Commit 76067522 authored by syz's avatar syz
Browse files

Direct call to visualizer.

parent 4f930cb3
......@@ -13,6 +13,7 @@ import numpy as np
from .hdf_utils import checkIfMain, get_attr, get_data_descriptor, get_formatted_labels, \
get_dimensionality, get_sort_order, get_unit_values, reshape_to_Ndims
from .io_utils import transformToReal
from ..viz.jupyter_utils import simple_ndim_visualizer
class PycroDataset(h5py.Dataset):
......@@ -339,3 +340,26 @@ class PycroDataset(h5py.Dataset):
return transformToReal(data_slice), success
else:
return data_slice, success
def visualize(self, slice_dict=None, **kwargs):
"""
Interactive visualization of this dataset. Only available on jupyter notebooks
Parameters
----------
slice_dict : dictionary, optional
Slicing instructions
"""
# TODO: Robust implementation that allows slicing
if len(self.pos_dim_labels + self.spec_dim_labels) > 4:
raise NotImplementedError('Unable to support visualization of more than 4 dimensions. Try slicing')
data_mat = self.get_n_dim_form()
pos_dim_names = self.pos_dim_labels[::-1]
spec_dim_names = self.spec_dim_labels
pos_dim_units_old = get_attr(self.h5_pos_inds, 'units')
spec_dim_units_old = get_attr(self.h5_spec_inds, 'units')
pos_ref_vals = get_unit_values(self.h5_pos_inds, self.h5_pos_vals, is_spec=False)
spec_ref_vals = get_unit_values(self.h5_spec_inds, self.h5_spec_vals, is_spec=True)
simple_ndim_visualizer(data_mat, pos_dim_names, pos_dim_units_old, spec_dim_names, spec_dim_units_old,
pos_ref_vals=pos_ref_vals, spec_ref_vals=spec_ref_vals, **kwargs)
......@@ -9,10 +9,11 @@ Submodules
be_viz_utils
plot_utils
jupyter_utils
"""
from . import tests
from . import plot_utils
from . import be_viz_utils
__all__ = ['plot_utils', 'be_viz_utils']
__all__ = ['plot_utils', 'be_viz_utils', 'jupyter_utils']
......@@ -73,7 +73,7 @@ def simple_ndim_visualizer(data_mat, pos_dim_names, pos_dim_units_old, spec_dim_
def get_slicing_tuple(slice_dict):
slice_list = []
for dim_name in pos_dim_names+spec_dim_names:
for dim_name in pos_dim_names + spec_dim_names:
cur_slice = slice(None)
if slice_dict[dim_name] is not None:
cur_slice = slice(slice_dict[dim_name], slice_dict[dim_name]+1)
......@@ -141,7 +141,7 @@ def simple_ndim_visualizer(data_mat, pos_dim_names, pos_dim_units_old, spec_dim_
for name, unit in zip(spec_dim_names, spec_dim_units_old):
spec_dim_units[name] = unit
data_type, data_names, data_funcs = check_data_type(data_mat)
data_type, data_names, data_funcs = check_data_type(data_mat)
sub_data = data_mat
component_name = 'Real'
......@@ -170,8 +170,8 @@ def simple_ndim_visualizer(data_mat, pos_dim_names, pos_dim_units_old, spec_dim_
axes[0].set_title('Spatial Map for\n' + component_title + '\n' + spec_titles)
if pos_plot_2d:
img_spat, cbar_spat = single_img_cbar_plot(axes[0], current_spatmap,
x_size=data_mat.shape[1], y_size=data_mat.shape[0],
clim=clims)
x_size=data_mat.shape[1], y_size=data_mat.shape[0],
clim=clims)
axes[0].set_xlabel(pos_dim_names[1] + ' (' + pos_dim_units_old[1] + ')')
axes[0].set_ylabel(pos_dim_names[0] + ' (' + pos_dim_units_old[0] + ')')
main_vert_line = axes[0].axvline(x=spgram_slicing[pos_dim_names[1]], color='k')
......@@ -193,9 +193,9 @@ def simple_ndim_visualizer(data_mat, pos_dim_names, pos_dim_units_old, spec_dim_
axes[1].set_xlabel(spec_dim_names[1] + ' (' + spec_dim_units_old[1] + ')')
axes[1].set_ylabel(spec_dim_names[0] + ' (' + spec_dim_units_old[0] + ')')
img_spec, cbar_spec = single_img_cbar_plot(axes[1], current_spgram,
x_size=data_mat.shape[len(pos_dim_names) + 1],
y_size=data_mat.shape[len(pos_dim_names)],
cbar_label=component_name, clim=clims)
x_size=data_mat.shape[len(pos_dim_names) + 1],
y_size=data_mat.shape[len(pos_dim_names)],
cbar_label=component_name, clim=clims)
else:
axes[1].set_xlabel(spec_xdim + ' (' + spec_dim_units[spec_xdim] + ')')
if current_spgram.shape[0] != spec_ref_vals[spec_xdim].size:
......@@ -218,7 +218,7 @@ def simple_ndim_visualizer(data_mat, pos_dim_names, pos_dim_units_old, spec_dim_
slice_dict['component'] = data_names
# stupid and hacky way of doing this:
global_vars = {'sub_data': sub_data, 'component_title':component_title}
global_vars = {'sub_data': sub_data, 'component_title': component_title}
def update_plots(**kwargs):
component_name = kwargs.get('component', None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment