hylite.hyimage

Store and manipulate hyperspectral image data.

  1"""
  2Store and manipulate hyperspectral image data.
  3"""
  4
  5import os
  6import numpy as np
  7import matplotlib
  8import matplotlib.pyplot as plt
  9from matplotlib import path
 10from roipoly import MultiRoi
 11import imageio
 12import scipy as sp
 13import hylite
 14from hylite.hydata import HyData
 15from hylite.hylibrary import HyLibrary
 16
 17
 18
 19class HyImage( HyData ):
 20    """
 21    A class for hyperspectral image data. These can be individual scenes or hyperspectral orthoimages.
 22    """
 23
 24    def __init__(self, data, **kwds):
 25        """
 26        Args:
 27            data (ndarray): a numpy array such that data[x][y][band] gives each pixel value.
 28            **kwds:
 29                wav = A numpy array containing band wavelengths for this image.
 30                affine = an affine transform of the format returned by GDAL.GetGeoTransform().
 31                projection = string defining the project. Default is None.
 32                sensor = sensor name. Default is "unknown".
 33                header = path to associated header file. Default is None.
 34        """
 35
 36        #call constructor for HyData
 37        super().__init__(data, **kwds)
 38
 39        # special case - if dataset only has oneband, slice it so it still has
 40        # the format data[x,y,b].
 41        if not self.data is None:
 42            if len(self.data.shape) == 1:
 43                self.data = self.data[None, None, :] # single pixel image
 44            if len(self.data.shape) == 2:
 45                self.data = self.data[:, :, None] # single band iamge
 46
 47        #load any additional project information (specific to images)
 48        self.set_projection(kwds.get("projection",None))
 49        self.affine = kwds.get("affine",[0,1,0,0,0,1])
 50
 51        # wavelengths
 52        if 'wav' in kwds:
 53            self.set_wavelengths(kwds['wav'])
 54
 55        #special header formatting
 56        self.header['file type'] = 'ENVI Standard'
 57
 58    def copy(self,data=True):
 59        """
 60        Make a deep copy of this image instance.
 61
 62        Args:
 63            data (bool): True if a copy of the data should be made, otherwise only copy header.
 64
 65        Returns:
 66            a new HyImage instance.
 67        """
 68        if not data:
 69            return HyImage(None, header=self.header.copy(), projection=self.projection, affine=self.affine)
 70        else:
 71            return HyImage( self.data.copy(), header=self.header.copy(), projection=self.projection, affine=self.affine)
 72
 73    def T(self):
 74        """
 75        Return a transposed view of the data matrix (corresponding with the [y,x] indexing used by matplotlib, opencv etc.
 76        """
 77        return np.transpose(self.data, (1,0,2))
 78
 79    def xdim(self):
 80        """
 81        Return number of pixels in x (first dimension of data array)
 82        """
 83        return self.data.shape[0]
 84
 85    def ydim(self):
 86        """
 87        Return number of pixels in y (second dimension of data array)
 88        """
 89        return self.data.shape[1]
 90
 91    def aspx(self):
 92        """
 93        Return the aspect ratio of this image (width/height).
 94        """
 95        return self.ydim() / self.xdim()
 96
 97    def get_extent(self):
 98        """
 99        Returns the width and height of this image in world coordinates.
100
101        Returns:
102            tuple with (width, height).
103        """
104        return self.xdim * self.pixel_size[0], self.ydim * self.pixel_size[1]
105
106    def set_projection(self,proj):
107        """
108        Set this project to an existing osgeo.osr.SpatialReference or GDAL georeference string.
109
110        Args:
111            proj (str, osgeo.osr.SpatialReference): the project to use as osgeo.osr.SpatialReference or GDAL georeference string.
112        """
113        if proj is None:
114            self.projection = None
115        else:
116            try:
117                from osgeo.osr import SpatialReference
118            except:
119                assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
120            if isinstance(proj, SpatialReference):
121                self.projection = proj
122            elif isinstance(proj, str):
123                self.projection = SpatialReference(proj)
124            else:
125                print("Invalid project %s" % proj)
126                raise
127
128    def set_projection_EPSG(self,EPSG):
129        """
130        Sets this image project using an EPSG code.
131
132        Args:
133            EPSG (str): string EPSG code that can be passed to SpatialReference.SetFromUserInput(...).
134        """
135
136        try:
137            from osgeo.osr import SpatialReference
138        except:
139            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
140
141        self.projection = SpatialReference()
142        self.projection.SetFromUserInput(EPSG)
143
144    def get_projection_EPSG(self):
145        """
146        Gets a string describing this projections EPSG code (if it is an EPSG project).
147
148        Returns:
149            an EPSG code string of the format "EPSG:XXXX".
150        """
151        if self.projection is None:
152            return None
153        else:
154            return "%s:%s" % (self.projection.GetAttrValue("AUTHORITY",0),self.projection.GetAttrValue("AUTHORITY",1))
155
156    def pix_to_world(self, px, py, proj=None):
157        """
158        Take pixel coordinates and return world coordinates
159
160        Args:
161            px (int): the pixel x-coord.
162            py (int): the pixel y-coord.
163            proj (str, osr.SpatialReference): the coordinate system to use. Default (None) uses the same system as this image. Otherwise
164                   an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
165        Returns:
166            the world coordinates in the coordinate system defined by get_projection_EPSG(...).
167        """
168
169        try:
170            from osgeo import osr
171            import osgeo.gdal as gdal
172            from osgeo import ogr
173        except:
174            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
175
176        # parse project
177        if proj is None:
178            proj = self.projection
179        elif isinstance(proj, str) or isinstance(proj, int):
180            epsg = proj
181            if isinstance(epsg, str):
182                try:
183                    epsg = int(str.split(':')[1])
184                except:
185                    assert False, "Error - %s is an invalid EPSG code." % proj
186            proj = osr.SpatialReference()
187            proj.ImportFromEPSG(epsg)
188
189        # check we have all the required info
190        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
191        assert (not self.affine is None) and (
192            not self.projection is None), "Error - project information is undefined."
193
194        #project to world coordinates in this images project/world coords
195        x,y = gdal.ApplyGeoTransform(self.affine, px, py)
196
197        #project to target coords (if different)
198        if not proj.IsSameGeogCS(self.projection):
199            P = ogr.Geometry(ogr.wkbPoint)
200            if proj.EPSGTreatsAsNorthingEasting():
201                P.AddPoint(x, y)
202            else:
203                P.AddPoint(y, x)
204            P.AssignSpatialReference(self.projection)  # tell the point what coordinates it's in
205            P.TransformTo(proj)  # reproject it to the out spatial reference
206            x, y = P.GetX(), P.GetY()
207
208            #do we need to transpose?
209            if proj.EPSGTreatsAsLatLong():
210                x,y=y,x #we want lon,lat not lat,lon
211        return x, y
212
213    def world_to_pix(self, x, y, proj = None):
214        """
215        Take world coordinates and return pixel coordinates
216
217        Args:
218            x (float): the world x-coord.
219            y (float): the world y-coord.
220            proj (str, osr.SpatialReference): the coordinate system of the input coordinates. Default (None) uses the same system as this image. Otherwise
221                   an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
222
223        Returns:
224            the pixel coordinates based on the affine transform stored in self.affine.
225        """
226
227        try:
228            from osgeo import osr
229            import osgeo.gdal as gdal
230            from osgeo import ogr
231        except:
232            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
233
234        # parse project
235        if proj is None:
236            proj = self.projection
237        elif isinstance(proj, str) or isinstance(proj, int):
238            epsg = proj
239            if isinstance(epsg, str):
240                try:
241                    epsg = int(str.split(':')[1])
242                except:
243                    assert False, "Error - %s is an invalid EPSG code." % proj
244            proj = osr.SpatialReference()
245            proj.ImportFromEPSG(epsg)
246
247
248        # check we have all the required info
249        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
250        assert (not self.affine is None) and (not self.projection is None), "Error - project information is undefined."
251
252        # project to this images CS (if different)
253        if not proj.IsSameGeogCS(self.projection):
254            P = ogr.Geometry(ogr.wkbPoint)
255            if proj.EPSGTreatsAsNorthingEasting():
256                P.AddPoint(x, y)
257            else:
258                P.AddPoint(y, x)
259            P.AssignSpatialReference(proj)  # tell the point what coordinates it's in
260            P.AddPoint(x, y)
261            P.TransformTo(self.projection)  # reproject it to the out spatial reference
262            x, y = P.GetX(), P.GetY()
263            if self.projection.EPSGTreatsAsLatLong(): # do we need to transpose?
264                x, y = y, x  # we want lon,lat not lat,lon
265
266        inv = gdal.InvGeoTransform(self.affine)
267        assert not inv is None, "Error - could not invert affine transform?"
268
269        #apply
270        return gdal.ApplyGeoTransform(inv, x, y)
271
272    def flip(self, axis='x'):
273        """
274        Flip the image on the x or y axis.
275
276        Args:
277            axis (str): 'x' or 'y' or both 'xy'.
278        """
279
280        if 'x' in axis.lower():
281            self.data = np.flip(self.data,axis=0)
282        if 'y' in axis.lower():
283            self.data = np.flip(self.data,axis=1)
284
285    def rot90(self):
286        """
287        Rotate this image by 90 degrees by transposing the underlying data array. Combine with flip('x') or flip('y')
288        to achieve positive/negative rotations.
289        """
290        self.data = np.transpose( self.data, (1,0,2) )
291        self.push_to_header()
292
293    #####################################
294    ##IMAGE FILTERING
295    #####################################
296    def fill_holes(self):
297        """
298        Replaces nan pixel with an average of their neighbours, thus removing 1-pixel large holes from an image. Note that
299        for performance reasons this assumes that holes line up across bands. Note that this is not vectorized so very slow...
300        """
301
302        # perform greyscale dilation
303        dilate = self.data.copy()
304        mask = np.logical_not(np.isfinite(dilate))
305        dilate[mask] = 0
306        for b in range(self.band_count()):
307            dilate[:, :, b] = sp.ndimage.grey_dilation(dilate[:, :, b], size=(3, 3))
308
309        # map back to holes in dataset
310        self.data[mask] = dilate[mask]
311        #self.data[self.data == 0] = np.nan  # replace remaining 0's with nans
312
313    def blur(self, n=3):
314        """
315        Applies a gaussian kernel of size n to the image using OpenCV.
316
317        Args:
318            n (int): the dimensions of the gaussian kernel to convolve. Default is 3. Increase for more blurry results.
319        """
320        import cv2 # import this here to avoid errors if opencv is not installed properly
321
322        nanmask = np.isnan(self.data)
323        assert isinstance(n, int) and n >= 3, "Error - invalid kernel. N must be an integer > 3. "
324        kernel = np.ones((n, n), np.float32) / (n ** 2)
325        self.data = cv2.filter2D(self.data, -1, kernel)
326        self.data[nanmask] = np.nan  # remove mask
327
328    def erode(self, size=3, iterations=1):
329        """
330        Apply an erode filter to this image to expand background (nan) pixels. Refer to open-cv's erode
331        function for more details.
332
333        Args:
334            size (int): the size of the erode filter. Default is a 3x3 kernel.
335            iterations (int): the number of erode iterations. Default is 1.
336        """
337        import cv2 # import this here to avoid errors if opencv is not installed properly
338
339        # erode
340        kernel = np.ones((size, size), np.uint8)
341        if self.is_float():
342            mask = np.isfinite(self.data).any(axis=-1)
343            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
344            self.data[mask == 0, :] = np.nan
345        else:
346            mask = (self.data != 0).any( axis=-1 )
347            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
348            self.data[mask == 0, :] = 0
349
350    def resize(self, newdims : tuple, interpolation : int = 1):
351        """
352        Resize this image with opencv.
353
354        Args:
355            newdims (tuple): the new image dimensions.
356            interpolation (int): opencv interpolation method. Default is cv2.INTER_LINEAR.
357        """
358        import cv2 # import this here to avoid errors if opencv is not installed properly
359        self.data = cv2.resize(self.data, (newdims[1],newdims[0]), interpolation=interpolation)
360
361    def despeckle(self, size=5):
362        """
363        Despeckle each band of this image (independently) using a median filter.
364
365        Args:
366            size (int): the size of the median filter kernel. Default is 5. Must be an odd number.
367        """
368
369        assert (size % 2) == 1, "Error - size must be an odd integer"
370        import cv2 # import this here to avoid errors if opencv is not installed properly
371        if self.is_float():
372            self.data = cv2.medianBlur( self.data.astype(np.float32), size )
373        else:
374            self.data = cv2.medianBlur( self.data, size )
375
376    #####################################
377    ##FEATURES AND FEATURE MATCHING
378    ######################################
379    def get_keypoints(self, band, eq=False, mask=True, method='sift', cfac=0.0,bfac=0.0, **kwds):
380        """
381        Get feature descriptors from the specified band.
382
383        Args:
384            band (int,float,str,tuple): the band index (int) or wavelength (float) to extract features from. Alternatively, a tuple can be passed
385                    containing a range of bands (min : max) to average before feature matching.
386            eq (bool): True if the image should be histogram equalized first. Default is False.
387            mask (bool): True if 0 value pixels should be masked. Default is True.
388            method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'.
389            cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
390            bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.
391            **kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are:
392
393                - contrastThreshold: default is 0.01.
394                - edgeThreshold: default is 10.
395                - sigma: default is 1.0
396
397                For ORB these are:
398
399                - nfeatures = the number of features to detect. Default is 5000.
400
401            Returns:
402                Tuple containing
403
404                    - k (ndarray): the keypoints detected
405                    - d (ndarray): corresponding feature descriptors
406         """
407        import cv2 # import this here to avoid errors if opencv is not installed properly
408
409        # get image
410        if isinstance(band, int) or isinstance(band, float): #single band
411            image = self.data[:, :, self.get_band_index(band)]
412        elif isinstance(band,tuple): #range of bands (averaged)
413            idx0 = self.get_band_index(band[0])
414            idx1 = self.get_band_index(band[1])
415
416            #deal with out of range errors
417            if idx0 is None:
418                idx0 = 0
419            if idx1 is None:
420                idx1 = self.band_count()
421
422            #average bands
423            image = np.nanmean(self.data[:,:,idx0:idx1],axis=2)
424        else:
425            assert False, "Error, unrecognised band %s" % band
426
427        #normalise image to range 0 - 1
428        image -= np.nanmin(image)
429        image = image / np.nanmax(image)
430
431        #apply brightness/contrast adjustment
432        image = (1.0+cfac)*image + bfac
433        image[image > 1.0] = 1.0
434        image[image < 0.0] = 0.0
435
436        #convert image to uint8 for opencv
437        image = np.uint8(255 * image)
438        if eq:
439            image = cv2.equalizeHist(image)
440
441        if mask:
442            mask = np.zeros(image.shape, dtype=np.uint8)
443            mask[image != 0] = 255  # include only non-zero pixels
444        else:
445            mask = None
446
447        if 'sift' in method.lower():  # SIFT
448
449            # setup default keywords
450            kwds["contrastThreshold"] = kwds.get("contrastThreshold", 0.01)
451            kwds["edgeThreshold"] = kwds.get("edgeThreshold", 10)
452            kwds["sigma"] = kwds.get("sigma", 1.0)
453
454            # make feature detector
455            #alg = cv2.xfeatures2d.SIFT_create(**kwds)
456            alg = cv2.SIFT_create()
457        elif 'orb' in method.lower():  # orb
458            kwds['nfeatures'] = kwds.get('nfeatures', 5000)
459            alg = cv2.ORB_create(scoreType=cv2.ORB_FAST_SCORE, **kwds)
460        else:
461            assert False, "Error - %s is not a recognised feature detector." % method
462
463        # detect keypoints
464        kp = alg.detect(image, mask)
465
466        # extract and return feature vectors
467        return alg.compute(image, kp)
468
469    @classmethod
470    def match_keypoints(cls, kp1, kp2, d1, d2, method='SIFT', dist=0.7, tree = 5, check = 100, min_count=5):
471        """
472        Compares keypoint feature vectors from two images and returns matching pairs.
473
474        Args:
475            kp1 (ndarray): keypoints from the first image
476            kp2 (ndarray): keypoints from the second image
477            d1 (ndarray): descriptors for the keypoints from the first image
478            d2 (ndarray): descriptors for the keypoints from the second image
479            method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'.
480            dist (float): minimum match distance (0 to 1), default is 0.7
481            tree (int): not sure what this does? Default is 5. See open-cv docs.
482            check (int): ditto. Default is 100.
483            min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are found,
484                       then the function returns None, None. Default is 5.
485        """
486        import cv2 # import this here to avoid errors if opencv is not installed properly
487        if 'sift' in method.lower():
488            algorithm = cv2.NORM_INF
489        elif 'orb' in method.lower():
490            algorithm = cv2.NORM_HAMMING
491        else:
492            assert False, "Error - unknown matching algorithm %s" % method
493
494        #calculate flann matches
495        index_params = dict(algorithm=algorithm, trees=tree)
496        search_params = dict(checks=check)
497        flann = cv2.FlannBasedMatcher(index_params, search_params)
498        matches = flann.knnMatch(d1, d2, k=2)
499
500        # store all the good matches as per Lowe's ratio test.
501        good = []
502        for m, n in matches:
503            if m.distance < dist * n.distance:
504                good.append(m)
505
506        if len(good) < min_count:
507            return None, None
508        else:
509            src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
510            dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
511            return src_pts, dst_pts
512
513    ############################
514    ## Visualisation methods
515    ############################
516    def quick_plot(self, band=0, ax=None, bfac=0.0, cfac=0.0, samples=False, tscale=False, invert=False, rot=False, flipX=False, flipY=False,
517                   **kwds):
518        """
519        Plot a band using matplotlib.imshow(...).
520
521        Args:
522            band (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple is passed then
523                  each band in the tuple (string or index) will be mapped to rgb. Bands with negative wavelengths or indices will be inverted before plotting.
524            ax: an axis object to plot to. If none, plt.imshow( ... ) is used.
525            bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
526            cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
527            samples (bool): True if sample points (defined in the header file) should be plotted. Default is False. Otherwise, a list of
528                     [ (x,y), ... ] points can be passed.
529            tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False.
530                    When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or
531                    (constant) values (float).
532            invert (bool) : True if each band should be inverted before plotting. Only works for multiband (ternary) images.
533            rot (bool): if True, the x and y axis will be flipped (90 degree rotation) before plotting. Default is False.
534            flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
535            flipY (bool): if True, the y axis will be flippe before plotting (after applying rotations).
536            **kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:
537
538                 - mask = a 2D boolean mask containing true if pixels should be drawn and false otherwise.
539                 - path = a file path to save the image too (at matching resolution; use fig.savefig(..) if you want to save the figure).
540                 - ticks = True if x- and y- ticks should be plotted. Default is False.
541                 - ps, pc = the size and color of sample points to plot. Can be constant or list.
542                 - figsize = a figsize for the figure to create (if ax is None).
543
544        Returns:
545            Tuple containing
546
547            - fig: matplotlib figure object
548            - ax:  matplotlib axes object. If a colorbar is created, (band is an integer or a float), then this will be stored in ax.cbar.
549        """
550
551        #create new axes?
552        if ax is None:
553            fig, ax = plt.subplots(figsize=kwds.pop('figsize', (18,18*self.ydim()/self.xdim()) ))
554
555        # deal with ticks
556        if not kwds.pop('ticks', False ):
557            ax.set_xticks([])
558            ax.set_yticks([])
559
560        #map individual band using colourmap
561        if isinstance(band, str) or isinstance(band, int) or isinstance(band, float):
562            #get band
563            if isinstance(band, str):
564                data = self.data[:, :, self.get_band_index(band)]
565            else:
566                data = self.data[:, :, self.get_band_index(np.abs(band))]
567            if not isinstance(band, str) and band < 0:
568                data = np.nanmax(data) - data # flip
569
570            # convert integer vmin and vmax values to percentiles
571            if 'vmin' in kwds:
572                if isinstance(kwds['vmin'], int):
573                    kwds['vmin'] = np.nanpercentile( data, kwds['vmin'] )
574            if 'vmax' in kwds:
575                if isinstance(kwds['vmax'], int):
576                    kwds['vmax'] = np.nanpercentile( data, kwds['vmax'] )
577
578            #mask nans (and apply custom mask)
579            mask = np.isnan(data)
580            if not np.isnan(self.header.get_data_ignore_value()):
581                mask = mask + data == self.header.get_data_ignore_value()
582            if 'mask' in kwds:
583                mask = mask + kwds.get('mask')
584                del kwds['mask']
585            data = np.ma.array(data, mask = mask > 0 )
586
587            # apply rotations and flipping
588            if rot:
589                data = data.T
590            if flipX:
591                data = data[::-1, :]
592            if flipY:
593                data = data[:, ::-1]
594
595            # save?
596            if 'path' in kwds:
597                path = kwds.pop('path')
598                from matplotlib.pyplot import imsave
599                if not os.path.exists(os.path.dirname(path)):
600                    os.makedirs(os.path.dirname(path)) # ensure output directory exists
601                imsave(path, data.T, **kwds)  # save the image
602
603            ax.cbar = ax.imshow(data.T, interpolation=kwds.pop('interpolation', 'none'), **kwds) # change default interpolation to None
604
605        #map 3 bands to RGB
606        elif isinstance(band, tuple) or isinstance(band, list):
607            #get band indices and range
608            rgb = []
609            for b in band:
610                if isinstance(b, str):
611                    rgb.append(self.get_band_index(b))
612                else:
613                    rgb.append(self.get_band_index(np.abs(b)))
614
615            #slice image (as copy) and map to 0 - 1
616            img = np.array(self.data[:, :, rgb]).copy()
617            if np.isnan(img).all():
618                print("Warning - image contains no data.")
619                return ax.get_figure(), ax
620
621            # invert if needed
622            if invert:
623                band = [-b for b in band]
624            for i,b in enumerate(band):
625                if not isinstance(b, str) and (b < 0):
626                    img[..., i] = np.nanmax(img[..., i]) - img[..., i]
627
628            # do scaling
629            if tscale: # scale bands independently
630                for b in range(3):
631                    mn = kwds.get("vmin", float(np.nanmin(img)))
632                    mx = kwds.get("vmax", float(np.nanmax(img)))
633                    if isinstance (mn, int):
634                        assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
635                        mn = float(np.nanpercentile(img[...,b], mn ))
636                    if isinstance (mx, int):
637                        assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
638                        mx = float(np.nanpercentile(img[...,b], mx ))
639                    img[...,b] = (img[..., b] - mn) / (mx - mn)
640            else: # scale bands together
641                mn = kwds.get("vmin", float(np.nanmin(img)))
642                mx = kwds.get("vmax", float(np.nanmax(img)))
643                if isinstance(mn, int):
644                    assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
645                    mn = float(np.nanpercentile(img, mn))
646                if isinstance(mx, int):
647                    assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
648                    mx = float(np.nanpercentile(img, mx))
649                img = (img - mn) / (mx - mn)
650
651            #apply brightness/contrast mapping
652            img = np.clip((1.0 + cfac) * img + bfac, 0, 1.0 )
653
654            #apply masking so background is white
655            img[np.logical_not( np.isfinite( img ) )] = 1.0
656            if 'mask' in kwds:
657                img[kwds.pop("mask"),:] = 1.0
658
659            # apply rotations and flipping
660            if rot:
661                img = np.transpose( img, (1,0,2) )
662            if flipX:
663                img = img[::-1, :, :]
664            if flipY:
665                img = img[:, ::-1, :]
666
667            # save?
668            if 'path' in kwds:
669                path = kwds.pop('path')
670                from matplotlib.pyplot import imsave
671                if not os.path.exists(os.path.dirname(path)):
672                    os.makedirs(os.path.dirname(path)) # ensure output directory exists
673                imsave(path, np.transpose( np.clip( img*255, 0, 255).astype(np.uint8), (1, 0, 2)))  # save the image
674
675            # plot samples?
676            ps = kwds.pop('ps', 5)
677            pc = kwds.pop('pc', 'r')
678            if samples:
679                if isinstance(samples, list) or isinstance(samples, np.ndarray):
680                    ax.scatter([s[0] for s in samples], [s[1] for s in samples], s=ps, c=pc)
681                else:
682                    for n in self.header.get_class_names():
683                        points = np.array(self.header.get_sample_points(n))
684                        ax.scatter(points[:, 0], points[:, 1], s=ps, c=pc)
685
686            #plot
687            ax.imshow(np.transpose(img, (1,0,2)), interpolation=kwds.pop('interpolation', 'none'), **kwds)
688            ax.cbar = None  # no colorbar
689
690        return ax.get_figure(), ax
691
692    def createGIF(self, path, bands=None, figsize=(10,10), fps=10, **kwds):
693        """
694        Create and save an animated gif that loops through the bands of the image.
695
696        Args:
697            path (str): the path to save the .gif
698            bands (tuple): Tuple containing the range of band indices to draw. Default is the whole range.
699            figsize (tuple): the size of the image to draw. Default is (10,10).
700            fps (int): the framerate (frames per second) of the gif. Default is 10.
701            **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
702        """
703
704        frames = []
705        if bands is None:
706            bands = (0,self.band_count())
707        else:
708            assert 0 < bands[0] < self.band_count(), "Error - invalid range."
709            assert 0 < bands[1] < self.band_count(), "Error - invalid range."
710            assert bands[1] > bands[0], "Error - invalid range."
711
712        #plot frames
713        for i in range(bands[0],bands[1]):
714            fig, ax = plt.subplots(figsize=figsize)
715            ax.imshow(self.data[:, :, i], **kwds)
716            fig.canvas.draw()
717            frames.append(np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8'))
718            frames[-1] = np.reshape(frames[-1], (fig.canvas.get_width_height()[1], fig.canvas.get_width_height()[0], 3))
719            plt.close(fig)
720
721        #save gif
722        imageio.mimsave( os.path.splitext(path)[0] + ".gif", frames, fps=fps)
723
724    ## masking
725    def drop_bbl(self, drop=True):
726        """
727        Remove bad bands as stored in the 'bbl' key in the image header. Note that this operates in-place.
728
729        Args:
730            drop (bool): True if bad bands should be completely dropped. If False, these bands will be kept but replaced with nans.
731        """
732        assert 'bbl' in self.header, "Please specify a bad band list ('bbl') in the image header, as per the ENVI format definition."
733        mask = self.header.get_list('bbl') == 0
734        self.data[...,mask] = np.nan
735        if drop:
736            self.delete_nan_bands(inplace=True)
737    
738    def mask(self, mask=None, flag=np.nan, invert=False, crop=False, bands=None):
739        """
740         Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the
741         image in-situ.
742
743         Args:
744            flag (float): the value to use for masked pixels. Default is np.nan
745            mask (ndarray): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None is passed then
746                    pickPolygon( ... ) is used to interactively define a polygon. If a file path is passed then the polygon
747                    will be loaded using np.load( ... ). Alternatively if mask.shape == image.shape[0,1] then it is treated as a
748                    binary image mask (must be boolean) and True values will be masked across all bands. Default is None.
749            invert (bool): if True, pixels within the polygon will be masked. If False, pixels outside the polygon are masked. Default is False.
750            crop (bool): True if rows/columns containing only zeros should be removed. Default is False.
751            bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used.
752
753         Returns:
754            Tuple containing
755
756            - mask (ndarray): a boolean array with True where pixels are masked and False elsewhere.
757            - poly (ndarray): the mask polygon array in the format described above. Useful if the polygon was interactively defined.
758         """
759
760        if mask is None:  # pick mask interactively
761            if bands is None:
762                bands = int(self.band_count() / 2)
763
764            regions = self.pickPolygons(region_names=["mask"], bands=bands)
765
766            # the user bailed without picking a mask?
767            if len(regions) == 0:
768                print("Warning - no mask picked/applied.")
769                return
770
771            # extract polygon mask
772            mask = regions[0]
773
774        # convert polygon mask to binary mask
775        if mask.shape[1] == 2:
776
777            # build meshgrid with pixel coords
778            xx, yy = np.meshgrid(np.arange(self.xdim()), np.arange(self.ydim()))
779            xx = xx.flatten()
780            yy = yy.flatten()
781            points = np.vstack([xx, yy]).T  # coordinates of each pixel
782
783            # calculate per-pixel mask
784            mask = path.Path(mask).contains_points(points)
785            mask = mask.reshape((self.ydim(), self.xdim())).T
786
787            # flip as we want to mask (==True) outside points (unless invert is true)
788            if not invert:
789                mask = np.logical_not(mask)
790
791        # apply binary image mask
792        assert mask.shape[0] == self.data.shape[0] and mask.shape[1] == self.data.shape[1], \
793            "Error - mask shape %s does not match image shape %s" % (mask.shape, self.data.shape)
794        for b in range(self.band_count()):
795            self.data[:, :, b][mask] = flag
796
797        # crop image
798        if crop:
799            # calculate non-masked pixels
800            valid = np.logical_not(mask)
801
802            # integrate along axes
803            xdata = np.sum(valid, axis=1) > 0.0
804            ydata = np.sum(valid, axis=0) > 0.0
805
806            # calculate domain containing valid pixels
807            xmin = np.argmax(xdata)
808            xmax = xdata.shape[0] - np.argmax(xdata[::-1])
809            ymin = np.argmax(ydata)
810            ymax = ydata.shape[0] - np.argmax(ydata[::-1])
811
812            # crop
813            self.data = self.data[xmin:xmax, ymin:ymax, :]
814
815        return mask
816
817    def crop_to_data(self):
818        """
819        Remove padding of nan or zero pixels from image. Note that this is performed in place.
820        """
821
822        valid = np.isfinite(self.data).any(axis=-1) & (self.data != 0).any(axis=-1)
823        ymin, ymax = np.percentile(np.argwhere(np.sum(valid, axis=0) != 0), (0, 100))
824        xmin, xmax = np.percentile(np.argwhere(np.sum(valid, axis=1) != 0), (0, 100))
825        self.data = self.data[int(xmin):int(xmax), int(ymin):int(ymax), :]  # do clipping
826
827    ##################################################
828    ## Interactive tools for picking regions/pixels
829    ##################################################
830    def pickPolygons(self, region_names, bands=0):
831        """
832        Creates a matplotlib gui for selecting polygon regions in an image.
833
834        Args:
835            names (list, str): a list containing the names of the regions to pick. If a string is passed only one name is used.
836            bands (tuple): the bands of the image to plot.
837        """
838
839        if isinstance(region_names, str):
840            region_names = [region_names]
841
842        assert isinstance(region_names, list), "Error - names must be a list or a string."
843
844        # set matplotlib backend
845        backend = matplotlib.get_backend()
846        matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work
847
848        # plot image and extract roi's
849        fig, ax = self.quick_plot(bands)
850        roi = MultiRoi(roi_names=region_names)
851        plt.close(fig)  # close figure
852
853        # extract regions
854        regions = []
855        for name, r in roi.rois.items():
856            # store region
857            x = r.x
858            y = r.y
859            regions.append(np.vstack([x, y]).T)
860
861        # restore matplotlib backend (if possible)
862        try:
863            matplotlib.use(backend)
864        except:
865            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")
866            pass
867
868        return regions
869
870    def pickPoints(self, n=-1, bands=hylite.RGB, integer=True, title="Pick Points", **kwds):
871        """
872        Creates a matplotlib gui for picking pixels from an image.
873
874        Args:
875            n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
876            bands (tuple): the bands of the image to plot. Default is HyImage.RGB
877            integer (bool): True if points coordinates should be cast to integers (for use as indices). Default is True.
878            title (str): The title of the point picking window.
879            **kwds: Keywords are passed to HyImage.quick_plot( ... ).
880
881        Returns:
882            A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].
883        """
884
885        # set matplotlib backend
886        backend = matplotlib.get_backend()
887        matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work
888
889        # create figure
890        fig, ax = self.quick_plot( bands, **kwds )
891        ax.set_title(title)
892
893        # get points
894        points = fig.ginput( n )
895
896        if integer:
897            points = [ (int(p[0]), int(p[1])) for p in points ]
898
899        # restore matplotlib backend (if possible)
900        try:
901            matplotlib.use(backend)
902        except:
903            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")
904            pass
905
906        return points
907
908    def pickSamples(self, names=None, store=True, **kwds):
909        """
910        Pick sample probe points and store these in the image header file.
911
912        Args:
913            names (str, list): the name of the sample to pick, or a list of names to pick multiple.
914            store (bool): True if sample should be stored in the image header file (for later access). Default is True.
915            **kwds: Keywords are passed to HyImage.quick_plot( ... )
916
917        Returns:
918            a list containing a list of points for each sample.
919        """
920
921        if isinstance(names, str):
922            names = [names]
923
924        # pick points
925        points = []
926        for s in names:
927            pnts = self.pickPoints(title="%s" % s, **kwds)
928            if store:
929                self.header['sample %s' % s] = pnts # store in header
930            points.append(pnts)
931        # add class to header file
932        if store:
933            cls_names = self.header.get_class_names()
934            if cls_names is None:
935                cls_names = []
936            self.header['class names'] = cls_names + names
937
938        return points
class HyImage(hylite.hydata.HyData):
 20class HyImage( HyData ):
 21    """
 22    A class for hyperspectral image data. These can be individual scenes or hyperspectral orthoimages.
 23    """
 24
 25    def __init__(self, data, **kwds):
 26        """
 27        Args:
 28            data (ndarray): a numpy array such that data[x][y][band] gives each pixel value.
 29            **kwds:
 30                wav = A numpy array containing band wavelengths for this image.
 31                affine = an affine transform of the format returned by GDAL.GetGeoTransform().
 32                projection = string defining the project. Default is None.
 33                sensor = sensor name. Default is "unknown".
 34                header = path to associated header file. Default is None.
 35        """
 36
 37        #call constructor for HyData
 38        super().__init__(data, **kwds)
 39
 40        # special case - if dataset only has oneband, slice it so it still has
 41        # the format data[x,y,b].
 42        if not self.data is None:
 43            if len(self.data.shape) == 1:
 44                self.data = self.data[None, None, :] # single pixel image
 45            if len(self.data.shape) == 2:
 46                self.data = self.data[:, :, None] # single band iamge
 47
 48        #load any additional project information (specific to images)
 49        self.set_projection(kwds.get("projection",None))
 50        self.affine = kwds.get("affine",[0,1,0,0,0,1])
 51
 52        # wavelengths
 53        if 'wav' in kwds:
 54            self.set_wavelengths(kwds['wav'])
 55
 56        #special header formatting
 57        self.header['file type'] = 'ENVI Standard'
 58
 59    def copy(self,data=True):
 60        """
 61        Make a deep copy of this image instance.
 62
 63        Args:
 64            data (bool): True if a copy of the data should be made, otherwise only copy header.
 65
 66        Returns:
 67            a new HyImage instance.
 68        """
 69        if not data:
 70            return HyImage(None, header=self.header.copy(), projection=self.projection, affine=self.affine)
 71        else:
 72            return HyImage( self.data.copy(), header=self.header.copy(), projection=self.projection, affine=self.affine)
 73
 74    def T(self):
 75        """
 76        Return a transposed view of the data matrix (corresponding with the [y,x] indexing used by matplotlib, opencv etc.
 77        """
 78        return np.transpose(self.data, (1,0,2))
 79
 80    def xdim(self):
 81        """
 82        Return number of pixels in x (first dimension of data array)
 83        """
 84        return self.data.shape[0]
 85
 86    def ydim(self):
 87        """
 88        Return number of pixels in y (second dimension of data array)
 89        """
 90        return self.data.shape[1]
 91
 92    def aspx(self):
 93        """
 94        Return the aspect ratio of this image (width/height).
 95        """
 96        return self.ydim() / self.xdim()
 97
 98    def get_extent(self):
 99        """
100        Returns the width and height of this image in world coordinates.
101
102        Returns:
103            tuple with (width, height).
104        """
105        return self.xdim * self.pixel_size[0], self.ydim * self.pixel_size[1]
106
107    def set_projection(self,proj):
108        """
109        Set this project to an existing osgeo.osr.SpatialReference or GDAL georeference string.
110
111        Args:
112            proj (str, osgeo.osr.SpatialReference): the project to use as osgeo.osr.SpatialReference or GDAL georeference string.
113        """
114        if proj is None:
115            self.projection = None
116        else:
117            try:
118                from osgeo.osr import SpatialReference
119            except:
120                assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
121            if isinstance(proj, SpatialReference):
122                self.projection = proj
123            elif isinstance(proj, str):
124                self.projection = SpatialReference(proj)
125            else:
126                print("Invalid project %s" % proj)
127                raise
128
129    def set_projection_EPSG(self,EPSG):
130        """
131        Sets this image project using an EPSG code.
132
133        Args:
134            EPSG (str): string EPSG code that can be passed to SpatialReference.SetFromUserInput(...).
135        """
136
137        try:
138            from osgeo.osr import SpatialReference
139        except:
140            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
141
142        self.projection = SpatialReference()
143        self.projection.SetFromUserInput(EPSG)
144
145    def get_projection_EPSG(self):
146        """
147        Gets a string describing this projections EPSG code (if it is an EPSG project).
148
149        Returns:
150            an EPSG code string of the format "EPSG:XXXX".
151        """
152        if self.projection is None:
153            return None
154        else:
155            return "%s:%s" % (self.projection.GetAttrValue("AUTHORITY",0),self.projection.GetAttrValue("AUTHORITY",1))
156
157    def pix_to_world(self, px, py, proj=None):
158        """
159        Take pixel coordinates and return world coordinates
160
161        Args:
162            px (int): the pixel x-coord.
163            py (int): the pixel y-coord.
164            proj (str, osr.SpatialReference): the coordinate system to use. Default (None) uses the same system as this image. Otherwise
165                   an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
166        Returns:
167            the world coordinates in the coordinate system defined by get_projection_EPSG(...).
168        """
169
170        try:
171            from osgeo import osr
172            import osgeo.gdal as gdal
173            from osgeo import ogr
174        except:
175            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
176
177        # parse project
178        if proj is None:
179            proj = self.projection
180        elif isinstance(proj, str) or isinstance(proj, int):
181            epsg = proj
182            if isinstance(epsg, str):
183                try:
184                    epsg = int(str.split(':')[1])
185                except:
186                    assert False, "Error - %s is an invalid EPSG code." % proj
187            proj = osr.SpatialReference()
188            proj.ImportFromEPSG(epsg)
189
190        # check we have all the required info
191        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
192        assert (not self.affine is None) and (
193            not self.projection is None), "Error - project information is undefined."
194
195        #project to world coordinates in this images project/world coords
196        x,y = gdal.ApplyGeoTransform(self.affine, px, py)
197
198        #project to target coords (if different)
199        if not proj.IsSameGeogCS(self.projection):
200            P = ogr.Geometry(ogr.wkbPoint)
201            if proj.EPSGTreatsAsNorthingEasting():
202                P.AddPoint(x, y)
203            else:
204                P.AddPoint(y, x)
205            P.AssignSpatialReference(self.projection)  # tell the point what coordinates it's in
206            P.TransformTo(proj)  # reproject it to the out spatial reference
207            x, y = P.GetX(), P.GetY()
208
209            #do we need to transpose?
210            if proj.EPSGTreatsAsLatLong():
211                x,y=y,x #we want lon,lat not lat,lon
212        return x, y
213
214    def world_to_pix(self, x, y, proj = None):
215        """
216        Take world coordinates and return pixel coordinates
217
218        Args:
219            x (float): the world x-coord.
220            y (float): the world y-coord.
221            proj (str, osr.SpatialReference): the coordinate system of the input coordinates. Default (None) uses the same system as this image. Otherwise
222                   an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
223
224        Returns:
225            the pixel coordinates based on the affine transform stored in self.affine.
226        """
227
228        try:
229            from osgeo import osr
230            import osgeo.gdal as gdal
231            from osgeo import ogr
232        except:
233            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
234
235        # parse project
236        if proj is None:
237            proj = self.projection
238        elif isinstance(proj, str) or isinstance(proj, int):
239            epsg = proj
240            if isinstance(epsg, str):
241                try:
242                    epsg = int(str.split(':')[1])
243                except:
244                    assert False, "Error - %s is an invalid EPSG code." % proj
245            proj = osr.SpatialReference()
246            proj.ImportFromEPSG(epsg)
247
248
249        # check we have all the required info
250        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
251        assert (not self.affine is None) and (not self.projection is None), "Error - project information is undefined."
252
253        # project to this images CS (if different)
254        if not proj.IsSameGeogCS(self.projection):
255            P = ogr.Geometry(ogr.wkbPoint)
256            if proj.EPSGTreatsAsNorthingEasting():
257                P.AddPoint(x, y)
258            else:
259                P.AddPoint(y, x)
260            P.AssignSpatialReference(proj)  # tell the point what coordinates it's in
261            P.AddPoint(x, y)
262            P.TransformTo(self.projection)  # reproject it to the out spatial reference
263            x, y = P.GetX(), P.GetY()
264            if self.projection.EPSGTreatsAsLatLong(): # do we need to transpose?
265                x, y = y, x  # we want lon,lat not lat,lon
266
267        inv = gdal.InvGeoTransform(self.affine)
268        assert not inv is None, "Error - could not invert affine transform?"
269
270        #apply
271        return gdal.ApplyGeoTransform(inv, x, y)
272
273    def flip(self, axis='x'):
274        """
275        Flip the image on the x or y axis.
276
277        Args:
278            axis (str): 'x' or 'y' or both 'xy'.
279        """
280
281        if 'x' in axis.lower():
282            self.data = np.flip(self.data,axis=0)
283        if 'y' in axis.lower():
284            self.data = np.flip(self.data,axis=1)
285
286    def rot90(self):
287        """
288        Rotate this image by 90 degrees by transposing the underlying data array. Combine with flip('x') or flip('y')
289        to achieve positive/negative rotations.
290        """
291        self.data = np.transpose( self.data, (1,0,2) )
292        self.push_to_header()
293
294    #####################################
295    ##IMAGE FILTERING
296    #####################################
297    def fill_holes(self):
298        """
299        Replaces nan pixel with an average of their neighbours, thus removing 1-pixel large holes from an image. Note that
300        for performance reasons this assumes that holes line up across bands. Note that this is not vectorized so very slow...
301        """
302
303        # perform greyscale dilation
304        dilate = self.data.copy()
305        mask = np.logical_not(np.isfinite(dilate))
306        dilate[mask] = 0
307        for b in range(self.band_count()):
308            dilate[:, :, b] = sp.ndimage.grey_dilation(dilate[:, :, b], size=(3, 3))
309
310        # map back to holes in dataset
311        self.data[mask] = dilate[mask]
312        #self.data[self.data == 0] = np.nan  # replace remaining 0's with nans
313
314    def blur(self, n=3):
315        """
316        Applies a gaussian kernel of size n to the image using OpenCV.
317
318        Args:
319            n (int): the dimensions of the gaussian kernel to convolve. Default is 3. Increase for more blurry results.
320        """
321        import cv2 # import this here to avoid errors if opencv is not installed properly
322
323        nanmask = np.isnan(self.data)
324        assert isinstance(n, int) and n >= 3, "Error - invalid kernel. N must be an integer > 3. "
325        kernel = np.ones((n, n), np.float32) / (n ** 2)
326        self.data = cv2.filter2D(self.data, -1, kernel)
327        self.data[nanmask] = np.nan  # remove mask
328
329    def erode(self, size=3, iterations=1):
330        """
331        Apply an erode filter to this image to expand background (nan) pixels. Refer to open-cv's erode
332        function for more details.
333
334        Args:
335            size (int): the size of the erode filter. Default is a 3x3 kernel.
336            iterations (int): the number of erode iterations. Default is 1.
337        """
338        import cv2 # import this here to avoid errors if opencv is not installed properly
339
340        # erode
341        kernel = np.ones((size, size), np.uint8)
342        if self.is_float():
343            mask = np.isfinite(self.data).any(axis=-1)
344            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
345            self.data[mask == 0, :] = np.nan
346        else:
347            mask = (self.data != 0).any( axis=-1 )
348            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
349            self.data[mask == 0, :] = 0
350
351    def resize(self, newdims : tuple, interpolation : int = 1):
352        """
353        Resize this image with opencv.
354
355        Args:
356            newdims (tuple): the new image dimensions.
357            interpolation (int): opencv interpolation method. Default is cv2.INTER_LINEAR.
358        """
359        import cv2 # import this here to avoid errors if opencv is not installed properly
360        self.data = cv2.resize(self.data, (newdims[1],newdims[0]), interpolation=interpolation)
361
362    def despeckle(self, size=5):
363        """
364        Despeckle each band of this image (independently) using a median filter.
365
366        Args:
367            size (int): the size of the median filter kernel. Default is 5. Must be an odd number.
368        """
369
370        assert (size % 2) == 1, "Error - size must be an odd integer"
371        import cv2 # import this here to avoid errors if opencv is not installed properly
372        if self.is_float():
373            self.data = cv2.medianBlur( self.data.astype(np.float32), size )
374        else:
375            self.data = cv2.medianBlur( self.data, size )
376
377    #####################################
378    ##FEATURES AND FEATURE MATCHING
379    ######################################
380    def get_keypoints(self, band, eq=False, mask=True, method='sift', cfac=0.0,bfac=0.0, **kwds):
381        """
382        Get feature descriptors from the specified band.
383
384        Args:
385            band (int,float,str,tuple): the band index (int) or wavelength (float) to extract features from. Alternatively, a tuple can be passed
386                    containing a range of bands (min : max) to average before feature matching.
387            eq (bool): True if the image should be histogram equalized first. Default is False.
388            mask (bool): True if 0 value pixels should be masked. Default is True.
389            method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'.
390            cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
391            bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.
392            **kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are:
393
394                - contrastThreshold: default is 0.01.
395                - edgeThreshold: default is 10.
396                - sigma: default is 1.0
397
398                For ORB these are:
399
400                - nfeatures = the number of features to detect. Default is 5000.
401
402            Returns:
403                Tuple containing
404
405                    - k (ndarray): the keypoints detected
406                    - d (ndarray): corresponding feature descriptors
407         """
408        import cv2 # import this here to avoid errors if opencv is not installed properly
409
410        # get image
411        if isinstance(band, int) or isinstance(band, float): #single band
412            image = self.data[:, :, self.get_band_index(band)]
413        elif isinstance(band,tuple): #range of bands (averaged)
414            idx0 = self.get_band_index(band[0])
415            idx1 = self.get_band_index(band[1])
416
417            #deal with out of range errors
418            if idx0 is None:
419                idx0 = 0
420            if idx1 is None:
421                idx1 = self.band_count()
422
423            #average bands
424            image = np.nanmean(self.data[:,:,idx0:idx1],axis=2)
425        else:
426            assert False, "Error, unrecognised band %s" % band
427
428        #normalise image to range 0 - 1
429        image -= np.nanmin(image)
430        image = image / np.nanmax(image)
431
432        #apply brightness/contrast adjustment
433        image = (1.0+cfac)*image + bfac
434        image[image > 1.0] = 1.0
435        image[image < 0.0] = 0.0
436
437        #convert image to uint8 for opencv
438        image = np.uint8(255 * image)
439        if eq:
440            image = cv2.equalizeHist(image)
441
442        if mask:
443            mask = np.zeros(image.shape, dtype=np.uint8)
444            mask[image != 0] = 255  # include only non-zero pixels
445        else:
446            mask = None
447
448        if 'sift' in method.lower():  # SIFT
449
450            # setup default keywords
451            kwds["contrastThreshold"] = kwds.get("contrastThreshold", 0.01)
452            kwds["edgeThreshold"] = kwds.get("edgeThreshold", 10)
453            kwds["sigma"] = kwds.get("sigma", 1.0)
454
455            # make feature detector
456            #alg = cv2.xfeatures2d.SIFT_create(**kwds)
457            alg = cv2.SIFT_create()
458        elif 'orb' in method.lower():  # orb
459            kwds['nfeatures'] = kwds.get('nfeatures', 5000)
460            alg = cv2.ORB_create(scoreType=cv2.ORB_FAST_SCORE, **kwds)
461        else:
462            assert False, "Error - %s is not a recognised feature detector." % method
463
464        # detect keypoints
465        kp = alg.detect(image, mask)
466
467        # extract and return feature vectors
468        return alg.compute(image, kp)
469
470    @classmethod
471    def match_keypoints(cls, kp1, kp2, d1, d2, method='SIFT', dist=0.7, tree = 5, check = 100, min_count=5):
472        """
473        Compares keypoint feature vectors from two images and returns matching pairs.
474
475        Args:
476            kp1 (ndarray): keypoints from the first image
477            kp2 (ndarray): keypoints from the second image
478            d1 (ndarray): descriptors for the keypoints from the first image
479            d2 (ndarray): descriptors for the keypoints from the second image
480            method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'.
481            dist (float): minimum match distance (0 to 1), default is 0.7
482            tree (int): not sure what this does? Default is 5. See open-cv docs.
483            check (int): ditto. Default is 100.
484            min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are found,
485                       then the function returns None, None. Default is 5.
486        """
487        import cv2 # import this here to avoid errors if opencv is not installed properly
488        if 'sift' in method.lower():
489            algorithm = cv2.NORM_INF
490        elif 'orb' in method.lower():
491            algorithm = cv2.NORM_HAMMING
492        else:
493            assert False, "Error - unknown matching algorithm %s" % method
494
495        #calculate flann matches
496        index_params = dict(algorithm=algorithm, trees=tree)
497        search_params = dict(checks=check)
498        flann = cv2.FlannBasedMatcher(index_params, search_params)
499        matches = flann.knnMatch(d1, d2, k=2)
500
501        # store all the good matches as per Lowe's ratio test.
502        good = []
503        for m, n in matches:
504            if m.distance < dist * n.distance:
505                good.append(m)
506
507        if len(good) < min_count:
508            return None, None
509        else:
510            src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
511            dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
512            return src_pts, dst_pts
513
514    ############################
515    ## Visualisation methods
516    ############################
517    def quick_plot(self, band=0, ax=None, bfac=0.0, cfac=0.0, samples=False, tscale=False, invert=False, rot=False, flipX=False, flipY=False,
518                   **kwds):
519        """
520        Plot a band using matplotlib.imshow(...).
521
522        Args:
523            band (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple is passed then
524                  each band in the tuple (string or index) will be mapped to rgb. Bands with negative wavelengths or indices will be inverted before plotting.
525            ax: an axis object to plot to. If none, plt.imshow( ... ) is used.
526            bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
527            cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
528            samples (bool): True if sample points (defined in the header file) should be plotted. Default is False. Otherwise, a list of
529                     [ (x,y), ... ] points can be passed.
530            tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False.
531                    When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or
532                    (constant) values (float).
533            invert (bool) : True if each band should be inverted before plotting. Only works for multiband (ternary) images.
534            rot (bool): if True, the x and y axis will be flipped (90 degree rotation) before plotting. Default is False.
535            flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
536            flipY (bool): if True, the y axis will be flippe before plotting (after applying rotations).
537            **kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:
538
539                 - mask = a 2D boolean mask containing true if pixels should be drawn and false otherwise.
540                 - path = a file path to save the image too (at matching resolution; use fig.savefig(..) if you want to save the figure).
541                 - ticks = True if x- and y- ticks should be plotted. Default is False.
542                 - ps, pc = the size and color of sample points to plot. Can be constant or list.
543                 - figsize = a figsize for the figure to create (if ax is None).
544
545        Returns:
546            Tuple containing
547
548            - fig: matplotlib figure object
549            - ax:  matplotlib axes object. If a colorbar is created, (band is an integer or a float), then this will be stored in ax.cbar.
550        """
551
552        #create new axes?
553        if ax is None:
554            fig, ax = plt.subplots(figsize=kwds.pop('figsize', (18,18*self.ydim()/self.xdim()) ))
555
556        # deal with ticks
557        if not kwds.pop('ticks', False ):
558            ax.set_xticks([])
559            ax.set_yticks([])
560
561        #map individual band using colourmap
562        if isinstance(band, str) or isinstance(band, int) or isinstance(band, float):
563            #get band
564            if isinstance(band, str):
565                data = self.data[:, :, self.get_band_index(band)]
566            else:
567                data = self.data[:, :, self.get_band_index(np.abs(band))]
568            if not isinstance(band, str) and band < 0:
569                data = np.nanmax(data) - data # flip
570
571            # convert integer vmin and vmax values to percentiles
572            if 'vmin' in kwds:
573                if isinstance(kwds['vmin'], int):
574                    kwds['vmin'] = np.nanpercentile( data, kwds['vmin'] )
575            if 'vmax' in kwds:
576                if isinstance(kwds['vmax'], int):
577                    kwds['vmax'] = np.nanpercentile( data, kwds['vmax'] )
578
579            #mask nans (and apply custom mask)
580            mask = np.isnan(data)
581            if not np.isnan(self.header.get_data_ignore_value()):
582                mask = mask + data == self.header.get_data_ignore_value()
583            if 'mask' in kwds:
584                mask = mask + kwds.get('mask')
585                del kwds['mask']
586            data = np.ma.array(data, mask = mask > 0 )
587
588            # apply rotations and flipping
589            if rot:
590                data = data.T
591            if flipX:
592                data = data[::-1, :]
593            if flipY:
594                data = data[:, ::-1]
595
596            # save?
597            if 'path' in kwds:
598                path = kwds.pop('path')
599                from matplotlib.pyplot import imsave
600                if not os.path.exists(os.path.dirname(path)):
601                    os.makedirs(os.path.dirname(path)) # ensure output directory exists
602                imsave(path, data.T, **kwds)  # save the image
603
604            ax.cbar = ax.imshow(data.T, interpolation=kwds.pop('interpolation', 'none'), **kwds) # change default interpolation to None
605
606        #map 3 bands to RGB
607        elif isinstance(band, tuple) or isinstance(band, list):
608            #get band indices and range
609            rgb = []
610            for b in band:
611                if isinstance(b, str):
612                    rgb.append(self.get_band_index(b))
613                else:
614                    rgb.append(self.get_band_index(np.abs(b)))
615
616            #slice image (as copy) and map to 0 - 1
617            img = np.array(self.data[:, :, rgb]).copy()
618            if np.isnan(img).all():
619                print("Warning - image contains no data.")
620                return ax.get_figure(), ax
621
622            # invert if needed
623            if invert:
624                band = [-b for b in band]
625            for i,b in enumerate(band):
626                if not isinstance(b, str) and (b < 0):
627                    img[..., i] = np.nanmax(img[..., i]) - img[..., i]
628
629            # do scaling
630            if tscale: # scale bands independently
631                for b in range(3):
632                    mn = kwds.get("vmin", float(np.nanmin(img)))
633                    mx = kwds.get("vmax", float(np.nanmax(img)))
634                    if isinstance (mn, int):
635                        assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
636                        mn = float(np.nanpercentile(img[...,b], mn ))
637                    if isinstance (mx, int):
638                        assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
639                        mx = float(np.nanpercentile(img[...,b], mx ))
640                    img[...,b] = (img[..., b] - mn) / (mx - mn)
641            else: # scale bands together
642                mn = kwds.get("vmin", float(np.nanmin(img)))
643                mx = kwds.get("vmax", float(np.nanmax(img)))
644                if isinstance(mn, int):
645                    assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
646                    mn = float(np.nanpercentile(img, mn))
647                if isinstance(mx, int):
648                    assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
649                    mx = float(np.nanpercentile(img, mx))
650                img = (img - mn) / (mx - mn)
651
652            #apply brightness/contrast mapping
653            img = np.clip((1.0 + cfac) * img + bfac, 0, 1.0 )
654
655            #apply masking so background is white
656            img[np.logical_not( np.isfinite( img ) )] = 1.0
657            if 'mask' in kwds:
658                img[kwds.pop("mask"),:] = 1.0
659
660            # apply rotations and flipping
661            if rot:
662                img = np.transpose( img, (1,0,2) )
663            if flipX:
664                img = img[::-1, :, :]
665            if flipY:
666                img = img[:, ::-1, :]
667
668            # save?
669            if 'path' in kwds:
670                path = kwds.pop('path')
671                from matplotlib.pyplot import imsave
672                if not os.path.exists(os.path.dirname(path)):
673                    os.makedirs(os.path.dirname(path)) # ensure output directory exists
674                imsave(path, np.transpose( np.clip( img*255, 0, 255).astype(np.uint8), (1, 0, 2)))  # save the image
675
676            # plot samples?
677            ps = kwds.pop('ps', 5)
678            pc = kwds.pop('pc', 'r')
679            if samples:
680                if isinstance(samples, list) or isinstance(samples, np.ndarray):
681                    ax.scatter([s[0] for s in samples], [s[1] for s in samples], s=ps, c=pc)
682                else:
683                    for n in self.header.get_class_names():
684                        points = np.array(self.header.get_sample_points(n))
685                        ax.scatter(points[:, 0], points[:, 1], s=ps, c=pc)
686
687            #plot
688            ax.imshow(np.transpose(img, (1,0,2)), interpolation=kwds.pop('interpolation', 'none'), **kwds)
689            ax.cbar = None  # no colorbar
690
691        return ax.get_figure(), ax
692
693    def createGIF(self, path, bands=None, figsize=(10,10), fps=10, **kwds):
694        """
695        Create and save an animated gif that loops through the bands of the image.
696
697        Args:
698            path (str): the path to save the .gif
699            bands (tuple): Tuple containing the range of band indices to draw. Default is the whole range.
700            figsize (tuple): the size of the image to draw. Default is (10,10).
701            fps (int): the framerate (frames per second) of the gif. Default is 10.
702            **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
703        """
704
705        frames = []
706        if bands is None:
707            bands = (0,self.band_count())
708        else:
709            assert 0 < bands[0] < self.band_count(), "Error - invalid range."
710            assert 0 < bands[1] < self.band_count(), "Error - invalid range."
711            assert bands[1] > bands[0], "Error - invalid range."
712
713        #plot frames
714        for i in range(bands[0],bands[1]):
715            fig, ax = plt.subplots(figsize=figsize)
716            ax.imshow(self.data[:, :, i], **kwds)
717            fig.canvas.draw()
718            frames.append(np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8'))
719            frames[-1] = np.reshape(frames[-1], (fig.canvas.get_width_height()[1], fig.canvas.get_width_height()[0], 3))
720            plt.close(fig)
721
722        #save gif
723        imageio.mimsave( os.path.splitext(path)[0] + ".gif", frames, fps=fps)
724
725    ## masking
726    def drop_bbl(self, drop=True):
727        """
728        Remove bad bands as stored in the 'bbl' key in the image header. Note that this operates in-place.
729
730        Args:
731            drop (bool): True if bad bands should be completely dropped. If False, these bands will be kept but replaced with nans.
732        """
733        assert 'bbl' in self.header, "Please specify a bad band list ('bbl') in the image header, as per the ENVI format definition."
734        mask = self.header.get_list('bbl') == 0
735        self.data[...,mask] = np.nan
736        if drop:
737            self.delete_nan_bands(inplace=True)
738    
739    def mask(self, mask=None, flag=np.nan, invert=False, crop=False, bands=None):
740        """
741         Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the
742         image in-situ.
743
744         Args:
745            flag (float): the value to use for masked pixels. Default is np.nan
746            mask (ndarray): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None is passed then
747                    pickPolygon( ... ) is used to interactively define a polygon. If a file path is passed then the polygon
748                    will be loaded using np.load( ... ). Alternatively if mask.shape == image.shape[0,1] then it is treated as a
749                    binary image mask (must be boolean) and True values will be masked across all bands. Default is None.
750            invert (bool): if True, pixels within the polygon will be masked. If False, pixels outside the polygon are masked. Default is False.
751            crop (bool): True if rows/columns containing only zeros should be removed. Default is False.
752            bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used.
753
754         Returns:
755            Tuple containing
756
757            - mask (ndarray): a boolean array with True where pixels are masked and False elsewhere.
758            - poly (ndarray): the mask polygon array in the format described above. Useful if the polygon was interactively defined.
759         """
760
761        if mask is None:  # pick mask interactively
762            if bands is None:
763                bands = int(self.band_count() / 2)
764
765            regions = self.pickPolygons(region_names=["mask"], bands=bands)
766
767            # the user bailed without picking a mask?
768            if len(regions) == 0:
769                print("Warning - no mask picked/applied.")
770                return
771
772            # extract polygon mask
773            mask = regions[0]
774
775        # convert polygon mask to binary mask
776        if mask.shape[1] == 2:
777
778            # build meshgrid with pixel coords
779            xx, yy = np.meshgrid(np.arange(self.xdim()), np.arange(self.ydim()))
780            xx = xx.flatten()
781            yy = yy.flatten()
782            points = np.vstack([xx, yy]).T  # coordinates of each pixel
783
784            # calculate per-pixel mask
785            mask = path.Path(mask).contains_points(points)
786            mask = mask.reshape((self.ydim(), self.xdim())).T
787
788            # flip as we want to mask (==True) outside points (unless invert is true)
789            if not invert:
790                mask = np.logical_not(mask)
791
792        # apply binary image mask
793        assert mask.shape[0] == self.data.shape[0] and mask.shape[1] == self.data.shape[1], \
794            "Error - mask shape %s does not match image shape %s" % (mask.shape, self.data.shape)
795        for b in range(self.band_count()):
796            self.data[:, :, b][mask] = flag
797
798        # crop image
799        if crop:
800            # calculate non-masked pixels
801            valid = np.logical_not(mask)
802
803            # integrate along axes
804            xdata = np.sum(valid, axis=1) > 0.0
805            ydata = np.sum(valid, axis=0) > 0.0
806
807            # calculate domain containing valid pixels
808            xmin = np.argmax(xdata)
809            xmax = xdata.shape[0] - np.argmax(xdata[::-1])
810            ymin = np.argmax(ydata)
811            ymax = ydata.shape[0] - np.argmax(ydata[::-1])
812
813            # crop
814            self.data = self.data[xmin:xmax, ymin:ymax, :]
815
816        return mask
817
818    def crop_to_data(self):
819        """
820        Remove padding of nan or zero pixels from image. Note that this is performed in place.
821        """
822
823        valid = np.isfinite(self.data).any(axis=-1) & (self.data != 0).any(axis=-1)
824        ymin, ymax = np.percentile(np.argwhere(np.sum(valid, axis=0) != 0), (0, 100))
825        xmin, xmax = np.percentile(np.argwhere(np.sum(valid, axis=1) != 0), (0, 100))
826        self.data = self.data[int(xmin):int(xmax), int(ymin):int(ymax), :]  # do clipping
827
828    ##################################################
829    ## Interactive tools for picking regions/pixels
830    ##################################################
831    def pickPolygons(self, region_names, bands=0):
832        """
833        Creates a matplotlib gui for selecting polygon regions in an image.
834
835        Args:
836            names (list, str): a list containing the names of the regions to pick. If a string is passed only one name is used.
837            bands (tuple): the bands of the image to plot.
838        """
839
840        if isinstance(region_names, str):
841            region_names = [region_names]
842
843        assert isinstance(region_names, list), "Error - names must be a list or a string."
844
845        # set matplotlib backend
846        backend = matplotlib.get_backend()
847        matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work
848
849        # plot image and extract roi's
850        fig, ax = self.quick_plot(bands)
851        roi = MultiRoi(roi_names=region_names)
852        plt.close(fig)  # close figure
853
854        # extract regions
855        regions = []
856        for name, r in roi.rois.items():
857            # store region
858            x = r.x
859            y = r.y
860            regions.append(np.vstack([x, y]).T)
861
862        # restore matplotlib backend (if possible)
863        try:
864            matplotlib.use(backend)
865        except:
866            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")
867            pass
868
869        return regions
870
871    def pickPoints(self, n=-1, bands=hylite.RGB, integer=True, title="Pick Points", **kwds):
872        """
873        Creates a matplotlib gui for picking pixels from an image.
874
875        Args:
876            n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
877            bands (tuple): the bands of the image to plot. Default is HyImage.RGB
878            integer (bool): True if points coordinates should be cast to integers (for use as indices). Default is True.
879            title (str): The title of the point picking window.
880            **kwds: Keywords are passed to HyImage.quick_plot( ... ).
881
882        Returns:
883            A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].
884        """
885
886        # set matplotlib backend
887        backend = matplotlib.get_backend()
888        matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work
889
890        # create figure
891        fig, ax = self.quick_plot( bands, **kwds )
892        ax.set_title(title)
893
894        # get points
895        points = fig.ginput( n )
896
897        if integer:
898            points = [ (int(p[0]), int(p[1])) for p in points ]
899
900        # restore matplotlib backend (if possible)
901        try:
902            matplotlib.use(backend)
903        except:
904            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")
905            pass
906
907        return points
908
909    def pickSamples(self, names=None, store=True, **kwds):
910        """
911        Pick sample probe points and store these in the image header file.
912
913        Args:
914            names (str, list): the name of the sample to pick, or a list of names to pick multiple.
915            store (bool): True if sample should be stored in the image header file (for later access). Default is True.
916            **kwds: Keywords are passed to HyImage.quick_plot( ... )
917
918        Returns:
919            a list containing a list of points for each sample.
920        """
921
922        if isinstance(names, str):
923            names = [names]
924
925        # pick points
926        points = []
927        for s in names:
928            pnts = self.pickPoints(title="%s" % s, **kwds)
929            if store:
930                self.header['sample %s' % s] = pnts # store in header
931            points.append(pnts)
932        # add class to header file
933        if store:
934            cls_names = self.header.get_class_names()
935            if cls_names is None:
936                cls_names = []
937            self.header['class names'] = cls_names + names
938
939        return points

A class for hyperspectral image data. These can be individual scenes or hyperspectral orthoimages.

HyImage(data, **kwds)
25    def __init__(self, data, **kwds):
26        """
27        Args:
28            data (ndarray): a numpy array such that data[x][y][band] gives each pixel value.
29            **kwds:
30                wav = A numpy array containing band wavelengths for this image.
31                affine = an affine transform of the format returned by GDAL.GetGeoTransform().
32                projection = string defining the project. Default is None.
33                sensor = sensor name. Default is "unknown".
34                header = path to associated header file. Default is None.
35        """
36
37        #call constructor for HyData
38        super().__init__(data, **kwds)
39
40        # special case - if dataset only has oneband, slice it so it still has
41        # the format data[x,y,b].
42        if not self.data is None:
43            if len(self.data.shape) == 1:
44                self.data = self.data[None, None, :] # single pixel image
45            if len(self.data.shape) == 2:
46                self.data = self.data[:, :, None] # single band iamge
47
48        #load any additional project information (specific to images)
49        self.set_projection(kwds.get("projection",None))
50        self.affine = kwds.get("affine",[0,1,0,0,0,1])
51
52        # wavelengths
53        if 'wav' in kwds:
54            self.set_wavelengths(kwds['wav'])
55
56        #special header formatting
57        self.header['file type'] = 'ENVI Standard'
Arguments:
  • data (ndarray): a numpy array such that data[x][y][band] gives each pixel value.
  • **kwds: wav = A numpy array containing band wavelengths for this image. affine = an affine transform of the format returned by GDAL.GetGeoTransform(). projection = string defining the projection. Default is None. sensor = sensor name. Default is "unknown". header = path to associated header file. Default is None.
def copy(self, data=True):
59    def copy(self,data=True):
60        """
61        Make a deep copy of this image instance.
62
63        Args:
64            data (bool): True if a copy of the data should be made, otherwise only copy header.
65
66        Returns:
67            a new HyImage instance.
68        """
69        if not data:
70            return HyImage(None, header=self.header.copy(), projection=self.projection, affine=self.affine)
71        else:
72            return HyImage( self.data.copy(), header=self.header.copy(), projection=self.projection, affine=self.affine)

Make a deep copy of this image instance.

Arguments:
  • data (bool): True if a copy of the data should be made, otherwise only copy header.
Returns:

a new HyImage instance.

def T(self):
74    def T(self):
75        """
76        Return a transposed view of the data matrix (corresponding with the [y,x] indexing used by matplotlib, opencv etc.
77        """
78        return np.transpose(self.data, (1,0,2))

Return a transposed view of the data matrix (corresponding with the [y,x] indexing used by matplotlib, opencv etc.).

def xdim(self):
80    def xdim(self):
81        """
82        Return number of pixels in x (first dimension of data array)
83        """
84        return self.data.shape[0]

Return number of pixels in x (first dimension of data array)

def ydim(self):
86    def ydim(self):
87        """
88        Return number of pixels in y (second dimension of data array)
89        """
90        return self.data.shape[1]

Return number of pixels in y (second dimension of data array)

def aspx(self):
92    def aspx(self):
93        """
94        Return the aspect ratio of this image (width/height).
95        """
96        return self.ydim() / self.xdim()

Return the aspect ratio of this image as height/width (i.e. ydim / xdim).

def get_extent(self):
 98    def get_extent(self):
 99        """
100        Returns the width and height of this image in world coordinates.
101
102        Returns:
103            tuple with (width, height).
104        """
105        return self.xdim * self.pixel_size[0], self.ydim * self.pixel_size[1]

Returns the width and height of this image in world coordinates.

Returns:

tuple with (width, height).

def set_projection(self, proj):
107    def set_projection(self,proj):
108        """
109        Set this project to an existing osgeo.osr.SpatialReference or GDAL georeference string.
110
111        Args:
112            proj (str, osgeo.osr.SpatialReference): the project to use as osgeo.osr.SpatialReference or GDAL georeference string.
113        """
114        if proj is None:
115            self.projection = None
116        else:
117            try:
118                from osgeo.osr import SpatialReference
119            except:
120                assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
121            if isinstance(proj, SpatialReference):
122                self.projection = proj
123            elif isinstance(proj, str):
124                self.projection = SpatialReference(proj)
125            else:
126                print("Invalid project %s" % proj)
127                raise

Set this image's projection to an existing osgeo.osr.SpatialReference or GDAL georeference string.

Arguments:
  • proj (str, osgeo.osr.SpatialReference): the projection to use as osgeo.osr.SpatialReference or GDAL georeference string.
def set_projection_EPSG(self, EPSG):
129    def set_projection_EPSG(self,EPSG):
130        """
131        Sets this image project using an EPSG code.
132
133        Args:
134            EPSG (str): string EPSG code that can be passed to SpatialReference.SetFromUserInput(...).
135        """
136
137        try:
138            from osgeo.osr import SpatialReference
139        except:
140            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
141
142        self.projection = SpatialReference()
143        self.projection.SetFromUserInput(EPSG)

Sets this image's projection using an EPSG code.

Arguments:
  • EPSG (str): string EPSG code that can be passed to SpatialReference.SetFromUserInput(...).
def get_projection_EPSG(self):
145    def get_projection_EPSG(self):
146        """
147        Gets a string describing this projections EPSG code (if it is an EPSG project).
148
149        Returns:
150            an EPSG code string of the format "EPSG:XXXX".
151        """
152        if self.projection is None:
153            return None
154        else:
155            return "%s:%s" % (self.projection.GetAttrValue("AUTHORITY",0),self.projection.GetAttrValue("AUTHORITY",1))

Gets a string describing this projection's EPSG code (if it is an EPSG projection).

Returns:

an EPSG code string of the format "EPSG:XXXX".

def pix_to_world(self, px, py, proj=None):
157    def pix_to_world(self, px, py, proj=None):
158        """
159        Take pixel coordinates and return world coordinates
160
161        Args:
162            px (int): the pixel x-coord.
163            py (int): the pixel y-coord.
164            proj (str, osr.SpatialReference): the coordinate system to use. Default (None) uses the same system as this image. Otherwise
165                   an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
166        Returns:
167            the world coordinates in the coordinate system defined by get_projection_EPSG(...).
168        """
169
170        try:
171            from osgeo import osr
172            import osgeo.gdal as gdal
173            from osgeo import ogr
174        except:
175            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
176
177        # parse project
178        if proj is None:
179            proj = self.projection
180        elif isinstance(proj, str) or isinstance(proj, int):
181            epsg = proj
182            if isinstance(epsg, str):
183                try:
184                    epsg = int(str.split(':')[1])
185                except:
186                    assert False, "Error - %s is an invalid EPSG code." % proj
187            proj = osr.SpatialReference()
188            proj.ImportFromEPSG(epsg)
189
190        # check we have all the required info
191        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
192        assert (not self.affine is None) and (
193            not self.projection is None), "Error - project information is undefined."
194
195        #project to world coordinates in this images project/world coords
196        x,y = gdal.ApplyGeoTransform(self.affine, px, py)
197
198        #project to target coords (if different)
199        if not proj.IsSameGeogCS(self.projection):
200            P = ogr.Geometry(ogr.wkbPoint)
201            if proj.EPSGTreatsAsNorthingEasting():
202                P.AddPoint(x, y)
203            else:
204                P.AddPoint(y, x)
205            P.AssignSpatialReference(self.projection)  # tell the point what coordinates it's in
206            P.TransformTo(proj)  # reproject it to the out spatial reference
207            x, y = P.GetX(), P.GetY()
208
209            #do we need to transpose?
210            if proj.EPSGTreatsAsLatLong():
211                x,y=y,x #we want lon,lat not lat,lon
212        return x, y

Take pixel coordinates and return world coordinates

Arguments:
  • px (int): the pixel x-coord.
  • py (int): the pixel y-coord.
  • proj (str, osr.SpatialReference): the coordinate system to use. Default (None) uses the same system as this image. Otherwise an osr.SpatialReference can be passed (HyImage.projection), or an EPSG string (e.g. get_projection_EPSG(...)).
Returns:

the world coordinates (x, y) in the coordinate system given by proj (by default, this image's projection).

def world_to_pix(self, x, y, proj=None):
214    def world_to_pix(self, x, y, proj = None):
215        """
216        Take world coordinates and return pixel coordinates
217
218        Args:
219            x (float): the world x-coord.
220            y (float): the world y-coord.
221            proj (str, osr.SpatialReference): the coordinate system of the input coordinates. Default (None) uses the same system as this image. Otherwise
222                   an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
223
224        Returns:
225            the pixel coordinates based on the affine transform stored in self.affine.
226        """
227
228        try:
229            from osgeo import osr
230            import osgeo.gdal as gdal
231            from osgeo import ogr
232        except:
233            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
234
235        # parse project
236        if proj is None:
237            proj = self.projection
238        elif isinstance(proj, str) or isinstance(proj, int):
239            epsg = proj
240            if isinstance(epsg, str):
241                try:
242                    epsg = int(str.split(':')[1])
243                except:
244                    assert False, "Error - %s is an invalid EPSG code." % proj
245            proj = osr.SpatialReference()
246            proj.ImportFromEPSG(epsg)
247
248
249        # check we have all the required info
250        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
251        assert (not self.affine is None) and (not self.projection is None), "Error - project information is undefined."
252
253        # project to this images CS (if different)
254        if not proj.IsSameGeogCS(self.projection):
255            P = ogr.Geometry(ogr.wkbPoint)
256            if proj.EPSGTreatsAsNorthingEasting():
257                P.AddPoint(x, y)
258            else:
259                P.AddPoint(y, x)
260            P.AssignSpatialReference(proj)  # tell the point what coordinates it's in
261            P.AddPoint(x, y)
262            P.TransformTo(self.projection)  # reproject it to the out spatial reference
263            x, y = P.GetX(), P.GetY()
264            if self.projection.EPSGTreatsAsLatLong(): # do we need to transpose?
265                x, y = y, x  # we want lon,lat not lat,lon
266
267        inv = gdal.InvGeoTransform(self.affine)
268        assert not inv is None, "Error - could not invert affine transform?"
269
270        #apply
271        return gdal.ApplyGeoTransform(inv, x, y)

Take world coordinates and return pixel coordinates

Arguments:
  • x (float): the world x-coord.
  • y (float): the world y-coord.
  • proj (str, osr.SpatialReference): the coordinate system of the input coordinates. Default (None) uses the same system as this image. Otherwise an osr.SpatialReference can be passed (HyImage.projection), or an EPSG string (e.g. get_projection_EPSG(...)).
Returns:

the pixel coordinates based on the affine transform stored in self.affine.

def flip(self, axis='x'):
273    def flip(self, axis='x'):
274        """
275        Flip the image on the x or y axis.
276
277        Args:
278            axis (str): 'x' or 'y' or both 'xy'.
279        """
280
281        if 'x' in axis.lower():
282            self.data = np.flip(self.data,axis=0)
283        if 'y' in axis.lower():
284            self.data = np.flip(self.data,axis=1)

Flip the image on the x or y axis.

Arguments:
  • axis (str): 'x' or 'y' or both 'xy'.
def rot90(self):
286    def rot90(self):
287        """
288        Rotate this image by 90 degrees by transposing the underlying data array. Combine with flip('x') or flip('y')
289        to achieve positive/negative rotations.
290        """
291        self.data = np.transpose( self.data, (1,0,2) )
292        self.push_to_header()

Rotate this image by 90 degrees by transposing the underlying data array. Combine with flip('x') or flip('y') to achieve positive/negative rotations.

def fill_holes(self):
297    def fill_holes(self):
298        """
299        Replaces nan pixel with an average of their neighbours, thus removing 1-pixel large holes from an image. Note that
300        for performance reasons this assumes that holes line up across bands. Note that this is not vectorized so very slow...
301        """
302
303        # perform greyscale dilation
304        dilate = self.data.copy()
305        mask = np.logical_not(np.isfinite(dilate))
306        dilate[mask] = 0
307        for b in range(self.band_count()):
308            dilate[:, :, b] = sp.ndimage.grey_dilation(dilate[:, :, b], size=(3, 3))
309
310        # map back to holes in dataset
311        self.data[mask] = dilate[mask]
312        #self.data[self.data == 0] = np.nan  # replace remaining 0's with nans

Replaces nan pixel with an average of their neighbours, thus removing 1-pixel large holes from an image. Note that for performance reasons this assumes that holes line up across bands. Note that this is not vectorized so very slow...

def blur(self, n=3):
314    def blur(self, n=3):
315        """
316        Applies a gaussian kernel of size n to the image using OpenCV.
317
318        Args:
319            n (int): the dimensions of the gaussian kernel to convolve. Default is 3. Increase for more blurry results.
320        """
321        import cv2 # import this here to avoid errors if opencv is not installed properly
322
323        nanmask = np.isnan(self.data)
324        assert isinstance(n, int) and n >= 3, "Error - invalid kernel. N must be an integer > 3. "
325        kernel = np.ones((n, n), np.float32) / (n ** 2)
326        self.data = cv2.filter2D(self.data, -1, kernel)
327        self.data[nanmask] = np.nan  # remove mask

Applies a gaussian kernel of size n to the image using OpenCV.

Arguments:
  • n (int): the dimensions of the gaussian kernel to convolve. Default is 3. Increase for more blurry results.
def erode(self, size=3, iterations=1):
329    def erode(self, size=3, iterations=1):
330        """
331        Apply an erode filter to this image to expand background (nan) pixels. Refer to open-cv's erode
332        function for more details.
333
334        Args:
335            size (int): the size of the erode filter. Default is a 3x3 kernel.
336            iterations (int): the number of erode iterations. Default is 1.
337        """
338        import cv2 # import this here to avoid errors if opencv is not installed properly
339
340        # erode
341        kernel = np.ones((size, size), np.uint8)
342        if self.is_float():
343            mask = np.isfinite(self.data).any(axis=-1)
344            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
345            self.data[mask == 0, :] = np.nan
346        else:
347            mask = (self.data != 0).any( axis=-1 )
348            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
349            self.data[mask == 0, :] = 0

Apply an erode filter to this image to expand background (nan) pixels. Refer to open-cv's erode function for more details.

Arguments:
  • size (int): the size of the erode filter. Default is a 3x3 kernel.
  • iterations (int): the number of erode iterations. Default is 1.
def resize(self, newdims: tuple, interpolation: int = 1):
351    def resize(self, newdims : tuple, interpolation : int = 1):
352        """
353        Resize this image with opencv.
354
355        Args:
356            newdims (tuple): the new image dimensions.
357            interpolation (int): opencv interpolation method. Default is cv2.INTER_LINEAR.
358        """
359        import cv2 # import this here to avoid errors if opencv is not installed properly
360        self.data = cv2.resize(self.data, (newdims[1],newdims[0]), interpolation=interpolation)

Resize this image with opencv.

Arguments:
  • newdims (tuple): the new image dimensions.
  • interpolation (int): opencv interpolation method. Default is cv2.INTER_LINEAR.
def despeckle(self, size=5):
362    def despeckle(self, size=5):
363        """
364        Despeckle each band of this image (independently) using a median filter.
365
366        Args:
367            size (int): the size of the median filter kernel. Default is 5. Must be an odd number.
368        """
369
370        assert (size % 2) == 1, "Error - size must be an odd integer"
371        import cv2 # import this here to avoid errors if opencv is not installed properly
372        if self.is_float():
373            self.data = cv2.medianBlur( self.data.astype(np.float32), size )
374        else:
375            self.data = cv2.medianBlur( self.data, size )

Despeckle each band of this image (independently) using a median filter.

Arguments:
  • size (int): the size of the median filter kernel. Default is 5. Must be an odd number.
def get_keypoints( self, band, eq=False, mask=True, method='sift', cfac=0.0, bfac=0.0, **kwds):
380    def get_keypoints(self, band, eq=False, mask=True, method='sift', cfac=0.0,bfac=0.0, **kwds):
381        """
382        Get feature descriptors from the specified band.
383
384        Args:
385            band (int,float,str,tuple): the band index (int) or wavelength (float) to extract features from. Alternatively, a tuple can be passed
386                    containing a range of bands (min : max) to average before feature matching.
387            eq (bool): True if the image should be histogram equalized first. Default is False.
388            mask (bool): True if 0 value pixels should be masked. Default is True.
389            method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'.
390            cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
391            bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.
392            **kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are:
393
394                - contrastThreshold: default is 0.01.
395                - edgeThreshold: default is 10.
396                - sigma: default is 1.0
397
398                For ORB these are:
399
400                - nfeatures = the number of features to detect. Default is 5000.
401
402            Returns:
403                Tuple containing
404
405                    - k (ndarray): the keypoints detected
406                    - d (ndarray): corresponding feature descriptors
407         """
408        import cv2 # import this here to avoid errors if opencv is not installed properly
409
410        # get image
411        if isinstance(band, int) or isinstance(band, float): #single band
412            image = self.data[:, :, self.get_band_index(band)]
413        elif isinstance(band,tuple): #range of bands (averaged)
414            idx0 = self.get_band_index(band[0])
415            idx1 = self.get_band_index(band[1])
416
417            #deal with out of range errors
418            if idx0 is None:
419                idx0 = 0
420            if idx1 is None:
421                idx1 = self.band_count()
422
423            #average bands
424            image = np.nanmean(self.data[:,:,idx0:idx1],axis=2)
425        else:
426            assert False, "Error, unrecognised band %s" % band
427
428        #normalise image to range 0 - 1
429        image -= np.nanmin(image)
430        image = image / np.nanmax(image)
431
432        #apply brightness/contrast adjustment
433        image = (1.0+cfac)*image + bfac
434        image[image > 1.0] = 1.0
435        image[image < 0.0] = 0.0
436
437        #convert image to uint8 for opencv
438        image = np.uint8(255 * image)
439        if eq:
440            image = cv2.equalizeHist(image)
441
442        if mask:
443            mask = np.zeros(image.shape, dtype=np.uint8)
444            mask[image != 0] = 255  # include only non-zero pixels
445        else:
446            mask = None
447
448        if 'sift' in method.lower():  # SIFT
449
450            # setup default keywords
451            kwds["contrastThreshold"] = kwds.get("contrastThreshold", 0.01)
452            kwds["edgeThreshold"] = kwds.get("edgeThreshold", 10)
453            kwds["sigma"] = kwds.get("sigma", 1.0)
454
455            # make feature detector
456            #alg = cv2.xfeatures2d.SIFT_create(**kwds)
457            alg = cv2.SIFT_create()
458        elif 'orb' in method.lower():  # orb
459            kwds['nfeatures'] = kwds.get('nfeatures', 5000)
460            alg = cv2.ORB_create(scoreType=cv2.ORB_FAST_SCORE, **kwds)
461        else:
462            assert False, "Error - %s is not a recognised feature detector." % method
463
464        # detect keypoints
465        kp = alg.detect(image, mask)
466
467        # extract and return feature vectors
468        return alg.compute(image, kp)

Get feature descriptors from the specified band.

Arguments:
  • band (int,float,str,tuple): the band index (int) or wavelength (float) to extract features from. Alternatively, a tuple can be passed containing a range of bands (min : max) to average before feature matching.
  • eq (bool): True if the image should be histogram equalized first. Default is False.
  • mask (bool): True if 0 value pixels should be masked. Default is True.
  • method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'.
  • cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
  • bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.
  • **kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are:

    • contrastThreshold: default is 0.01.
    • edgeThreshold: default is 10.
    • sigma: default is 1.0

    For ORB these are:

    • nfeatures = the number of features to detect. Default is 5000.
  • Returns: Tuple containing

    • k (ndarray): the keypoints detected
    • d (ndarray): corresponding feature descriptors
@classmethod
def match_keypoints( cls, kp1, kp2, d1, d2, method='SIFT', dist=0.7, tree=5, check=100, min_count=5):
470    @classmethod
471    def match_keypoints(cls, kp1, kp2, d1, d2, method='SIFT', dist=0.7, tree = 5, check = 100, min_count=5):
472        """
473        Compares keypoint feature vectors from two images and returns matching pairs.
474
475        Args:
476            kp1 (ndarray): keypoints from the first image
477            kp2 (ndarray): keypoints from the second image
478            d1 (ndarray): descriptors for the keypoints from the first image
479            d2 (ndarray): descriptors for the keypoints from the second image
480            method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'.
481            dist (float): minimum match distance (0 to 1), default is 0.7
482            tree (int): not sure what this does? Default is 5. See open-cv docs.
483            check (int): ditto. Default is 100.
484            min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are found,
485                       then the function returns None, None. Default is 5.
486        """
487        import cv2 # import this here to avoid errors if opencv is not installed properly
488        if 'sift' in method.lower():
489            algorithm = cv2.NORM_INF
490        elif 'orb' in method.lower():
491            algorithm = cv2.NORM_HAMMING
492        else:
493            assert False, "Error - unknown matching algorithm %s" % method
494
495        #calculate flann matches
496        index_params = dict(algorithm=algorithm, trees=tree)
497        search_params = dict(checks=check)
498        flann = cv2.FlannBasedMatcher(index_params, search_params)
499        matches = flann.knnMatch(d1, d2, k=2)
500
501        # store all the good matches as per Lowe's ratio test.
502        good = []
503        for m, n in matches:
504            if m.distance < dist * n.distance:
505                good.append(m)
506
507        if len(good) < min_count:
508            return None, None
509        else:
510            src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
511            dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
512            return src_pts, dst_pts

Compares keypoint feature vectors from two images and returns matching pairs.

Arguments:
  • kp1 (ndarray): keypoints from the first image
  • kp2 (ndarray): keypoints from the second image
  • d1 (ndarray): descriptors for the keypoints from the first image
  • d2 (ndarray): descriptors for the keypoints from the second image
  • method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'.
  • dist (float): minimum match distance (0 to 1), default is 0.7
  • tree (int): not sure what this does? Default is 5. See open-cv docs.
  • check (int): ditto. Default is 100.
  • min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are found, then the function returns None, None. Default is 5.
def quick_plot( self, band=0, ax=None, bfac=0.0, cfac=0.0, samples=False, tscale=False, invert=False, rot=False, flipX=False, flipY=False, **kwds):
517    def quick_plot(self, band=0, ax=None, bfac=0.0, cfac=0.0, samples=False, tscale=False, invert=False, rot=False, flipX=False, flipY=False,
518                   **kwds):
519        """
520        Plot a band using matplotlib.imshow(...).
521
522        Args:
523            band (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple is passed then
524                  each band in the tuple (string or index) will be mapped to rgb. Bands with negative wavelengths or indices will be inverted before plotting.
525            ax: an axis object to plot to. If none, plt.imshow( ... ) is used.
526            bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
527            cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
528            samples (bool): True if sample points (defined in the header file) should be plotted. Default is False. Otherwise, a list of
529                     [ (x,y), ... ] points can be passed.
530            tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False.
531                    When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or
532                    (constant) values (float).
533            invert (bool) : True if each band should be inverted before plotting. Only works for multiband (ternary) images.
534            rot (bool): if True, the x and y axis will be flipped (90 degree rotation) before plotting. Default is False.
535            flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
536            flipY (bool): if True, the y axis will be flippe before plotting (after applying rotations).
537            **kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:
538
539                 - mask = a 2D boolean mask containing true if pixels should be drawn and false otherwise.
540                 - path = a file path to save the image too (at matching resolution; use fig.savefig(..) if you want to save the figure).
541                 - ticks = True if x- and y- ticks should be plotted. Default is False.
542                 - ps, pc = the size and color of sample points to plot. Can be constant or list.
543                 - figsize = a figsize for the figure to create (if ax is None).
544
545        Returns:
546            Tuple containing
547
548            - fig: matplotlib figure object
549            - ax:  matplotlib axes object. If a colorbar is created, (band is an integer or a float), then this will be stored in ax.cbar.
550        """
551
552        #create new axes?
553        if ax is None:
554            fig, ax = plt.subplots(figsize=kwds.pop('figsize', (18,18*self.ydim()/self.xdim()) ))
555
556        # deal with ticks
557        if not kwds.pop('ticks', False ):
558            ax.set_xticks([])
559            ax.set_yticks([])
560
561        #map individual band using colourmap
562        if isinstance(band, str) or isinstance(band, int) or isinstance(band, float):
563            #get band
564            if isinstance(band, str):
565                data = self.data[:, :, self.get_band_index(band)]
566            else:
567                data = self.data[:, :, self.get_band_index(np.abs(band))]
568            if not isinstance(band, str) and band < 0:
569                data = np.nanmax(data) - data # flip
570
571            # convert integer vmin and vmax values to percentiles
572            if 'vmin' in kwds:
573                if isinstance(kwds['vmin'], int):
574                    kwds['vmin'] = np.nanpercentile( data, kwds['vmin'] )
575            if 'vmax' in kwds:
576                if isinstance(kwds['vmax'], int):
577                    kwds['vmax'] = np.nanpercentile( data, kwds['vmax'] )
578
579            #mask nans (and apply custom mask)
580            mask = np.isnan(data)
581            if not np.isnan(self.header.get_data_ignore_value()):
582                mask = mask + data == self.header.get_data_ignore_value()
583            if 'mask' in kwds:
584                mask = mask + kwds.get('mask')
585                del kwds['mask']
586            data = np.ma.array(data, mask = mask > 0 )
587
588            # apply rotations and flipping
589            if rot:
590                data = data.T
591            if flipX:
592                data = data[::-1, :]
593            if flipY:
594                data = data[:, ::-1]
595
596            # save?
597            if 'path' in kwds:
598                path = kwds.pop('path')
599                from matplotlib.pyplot import imsave
600                if not os.path.exists(os.path.dirname(path)):
601                    os.makedirs(os.path.dirname(path)) # ensure output directory exists
602                imsave(path, data.T, **kwds)  # save the image
603
604            ax.cbar = ax.imshow(data.T, interpolation=kwds.pop('interpolation', 'none'), **kwds) # change default interpolation to None
605
606        #map 3 bands to RGB
607        elif isinstance(band, tuple) or isinstance(band, list):
608            #get band indices and range
609            rgb = []
610            for b in band:
611                if isinstance(b, str):
612                    rgb.append(self.get_band_index(b))
613                else:
614                    rgb.append(self.get_band_index(np.abs(b)))
615
616            #slice image (as copy) and map to 0 - 1
617            img = np.array(self.data[:, :, rgb]).copy()
618            if np.isnan(img).all():
619                print("Warning - image contains no data.")
620                return ax.get_figure(), ax
621
622            # invert if needed
623            if invert:
624                band = [-b for b in band]
625            for i,b in enumerate(band):
626                if not isinstance(b, str) and (b < 0):
627                    img[..., i] = np.nanmax(img[..., i]) - img[..., i]
628
629            # do scaling
630            if tscale: # scale bands independently
631                for b in range(3):
632                    mn = kwds.get("vmin", float(np.nanmin(img)))
633                    mx = kwds.get("vmax", float(np.nanmax(img)))
634                    if isinstance (mn, int):
635                        assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
636                        mn = float(np.nanpercentile(img[...,b], mn ))
637                    if isinstance (mx, int):
638                        assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
639                        mx = float(np.nanpercentile(img[...,b], mx ))
640                    img[...,b] = (img[..., b] - mn) / (mx - mn)
641            else: # scale bands together
642                mn = kwds.get("vmin", float(np.nanmin(img)))
643                mx = kwds.get("vmax", float(np.nanmax(img)))
644                if isinstance(mn, int):
645                    assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
646                    mn = float(np.nanpercentile(img, mn))
647                if isinstance(mx, int):
648                    assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
649                    mx = float(np.nanpercentile(img, mx))
650                img = (img - mn) / (mx - mn)
651
652            #apply brightness/contrast mapping
653            img = np.clip((1.0 + cfac) * img + bfac, 0, 1.0 )
654
655            #apply masking so background is white
656            img[np.logical_not( np.isfinite( img ) )] = 1.0
657            if 'mask' in kwds:
658                img[kwds.pop("mask"),:] = 1.0
659
660            # apply rotations and flipping
661            if rot:
662                img = np.transpose( img, (1,0,2) )
663            if flipX:
664                img = img[::-1, :, :]
665            if flipY:
666                img = img[:, ::-1, :]
667
668            # save?
669            if 'path' in kwds:
670                path = kwds.pop('path')
671                from matplotlib.pyplot import imsave
672                if not os.path.exists(os.path.dirname(path)):
673                    os.makedirs(os.path.dirname(path)) # ensure output directory exists
674                imsave(path, np.transpose( np.clip( img*255, 0, 255).astype(np.uint8), (1, 0, 2)))  # save the image
675
676            # plot samples?
677            ps = kwds.pop('ps', 5)
678            pc = kwds.pop('pc', 'r')
679            if samples:
680                if isinstance(samples, list) or isinstance(samples, np.ndarray):
681                    ax.scatter([s[0] for s in samples], [s[1] for s in samples], s=ps, c=pc)
682                else:
683                    for n in self.header.get_class_names():
684                        points = np.array(self.header.get_sample_points(n))
685                        ax.scatter(points[:, 0], points[:, 1], s=ps, c=pc)
686
687            #plot
688            ax.imshow(np.transpose(img, (1,0,2)), interpolation=kwds.pop('interpolation', 'none'), **kwds)
689            ax.cbar = None  # no colorbar
690
691        return ax.get_figure(), ax

Plot a band using matplotlib.imshow(...).

Arguments:
  • band (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple is passed then each band in the tuple (string or index) will be mapped to rgb. Bands with negative wavelengths or indices will be inverted before plotting.
  • ax: an axis object to plot to. If none, plt.imshow( ... ) is used.
  • bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
  • cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
  • samples (bool): True if sample points (defined in the header file) should be plotted. Default is False. Otherwise, a list of [ (x,y), ... ] points can be passed.
  • tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False. When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or (constant) values (float).
  • invert (bool) : True if each band should be inverted before plotting. Only works for multiband (ternary) images.
  • rot (bool): if True, the x and y axis will be flipped (90 degree rotation) before plotting. Default is False.
  • flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
  • flipY (bool): if True, the y axis will be flippe before plotting (after applying rotations).
  • **kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:

    • mask = a 2D boolean mask containing true if pixels should be drawn and false otherwise.
    • path = a file path to save the image too (at matching resolution; use fig.savefig(..) if you want to save the figure).
    • ticks = True if x- and y- ticks should be plotted. Default is False.
    • ps, pc = the size and color of sample points to plot. Can be constant or list.
    • figsize = a figsize for the figure to create (if ax is None).
Returns:

Tuple containing

  • fig: matplotlib figure object
  • ax: matplotlib axes object. If a colorbar is created, (band is an integer or a float), then this will be stored in ax.cbar.
def createGIF(self, path, bands=None, figsize=(10, 10), fps=10, **kwds):
693    def createGIF(self, path, bands=None, figsize=(10,10), fps=10, **kwds):
694        """
695        Create and save an animated gif that loops through the bands of the image.
696
697        Args:
698            path (str): the path to save the .gif
699            bands (tuple): Tuple containing the range of band indices to draw. Default is the whole range.
700            figsize (tuple): the size of the image to draw. Default is (10,10).
701            fps (int): the framerate (frames per second) of the gif. Default is 10.
702            **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
703        """
704
705        frames = []
706        if bands is None:
707            bands = (0,self.band_count())
708        else:
709            assert 0 < bands[0] < self.band_count(), "Error - invalid range."
710            assert 0 < bands[1] < self.band_count(), "Error - invalid range."
711            assert bands[1] > bands[0], "Error - invalid range."
712
713        #plot frames
714        for i in range(bands[0],bands[1]):
715            fig, ax = plt.subplots(figsize=figsize)
716            ax.imshow(self.data[:, :, i], **kwds)
717            fig.canvas.draw()
718            frames.append(np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8'))
719            frames[-1] = np.reshape(frames[-1], (fig.canvas.get_width_height()[1], fig.canvas.get_width_height()[0], 3))
720            plt.close(fig)
721
722        #save gif
723        imageio.mimsave( os.path.splitext(path)[0] + ".gif", frames, fps=fps)

Create and save an animated gif that loops through the bands of the image.

Arguments:
  • path (str): the path to save the .gif
  • bands (tuple): Tuple containing the range of band indices to draw. Default is the whole range.
  • figsize (tuple): the size of the image to draw. Default is (10,10).
  • fps (int): the framerate (frames per second) of the gif. Default is 10.
  • **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
def drop_bbl(self, drop=True):
726    def drop_bbl(self, drop=True):
727        """
728        Remove bad bands as stored in the 'bbl' key in the image header. Note that this operates in-place.
729
730        Args:
731            drop (bool): True if bad bands should be completely dropped. If False, these bands will be kept but replaced with nans.
732        """
733        assert 'bbl' in self.header, "Please specify a bad band list ('bbl') in the image header, as per the ENVI format definition."
734        mask = self.header.get_list('bbl') == 0
735        self.data[...,mask] = np.nan
736        if drop:
737            self.delete_nan_bands(inplace=True)

Remove bad bands as stored in the 'bbl' key in the image header. Note that this operates in-place.

Arguments:
  • drop (bool): True if bad bands should be completely dropped. If False, these bands will be kept but replaced with nans.
def mask(self, mask=None, flag=nan, invert=False, crop=False, bands=None):
739    def mask(self, mask=None, flag=np.nan, invert=False, crop=False, bands=None):
740        """
741         Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the
742         image in-situ.
743
744         Args:
745            flag (float): the value to use for masked pixels. Default is np.nan
746            mask (ndarray): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None is passed then
747                    pickPolygon( ... ) is used to interactively define a polygon. If a file path is passed then the polygon
748                    will be loaded using np.load( ... ). Alternatively if mask.shape == image.shape[0,1] then it is treated as a
749                    binary image mask (must be boolean) and True values will be masked across all bands. Default is None.
750            invert (bool): if True, pixels within the polygon will be masked. If False, pixels outside the polygon are masked. Default is False.
751            crop (bool): True if rows/columns containing only zeros should be removed. Default is False.
752            bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used.
753
754         Returns:
755            Tuple containing
756
757            - mask (ndarray): a boolean array with True where pixels are masked and False elsewhere.
758            - poly (ndarray): the mask polygon array in the format described above. Useful if the polygon was interactively defined.
759         """
760
761        if mask is None:  # pick mask interactively
762            if bands is None:
763                bands = int(self.band_count() / 2)
764
765            regions = self.pickPolygons(region_names=["mask"], bands=bands)
766
767            # the user bailed without picking a mask?
768            if len(regions) == 0:
769                print("Warning - no mask picked/applied.")
770                return
771
772            # extract polygon mask
773            mask = regions[0]
774
775        # convert polygon mask to binary mask
776        if mask.shape[1] == 2:
777
778            # build meshgrid with pixel coords
779            xx, yy = np.meshgrid(np.arange(self.xdim()), np.arange(self.ydim()))
780            xx = xx.flatten()
781            yy = yy.flatten()
782            points = np.vstack([xx, yy]).T  # coordinates of each pixel
783
784            # calculate per-pixel mask
785            mask = path.Path(mask).contains_points(points)
786            mask = mask.reshape((self.ydim(), self.xdim())).T
787
788            # flip as we want to mask (==True) outside points (unless invert is true)
789            if not invert:
790                mask = np.logical_not(mask)
791
792        # apply binary image mask
793        assert mask.shape[0] == self.data.shape[0] and mask.shape[1] == self.data.shape[1], \
794            "Error - mask shape %s does not match image shape %s" % (mask.shape, self.data.shape)
795        for b in range(self.band_count()):
796            self.data[:, :, b][mask] = flag
797
798        # crop image
799        if crop:
800            # calculate non-masked pixels
801            valid = np.logical_not(mask)
802
803            # integrate along axes
804            xdata = np.sum(valid, axis=1) > 0.0
805            ydata = np.sum(valid, axis=0) > 0.0
806
807            # calculate domain containing valid pixels
808            xmin = np.argmax(xdata)
809            xmax = xdata.shape[0] - np.argmax(xdata[::-1])
810            ymin = np.argmax(ydata)
811            ymax = ydata.shape[0] - np.argmax(ydata[::-1])
812
813            # crop
814            self.data = self.data[xmin:xmax, ymin:ymax, :]
815
816        return mask

Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the image in-situ.

Arguments:
  • flag (float): the value to use for masked pixels. Default is np.nan
  • mask (ndarray): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None is passed then pickPolygon( ... ) is used to interactively define a polygon. If a file path is passed then the polygon will be loaded using np.load( ... ). Alternatively if mask.shape == image.shape[0,1] then it is treated as a binary image mask (must be boolean) and True values will be masked across all bands. Default is None.
  • invert (bool): if True, pixels within the polygon will be masked. If False, pixels outside the polygon are masked. Default is False.
  • crop (bool): True if rows/columns containing only zeros should be removed. Default is False.
  • bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used.
Returns:

Tuple containing

  • mask (ndarray): a boolean array with True where pixels are masked and False elsewhere.
  • poly (ndarray): the mask polygon array in the format described above. Useful if the polygon was interactively defined.
def crop_to_data(self):
def crop_to_data(self):
    """
    Remove padding of nan or zero pixels from image. Note that this is performed in place.

    A pixel counts as data if at least one of its bands is both finite and non-zero; pixels made
    up only of non-finite (nan/inf) and/or zero values are treated as padding. The image is cropped
    to the smallest rectangle containing every data pixel (including the last data row/column).
    If the image contains no data pixels at all it is left unchanged.
    """
    data = self.data

    # per-pixel flag: True if any band holds a finite, non-zero value
    valid = (np.isfinite(data) & (data != 0)).any(axis=-1)

    rows = np.flatnonzero(valid.any(axis=1))  # indices along axis 0 that contain data
    cols = np.flatnonzero(valid.any(axis=0))  # indices along axis 1 that contain data
    if rows.size == 0:
        return  # image is entirely padding; nothing sensible to crop to

    # slice ends are exclusive, so +1 keeps the last row/column that holds data
    self.data = data[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1, :]

Remove padding of nan or zero pixels from image. Note that this is performed in place.

def pickPolygons(self, region_names, bands=0):
def pickPolygons(self, region_names, bands=0):
    """
    Creates a matplotlib gui for selecting polygon regions in an image.

    Args:
        region_names (list, tuple, str): the names of the regions to pick. If a string is passed only one name is used.
        bands (tuple): the bands of the image to plot.

    Returns:
        A list containing, for each picked region, an array built by stacking the region's x and y
        vertex coordinates column-wise (one [x, y] row per vertex).
    """
    if isinstance(region_names, str):
        region_names = [region_names]
    elif isinstance(region_names, tuple):
        region_names = list(region_names)

    assert isinstance(region_names, list), "Error - names must be a list or a string."

    # remember the active backend so it can be restored once picking is finished
    backend = matplotlib.get_backend()
    matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work
    try:
        # plot image and extract roi's
        fig, _ = self.quick_plot(bands)
        roi = MultiRoi(roi_names=region_names)
        plt.close(fig)  # close figure

        # each ROI exposes its vertex coordinates as separate x and y sequences
        regions = [np.vstack([r.x, r.y]).T for r in roi.rois.values()]
    finally:
        # restore matplotlib backend (if possible), even when picking failed
        try:
            matplotlib.use(backend)
        except Exception:
            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")

    return regions

Creates a matplotlib gui for selecting polygon regions in an image.

Arguments:
  • region_names (list, str): a list containing the names of the regions to pick. If a string is passed only one name is used.
  • bands (tuple): the bands of the image to plot.
def pickPoints( self, n=-1, bands=(680.0, 550.0, 505.0), integer=True, title='Pick Points', **kwds):
def pickPoints(self, n=-1, bands=hylite.RGB, integer=True, title="Pick Points", timeout=30, **kwds):
    """
    Creates a matplotlib gui for picking pixels from an image.

    Args:
        n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
        bands (tuple): the bands of the image to plot. Default is hylite.RGB.
        integer (bool): True if points coordinates should be cast to integers (for use as indices). Default is True.
                        Note that int() truncates towards zero.
        title (str): The title of the point picking window.
        timeout (float): number of seconds after which picking stops automatically (passed to matplotlib's
                         Figure.ginput). Default is 30, matching the previous behaviour; pass 0 to wait until
                         the user ends picking.
        **kwds: Keywords are passed to HyImage.quick_plot( ... ).

    Returns:
        A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].
    """
    # remember the active backend so it can be restored once picking is finished
    backend = matplotlib.get_backend()
    # switch to an interactive (Qt) backend so ginput can collect mouse clicks
    matplotlib.use('Qt5Agg')
    try:
        # create figure and collect the clicked points
        fig, ax = self.quick_plot(bands, **kwds)
        ax.set_title(title)
        points = fig.ginput(n, timeout=timeout)
        plt.close(fig)  # the figure is not returned, so release it
    finally:
        # restore matplotlib backend (if possible), even when picking failed
        try:
            matplotlib.use(backend)
        except Exception:
            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")

    if integer:
        points = [(int(p[0]), int(p[1])) for p in points]

    return points

Creates a matplotlib gui for picking pixels from an image.

Arguments:
  • n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
  • bands (tuple): the bands of the image to plot. Default is hylite.RGB.
  • integer (bool): True if points coordinates should be cast to integers (for use as indices). Default is True.
  • title (str): The title of the point picking window.
  • **kwds: Keywords are passed to HyImage.quick_plot( ... ).
Returns:

A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].

def pickSamples(self, names=None, store=True, **kwds):
def pickSamples(self, names=None, store=True, **kwds):
    """
    Pick sample probe points and store these in the image header file.

    Args:
        names (str, list): the name of the sample to pick, or a list (or tuple) of names to pick multiple.
        store (bool): True if sample should be stored in the image header file (for later access). Default is True.
        **kwds: Keywords are passed to HyImage.pickPoints( ... ), which forwards them to HyImage.quick_plot( ... ).

    Returns:
        a list containing a list of points for each sample.

    Raises:
        TypeError: if no sample names are given.
    """
    if names is None:
        raise TypeError("Error - names must be a string or a list of strings.")
    if isinstance(names, str):
        names = [names]
    names = list(names)  # also accept tuples / other iterables of names

    # pick points
    points = []
    for s in names:
        pnts = self.pickPoints(title="%s" % s, **kwds)
        if store:
            self.header['sample %s' % s] = pnts  # store in header
        points.append(pnts)

    # add class to header file
    if store:
        existing = self.header.get_class_names()
        cls_names = [] if existing is None else list(existing)
        # re-picking an existing sample must not register its class name twice
        for s in names:
            if s not in cls_names:
                cls_names.append(s)
        self.header['class names'] = cls_names

    return points

Pick sample probe points and store these in the image header file.

Arguments:
  • names (str, list): the name of the sample to pick, or a list of names to pick multiple.
  • store (bool): True if sample should be stored in the image header file (for later access). Default is True.
  • **kwds: Keywords are passed to HyImage.pickPoints( ... ), which forwards them to HyImage.quick_plot( ... ).
Returns:

a list containing a list of points for each sample.