-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplotting.py
More file actions
272 lines (231 loc) · 9.28 KB
/
plotting.py
File metadata and controls
272 lines (231 loc) · 9.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
"""Plotting utilities for visualization of results.
This module provides functions for creating standardized plots including contour plots
and line plots with consistent styling and formatting.
"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
from matplotlib import rc
from matplotlib import colors
def set_matplotlib_style():
"""Sets up standard matplotlib styling for consistent plots.
Configures:
- Serif font family with Computer Modern
- LaTeX text rendering
- AMS math package for mathematical notation
"""
rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})
rc('text', usetex=True)
rc('text.latex', preamble=r'\usepackage{amsmath}')
def standard_textbox(text, override_kwargs=None):
"""Creates standardized textbox parameters for plot annotations.
Args:
text: Text to display in the textbox.
override_kwargs: Dictionary of parameters to override defaults.
Returns:
dict: Textbox parameters for matplotlib text function.
"""
props = dict(
boxstyle='round',
facecolor=(0.93, 0.93, 0.93),
alpha=1.0,
linewidth=0.5,
edgecolor='gray'
)
txtbox_args = {
'x': 0.05,
'y': 0.92,
's': text,
'ha': 'left',
'va': 'top',
'bbox': props,
'fontsize': 20
}
if override_kwargs is not None:
txtbox_args.update(override_kwargs)
return txtbox_args
def contour_plot(data, cmap="RdBu_r",
figsize=(5, 4), dpi=150,
lineplot_data=None,
txtbox_kwargs=None,
increasing_colorbar=True,
colorbar_kwargs=None,
overlay_kwargs=None,
vmin=None, vmax=None, center=True,
ax=None, save_file=None, add_colorbar=True,
colorbar_label=None, xlabel=None, ylabel=None, title=None,
return_plt=False, label_size=14, tick_size=12, **kwargs):
"""Creates a contour plot with standardized styling and formatting.
Args:
data: List of tuples (xs, ys, zs) or (xs, ys, zs, overlay) for each subplot.
cmap: Colormap to use for the contour plot.
figsize: Tuple of (width, height) for the figure.
dpi: Dots per inch for the figure.
lineplot_data: Optional list of (x, y, style_dict) tuples for line overlays.
txtbox_kwargs: Optional list of textbox parameters for each subplot.
increasing_colorbar: Whether colorbar values increase left to right.
colorbar_kwargs: Additional arguments for colorbar customization.
overlay_kwargs: Arguments for overlay contour plot if provided.
vmin: Minimum value for color scaling.
vmax: Maximum value for color scaling.
center: Whether to center the colormap at zero.
ax: Optional existing axes to plot on.
save_file: Optional path to save the figure.
add_colorbar: Whether to add a colorbar.
colorbar_label: Label for the colorbar.
xlabel: Label for x-axis.
ylabel: Label for y-axis.
title: Title for the plot.
return_plt: Whether to return the plot object instead of showing.
label_size: Font size for axis labels.
tick_size: Font size for tick labels.
**kwargs: Additional arguments passed to contourf.
Returns:
If return_plt is True, returns the contour plot object.
Otherwise, displays the plot.
"""
# Initialize default kwargs
colorbar_kwargs = colorbar_kwargs or {}
overlay_kwargs = overlay_kwargs or {}
# Standardize matplotlib settings
set_matplotlib_style()
# Create figure and axes
fig, axes = plt.subplots(1, len(data), figsize=figsize, dpi=dpi,
sharex=True, sharey=True)
# Determine color scaling
all_zs = np.concatenate([d[2] for d in data])
vmin = np.nanmin(all_zs) if vmin is None else vmin
vmax = np.nanmax(all_zs) if vmax is None else vmax
# Create normalization
if center:
abs_max = max(abs(vmin), abs(vmax))
norm = TwoSlopeNorm(vmin=-abs_max, vcenter=0, vmax=abs_max)
else:
norm = colors.Normalize(vmin=vmin, vmax=vmax)
# Plot contours
contours = []
for i, d in enumerate(data):
# Extract data
if len(d) == 3:
xs, ys, zs = d
overlay = None
elif len(d) == 4:
xs, ys, zs, overlay = d
ax = axes[i] if len(data) > 1 else axes
# Plot overlay if provided
if overlay is not None:
ax.contourf(xs, ys, overlay, **overlay_kwargs)
# Plot main contour
contour = ax.contourf(xs, ys, zs, cmap=cmap, norm=norm, **kwargs)
contour.set_edgecolor('face')
contours.append(contour)
# Add lineplot if provided
if lineplot_data is not None:
ax.plot(lineplot_data[i][0], lineplot_data[i][1], **lineplot_data[i][2])
# Set labels and title
if i == 0:
ax.set_ylabel(ylabel, fontsize=label_size)
ax.set_xlabel(xlabel, fontsize=label_size)
ax.set_title(title)
ax.tick_params(axis='both', which='major', labelsize=tick_size)
# Add textbox if provided
if txtbox_kwargs is not None:
ax.text(**txtbox_kwargs[i], transform=ax.transAxes)
# Add colorbar
if add_colorbar:
# Add a single colorbar for all subplots
contour_for_colorbar = contours[np.nanargmax([np.nanmax(d[2]) for d in data])] if increasing_colorbar else contours[np.nanargmin([np.nanmin(d[2]) for d in data])]
cbar = fig.colorbar(contour_for_colorbar, ax=axes,
cmap=cmap,
norm=norm,
orientation='horizontal',
fraction=0.1, pad=0.1, location='top',
**colorbar_kwargs)
if colorbar_label:
cbar.set_label(colorbar_label, fontsize=label_size, labelpad=10)
if not increasing_colorbar:
cbar.ax.invert_xaxis()
# Save figure if requested
if save_file is not None:
fig.savefig(save_file, bbox_inches='tight')
# Return or show
if return_plt:
return contour
else:
plt.show()
def line_plot(data, figsize=(4, 4), dpi=150,
txtbox_kwargs=None,
ax=None,
xlabel=None, ylabel=None, title=None,
plt_legend=False, label_size=14,
tick_size=12,
save_file=None, return_plt=False, **kwargs):
"""Creates a line plot with standardized styling and formatting.
Args:
data: List of tuples (xs, ys, style_dict) for each line to plot.
figsize: Tuple of (width, height) for the figure.
dpi: Dots per inch for the figure.
txtbox_kwargs: Optional textbox parameters.
ax: Optional existing axes to plot on.
xlabel: Label for x-axis.
ylabel: Label for y-axis.
title: Title for the plot.
plt_legend: Whether to show the legend.
label_size: Font size for axis labels.
tick_size: Font size for tick labels.
save_file: Optional path to save the figure.
return_plt: Whether to return the plot object instead of showing.
**kwargs: Additional arguments passed to plot.
Returns:
If return_plt is True, returns the axes object.
Otherwise, displays the plot.
"""
set_matplotlib_style()
# Create or use existing axes
if ax is None:
ax = plt.figure(figsize=figsize, dpi=dpi).add_subplot(111)
# Plot each line
for (xs, ys, subplot_kwargs) in data:
ax.plot(xs, ys, **subplot_kwargs, **kwargs)
# Set labels and title
ax.set_xlabel(xlabel, fontsize=label_size)
ax.set_ylabel(ylabel, fontsize=label_size)
ax.set_title(title)
ax.tick_params(axis='both', which='major', labelsize=tick_size)
# Add legend if requested
if plt_legend:
ax.legend()
# Add textbox if provided
if txtbox_kwargs is not None:
ax.text(**txtbox_kwargs, transform=ax.transAxes)
# Save figure if requested
if save_file is not None:
# Handle subplot case
if ax.get_gridspec() is not None:
fig = plt.figure(figsize=figsize, dpi=dpi)
new_ax = fig.add_subplot(111)
# Recreate plot
for (xs, ys, subplot_kwargs) in data:
new_ax.plot(xs, ys, **subplot_kwargs, **kwargs)
# Set labels and title
new_ax.set_xlabel(xlabel, fontsize=label_size)
new_ax.set_ylabel(ylabel, fontsize=label_size)
new_ax.set_title(title)
new_ax.tick_params(axis='both', which='major', labelsize=tick_size)
if plt_legend:
new_ax.legend()
if txtbox_kwargs is not None:
new_ax.text(**txtbox_kwargs, transform=new_ax.transAxes)
# Save and cleanup
fig.tight_layout()
fig.savefig(save_file, bbox_inches='tight')
plt.close(fig)
else:
# Save original figure
ax.figure.tight_layout()
ax.figure.savefig(save_file, bbox_inches='tight')
# Return or show
if return_plt:
return ax
else:
plt.show()