Commit 870c0a0e authored by syz's avatar syz
Browse files

Fixed bug in discrete_cmap + changed cmap arg name

parent 4e149f7e
......@@ -115,7 +115,7 @@ def cbar_for_line_plot(axis, num_steps, discrete_ticks=True, **kwargs):
Whether or not to have the ticks match the number of number of steps. Default = True
"""
cmap = get_cmap_object(kwargs.pop('cmap', None))
cmap = discrete_cmap(num_steps, base_cmap=cmap.name)
cmap = discrete_cmap(num_steps, cmap=cmap.name)
sm = make_scalar_mappable(0, num_steps - 1, cmap=cmap, **kwargs)
......@@ -261,20 +261,20 @@ def cmap_hot_desaturated():
return cmap_from_rgba('hot_desaturated', hot_desaturated, 255)
def discrete_cmap(num_bins, base_cmap=default_cmap):
def discrete_cmap(num_bins, cmap=None):
"""
Create an N-bin discrete colormap from the specified input map
Create an N-bin discrete colormap from the specified input map specified
Parameters
----------
num_bins : unsigned int
Number of discrete bins
base_cmap : matplotlib.colors.LinearSegmentedColormap object
cmap : matplotlib.colors.Colormap object
Base color map to discretize
Returns
-------
new_cmap : String or matplotlib.colors.LinearSegmentedColormap object
new_cmap : matplotlib.colors.LinearSegmentedColormap object
Discretized color map
Notes
......@@ -283,16 +283,17 @@ def discrete_cmap(num_bins, base_cmap=default_cmap):
https://gist.github.com/jakevdp/91077b0cae40f8f8244a
"""
if base_cmap is None:
base_cmap = default_cmap.name
if cmap is None:
cmap = default_cmap.name
elif isinstance(base_cmap, type(default_cmap)):
base_cmap = base_cmap.name
elif not isinstance(cmap, str):
# could not figure out a better type check
cmap = cmap.name
if type(base_cmap) == str:
return plt.get_cmap(base_cmap, num_bins)
if type(cmap) == str:
return plt.get_cmap(cmap, num_bins)
return base_cmap
return cmap
def rainbow_plot(axis, x_vec, y_vec, num_steps=32, **kwargs):
......@@ -376,7 +377,7 @@ def plot_line_family(axis, x_vec, line_family, line_names=None, label_prefix='',
if show_cbar:
# put back the cmap parameter:
kwargs.update({'cmap': cmap})
cbar_for_line_plot(axis, num_lines, **kwargs)
_ = cbar_for_line_plot(axis, num_lines, **kwargs)
def plot_map(axis, data, stdevs=None, origin='lower', **kwargs):
......@@ -1062,7 +1063,7 @@ def plot_cluster_results_together(label_mat, mean_response, spec_val=None, cmap=
fig.colorbar(im, cax=cax, ticks=np.arange(num_clusters),
cmap=discrete_cmap(num_clusters, base_cmap=plt.cm.viridis))
ax_map.axis('tight')"""
pcol0 = ax_map.pcolor(label_mat, cmap=discrete_cmap(num_clusters, base_cmap=cmap))
pcol0 = ax_map.pcolor(label_mat, cmap=discrete_cmap(num_clusters, cmap=cmap))
fig.colorbar(pcol0, ax=ax_map, ticks=np.arange(num_clusters))
ax_map.axis('tight')
ax_map.set_aspect('auto')
......@@ -1137,7 +1138,7 @@ def plot_cluster_results_separate(label_mat, cluster_centroids, max_centroids=4,
axes_handles = [fax1, fax2, fax3, fax4, fax5, fax6, fax7, fax8, fax9, fax10]
# First plot the labels map:
pcol0 = fax1.pcolor(label_mat, cmap=discrete_cmap(cluster_centroids.shape[0], base_cmap=cmap))
pcol0 = fax1.pcolor(label_mat, cmap=discrete_cmap(cluster_centroids.shape[0], cmap=cmap))
fig501.colorbar(pcol0, ax=fax1, ticks=np.arange(cluster_centroids.shape[0]))
fax1.axis('tight')
fax1.set_aspect('auto')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment