This GitLab instance is undergoing maintenance and is operating in read-only mode.

You are on a read-only GitLab instance.
fitter.py 16.1 KB
Newer Older
1
# -*- coding: utf-8 -*-
2
"""
3
4
:class:`~pycroscopy.analysis.fitter.Fitter` - Abstract class that provides the
framework for building application-specific children classes
5

6
Created on Thu Aug 15 11:48:53 2019
7

8
@author: Suhas Somnath
9
"""
10
11
from __future__ import division, print_function, absolute_import, \
    unicode_literals
12
import numpy as np
13
14
15
from warnings import warn
import joblib
from scipy.optimize import least_squares
16

17
18
19
from pyUSID.processing.comp_utils import recommend_cpu_cores
from pyUSID.processing.process import Process
from pyUSID.io.usi_data import USIDataset
20

21
# TODO: All reading, holding operations should use Dask arrays
22

23

24
class Fitter(Process):
25

26
    def __init__(self, h5_main, proc_name, variables=None, **kwargs):
        """
        Creates a new instance of the abstract Fitter class

        Parameters
        ----------
        h5_main : h5py.Dataset or pyUSID.io.USIDataset object
            Main datasets whose one or dimensions will be reduced
        proc_name : str or unicode
            Name of the child process
        variables : str or list, optional
            List of spectroscopic dimension names that will be reduced
        h5_target_group : h5py.Group, optional. Default = None
            Location where to look for existing results and to place newly
            computed results. Use this kwarg if the results need to be written
            to a different HDF5 file. By default, this value is set to the
            parent group containing `h5_main`
        kwargs : dict
            Keyword arguments that will be passed on to
            pyUSID.processing.process.Process

        Raises
        ------
        TypeError
            If ``variables`` is provided but is not a str, list, or tuple
        ValueError
            If any name in ``variables`` is not a spectroscopic dimension
            of ``h5_main``
        """
        super(Fitter, self).__init__(h5_main, proc_name, **kwargs)

        # Validate other arguments / kwargs here:
        if variables is not None:
            # Accept a single dimension name as a convenience:
            if isinstance(variables, str):
                variables = [variables]
            if not isinstance(variables, (list, tuple)):
                # BUG FIX: original message was missing the space between
                # 'tuple' and 'of strings' across the implicit string join
                raise TypeError('variables should be a string / list or tuple '
                                'of strings. Provided object was of type: {}'
                                ''.format(type(variables)))
            if not all([dim in self.h5_main.spec_dim_labels for dim in variables]):
                # BUG FIX: original message claimed the dataset "does not
                # appear to have" its OWN dimension labels; reworded so the
                # available and requested dimensions are clearly identified
                raise ValueError('Provided dataset with spectroscopic '
                                 'dimension(s): {} does not contain the '
                                 'dimension(s) that need to be fitted: {}'
                                 ''.format(self.h5_main.spec_dim_labels,
                                           variables))

        # Variables specific to Fitter
        self._guess = None
        self._fit = None
        self._is_guess = True
        self._h5_guess = None
        self._h5_fit = None
        # Forces callers to run set_up_guess() / set_up_fit() before compute:
        self.__set_up_called = False

        # Variables from Process:
        self.compute = self.set_up_guess
        self._unit_computation = super(Fitter, self)._unit_computation
        self._create_results_datasets = self._create_guess_datasets
        self._map_function = None
78

79
    def _read_guess_chunk(self):
        """
        Loads the slice of the Guess dataset that corresponds to the pixels
        of the main dataset currently being processed into ``self._guess``.
        """
        batch = self._get_pixels_in_current_batch()
        self._guess = self._h5_guess[batch, :]

        if self.verbose and self.mpi_rank == 0:
            print('Guess of shape: {}'.format(self._guess.shape))
89

90
    def _write_results_chunk(self):
        """
        Writes the in-memory guess or fit results for the current batch of
        pixels into the corresponding HDF5 dataset.
        """
        # Select destination dataset and in-memory source for this mode:
        if self._is_guess:
            dest, chunk = self._h5_guess, self._guess
        else:
            dest, chunk = self._h5_fit, self._fit

        batch = self._get_pixels_in_current_batch()

        if self.verbose and self.mpi_rank == 0:
            print('Writing data of shape: {} and dtype: {} to position range: '
                  '{} in HDF5 dataset:{}'.format(chunk.shape,
                                                 chunk.dtype,
                                              [batch[0],batch[-1]],
                                                 dest))
        dest[batch, :] = chunk
110

Somnath, Suhas's avatar
Somnath, Suhas committed
111
    def _create_guess_datasets(self):
112
        """
113
114
115
        Model specific call that will create the h5 group, empty guess dataset,
        corresponding spectroscopic datasets and also link the guess dataset
        to the spectroscopic datasets.
116
        """
117
118
        raise NotImplementedError('Please override the _create_guess_datasets '
                                  'specific to your model')
119

Somnath, Suhas's avatar
Somnath, Suhas committed
120
    def _create_fit_datasets(self):
Chris Smith's avatar
Chris Smith committed
121
        """
122
123
124
125
126
        Model specific call that will create the (empty) fit dataset, and
        link the fit dataset to the spectroscopic datasets.
        """
        raise NotImplementedError('Please override the _create_fit_datasets '
                                  'specific to your model')
Chris Smith's avatar
Chris Smith committed
127

128
    def _get_existing_datasets(self):
        """
        Fetches the existing Guess, status, and Fit datasets from the HDF5
        results group.

        All other domain-specific datasets should be loaded in the classes
        that extend this class.
        """
        grp = self.h5_results_grp

        # A Guess dataset must already exist by the time this is called:
        self._h5_guess = USIDataset(grp['Guess'])

        try:
            self._h5_status_dset = grp[self._status_dset_name]
        except KeyError:
            warn('status dataset not created yet')
            self._h5_status_dset = None

        try:
            self._h5_fit = USIDataset(grp['Fit'])
        except KeyError:
            self._h5_fit = None
            # When preparing to fit, the Fit dataset must be created now:
            if not self._is_guess:
                self._create_fit_datasets()

    def do_guess(self, *args, override=False, **kwargs):
        """
        Computes the Guess

        Parameters
        ----------
        args : list, optional
            List of arguments
        override : bool, optional
            If True, computes a fresh guess even if existing Guess was found
            Else, returns existing Guess dataset. Default = False
        kwargs : dict, optional
            Keyword arguments

        Returns
        -------
        USIDataset
            HDF5 dataset with the Guesses computed
        """
        # Guard: set_up_guess() performs mandatory book-keeping first
        if not self.__set_up_called:
            raise ValueError('Please call set_up_guess() before calling '
                             'do_guess()')
        results_grp = super(Fitter, self).compute(override=override)
        self.h5_results_grp = results_grp
        # to be on the safe side, expect setup again
        self.__set_up_called = False
        return USIDataset(results_grp['Guess'])

    def do_fit(self, *args, override=False, **kwargs):
        """
        Computes the Fit

        Parameters
        ----------
        args : list, optional
            List of arguments
        override : bool, optional
            If True, computes a fresh guess even if existing Fit was found
            Else, returns existing Fit dataset. Default = False
        kwargs : dict, optional
            Keyword arguments

        Returns
        -------
        USIDataset
            HDF5 dataset with the Fit computed

        Raises
        ------
        ValueError
            If set_up_fit() was not called before this method
        """
        if not self.__set_up_called:
            # BUG FIX: message previously told users to call set_up_guess()
            # / do_guess() - copy-paste error from do_guess()
            raise ValueError('Please call set_up_fit() before calling '
                             'do_fit()')
        # Reset 'last_pixel' attribute to 0 (was a bare string statement,
        # now a real comment). This value is used for filling in the status
        # dataset.
        self.h5_results_grp.attrs['last_pixel'] = 0
        self.h5_results_grp = super(Fitter, self).compute(override=override)
        # to be on the safe side, expect setup again
        self.__set_up_called = False
        return USIDataset(self.h5_results_grp['Fit'])
209

Somnath, Suhas's avatar
Somnath, Suhas committed
210
    def _reformat_results(self, results, strategy='wavelet_peaks'):
211
        """
212
213
        Model specific restructuring / reformatting of the parallel compute
        results
Chris Smith's avatar
Chris Smith committed
214
215
216

        Parameters
        ----------
217
        results : list or array-like
Chris Smith's avatar
Chris Smith committed
218
            Results to be formatted for writing
Chris Smith's avatar
Chris Smith committed
219
        strategy : str
220
221
            The strategy used in the fit.  Determines how the results will be
            reformatted, if multiple strategies for guess / fit are available
Chris Smith's avatar
Chris Smith committed
222
223
224

        Returns
        -------
Chris Smith's avatar
Chris Smith committed
225
        results : numpy.ndarray
226
            Formatted array that is ready to be writen to the HDF5 file
Chris Smith's avatar
Chris Smith committed
227

228
229
        """
        return np.array(results)
230

231
    def set_up_guess(self, h5_partial_guess=None):
        """
        Performs necessary book-keeping before do_guess can be called

        Parameters
        ----------
        h5_partial_guess: h5py.Dataset or pyUSID.io.USIDataset, optional
            HDF5 dataset containing partial Guess. Not implemented
        """
        # TODO: h5_partial_guess needs to be utilized
        if h5_partial_guess is not None:
            raise NotImplementedError('Provided h5_partial_guess cannot be '
                                      'used yet. Ask developer to implement')

        # Prepare the parameters dict so that previous guess / fit results
        # can be checked for:
        self._is_guess = True
        self._status_dset_name = 'completed_guess_positions'
        self.duplicate_h5_groups, self.partial_h5_groups = \
            self._check_for_duplicates()

        if self.verbose and self.mpi_rank == 0:
            print('Groups with Guess in:\nCompleted: {}\nPartial:{}'.format(
                self.duplicate_h5_groups, self.partial_h5_groups))

        # Point the Process machinery at the guess-specific callables:
        self._unit_computation = super(Fitter, self)._unit_computation
        self._create_results_datasets = self._create_guess_datasets
        self.compute = self.do_guess
        self.__set_up_called = True

    def set_up_fit(self, h5_partial_fit=None, h5_guess=None):
        """
        Performs necessary book-keeping before do_fit can be called

        Parameters
        ----------
        h5_partial_fit: h5py.Dataset or pyUSID.io.USIDataset, optional
            HDF5 dataset containing partial Fit. Not implemented
        h5_guess: h5py.Dataset or pyUSID.io.USIDataset, optional
            HDF5 dataset containing completed Guess. Not implemented

        Raises
        ------
        NotImplementedError
            If h5_partial_fit or h5_guess are provided
        FileNotFoundError
            If no completed Guess exists for this dataset yet
        """
        # TODO: h5_partial_guess needs to be utilized
        if h5_partial_fit is not None or h5_guess is not None:
            raise NotImplementedError('Provided h5_partial_fit cannot be '
                                      'used yet. Ask developer to implement')
        self._is_guess = False

        self._map_function = None
        self._unit_computation = None
        self._create_results_datasets = self._create_fit_datasets

        # Case 1: Fit already complete or partially complete.
        # This is similar to a partial process. Leave as is
        self._status_dset_name = 'completed_fit_positions'
        ret_vals = self._check_for_duplicates()
        self.duplicate_h5_groups, self.partial_h5_groups = ret_vals
        if self.verbose and self.mpi_rank == 0:
            print('Checking on partial / completed fit datasets')
            print(
                'Completed results groups:\n{}\nPartial results groups:\n'
                '{}'.format(self.duplicate_h5_groups, self.partial_h5_groups))

        # Case 2: Fit neither partial / completed. Search for guess.
        # Most popular scenario:
        if len(self.duplicate_h5_groups) == 0 and len(
                self.partial_h5_groups) == 0:
            if self.verbose and self.mpi_rank == 0:
                print('No fit datasets found. Looking for Guess datasets')
            # Change status dataset name back to guess to check for status
            # on guesses:
            self._status_dset_name = 'completed_guess_positions'
            # Note that check_for_duplicates() will be against fit's parm_dict.
            # So make a backup of that
            fit_parms = self.parms_dict.copy()
            # Set parms_dict to an empty dict so that we can accept any Guess
            # dataset:
            self.parms_dict = dict()
            ret_vals = self._check_for_duplicates()
            guess_complete_h5_grps, guess_partial_h5_grps = ret_vals
            if self.verbose and self.mpi_rank == 0:
                print(
                    'Guess datasets search resulted in:\nCompleted: {}\n'
                    'Partial:{}'.format(guess_complete_h5_grps,
                                        guess_partial_h5_grps))
            # Now put back the original parms_dict:
            self.parms_dict.update(fit_parms)

            # Case 2.1: At least guess is completed:
            if len(guess_complete_h5_grps) > 0:
                # Just set the last group as the current results group
                self.h5_results_grp = guess_complete_h5_grps[-1]
                if self.verbose and self.mpi_rank == 0:
                    print('Guess found! Using Guess in:\n{}'.format(
                        self.h5_results_grp))
                # It will grab the older status default unless we set the
                # status dataset back to fit
                self._status_dset_name = 'completed_fit_positions'
                # Get handles to the guess dataset. Nothing else will be found
                self._get_existing_datasets()

            elif len(guess_complete_h5_grps) == 0 and len(
                    guess_partial_h5_grps) > 0:
                # BUG FIX: the exception object was constructed but never
                # raised, so execution silently returned without signalling
                raise FileNotFoundError(
                    'Guess not yet completed. Please complete guess first')
            else:
                # BUG FIX: same missing 'raise' as above
                raise FileNotFoundError(
                    'No Guess found. Please complete guess first')

        # We want compute to call our own manual unit computation function:
        self._unit_computation = self._unit_compute_fit
        self.compute = self.do_fit
        self.__set_up_called = True
345

346
347
348
349
350
    def _unit_compute_fit(self, obj_func, obj_func_args=None,
                          solver_options=None):
        """
        Performs least-squares fitting on self.data using self._guess for
        initial conditions.

        Results of the computation are captured in self._results

        Parameters
        ----------
        obj_func : callable
            Objective function to minimize on
        obj_func_args : list, optional
            Arguments required by obj_func following the guess parameters
            (which should be the first argument)
        solver_options : dict, optional
            Keyword arguments passed onto scipy.optimize.least_squares.
            Default = {'jac': 'cs'}
        """
        # BUG FIX: defaults were mutable objects ([] and a dict) shared
        # across all calls; use None sentinels instead
        if obj_func_args is None:
            obj_func_args = []
        if solver_options is None:
            solver_options = {'jac': 'cs'}

        # At this point data has been read in. Read in the guess as well:
        self._read_guess_chunk()

        if self.verbose and self.mpi_rank == 0:
            print('_unit_compute_fit got:\nobj_func: {}\nobj_func_args: {}\n'
                  'solver_options: {}'.format(obj_func, obj_func_args,
                                              solver_options))

        # TODO: Generalize this bit. Use Parallel compute instead!
        if self.mpi_size > 1:
            # Each MPI rank fits its own batch of pixels serially:
            if self.verbose:
                print('Rank {}: About to start serial computation'
                      '.'.format(self.mpi_rank))

            self._results = list()
            for pulse_resp, pulse_guess in zip(self.data, self._guess):
                curr_results = least_squares(obj_func, pulse_guess,
                                             args=[pulse_resp] + obj_func_args,
                                             **solver_options)
                self._results.append(curr_results)
        else:
            # Single process: fan the fits out over local CPU cores:
            cores = recommend_cpu_cores(self.data.shape[0],
                                        verbose=self.verbose)
            if self.verbose:
                print('Starting parallel fitting with {} cores'.format(cores))

            values = [joblib.delayed(least_squares)(obj_func, pulse_guess,
                                                    args=[pulse_resp] + obj_func_args,
                                                    **solver_options) for
                      pulse_resp, pulse_guess in zip(self.data, self._guess)]
            self._results = joblib.Parallel(n_jobs=cores)(values)

        if self.verbose and self.mpi_rank == 0:
            print(
                'Finished computing fits on {} objects. Results of length: {}'
                '.'.format(self.data.shape[0], len(self._results)))

        # What least_squares returns is an object that needs to be extracted
        # to get the coefficients. This is handled by the write function