Commit 905b1ae2 authored by Somnath, Suhas's avatar Somnath, Suhas Committed by CompPhysChris
Browse files

Implemented a module-wide default colormap and an easy function to get a colormap

parent 620d3693
......@@ -24,6 +24,29 @@ from ..io.hdf_utils import reshape_to_Ndims, get_formatted_labels, get_data_desc
if sys.version_info.major == 3:
unicode = str
default_cmap = plt.cm.virdis
def get_cmap_object(cmap):
"""
Get the matplotlib.colors.LinearSegmentedColormap object regardless of the input
Parameters
----------
cmap : String, or matplotlib.colors.LinearSegmentedColormap object (Optional)
Requested color map
Returns
-------
cmap : matplotlib.colors.LinearSegmentedColormap object
Requested / Default colormap object
"""
if cmap is None:
return default_cmap
elif isinstance(cmap, str):
return plt.get_cmap(cmap)
return cmap
def set_tick_font_size(axes, font_size):
"""
Sets the font size of the ticks in the provided axes
......@@ -64,7 +87,7 @@ def cmap_jet_white_center():
Returns
-------
white_jet : matplotlib.colors.LinearSegmentedColormap object
color map object that can be used in place of plt.cm.viridis
color map object that can be used in place of the default colormap
"""
# For red - central column is like brightness
# For blue - last column is like brightness
......@@ -172,7 +195,7 @@ def cmap_hot_desaturated():
return cmap_from_rgba('hot_desaturated', hot_desaturated, 255)
def discrete_cmap(num_bins, base_cmap=None):
def discrete_cmap(num_bins, base_cmap=default_cmap):
"""
Create an N-bin discrete colormap from the specified input map
......@@ -195,9 +218,9 @@ def discrete_cmap(num_bins, base_cmap=None):
"""
if base_cmap is None:
base_cmap = 'viridis'
base_cmap = default_cmap.name
if type(base_cmap) == type(plt.cm.viridis):
elif isinstance(base_cmap, type(default_cmap)):
base_cmap = base_cmap.name
if type(base_cmap) == str:
......@@ -233,7 +256,8 @@ def _add_loop_parameters(axes, switching_coef_vec):
return axes
def rainbow_plot(ax, ao_vec, ai_vec, num_steps=32, cmap=plt.cm.viridis, **kwargs):
def rainbow_plot(ax, ao_vec, ai_vec, num_steps=32, cmap=default_cmap, **kwargs):
"""
Plots the input against the output waveform (typically loops).
The color of the curve changes as a function of time using the jet colorscheme
......@@ -251,6 +275,8 @@ def rainbow_plot(ax, ao_vec, ai_vec, num_steps=32, cmap=plt.cm.viridis, **kwargs
cmap : matplotlib.colors.LinearSegmentedColormap object
Colormap to be used
"""
cmap = get_cmap_object(cmap)
pts_per_step = int(len(ai_vec) / num_steps)
for step in range(num_steps - 1):
ax.plot(ao_vec[step * pts_per_step:(step + 1) * pts_per_step],
......@@ -266,7 +292,7 @@ def rainbow_plot(ax, ao_vec, ai_vec, num_steps=32, cmap=plt.cm.viridis, **kwargs
def plot_line_family(axis, x_axis, line_family, line_names=None, label_prefix='Line', label_suffix='',
cmap=plt.cm.viridis, y_offset=0, **kwargs):
cmap=default_cmap, y_offset=0, **kwargs):
"""
Plots a family of lines with a sequence of colors
......@@ -289,6 +315,8 @@ def plot_line_family(axis, x_axis, line_family, line_names=None, label_prefix='L
y_offset : (optional) number
quantity by which the lines are offset from each other vertically (useful for spectra)
"""
cmap = get_cmap_object(cmap)
num_lines = line_family.shape[0]
if line_names is None:
......@@ -468,7 +496,8 @@ def plot_loops(excit_wfm, datasets, line_colors=[], dataset_names=[], evenly_spa
###############################################################################
def plot_complex_map_stack(map_stack, num_comps=4, title='Eigenvectors', xlabel='UDVS Step', stdevs=2):
def plot_complex_map_stack(map_stack, num_comps=4, title='Eigenvectors', xlabel='UDVS Step', stdevs=2,
cmap=default_cmap):
"""
Plots the provided spectrograms from SVD V vector
......@@ -484,11 +513,15 @@ def plot_complex_map_stack(map_stack, num_comps=4, title='Eigenvectors', xlabel=
Label for x axis
stdevs : int
Number of standard deviations to consider for plotting
cmap : String, or matplotlib.colors.LinearSegmentedColormap object (Optional)
Requested color map
Returns
---------
fig, axes
"""
cmap = get_cmap_object(cmap)
fig201, axes201 = plt.subplots(2, num_comps, figsize=(4 * num_comps, 8))
fig201.subplots_adjust(hspace=0.4, wspace=0.4)
fig201.canvas.set_window_title(title)
......@@ -501,7 +534,7 @@ def plot_complex_map_stack(map_stack, num_comps=4, title='Eigenvectors', xlabel=
for func, lab, ax in zip(funcs, labels, axes):
amp_mean = np.mean(func(cur_map))
amp_std = np.std(func(cur_map))
ax.imshow(func(cur_map), cmap='inferno',
ax.imshow(func(cur_map), cmap=cmap,
vmin=amp_mean - stdevs * amp_std,
vmax=amp_mean + stdevs * amp_std)
ax.set_title('Eigenvector: %d - %s' % (index + 1, lab))
......@@ -707,7 +740,7 @@ def plot_map_stack(map_stack, num_comps=9, stdevs=2, color_bar_mode=None, evenly
return fig202, axes202
def plot_cluster_h5_group(h5_group, centroids_together=True, cmap=None):
def plot_cluster_h5_group(h5_group, centroids_together=True, cmap=default_cmap):
"""
Plots the cluster labels and mean response for each cluster
......@@ -776,7 +809,7 @@ def plot_cluster_h5_group(h5_group, centroids_together=True, cmap=None):
###############################################################################
def plot_cluster_results_together(label_mat, mean_response, spec_val=None, cmap=None,
def plot_cluster_results_together(label_mat, mean_response, spec_val=None, cmap=default_cmap,
spec_label='Spectroscopic Value', resp_label='Response',
pos_labels=('X', 'Y'), pos_ticks=None):
"""
......@@ -814,11 +847,7 @@ def plot_cluster_results_together(label_mat, mean_response, spec_val=None, cmap=
axes : 1D array_like of axes objects
Axes of the individual plots within `fig`
"""
if cmap is None:
cmap = plt.cm.viridis
else:
if isinstance(cmap, str):
cmap = plt.get_cmap(cmap)
cmap = get_cmap_object(cmap)
if isinstance(cmap, str):
cmap = plt.get_cmap(cmap)
......@@ -900,7 +929,7 @@ def plot_cluster_results_together(label_mat, mean_response, spec_val=None, cmap=
###############################################################################
def plot_cluster_results_separate(label_mat, cluster_centroids, max_centroids=4, cmap=None,
def plot_cluster_results_separate(label_mat, cluster_centroids, max_centroids=4, cmap=default_cmap,
spec_val=None, x_label='Excitation (a.u.)', y_label='Response (a.u.)'):
"""
Plots the provided labels mat and centroids from clustering
......@@ -928,11 +957,7 @@ def plot_cluster_results_separate(label_mat, cluster_centroids, max_centroids=4,
fig
"""
if cmap is None:
cmap = plt.cm.viridis
else:
if isinstance(cmap, str):
cmap = plt.get_cmap(cmap)
cmap = get_cmap_object(cmap)
if max_centroids < 5:
......@@ -1128,7 +1153,7 @@ def plot_1d_spectrum(data_vec, freq, title, figure_path=None):
###############################################################################
def plot_2d_spectrogram(mean_spectrogram, freq, title, cmap=None, figure_path=None, **kwargs):
def plot_2d_spectrogram(mean_spectrogram, freq, title, figure_path=None, **kwargs):
"""
Plots the position averaged spectrogram
......@@ -1140,8 +1165,6 @@ def plot_2d_spectrogram(mean_spectrogram, freq, title, cmap=None, figure_path=No
BE frequency that serves as the X axis of the plot
title : String
Plot group name
cmap : matplotlib.colors.LinearSegmentedColormap object
color map. Default = plt.cm.viridis
figure_path : String / Unicode
Absolute path of the file to write the figure to
......@@ -1157,22 +1180,17 @@ def plot_2d_spectrogram(mean_spectrogram, freq, title, cmap=None, figure_path=No
print('2D:', mean_spectrogram.shape, freq.shape)
return
"""cmap = kwargs.get('cmap')
kwargs.pop('cmap')"""
if cmap is None: # unpack from kwargs instead
col_map = plt.cm.viridis # overriding default
freq *= 1E-3 # to kHz
fig, ax = plt.subplots(nrows=2, ncols=1, sharex=True)
# print(mean_spectrogram.shape)
# print(freq.shape)
ax[0].imshow(np.abs(mean_spectrogram), interpolation='nearest', cmap=col_map,
ax[0].imshow(np.abs(mean_spectrogram), interpolation='nearest',
extent=[freq[0], freq[-1], mean_spectrogram.shape[0], 0], **kwargs)
ax[0].set_title('Amplitude')
# ax[0].set_xticks(freq)
# ax[0].set_ylabel('UDVS Step')
ax[0].axis('tight')
ax[1].imshow(np.angle(mean_spectrogram), interpolation='nearest', cmap=col_map,
ax[1].imshow(np.angle(mean_spectrogram), interpolation='nearest',
extent=[freq[0], freq[-1], mean_spectrogram.shape[0], 0], **kwargs)
ax[1].set_title('Phase')
ax[1].set_xlabel('Frequency (kHz)')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment