feature_extraction.py 8.98 KB
Newer Older
1
2
3
4
5
6
7
# -*- coding: utf-8 -*-
"""
Created on Tue Oct  6 15:34:12 2015

@author: Numan Laanait -- nlaanait@gmail.com
"""

8
from __future__ import division, print_function, absolute_import
9
import warnings
10
import multiprocessing as mp
11
12
13
14
15
16

import h5py
import numpy as np
import skimage.feature


17
# TODO: Write docstrings that follow the numpy standard.
18

Unknown's avatar
Unknown committed
19
# Functions
20
def pickle_keypoints(keypoints):
    """
    Convert keypoint objects (e.g. cv2 SIFT keypoints) into a plain numpy array.

    Despite the name, nothing is pickled here: the position of every
    keypoint is copied into an ordinary float array, which is easy to store
    or pass around.

    Parameters
    ----------
    keypoints : iterable
        Objects exposing a ``pt`` attribute holding a 2-element position
        (for OpenCV keypoints this is ``(x, y)``).

    Returns
    -------
    kpArray : numpy.ndarray
        Float array of shape ``(N, 2)`` holding ``[pt[1], pt[0]]`` for each
        keypoint (i.e. ``[y, x]`` for OpenCV points). Shape is ``(0, 2)``
        when *keypoints* is empty.
    """
    # Build all rows at once; growing an array with np.append in a loop
    # copies the whole array on every iteration (quadratic time).
    kpArray = np.array([[point.pt[1], point.pt[0]] for point in keypoints],
                       dtype=float)
    return kpArray.reshape(-1, 2)


# Classes that perform feature extraction. They are wrappers around the scikit-image and OpenCV feature detectors.
41
42
43
# TODO: Add support for OpenCV or implement SIFT.
# TODO: Add io operations for extracted features.
# TODO: Memory checking, since some of the features are quite large.
44
45

def _extract_image_features(task):
    """
    Detect keypoints and compute descriptors for a single image.

    Defined at module level, not inside ``getFeatures``, because
    ``multiprocessing`` has to pickle the callable it sends to worker
    processes, and locally defined functions cannot be pickled.

    Parameters
    ----------
    task : tuple
        ``(lib, detector, image)``: library name (``'opencv'`` or
        ``'skimage'``), the detector object, and one 2D image.

    Returns
    -------
    keypts :
        keypoints found in the image
    descs :
        descriptors of those keypoints

    Raises
    ------
    ValueError
        If ``lib`` is neither ``'opencv'`` nor ``'skimage'``.
    """
    lib, detector, image = task
    image = np.asarray(image)

    if lib == 'opencv':
        # Stretch intensities to 0..255 before the uint8 cast; casting a
        # zero-mean image directly would mangle its negative values.
        imp = image - image.min()
        peak = imp.max()
        if peak > 0:
            imp = imp * (255.0 / peak)
        k_obj, descs = detector.detectAndCompute(imp.astype('uint8'), None)
        # Only the keypoint objects are converted to coordinates; the
        # descriptors are returned exactly as the detector gave them.
        return pickle_keypoints(k_obj), descs

    if lib == 'skimage':
        # Standardise, then keep only above-mean intensities.
        imp = image - image.mean()
        std = np.std(image)
        if std > 0:  # a constant image would otherwise become all-NaN
            imp = imp / std
        imp[imp < 0] = 0
        detector.detect_and_extract(imp)
        return detector.keypoints, detector.descriptors

    raise ValueError("Unsupported lib %r; expected 'opencv' or 'skimage'" % (lib,))


class FeatureExtractorParallel(object):
    """
    This is an Object used to contain a data set and has methods to perform
    feature extraction on the data set that are detector based.
    Begin by loading a detector for features and a computer vision library.
    Images can be distributed over several worker processes.

    Parameters
    ----------
    detector_name : (string)
        name of detector.
    lib : (string)
        computer vision library to use (opencv or skimage)

        The following can be used for:
        lib = opencv: SIFT, ORB, SURF
        lib = skimage: ORB, BRIEF, CENSURE

        NOTE: no opencv detector is constructed yet; ``self.detector`` stays
        None for lib='opencv' unless it is assigned manually. The skimage
        path calls ``detect_and_extract``, so the chosen detector must
        provide that method (presumably ORB among those listed -- verify).
    detector_kwargs : optional
        Keyword arguments forwarded to the skimage detector's constructor
        (e.g. ``n_keypoints`` for ORB).
    """

    def __init__(self, detector_name, lib, **detector_kwargs):
        self.data = []
        self.lib = lib
        # Remains None for opencv (not wired up) or an unknown detector name.
        self.detector = None

        if self.lib == 'skimage':
            try:
                detector = getattr(skimage.feature, detector_name)
            except AttributeError:
                print('Error: The Library does not contain the specified detector')
            else:
                # skimage detectors such as ORB are classes; detect_and_extract()
                # and the keypoints/descriptors results live on an instance.
                if isinstance(detector, type):
                    detector = detector(**detector_kwargs)
                self.detector = detector

    def clearData(self):
        """Drop the loaded image stack."""
        self.data = []

    def loadData(self, dataset):
        """
        Load an h5 Dataset and reshape it into a stack of square images.

        Parameters
        ----------
        dataset : h5py.Dataset
            Dataset whose last dimension holds flattened square images; its
            length must be a perfect square.

        Raises
        ------
        ValueError
            If the last dimension is not a perfect square.
        """
        if not isinstance(dataset, h5py.Dataset):
            warnings.warn('Error: Data must be an h5 Dataset object')
            return
        # h5py.Dataset has no reshape(); read it into an ndarray first.
        data = dataset[()]
        n_pix = data.shape[-1]
        dim = int(round(np.sqrt(n_pix)))
        if dim * dim != n_pix:
            raise ValueError('Last dimension (%d) is not a perfect square; '
                             'cannot reshape into square images' % n_pix)
        self.data = data.reshape(-1, dim, dim)

    def getData(self):
        """
        This is a Method that returns the loaded image stack.
        """
        return self.data

    @staticmethod
    def _crop_window(stack, origin, win_size):
        """
        Cut a ``win_size`` x ``win_size`` window out of every image in ``stack``.

        The window starts ``win_size // 2`` pixels above and left of
        ``origin`` (row, col). Raises ValueError if it does not fit.
        """
        win_size = int(win_size)
        row0 = int(origin[0]) - win_size // 2
        col0 = int(origin[1]) - win_size // 2
        n_rows, n_cols = stack.shape[-2:]
        if (win_size < 1 or row0 < 0 or col0 < 0
                or row0 + win_size > n_rows or col0 + win_size > n_cols):
            raise ValueError(
                'window_size=%d around origin=%s does not fit inside %dx%d images'
                % (win_size, list(origin), n_rows, n_cols))
        # Copy so the cropped stack does not keep the full stack alive.
        return np.array(stack[:, row0:row0 + win_size, col0:col0 + win_size])

    def getFeatures(self, **kwargs):
        """
        Detect keypoints and compute descriptors for every loaded image.

        Parameters
        ----------
        processors : int, optional
            Number of worker processes, default 1 (runs in this process
            without a pool). ``None`` uses ``multiprocessing.cpu_count()``.
        mask : bool, optional
            If True, crop every image to a ``window_size`` x ``window_size``
            window around ``origin`` before detection. The cropped stack
            replaces ``self.data``. Default False.
        origin : sequence of two ints, optional
            (row, col) the window is placed around, default [0, 0].
        window_size : int, optional
            Side length of the window, default 0 (must be set when ``mask``
            is True).

        Returns
        -------
        keypts : list of numpy.ndarray
            keypoints of each image, cast to int
        descs : list
            descriptors of each image

        Raises
        ------
        ValueError
            If ``lib`` is unsupported, no detector or image stack is loaded,
            or the mask window does not fit inside the images.

        Notes
        -----
        A ValueError raised by the detector while processing images is turned
        into a warning, and the features gathered so far are returned.
        """
        processes = kwargs.get('processors', 1)
        if processes is None:
            processes = mp.cpu_count()
        mask = kwargs.get('mask', False)
        origin = kwargs.get('origin', [0, 0])
        winSize = kwargs.get('window_size', 0)

        if self.lib not in ('opencv', 'skimage'):
            raise ValueError("Unsupported lib %r; expected 'opencv' or 'skimage'"
                             % (self.lib,))
        if self.detector is None:
            raise ValueError('No detector loaded; check detector_name and lib.')
        dset = np.asarray(self.data)
        if dset.ndim != 3:
            raise ValueError('No image stack loaded; call loadData() first.')

        if mask:
            self.data = self._crop_window(dset, origin, winSize)

        tasks = [(self.lib, self.detector, image) for image in self.data]

        pool = None
        if processes > 1:
            print('launching %i kernels...' % processes)
            pool = mp.Pool(processes)

        results = []
        try:
            if pool is None:
                # A single worker gains nothing from a pool; run in-process.
                jobs = (_extract_image_features(task) for task in tasks)
            else:
                # imap rejects chunksize < 1, which plain division yields
                # when there are fewer images than processes.
                chunk = max(1, len(tasks) // processes)
                jobs = pool.imap(_extract_image_features, tasks, chunksize=chunk)

            print('Extracting features...')
            for result in jobs:
                results.append(result)
        except ValueError as err:
            warnings.warn('ValueError something about 2d-image. Probably some of the '
                          'detector input params are wrong. (%s)' % err)
        finally:
            if pool is not None:
                print('Closing down the kernels... \n')
                # Every result is consumed (or an error is unwinding), so stop
                # the workers and reap them instead of leaving them running.
                pool.terminate()
                pool.join()

        keypts = [itm[0].astype('int') for itm in results]
        desc = [itm[1] for itm in results]
        return keypts, desc
class FeatureExtractorSerial(object):
    """
    This is an Object used to contain a data set and has methods to perform
    feature extraction on the data set that are detector based.
    Begin by loading a detector for features and a computer vision library.
    Images are processed one after another in the calling process.

    Parameters
    ----------
    detector_name : (string)
        name of detector.
    lib : (string)
        computer vision library to use (opencv or skimage)

        The following can be used for:
        lib = opencv: SIFT, ORB, SURF
        lib = skimage: ORB, BRIEF, CENSURE

        NOTE: no opencv detector is constructed yet; ``self.detector`` stays
        None for lib='opencv' unless it is assigned manually. The skimage
        path calls ``detect_and_extract``, so the chosen detector must
        provide that method (presumably ORB among those listed -- verify).
    detector_kwargs : optional
        Keyword arguments forwarded to the skimage detector's constructor
        (e.g. ``n_keypoints`` for ORB).
    """

    def __init__(self, detector_name, lib, **detector_kwargs):
        self.data = []
        self.lib = lib
        # Remains None for opencv (not wired up) or an unknown detector name.
        self.detector = None

        if self.lib == 'skimage':
            try:
                detector = getattr(skimage.feature, detector_name)
            except AttributeError:
                print('Error: The Library does not contain the specified detector')
            else:
                # skimage detectors such as ORB are classes; detect_and_extract()
                # and the keypoints/descriptors results live on an instance.
                if isinstance(detector, type):
                    detector = detector(**detector_kwargs)
                self.detector = detector

    def clearData(self):
        """Drop the loaded image stack."""
        self.data = []

    def loadData(self, dataset):
        """
        Load an h5 Dataset and reshape it into a stack of square images.

        Parameters
        ----------
        dataset : h5py.Dataset
            Dataset whose last dimension holds flattened square images; its
            length must be a perfect square.

        Raises
        ------
        ValueError
            If the last dimension is not a perfect square.
        """
        if not isinstance(dataset, h5py.Dataset):
            warnings.warn('Error: Data must be an h5 Dataset object')
            return
        # h5py.Dataset has no reshape(); read it into an ndarray first.
        data = dataset[()]
        n_pix = data.shape[-1]
        dim = int(round(np.sqrt(n_pix)))
        if dim * dim != n_pix:
            raise ValueError('Last dimension (%d) is not a perfect square; '
                             'cannot reshape into square images' % n_pix)
        self.data = data.reshape(-1, dim, dim)

    def getData(self):
        """
        This is a Method that returns the loaded image stack.
        """
        return self.data

    @staticmethod
    def _crop_window(stack, origin, win_size):
        """
        Cut a ``win_size`` x ``win_size`` window out of every image in ``stack``.

        The window starts ``win_size // 2`` pixels above and left of
        ``origin`` (row, col). Raises ValueError if it does not fit.
        """
        win_size = int(win_size)
        row0 = int(origin[0]) - win_size // 2
        col0 = int(origin[1]) - win_size // 2
        n_rows, n_cols = stack.shape[-2:]
        if (win_size < 1 or row0 < 0 or col0 < 0
                or row0 + win_size > n_rows or col0 + win_size > n_cols):
            raise ValueError(
                'window_size=%d around origin=%s does not fit inside %dx%d images'
                % (win_size, list(origin), n_rows, n_cols))
        # Copy so the cropped stack does not keep the full stack alive.
        return np.array(stack[:, row0:row0 + win_size, col0:col0 + win_size])

    def _detect(self, image):
        """Return ``(keypoints, descriptors)`` for one 2D image."""
        image = np.asarray(image)

        if self.lib == 'opencv':
            # Stretch intensities to 0..255 before the uint8 cast; casting a
            # zero-mean image directly would mangle its negative values.
            imp = image - image.min()
            peak = imp.max()
            if peak > 0:
                imp = imp * (255.0 / peak)
            k_obj, descs = self.detector.detectAndCompute(imp.astype('uint8'), None)
            # Only the keypoint objects are converted to coordinates; the
            # descriptors are returned exactly as the detector gave them.
            return pickle_keypoints(k_obj), descs

        if self.lib == 'skimage':
            # Standardise, then keep only above-mean intensities.
            imp = image - image.mean()
            std = np.std(image)
            if std > 0:  # a constant image would otherwise become all-NaN
                imp = imp / std
            imp[imp < 0] = 0
            self.detector.detect_and_extract(imp)
            return self.detector.keypoints, self.detector.descriptors

        raise ValueError("Unsupported lib %r; expected 'opencv' or 'skimage'"
                         % (self.lib,))

    def getFeatures(self, **kwargs):
        """
        Detect keypoints and compute descriptors for every loaded image.

        Parameters
        ----------
        mask : bool, optional
            If True, crop every image to a ``window_size`` x ``window_size``
            window around ``origin`` before detection. The cropped stack
            replaces ``self.data``. Default False.
        origin : sequence of two ints, optional
            (row, col) the window is placed around, default [0, 0].
        window_size : int, optional
            Side length of the window, default 0 (must be set when ``mask``
            is True).

        Returns
        -------
        keypts : list of numpy.ndarray
            keypoints of each image, cast to int
        descs : list
            descriptors of each image

        Raises
        ------
        ValueError
            If ``lib`` is unsupported, no detector or image stack is loaded,
            or the mask window does not fit inside the images.
        """
        mask = kwargs.get('mask', False)
        origin = kwargs.get('origin', [0, 0])
        winSize = kwargs.get('window_size', 0)

        if self.lib not in ('opencv', 'skimage'):
            raise ValueError("Unsupported lib %r; expected 'opencv' or 'skimage'"
                             % (self.lib,))
        if self.detector is None:
            raise ValueError('No detector loaded; check detector_name and lib.')
        dset = np.asarray(self.data)
        if dset.ndim != 3:
            raise ValueError('No image stack loaded; call loadData() first.')

        if mask:
            self.data = self._crop_window(dset, origin, winSize)

        results = [self._detect(image) for image in self.data]

        keypts = [itm[0].astype('int') for itm in results]
        desc = [itm[1] for itm in results]
        return keypts, desc