Commit 16a055f5 authored by Somnath, Suhas
Browse files

More convenient plotting of clustering results

parent 9d338870
......@@ -15,6 +15,7 @@ from matplotlib.colors import LinearSegmentedColormap
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable
from ..analysis.utils.be_loop import loopFitFunction
from pycroscopy.io.hdf_utils import reshape_to_Ndims, get_formatted_labels
def set_tick_font_size(axes, font_size):
......@@ -124,13 +125,14 @@ def plotLoopFitNGuess(Vdc, ds_proj_loops, ds_guess, ds_fit, title=''):
###############################################################################
def rainbowPlot(ax, ao_vec, ai_vec, num_steps=32):
def rainbowPlot(ax, ao_vec, ai_vec, num_steps=32, cmap=plt.cm.jet, **kwargs):
"""
Plots the input against the output waveform (typically loops).
The color of the curve changes as a function of time using the jet colorscheme
Inputs:
---------
Parameters
----------
ax : axis handle
Axis to plot the curve
ao_vec : 1D float numpy array
......@@ -139,21 +141,60 @@ def rainbowPlot(ax, ao_vec, ai_vec, num_steps=32):
vector that forms the Y axis
num_steps : unsigned int (Optional)
Number of discrete color steps
cmap : matplotlib.colors.LinearSegmentedColormap object
Colormap to be used
"""
pts_per_step = int(len(ai_vec) / num_steps)
for step in xrange(num_steps - 1):
ax.plot(ao_vec[step * pts_per_step:(step + 1) * pts_per_step],
ai_vec[step * pts_per_step:(step + 1) * pts_per_step],
color=plt.cm.jet(255 * step / num_steps))
color=cmap(255 * step / num_steps), **kwargs)
# plot the remainder:
ax.plot(ao_vec[(num_steps - 1) * pts_per_step:],
ai_vec[(num_steps - 1) * pts_per_step:],
color=plt.cm.jet(255 * num_steps / num_steps))
color=cmap(255 * num_steps / num_steps), **kwargs)
"""
CS3=plt.contourf([[0,0],[0,0]], range(0,310),cmap=plt.cm.jet)
fig.colorbar(CS3)"""
def plot_line_family(ax, x_axis, line_family, line_names=None, label_prefix='Line', label_suffix='', cmap=plt.cm.jet, **kwargs):
    """
    Plots a family of lines with a sequence of colors

    Parameters
    ----------
    ax : axis handle
        Axis to plot the curve
    x_axis : array-like
        Values to plot against
    line_family : 2D numpy array
        family of curves arranged as [curve_index, features]
    line_names : array-like
        array of string or numbers that represent the identity of each curve in the family.
        If None, or if its length does not match the number of curves, names are generated
        as '<label_prefix> <index> <label_suffix>'.
    label_prefix : string / unicode
        prefix for the legend (before the index of the curve)
    label_suffix : string / unicode
        suffix for the legend (after the index of the curve)
    cmap : matplotlib.colors.LinearSegmentedColormap object
        Colormap to be used. It is called with an integer between 0 and 255.
    kwargs : dict
        Additional keyword arguments passed on to every ax.plot() call
    """
    # Imported locally so that this function does not rely on a module-level import of warn
    from warnings import warn

    num_lines = line_family.shape[0]

    if line_names is not None and len(line_names) != num_lines:
        warn('Line names of different length compared to provided dataset')
        # Discard the mismatched names and fall back to the generated ones
        line_names = None

    if line_names is None:
        line_names = ['{} {} {}'.format(label_prefix, line_ind, label_suffix) for line_ind in range(num_lines)]

    # With a single curve (num_lines - 1) is zero; clamp the divisor so that the lone curve
    # gets the first color of the colormap instead of raising ZeroDivisionError
    color_span = max(num_lines - 1, 1)

    for line_ind in range(num_lines):
        ax.plot(x_axis, line_family[line_ind],
                label=line_names[line_ind],
                color=cmap(int(255 * line_ind / color_span)), **kwargs)
def plot_map(axis, data, stdevs=2, show_colorbar=False, **kwargs):
"""
Plots a 2d map with a tight z axis, with or without color bars.
......@@ -621,12 +662,54 @@ def plotLoadingMaps(loadings, num_comps=4, stdevs=2, show_colorbar=True, **kwarg
return fig202, axes202
# TODO: The label and units for the main dataset itself are missing in most cases! - ie. I don't know that the data is 'Current' and 'nA'
def plot_cluster_results(h5_group, y_spec_label):
    """
    Reads clustering results from an HDF5 group and passes them to plotClusterResults

    Parameters
    ----------
    h5_group : HDF5 group (h5py.Group-like object)
        Group that contains a 'Labels' dataset and either a 'Mean_Response' dataset or
        (old PySPM format) a 'Centroids' dataset. The 'Labels' dataset must carry the
        'Position_Values' and 'Position_Indices' attributes and the mean response dataset
        must carry the 'Spectroscopic_Values' attribute.
    y_spec_label : String / unicode
        Label for the response axis of the mean response plots (passed on as resp_label)

    Notes
    -----
    The tick values handed to plotClusterResults are built from the first two columns of
    the position values only.
    # NOTE(review): this presumably requires at least two position dimensions - confirm.
    """
    h5_labels = h5_group['Labels']
    try:
        h5_mean_resp = h5_group['Mean_Response']
    except KeyError:
        # old PySPM format:
        h5_mean_resp = h5_group['Centroids']

    # Reshape the mean response to N dimensions.
    # The second value returned by reshape_to_Ndims is not used here.
    mean_response, _ = reshape_to_Ndims(h5_mean_resp)

    # unfortunately, we cannot use the above function for the labels
    # However, we will assume that the position values are linked to the labels:
    h5_pos_vals = h5_labels.file[h5_labels.attrs['Position_Values']]
    h5_pos_inds = h5_labels.file[h5_labels.attrs['Position_Indices']]

    # Number of unique index values in each column of the position indices:
    pos_dims = [np.unique(h5_pos_inds[:, col]).size for col in range(h5_pos_inds.shape[1])]

    # prepare the axes ticks for the map (must happen before pos_dims is reversed):
    # the first pos_dims[0] entries of column 0 and every pos_dims[0]-th entry of column 1
    pos_ticks = [h5_pos_vals[:pos_dims[0], 0], h5_pos_vals[slice(0, None, pos_dims[0]), 1]]

    # Reshape the labels correctly:
    pos_dims.reverse()  # go from slowest to fastest
    pos_dims = tuple(pos_dims)
    # dataset[()] reads the whole dataset; the .value attribute is deprecated in h5py
    label_mat = np.reshape(h5_labels[()], pos_dims)

    # Figure out the correct units and labels for mean response:
    h5_spec_vals = h5_mean_resp.file[h5_mean_resp.attrs['Spectroscopic_Values']]
    x_spec_label = get_formatted_labels(h5_spec_vals)[0]

    # Figure out the correct axes labels for label map:
    pos_labels = get_formatted_labels(h5_pos_vals)

    plotClusterResults(label_mat, mean_response, spec_val=np.squeeze(h5_spec_vals[0]),
                       spec_label=x_spec_label, resp_label=y_spec_label,
                       pos_labels=pos_labels, pos_ticks=pos_ticks)
###############################################################################
# TODO: Pull the spectroscopic value from the h5 dataset if 1D and nothing is specified
# TODO: Pull the name of the spectroscopic axis as well
def plotClusterResults(label_mat, mean_response, spec_val=None, cmap=plt.cm.jet,
spec_label='Spectroscopic Value', resp_label='Response'):
spec_label='Spectroscopic Value', resp_label='Response',
pos_labels=('X', 'Y'), pos_ticks=None):
"""
Plot the cluster labels and mean response for each cluster
......@@ -660,11 +743,7 @@ def plotClusterResults(label_mat, mean_response, spec_val=None, cmap=plt.cm.jet,
"""
def __plotCentroids(centroids, ax, spec_val, spec_label, y_label, cmap, title=None):
num_clusters = centroids.shape[0]
for clust in xrange(num_clusters):
ax.plot(spec_val, centroids[clust],
label='Cluster {}'.format(clust),
color=cmap(int(255 * clust / (num_clusters - 1))))
plot_line_family(ax, spec_val, centroids, label_prefix='Cluster', cmap=cmap)
ax.set_ylabel(y_label)
# ax.legend(loc='best')
if title:
......@@ -709,7 +788,19 @@ def plotClusterResults(label_mat, mean_response, spec_val=None, cmap=plt.cm.jet,
nx = len(np.unique(pos[:, 0]))
ny = len(np.unique(pos[:, 1]))
label_mat = label_mat[()].reshape(nx, ny)
im = ax_map.imshow(label_mat, interpolation='none')
ax_map.set_xlabel(pos_labels[0])
ax_map.set_ylabel(pos_labels[1])
if pos_ticks is not None:
x_ticks = np.linspace(0, label_mat.shape[1] - 1, 5, dtype=np.uint16)
y_ticks = np.linspace(0, label_mat.shape[0] - 1, 5, dtype=np.uint16)
ax_map.set_xticks(x_ticks)
ax_map.set_yticks(y_ticks)
ax_map.set_xticklabels(pos_ticks[0][x_ticks])
ax_map.set_yticklabels(pos_ticks[1][y_ticks])
divider = make_axes_locatable(ax_map)
cax = divider.append_axes("right", size="5%", pad=0.05) # space for colorbar
fig.colorbar(im, cax=cax)
......@@ -726,26 +817,29 @@ def plotClusterResults(label_mat, mean_response, spec_val=None, cmap=plt.cm.jet,
###############################################################################
def plotKMeansClusters(label_mat, cluster_centroids,
num_cluster=4):
max_centroids=4, x_label='', y_label=''):
"""
Plots the provided label mat and centroids
from KMeans clustering
Plots the provided labels mat and centroids from clustering
Parameters:
-------------
Parameters
----------
label_mat : 2D int numpy array
structured as [rows, cols]
cluster_centroids: 2D real numpy array
structured as [cluster,features]
num_cluster : int
Number of centroids to plot
max_centroids : unsigned int
Number of centroids to plot
x_label : String / unicode
X label for centroid plots
y_label : String / unicode
Y label for centroid plots
Returns:
---------
Returns
-------
fig
"""
if num_cluster < 5:
if max_centroids < 5:
fig501 = plt.figure(figsize=(20, 10))
fax1 = plt.subplot2grid((2, 4), (0, 0), colspan=2, rowspan=2)
......@@ -774,20 +868,19 @@ def plotKMeansClusters(label_mat, cluster_centroids,
fig501.tight_layout()
axes_handles = [fax1, fax2, fax3, fax4, fax5, fax6, fax7, fax8, fax9, fax10]
# Plot results
for ax, index in zip(axes_handles[0:num_cluster + 1], np.arange(num_cluster + 1)):
if index == 0:
im = ax.imshow(label_mat, interpolation='none')
ax.set_title('K-means Cluster Map')
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05) # space for colorbar
plt.colorbar(im, cax=cax)
else:
# ax.plot(Vdc_vec, cluster_centroids[index-1,:], 'g-')
ax.plot(cluster_centroids[index - 1, :], 'g-')
ax.set_xlabel('Voltage (V)')
ax.set_ylabel('Current (arb.)')
ax.set_title('K-means Centroid: %d' % (index))
# First plot the labels map:
im = fax1.imshow(label_mat, interpolation='none')
fax1.set_title('K-means Cluster Map')
divider = make_axes_locatable(fax1)
cax = divider.append_axes("right", size="5%", pad=0.05) # space for colorbar
plt.colorbar(im, cax=cax)
# Plot results
for ax, index in zip(axes_handles[1: max_centroids + 1], np.arange(1, max_centroids + 1)):
ax.plot(cluster_centroids[index - 1, :], 'g-')
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
ax.set_title('Centroid: %d' % index)
fig501.subplots_adjust(hspace=0.60, wspace=0.60)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment