Commit 905b1ae2 authored by Somnath, Suhas's avatar Somnath, Suhas Committed by CompPhysChris
Browse files

Implemented a module-wide default colormap and an easy function to get a colormap

parent 620d3693
...@@ -24,6 +24,29 @@ from ..io.hdf_utils import reshape_to_Ndims, get_formatted_labels, get_data_desc ...@@ -24,6 +24,29 @@ from ..io.hdf_utils import reshape_to_Ndims, get_formatted_labels, get_data_desc
if sys.version_info.major == 3: if sys.version_info.major == 3:
unicode = str unicode = str
default_cmap = plt.cm.virdis
def get_cmap_object(cmap):
"""
Get the matplotlib.colors.LinearSegmentedColormap object regardless of the input
Parameters
----------
cmap : String, or matplotlib.colors.LinearSegmentedColormap object (Optional)
Requested color map
Returns
-------
cmap : matplotlib.colors.LinearSegmentedColormap object
Requested / Default colormap object
"""
if cmap is None:
return default_cmap
elif isinstance(cmap, str):
return plt.get_cmap(cmap)
return cmap
def set_tick_font_size(axes, font_size): def set_tick_font_size(axes, font_size):
""" """
Sets the font size of the ticks in the provided axes Sets the font size of the ticks in the provided axes
...@@ -64,7 +87,7 @@ def cmap_jet_white_center(): ...@@ -64,7 +87,7 @@ def cmap_jet_white_center():
Returns Returns
------- -------
white_jet : matplotlib.colors.LinearSegmentedColormap object white_jet : matplotlib.colors.LinearSegmentedColormap object
color map object that can be used in place of plt.cm.viridis color map object that can be used in place of the default colormap
""" """
# For red - central column is like brightness # For red - central column is like brightness
# For blue - last column is like brightness # For blue - last column is like brightness
...@@ -172,7 +195,7 @@ def cmap_hot_desaturated(): ...@@ -172,7 +195,7 @@ def cmap_hot_desaturated():
return cmap_from_rgba('hot_desaturated', hot_desaturated, 255) return cmap_from_rgba('hot_desaturated', hot_desaturated, 255)
def discrete_cmap(num_bins, base_cmap=None): def discrete_cmap(num_bins, base_cmap=default_cmap):
""" """
Create an N-bin discrete colormap from the specified input map Create an N-bin discrete colormap from the specified input map
...@@ -195,9 +218,9 @@ def discrete_cmap(num_bins, base_cmap=None): ...@@ -195,9 +218,9 @@ def discrete_cmap(num_bins, base_cmap=None):
""" """
if base_cmap is None: if base_cmap is None:
base_cmap = 'viridis' base_cmap = default_cmap.name
if type(base_cmap) == type(plt.cm.viridis): elif isinstance(base_cmap, type(default_cmap)):
base_cmap = base_cmap.name base_cmap = base_cmap.name
if type(base_cmap) == str: if type(base_cmap) == str:
...@@ -233,7 +256,8 @@ def _add_loop_parameters(axes, switching_coef_vec): ...@@ -233,7 +256,8 @@ def _add_loop_parameters(axes, switching_coef_vec):
return axes return axes
def rainbow_plot(ax, ao_vec, ai_vec, num_steps=32, cmap=plt.cm.viridis, **kwargs):
def rainbow_plot(ax, ao_vec, ai_vec, num_steps=32, cmap=default_cmap, **kwargs):
""" """
Plots the input against the output waveform (typically loops). Plots the input against the output waveform (typically loops).
The color of the curve changes as a function of time using the jet colorscheme The color of the curve changes as a function of time using the jet colorscheme
...@@ -251,6 +275,8 @@ def rainbow_plot(ax, ao_vec, ai_vec, num_steps=32, cmap=plt.cm.viridis, **kwargs ...@@ -251,6 +275,8 @@ def rainbow_plot(ax, ao_vec, ai_vec, num_steps=32, cmap=plt.cm.viridis, **kwargs
cmap : matplotlib.colors.LinearSegmentedColormap object cmap : matplotlib.colors.LinearSegmentedColormap object
Colormap to be used Colormap to be used
""" """
cmap = get_cmap_object(cmap)
pts_per_step = int(len(ai_vec) / num_steps) pts_per_step = int(len(ai_vec) / num_steps)
for step in range(num_steps - 1): for step in range(num_steps - 1):
ax.plot(ao_vec[step * pts_per_step:(step + 1) * pts_per_step], ax.plot(ao_vec[step * pts_per_step:(step + 1) * pts_per_step],
...@@ -266,7 +292,7 @@ def rainbow_plot(ax, ao_vec, ai_vec, num_steps=32, cmap=plt.cm.viridis, **kwargs ...@@ -266,7 +292,7 @@ def rainbow_plot(ax, ao_vec, ai_vec, num_steps=32, cmap=plt.cm.viridis, **kwargs
def plot_line_family(axis, x_axis, line_family, line_names=None, label_prefix='Line', label_suffix='', def plot_line_family(axis, x_axis, line_family, line_names=None, label_prefix='Line', label_suffix='',
cmap=plt.cm.viridis, y_offset=0, **kwargs): cmap=default_cmap, y_offset=0, **kwargs):
""" """
Plots a family of lines with a sequence of colors Plots a family of lines with a sequence of colors
...@@ -289,6 +315,8 @@ def plot_line_family(axis, x_axis, line_family, line_names=None, label_prefix='L ...@@ -289,6 +315,8 @@ def plot_line_family(axis, x_axis, line_family, line_names=None, label_prefix='L
y_offset : (optional) number y_offset : (optional) number
quantity by which the lines are offset from each other vertically (useful for spectra) quantity by which the lines are offset from each other vertically (useful for spectra)
""" """
cmap = get_cmap_object(cmap)
num_lines = line_family.shape[0] num_lines = line_family.shape[0]
if line_names is None: if line_names is None:
...@@ -468,7 +496,8 @@ def plot_loops(excit_wfm, datasets, line_colors=[], dataset_names=[], evenly_spa ...@@ -468,7 +496,8 @@ def plot_loops(excit_wfm, datasets, line_colors=[], dataset_names=[], evenly_spa
############################################################################### ###############################################################################
def plot_complex_map_stack(map_stack, num_comps=4, title='Eigenvectors', xlabel='UDVS Step', stdevs=2): def plot_complex_map_stack(map_stack, num_comps=4, title='Eigenvectors', xlabel='UDVS Step', stdevs=2,
cmap=default_cmap):
""" """
Plots the provided spectrograms from SVD V vector Plots the provided spectrograms from SVD V vector
...@@ -484,11 +513,15 @@ def plot_complex_map_stack(map_stack, num_comps=4, title='Eigenvectors', xlabel= ...@@ -484,11 +513,15 @@ def plot_complex_map_stack(map_stack, num_comps=4, title='Eigenvectors', xlabel=
Label for x axis Label for x axis
stdevs : int stdevs : int
Number of standard deviations to consider for plotting Number of standard deviations to consider for plotting
cmap : String, or matplotlib.colors.LinearSegmentedColormap object (Optional)
Requested color map
Returns Returns
--------- ---------
fig, axes fig, axes
""" """
cmap = get_cmap_object(cmap)
fig201, axes201 = plt.subplots(2, num_comps, figsize=(4 * num_comps, 8)) fig201, axes201 = plt.subplots(2, num_comps, figsize=(4 * num_comps, 8))
fig201.subplots_adjust(hspace=0.4, wspace=0.4) fig201.subplots_adjust(hspace=0.4, wspace=0.4)
fig201.canvas.set_window_title(title) fig201.canvas.set_window_title(title)
...@@ -501,7 +534,7 @@ def plot_complex_map_stack(map_stack, num_comps=4, title='Eigenvectors', xlabel= ...@@ -501,7 +534,7 @@ def plot_complex_map_stack(map_stack, num_comps=4, title='Eigenvectors', xlabel=
for func, lab, ax in zip(funcs, labels, axes): for func, lab, ax in zip(funcs, labels, axes):
amp_mean = np.mean(func(cur_map)) amp_mean = np.mean(func(cur_map))
amp_std = np.std(func(cur_map)) amp_std = np.std(func(cur_map))
ax.imshow(func(cur_map), cmap='inferno', ax.imshow(func(cur_map), cmap=cmap,
vmin=amp_mean - stdevs * amp_std, vmin=amp_mean - stdevs * amp_std,
vmax=amp_mean + stdevs * amp_std) vmax=amp_mean + stdevs * amp_std)
ax.set_title('Eigenvector: %d - %s' % (index + 1, lab)) ax.set_title('Eigenvector: %d - %s' % (index + 1, lab))
...@@ -707,7 +740,7 @@ def plot_map_stack(map_stack, num_comps=9, stdevs=2, color_bar_mode=None, evenly ...@@ -707,7 +740,7 @@ def plot_map_stack(map_stack, num_comps=9, stdevs=2, color_bar_mode=None, evenly
return fig202, axes202 return fig202, axes202
def plot_cluster_h5_group(h5_group, centroids_together=True, cmap=None): def plot_cluster_h5_group(h5_group, centroids_together=True, cmap=default_cmap):
""" """
Plots the cluster labels and mean response for each cluster Plots the cluster labels and mean response for each cluster
...@@ -776,7 +809,7 @@ def plot_cluster_h5_group(h5_group, centroids_together=True, cmap=None): ...@@ -776,7 +809,7 @@ def plot_cluster_h5_group(h5_group, centroids_together=True, cmap=None):
############################################################################### ###############################################################################
def plot_cluster_results_together(label_mat, mean_response, spec_val=None, cmap=None, def plot_cluster_results_together(label_mat, mean_response, spec_val=None, cmap=default_cmap,
spec_label='Spectroscopic Value', resp_label='Response', spec_label='Spectroscopic Value', resp_label='Response',
pos_labels=('X', 'Y'), pos_ticks=None): pos_labels=('X', 'Y'), pos_ticks=None):
""" """
...@@ -814,11 +847,7 @@ def plot_cluster_results_together(label_mat, mean_response, spec_val=None, cmap= ...@@ -814,11 +847,7 @@ def plot_cluster_results_together(label_mat, mean_response, spec_val=None, cmap=
axes : 1D array_like of axes objects axes : 1D array_like of axes objects
Axes of the individual plots within `fig` Axes of the individual plots within `fig`
""" """
if cmap is None: cmap = get_cmap_object(cmap)
cmap = plt.cm.viridis
else:
if isinstance(cmap, str):
cmap = plt.get_cmap(cmap)
if isinstance(cmap, str): if isinstance(cmap, str):
cmap = plt.get_cmap(cmap) cmap = plt.get_cmap(cmap)
...@@ -900,7 +929,7 @@ def plot_cluster_results_together(label_mat, mean_response, spec_val=None, cmap= ...@@ -900,7 +929,7 @@ def plot_cluster_results_together(label_mat, mean_response, spec_val=None, cmap=
############################################################################### ###############################################################################
def plot_cluster_results_separate(label_mat, cluster_centroids, max_centroids=4, cmap=None, def plot_cluster_results_separate(label_mat, cluster_centroids, max_centroids=4, cmap=default_cmap,
spec_val=None, x_label='Excitation (a.u.)', y_label='Response (a.u.)'): spec_val=None, x_label='Excitation (a.u.)', y_label='Response (a.u.)'):
""" """
Plots the provided labels mat and centroids from clustering Plots the provided labels mat and centroids from clustering
...@@ -928,11 +957,7 @@ def plot_cluster_results_separate(label_mat, cluster_centroids, max_centroids=4, ...@@ -928,11 +957,7 @@ def plot_cluster_results_separate(label_mat, cluster_centroids, max_centroids=4,
fig fig
""" """
if cmap is None: cmap = get_cmap_object(cmap)
cmap = plt.cm.viridis
else:
if isinstance(cmap, str):
cmap = plt.get_cmap(cmap)
if max_centroids < 5: if max_centroids < 5:
...@@ -1128,7 +1153,7 @@ def plot_1d_spectrum(data_vec, freq, title, figure_path=None): ...@@ -1128,7 +1153,7 @@ def plot_1d_spectrum(data_vec, freq, title, figure_path=None):
############################################################################### ###############################################################################
def plot_2d_spectrogram(mean_spectrogram, freq, title, cmap=None, figure_path=None, **kwargs): def plot_2d_spectrogram(mean_spectrogram, freq, title, figure_path=None, **kwargs):
""" """
Plots the position averaged spectrogram Plots the position averaged spectrogram
...@@ -1140,8 +1165,6 @@ def plot_2d_spectrogram(mean_spectrogram, freq, title, cmap=None, figure_path=No ...@@ -1140,8 +1165,6 @@ def plot_2d_spectrogram(mean_spectrogram, freq, title, cmap=None, figure_path=No
BE frequency that serves as the X axis of the plot BE frequency that serves as the X axis of the plot
title : String title : String
Plot group name Plot group name
cmap : matplotlib.colors.LinearSegmentedColormap object
color map. Default = plt.cm.viridis
figure_path : String / Unicode figure_path : String / Unicode
Absolute path of the file to write the figure to Absolute path of the file to write the figure to
...@@ -1157,22 +1180,17 @@ def plot_2d_spectrogram(mean_spectrogram, freq, title, cmap=None, figure_path=No ...@@ -1157,22 +1180,17 @@ def plot_2d_spectrogram(mean_spectrogram, freq, title, cmap=None, figure_path=No
print('2D:', mean_spectrogram.shape, freq.shape) print('2D:', mean_spectrogram.shape, freq.shape)
return return
"""cmap = kwargs.get('cmap')
kwargs.pop('cmap')"""
if cmap is None: # unpack from kwargs instead
col_map = plt.cm.viridis # overriding default
freq *= 1E-3 # to kHz freq *= 1E-3 # to kHz
fig, ax = plt.subplots(nrows=2, ncols=1, sharex=True) fig, ax = plt.subplots(nrows=2, ncols=1, sharex=True)
# print(mean_spectrogram.shape) # print(mean_spectrogram.shape)
# print(freq.shape) # print(freq.shape)
ax[0].imshow(np.abs(mean_spectrogram), interpolation='nearest', cmap=col_map, ax[0].imshow(np.abs(mean_spectrogram), interpolation='nearest',
extent=[freq[0], freq[-1], mean_spectrogram.shape[0], 0], **kwargs) extent=[freq[0], freq[-1], mean_spectrogram.shape[0], 0], **kwargs)
ax[0].set_title('Amplitude') ax[0].set_title('Amplitude')
# ax[0].set_xticks(freq) # ax[0].set_xticks(freq)
# ax[0].set_ylabel('UDVS Step') # ax[0].set_ylabel('UDVS Step')
ax[0].axis('tight') ax[0].axis('tight')
ax[1].imshow(np.angle(mean_spectrogram), interpolation='nearest', cmap=col_map, ax[1].imshow(np.angle(mean_spectrogram), interpolation='nearest',
extent=[freq[0], freq[-1], mean_spectrogram.shape[0], 0], **kwargs) extent=[freq[0], freq[-1], mean_spectrogram.shape[0], 0], **kwargs)
ax[1].set_title('Phase') ax[1].set_title('Phase')
ax[1].set_xlabel('Frequency (kHz)') ax[1].set_xlabel('Frequency (kHz)')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment