# jupyter_utils.py
"""
Created on 11/11/16 10:08 AM
@author: Suhas Somnath, Chris Smith
"""

import matplotlib.pyplot as plt
from IPython.display import display
import ipywidgets as widgets
import numpy as np

from .plot_utils import single_img_cbar_plot


def simple_ndim_visualizer(data_mat, pos_dim_names, pos_dim_units_old, spec_dim_names, spec_dim_units_old,
                           pos_ref_vals=None, spec_ref_vals=None, pos_plot_2d=True, spec_plot_2d=True, spec_xdim=None,
                           pos_xdim=None):
    """
    Generates a simple visualizer for visualizing simple datasets (up to 4 dimensions). The visualizer will ONLY work
    within the context of a jupyter notebook!

    The visualizer consists of two panels - spatial map and spectrograms. Slider widgets will be generated to slice
    dimensions. The data matrix can be real, complex or compound valued. For complex or compound data an additional
    'component' selection widget is generated.

    The axes of data_mat are expected to be ordered as all position dimensions (in the order of pos_dim_names)
    followed by all spectroscopic dimensions (in the order of spec_dim_names).

    Parameters
    ----------
    data_mat : numpy.array object
        Data to be visualized
    pos_dim_names : list of strings
        Names of the position dimensions
    pos_dim_units_old : list of strings
        Units for the position dimension
    spec_dim_names : list of strings
        Names of the spectroscopic dimensions
    spec_dim_units_old : list of strings
        Units for the spectroscopic dimensions
    pos_ref_vals : dictionary, optional
        Dictionary of names and reference values for each of the position dimensions.
        Default - linear distribution (0, 1, 2, ...) for each dimension
    spec_ref_vals : dictionary, optional
        Dictionary of names and reference values for each of the spectroscopic dimensions.
        Default - linear distribution (0, 1, 2, ...) for each dimension
    pos_plot_2d : bool, optional
        Whether or not to plot spatial data as 2D images. Default = True
        Ignored (treated as False) when there are fewer than two position dimensions
    spec_plot_2d : bool, optional
        Whether or not to plot spectral data as 2D images. Default = True
        Ignored (treated as False) when there are fewer than two spectroscopic dimensions
    spec_xdim : str, optional
        Name of dimension with respect to which the spectral data will be plotted for 1D plots
        Default - the largest spectroscopic dimension
    pos_xdim : str, optional
        Name of dimension with respect to which the position data will be plotted for 1D plots
        Default - the largest position dimension

    Returns
    -------
    None
    """
    def check_data_type(data_mat):
        """
        Classifies the dtype of the data.

        Returns a tuple of (type code, component names, component functions) where the type code is
        2 for compound (structured) data, 1 for complex data and 0 for everything else.
        """
        if data_mat.dtype.names is not None:
            return 2, list(data_mat.dtype.names), None
        # np.complex was only an alias of the builtin complex and no longer exists in recent numpy versions.
        # issubdtype covers every complex floating point precision.
        if np.issubdtype(data_mat.dtype, np.complexfloating):
            return 1, ['Real', 'Imaginary', 'Amplitude', 'Phase'], [np.real, np.imag, np.abs, np.angle]
        return 0, None, None

    def get_clims(data, stdev=2):
        """Color limits spanning `stdev` standard deviations on either side of the mean of the data"""
        avg = np.mean(data)
        std = np.std(data)
        return avg - stdev * std, avg + stdev * std

    def get_slice_string(slice_dict, dim_names, values_dict, units_dict):
        """Builds one 'name = value unit' line per dimension in dim_names for use in the plot titles"""
        return '\n'.join('{} = {} {}'.format(cur_name,
                                             values_dict[cur_name][slice_dict[cur_name]],
                                             units_dict[cur_name])
                         for cur_name in dim_names)

    def get_slicing_tuple(slice_dict):
        """
        Translates a dictionary of {dimension name: index or None} to a tuple of slice objects.
        None results in the complete dimension being retained.
        """
        slice_list = []
        for dim_name in pos_dim_names + spec_dim_names:
            cur_slice = slice(None)
            if slice_dict[dim_name] is not None:
                # Single element slices retain the dimension. naive_slice squeezes these out later
                cur_slice = slice(slice_dict[dim_name], slice_dict[dim_name] + 1)
            slice_list.append(cur_slice)
        return tuple(slice_list)

    def naive_slice(data_mat, slice_dict):
        """Slices the data according to slice_dict and removes all singleton dimensions"""
        return np.squeeze(data_mat[get_slicing_tuple(slice_dict)])

    def get_spatmap_slice_dict(slice_dict=None):
        """
        Slicing dictionary for the spatial map - all position dimensions are retained while each spectroscopic
        dimension is fixed to the index in slice_dict (or the center of that dimension if absent)
        """
        if slice_dict is None:
            slice_dict = {}
        spatmap_slicing = {}
        for name in pos_dim_names:
            spatmap_slicing[name] = None
        for ind, name in enumerate(spec_dim_names):
            spatmap_slicing[name] = slice_dict.get(name, data_mat.shape[ind + len(pos_dim_names)] // 2)
        return spatmap_slicing

    def get_spgram_slice_dict(slice_dict=None):
        """
        Slicing dictionary for the spectrogram - all spectroscopic dimensions are retained while each position
        dimension is fixed to the index in slice_dict (or the center of that dimension if absent)
        """
        if slice_dict is None:
            slice_dict = {}
        spgram_slicing = {}
        for ind, name in enumerate(pos_dim_names):
            spgram_slicing[name] = slice_dict.get(name, data_mat.shape[ind] // 2)
        for name in spec_dim_names:
            spgram_slicing[name] = None
        return spgram_slicing

    def update_image(img_handle, data_mat, slice_dict, twoD=True):
        """
        Pushes freshly sliced data into an existing plot.

        img_handle is the image handle for 2D plots and the list of line handles for 1D plots
        """
        new_data = naive_slice(data_mat, slice_dict)
        if twoD:
            img_handle.set_data(new_data)
            return
        # A single curve arrives as a 1D array. Promote it to a single row so that the loop below hands
        # the complete curve (rather than individual scalars) to the one and only line
        y_mat = np.atleast_2d(new_data)
        if y_mat.shape[0] != len(img_handle):
            # one row per line is necessary
            y_mat = y_mat.T
        for line_handle, y_vec in zip(img_handle, y_mat):
            line_handle.set_ydata(y_vec)
        # Artist.get_axes() is not available in recent matplotlib versions. The axes attribute is.
        img_handle[0].axes.set_ylim([np.min(y_mat), np.max(y_mat)])

    # ###########################################################################

    num_pos_dims = len(pos_dim_names)

    # 2D plots are only possible when at least two dimensions are available
    pos_plot_2d = pos_plot_2d and len(pos_dim_names) > 1
    spec_plot_2d = spec_plot_2d and len(spec_dim_names) > 1

    if not spec_plot_2d and spec_xdim is None:
        # Take the largest dimension you can find:
        spec_xdim = spec_dim_names[np.argmax(data_mat.shape[num_pos_dims:])]

    if not pos_plot_2d and pos_xdim is None:
        # Take the largest dimension you can find:
        pos_xdim = pos_dim_names[np.argmax(data_mat.shape[:num_pos_dims])]

    # Default reference values are simply the indices along each dimension
    if pos_ref_vals is None:
        pos_ref_vals = {}
        for ind, name in enumerate(pos_dim_names):
            pos_ref_vals[name] = np.arange(data_mat.shape[ind])

    if spec_ref_vals is None:
        spec_ref_vals = {}
        for ind, name in enumerate(spec_dim_names):
            spec_ref_vals[name] = np.arange(data_mat.shape[ind + num_pos_dims])

    pos_dim_units = dict(zip(pos_dim_names, pos_dim_units_old))
    spec_dim_units = dict(zip(spec_dim_names, spec_dim_units_old))

    data_type, data_names, data_funcs = check_data_type(data_mat)

    # Real valued data is plotted as is. Otherwise, start with the first component
    sub_data = data_mat
    component_name = 'Real'

    if data_type == 1:
        sub_data = data_funcs[0](data_mat)
        component_name = data_names[0]
    elif data_type == 2:
        component_name = data_names[0]
        sub_data = data_mat[component_name]

    component_title = 'Component: ' + component_name

    clims = get_clims(sub_data)

    spatmap_slicing = get_spatmap_slice_dict()
    current_spatmap = naive_slice(sub_data, spatmap_slicing)
    spgram_slicing = get_spgram_slice_dict()
    current_spgram = naive_slice(sub_data, spgram_slicing)

    fig, axes = plt.subplots(ncols=2, figsize=(14, 7))

    # ------------------------------ Spatial map ------------------------------
    spec_titles = get_slice_string(spatmap_slicing, spec_dim_names, spec_ref_vals, spec_dim_units)
    axes[0].set_title('Spatial Map for\n' + component_title + '\n' + spec_titles)
    if pos_plot_2d:
        img_spat, cbar_spat = single_img_cbar_plot(axes[0], current_spatmap,
                                                   x_size=data_mat.shape[1], y_size=data_mat.shape[0],
                                                   clim=clims)
        axes[0].set_xlabel(pos_dim_names[1] + ' (' + pos_dim_units_old[1] + ')')
        axes[0].set_ylabel(pos_dim_names[0] + ' (' + pos_dim_units_old[0] + ')')
        # Cross-hairs marking the position whose spectrogram is shown in the other panel
        main_vert_line = axes[0].axvline(x=spgram_slicing[pos_dim_names[1]], color='k')
        main_hor_line = axes[0].axhline(y=spgram_slicing[pos_dim_names[0]], color='k')
    else:
        axes[0].set_xlabel(pos_xdim + ' (' + pos_dim_units[pos_xdim] + ')')
        if current_spatmap.shape[0] != pos_ref_vals[pos_xdim].size:
            # The first axis must correspond to the X axis for plotting
            current_spatmap = current_spatmap.T
        img_spat = axes[0].plot(pos_ref_vals[pos_xdim], current_spatmap)
        if current_spatmap.ndim > 1:
            # One line per value of the other position dimension
            other_pos_dim = [name for name in pos_dim_names if name != pos_xdim][0]
            axes[0].legend(pos_ref_vals[other_pos_dim])

    # ------------------------------ Spectrogram ------------------------------
    pos_titles = get_slice_string(spgram_slicing, pos_dim_names, pos_ref_vals, pos_dim_units)
    axes[1].set_title('Spectrogram for\n' + component_title + '\n' + pos_titles)
    if spec_plot_2d:
        axes[1].set_xlabel(spec_dim_names[1] + ' (' + spec_dim_units_old[1] + ')')
        axes[1].set_ylabel(spec_dim_names[0] + ' (' + spec_dim_units_old[0] + ')')
        img_spec, cbar_spec = single_img_cbar_plot(axes[1], current_spgram,
                                                   x_size=data_mat.shape[num_pos_dims + 1],
                                                   y_size=data_mat.shape[num_pos_dims],
                                                   cbar_label=component_name, clim=clims)
    else:
        axes[1].set_xlabel(spec_xdim + ' (' + spec_dim_units[spec_xdim] + ')')
        if current_spgram.shape[0] != spec_ref_vals[spec_xdim].size:
            # The first axis must correspond to the X axis for plotting
            current_spgram = current_spgram.T
        img_spec = axes[1].plot(spec_ref_vals[spec_xdim], current_spgram)
        if current_spgram.ndim > 1:
            # One line per value of the other spectroscopic dimension
            other_spec_dim = [name for name in spec_dim_names if name != spec_xdim][0]
            axes[1].legend(spec_ref_vals[other_spec_dim])

    fig.tight_layout()

    # Widget specifications - (min, max, step) tuples result in sliders while a list results in a dropdown.
    # Once the widgets start firing, this dictionary holds the most recently applied widget values instead
    slice_dict = {}
    for dim_ind, dim_name in enumerate(pos_dim_names):
        slice_dict[dim_name] = (0, sub_data.shape[dim_ind] - 1, 1)
    for dim_ind, dim_name in enumerate(spec_dim_names):
        slice_dict[dim_name] = (0, sub_data.shape[dim_ind + num_pos_dims] - 1, 1)
    if data_type > 0:
        slice_dict['component'] = data_names

    # State shared between successive calls of update_plots. A mutable container is used since
    # assignments within update_plots would otherwise only create local variables
    global_vars = {'sub_data': sub_data, 'component_title': component_title}

    def refresh_spatmap_title():
        spec_titles = get_slice_string(spatmap_slicing, spec_dim_names, spec_ref_vals, spec_dim_units)
        axes[0].set_title('Spatial Map for\n' + global_vars['component_title'] + '\n' + spec_titles)

    def refresh_spgram_title():
        pos_titles = get_slice_string(spgram_slicing, pos_dim_names, pos_ref_vals, pos_dim_units)
        axes[1].set_title('Spectrogram for\n' + global_vars['component_title'] + '\n' + pos_titles)

    def update_plots(**kwargs):
        """Callback for the widgets. kwargs holds the current value of every widget keyed by dimension name"""
        component_name = kwargs.get('component', None)
        if component_name is not None and component_name != slice_dict['component']:
            # The component widget only exists for complex (1) and compound (2) data
            if data_type == 1:
                new_data = data_funcs[data_names.index(component_name)](data_mat)
            else:
                new_data = data_mat[component_name]
            global_vars.update({'sub_data': new_data, 'component_title': 'Component: ' + component_name})

            # Both panels and their color scales need to reflect the new component
            new_clims = get_clims(new_data)
            update_image(img_spat, new_data, spatmap_slicing, twoD=pos_plot_2d)
            if pos_plot_2d:
                img_spat.set_clim(new_clims)
            update_image(img_spec, new_data, spgram_slicing, twoD=spec_plot_2d)
            if spec_plot_2d:
                img_spec.set_clim(new_clims)

            refresh_spatmap_title()
            refresh_spgram_title()

        # The spectrogram (and cross-hairs) only need to be updated if any position slider was moved:
        if any(kwargs[dim_name] != slice_dict[dim_name] for dim_name in pos_dim_names):
            spgram_slicing.update(get_spgram_slice_dict(slice_dict=kwargs))
            update_image(img_spec, global_vars['sub_data'], spgram_slicing, twoD=spec_plot_2d)
            refresh_spgram_title()
            if pos_plot_2d:
                # Line data must be sequences. These lines consist of two points sharing the same coordinate
                x_pos = spgram_slicing[pos_dim_names[1]]
                y_pos = spgram_slicing[pos_dim_names[0]]
                main_vert_line.set_xdata([x_pos, x_pos])
                main_hor_line.set_ydata([y_pos, y_pos])

        # The spatial map only needs to be updated if any spectroscopic slider was moved:
        if any(kwargs[dim_name] != slice_dict[dim_name] for dim_name in spec_dim_names):
            spatmap_slicing.update(get_spatmap_slice_dict(slice_dict=kwargs))
            update_image(img_spat, global_vars['sub_data'], spatmap_slicing, twoD=pos_plot_2d)
            refresh_spatmap_title()

        # Remember the applied values to detect changes on the next call
        slice_dict.update(kwargs)
        display(fig)

    widgets.interact(update_plots, **slice_dict)