hylite.hyimage

Store and manipulate hyperspectral image data.

  1"""
  2Store and manipulate hyperspectral image data.
  3"""
  4
  5import os
  6import numpy as np
  7import matplotlib
  8import matplotlib.pyplot as plt
  9from matplotlib import path
 10from roipoly import MultiRoi
 11import imageio
 12import scipy as sp
 13import hylite
 14from hylite.hydata import HyData
 15from hylite.hylibrary import HyLibrary
 16
 17
 18
 19class HyImage( HyData ):
 20    """
 21    A class for hyperspectral image data. These can be individual scenes or hyperspectral orthoimages.
 22    """
 23
 24    def __init__(self, data, **kwds):
 25        """
 26        Args:
 27            data (ndarray): a numpy array such that data[x][y][band] gives each pixel value.
 28            **kwds:
 29                wav = A numpy array containing band wavelengths for this image.
 30                affine = an affine transform of the format returned by GDAL.GetGeoTransform().
 31                projection = string defining the project. Default is None.
 32                sensor = sensor name. Default is "unknown".
 33                header = path to associated header file. Default is None.
 34        """
 35
 36        #call constructor for HyData
 37        super().__init__(data, **kwds)
 38
 39        # special case - if dataset only has oneband, slice it so it still has
 40        # the format data[x,y,b].
 41        if not self.data is None:
 42            if len(self.data.shape) == 1:
 43                self.data = self.data[None, None, :] # single pixel image
 44            if len(self.data.shape) == 2:
 45                self.data = self.data[:, :, None] # single band iamge
 46
 47        #load any additional project information (specific to images)
 48        self.set_projection(kwds.get("projection",None))
 49        self.affine = kwds.get("affine",[0,1,0,0,0,1])
 50
 51        # wavelengths
 52        if 'wav' in kwds:
 53            self.set_wavelengths(kwds['wav'])
 54
 55        #special header formatting
 56        self.header['file type'] = 'ENVI Standard'
 57
 58    def copy(self,data=True):
 59        """
 60        Make a deep copy of this image instance.
 61
 62        Args:
 63            data (bool): True if a copy of the data should be made, otherwise only copy header.
 64
 65        Returns:
 66            a new HyImage instance.
 67        """
 68        if not data:
 69            return HyImage(None, header=self.header.copy(), projection=self.projection, affine=self.affine)
 70        else:
 71            return HyImage( self.data.copy(), header=self.header.copy(), projection=self.projection, affine=self.affine)
 72
 73    def T(self):
 74        """
 75        Return a transposed view of the data matrix (corresponding with the [y,x] indexing used by matplotlib, opencv etc.
 76        """
 77        return np.transpose(self.data, (1,0,2))
 78
 79    def xdim(self):
 80        """
 81        Return number of pixels in x (first dimension of data array)
 82        """
 83        return self.data.shape[0]
 84
 85    def ydim(self):
 86        """
 87        Return number of pixels in y (second dimension of data array)
 88        """
 89        return self.data.shape[1]
 90
 91    def aspx(self):
 92        """
 93        Return the aspect ratio of this image (width/height).
 94        """
 95        return self.ydim() / self.xdim()
 96
 97    def get_extent(self):
 98        """
 99        Returns the width and height of this image in world coordinates.
100
101        Returns:
102            tuple with (width, height).
103        """
104        return self.xdim * self.pixel_size[0], self.ydim * self.pixel_size[1]
105
106    def set_projection(self,proj):
107        """
108        Set this project to an existing osgeo.osr.SpatialReference or GDAL georeference string.
109
110        Args:
111            proj (str, osgeo.osr.SpatialReference): the project to use as osgeo.osr.SpatialReference or GDAL georeference string.
112        """
113        if proj is None:
114            self.projection = None
115        else:
116            try:
117                from osgeo.osr import SpatialReference
118            except:
119                assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
120            if isinstance(proj, SpatialReference):
121                self.projection = proj
122            elif isinstance(proj, str):
123                self.projection = SpatialReference(proj)
124            else:
125                print("Invalid project %s" % proj)
126                raise
127
128    def set_projection_EPSG(self,EPSG):
129        """
130        Sets this image project using an EPSG code.
131
132        Args:
133            EPSG (str): string EPSG code that can be passed to SpatialReference.SetFromUserInput(...).
134        """
135
136        try:
137            from osgeo.osr import SpatialReference
138        except:
139            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
140
141        self.projection = SpatialReference()
142        self.projection.SetFromUserInput(EPSG)
143
144    def get_projection_EPSG(self):
145        """
146        Gets a string describing this projections EPSG code (if it is an EPSG project).
147
148        Returns:
149            an EPSG code string of the format "EPSG:XXXX".
150        """
151        if self.projection is None:
152            return None
153        else:
154            return "%s:%s" % (self.projection.GetAttrValue("AUTHORITY",0),self.projection.GetAttrValue("AUTHORITY",1))
155
156    def pix_to_world(self, px, py, proj=None):
157        """
158        Take pixel coordinates and return world coordinates
159
160        Args:
161            px (int): the pixel x-coord.
162            py (int): the pixel y-coord.
163            proj (str, osr.SpatialReference): the coordinate system to use. Default (None) uses the same system as this image. Otherwise
164                   an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
165        Returns:
166            the world coordinates in the coordinate system defined by get_projection_EPSG(...).
167        """
168
169        try:
170            from osgeo import osr
171            import osgeo.gdal as gdal
172            from osgeo import ogr
173        except:
174            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
175
176        # parse project
177        if proj is None:
178            proj = self.projection
179        elif isinstance(proj, str) or isinstance(proj, int):
180            epsg = proj
181            if isinstance(epsg, str):
182                try:
183                    epsg = int(str.split(':')[1])
184                except:
185                    assert False, "Error - %s is an invalid EPSG code." % proj
186            proj = osr.SpatialReference()
187            proj.ImportFromEPSG(epsg)
188
189        # check we have all the required info
190        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
191        assert (not self.affine is None) and (
192            not self.projection is None), "Error - project information is undefined."
193
194        #project to world coordinates in this images project/world coords
195        x,y = gdal.ApplyGeoTransform(self.affine, px, py)
196
197        #project to target coords (if different)
198        if not proj.IsSameGeogCS(self.projection):
199            P = ogr.Geometry(ogr.wkbPoint)
200            if proj.EPSGTreatsAsNorthingEasting():
201                P.AddPoint(x, y)
202            else:
203                P.AddPoint(y, x)
204            P.AssignSpatialReference(self.projection)  # tell the point what coordinates it's in
205            P.TransformTo(proj)  # reproject it to the out spatial reference
206            x, y = P.GetX(), P.GetY()
207
208            #do we need to transpose?
209            if proj.EPSGTreatsAsLatLong():
210                x,y=y,x #we want lon,lat not lat,lon
211        return x, y
212
213    def world_to_pix(self, x, y, proj = None):
214        """
215        Take world coordinates and return pixel coordinates
216
217        Args:
218            x (float): the world x-coord.
219            y (float): the world y-coord.
220            proj (str, osr.SpatialReference): the coordinate system of the input coordinates. Default (None) uses the same system as this image. Otherwise
221                   an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
222
223        Returns:
224            the pixel coordinates based on the affine transform stored in self.affine.
225        """
226
227        try:
228            from osgeo import osr
229            import osgeo.gdal as gdal
230            from osgeo import ogr
231        except:
232            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
233
234        # parse project
235        if proj is None:
236            proj = self.projection
237        elif isinstance(proj, str) or isinstance(proj, int):
238            epsg = proj
239            if isinstance(epsg, str):
240                try:
241                    epsg = int(str.split(':')[1])
242                except:
243                    assert False, "Error - %s is an invalid EPSG code." % proj
244            proj = osr.SpatialReference()
245            proj.ImportFromEPSG(epsg)
246
247
248        # check we have all the required info
249        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
250        assert (not self.affine is None) and (not self.projection is None), "Error - project information is undefined."
251
252        # project to this images CS (if different)
253        if not proj.IsSameGeogCS(self.projection):
254            P = ogr.Geometry(ogr.wkbPoint)
255            if proj.EPSGTreatsAsNorthingEasting():
256                P.AddPoint(x, y)
257            else:
258                P.AddPoint(y, x)
259            P.AssignSpatialReference(proj)  # tell the point what coordinates it's in
260            P.AddPoint(x, y)
261            P.TransformTo(self.projection)  # reproject it to the out spatial reference
262            x, y = P.GetX(), P.GetY()
263            if self.projection.EPSGTreatsAsLatLong(): # do we need to transpose?
264                x, y = y, x  # we want lon,lat not lat,lon
265
266        inv = gdal.InvGeoTransform(self.affine)
267        assert not inv is None, "Error - could not invert affine transform?"
268
269        #apply
270        return gdal.ApplyGeoTransform(inv, x, y)
271
272    def flip(self, axis='x'):
273        """
274        Flip the image on the x or y axis.
275
276        Args:
277            axis (str): 'x' or 'y' or both 'xy'.
278        """
279
280        if 'x' in axis.lower():
281            self.data = np.flip(self.data,axis=0)
282        if 'y' in axis.lower():
283            self.data = np.flip(self.data,axis=1)
284
285    def rot90(self):
286        """
287        Rotate this image by 90 degrees by transposing the underlying data array. Combine with flip('x') or flip('y')
288        to achieve positive/negative rotations.
289        """
290        self.data = np.transpose( self.data, (1,0,2) )
291        self.push_to_header()
292
293    #####################################
294    ##IMAGE FILTERING
295    #####################################
296    def fill_holes(self):
297        """
298        Replaces nan pixel with an average of their neighbours, thus removing 1-pixel large holes from an image. Note that
299        for performance reasons this assumes that holes line up across bands. Note that this is not vectorized so very slow...
300        """
301
302        # perform greyscale dilation
303        dilate = self.data.copy()
304        mask = np.logical_not(np.isfinite(dilate))
305        dilate[mask] = 0
306        for b in range(self.band_count()):
307            dilate[:, :, b] = sp.ndimage.grey_dilation(dilate[:, :, b], size=(3, 3))
308
309        # map back to holes in dataset
310        self.data[mask] = dilate[mask]
311        #self.data[self.data == 0] = np.nan  # replace remaining 0's with nans
312
313    def blur(self, n=3):
314        """
315        Applies a gaussian kernel of size n to the image using OpenCV.
316
317        Args:
318            n (int): the dimensions of the gaussian kernel to convolve. Default is 3. Increase for more blurry results.
319        """
320        import cv2 # import this here to avoid errors if opencv is not installed properly
321
322        nanmask = np.isnan(self.data)
323        assert isinstance(n, int) and n >= 3, "Error - invalid kernel. N must be an integer > 3. "
324        kernel = np.ones((n, n), np.float32) / (n ** 2)
325        self.data = cv2.filter2D(self.data, -1, kernel)
326        self.data[nanmask] = np.nan  # remove mask
327
328    def erode(self, size=3, iterations=1):
329        """
330        Apply an erode filter to this image to expand background (nan) pixels. Refer to open-cv's erode
331        function for more details.
332
333        Args:
334            size (int): the size of the erode filter. Default is a 3x3 kernel.
335            iterations (int): the number of erode iterations. Default is 1.
336        """
337        import cv2 # import this here to avoid errors if opencv is not installed properly
338
339        # erode
340        kernel = np.ones((size, size), np.uint8)
341        if self.is_float():
342            mask = np.isfinite(self.data).any(axis=-1)
343            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
344            self.data[mask == 0, :] = np.nan
345        else:
346            mask = (self.data != 0).any( axis=-1 )
347            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
348            self.data[mask == 0, :] = 0
349
350    def resize(self, newdims : tuple, interpolation : int = 1):
351        """
352        Resize this image with opencv.
353
354        Args:
355            newdims (tuple): the new image dimensions.
356            interpolation (int): opencv interpolation method. Default is cv2.INTER_LINEAR.
357        """
358        import cv2 # import this here to avoid errors if opencv is not installed properly
359        self.data = cv2.resize(self.data, (newdims[1],newdims[0]), interpolation=interpolation)
360
361    def despeckle(self, size=5):
362        """
363        Despeckle each band of this image (independently) using a median filter.
364
365        Args:
366            size (int): the size of the median filter kernel. Default is 5. Must be an odd number.
367        """
368
369        assert (size % 2) == 1, "Error - size must be an odd integer"
370        import cv2 # import this here to avoid errors if opencv is not installed properly
371        if self.is_float():
372            self.data = cv2.medianBlur( self.data.astype(np.float32), size )
373        else:
374            self.data = cv2.medianBlur( self.data, size )
375
376    #####################################
377    ##FEATURES AND FEATURE MATCHING
378    ######################################
379    def get_keypoints(self, band, eq=False, mask=True, method='sift', cfac=0.0,bfac=0.0, **kwds):
380        """
381        Get feature descriptors from the specified band.
382
383        Args:
384            band (int,float,str,tuple): the band index (int) or wavelength (float) to extract features from. Alternatively, a tuple can be passed
385                    containing a range of bands (min : max) to average before feature matching.
386            eq (bool): True if the image should be histogram equalized first. Default is False.
387            mask (bool): True if 0 value pixels should be masked. Default is True.
388            method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'.
389            cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
390            bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.
391            **kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are:
392
393                - contrastThreshold: default is 0.01.
394                - edgeThreshold: default is 10.
395                - sigma: default is 1.0
396
397                For ORB these are:
398
399                - nfeatures = the number of features to detect. Default is 5000.
400
401            Returns:
402                Tuple containing
403
404                    - k (ndarray): the keypoints detected
405                    - d (ndarray): corresponding feature descriptors
406         """
407        import cv2 # import this here to avoid errors if opencv is not installed properly
408
409        # get image
410        if isinstance(band, int) or isinstance(band, float): #single band
411            image = self.data[:, :, self.get_band_index(band)]
412        elif isinstance(band,tuple): #range of bands (averaged)
413            idx0 = self.get_band_index(band[0])
414            idx1 = self.get_band_index(band[1])
415
416            #deal with out of range errors
417            if idx0 is None:
418                idx0 = 0
419            if idx1 is None:
420                idx1 = self.band_count()
421
422            #average bands
423            image = np.nanmean(self.data[:,:,idx0:idx1],axis=2)
424        else:
425            assert False, "Error, unrecognised band %s" % band
426
427        #normalise image to range 0 - 1
428        image -= np.nanmin(image)
429        image = image / np.nanmax(image)
430
431        #apply brightness/contrast adjustment
432        image = (1.0+cfac)*image + bfac
433        image[image > 1.0] = 1.0
434        image[image < 0.0] = 0.0
435
436        #convert image to uint8 for opencv
437        image = np.uint8(255 * image)
438        if eq:
439            image = cv2.equalizeHist(image)
440
441        if mask:
442            mask = np.zeros(image.shape, dtype=np.uint8)
443            mask[image != 0] = 255  # include only non-zero pixels
444        else:
445            mask = None
446
447        if 'sift' in method.lower():  # SIFT
448
449            # setup default keywords
450            kwds["contrastThreshold"] = kwds.get("contrastThreshold", 0.01)
451            kwds["edgeThreshold"] = kwds.get("edgeThreshold", 10)
452            kwds["sigma"] = kwds.get("sigma", 1.0)
453
454            # make feature detector
455            #alg = cv2.xfeatures2d.SIFT_create(**kwds)
456            alg = cv2.SIFT_create()
457        elif 'orb' in method.lower():  # orb
458            kwds['nfeatures'] = kwds.get('nfeatures', 5000)
459            alg = cv2.ORB_create(scoreType=cv2.ORB_FAST_SCORE, **kwds)
460        else:
461            assert False, "Error - %s is not a recognised feature detector." % method
462
463        # detect keypoints
464        kp = alg.detect(image, mask)
465
466        # extract and return feature vectors
467        return alg.compute(image, kp)
468
469    @classmethod
470    def match_keypoints(cls, kp1, kp2, d1, d2, method='SIFT', dist=0.7, tree = 5, check = 100, min_count=5):
471        """
472        Compares keypoint feature vectors from two images and returns matching pairs.
473
474        Args:
475            kp1 (ndarray): keypoints from the first image
476            kp2 (ndarray): keypoints from the second image
477            d1 (ndarray): descriptors for the keypoints from the first image
478            d2 (ndarray): descriptors for the keypoints from the second image
479            method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'.
480            dist (float): minimum match distance (0 to 1), default is 0.7
481            tree (int): not sure what this does? Default is 5. See open-cv docs.
482            check (int): ditto. Default is 100.
483            min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are found,
484                       then the function returns None, None. Default is 5.
485        """
486        import cv2 # import this here to avoid errors if opencv is not installed properly
487        if 'sift' in method.lower():
488            algorithm = cv2.NORM_INF
489        elif 'orb' in method.lower():
490            algorithm = cv2.NORM_HAMMING
491        else:
492            assert False, "Error - unknown matching algorithm %s" % method
493
494        #calculate flann matches
495        index_params = dict(algorithm=algorithm, trees=tree)
496        search_params = dict(checks=check)
497        flann = cv2.FlannBasedMatcher(index_params, search_params)
498        matches = flann.knnMatch(d1, d2, k=2)
499
500        # store all the good matches as per Lowe's ratio test.
501        good = []
502        for m, n in matches:
503            if m.distance < dist * n.distance:
504                good.append(m)
505
506        if len(good) < min_count:
507            return None, None
508        else:
509            src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
510            dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
511            return src_pts, dst_pts
512
513    ############################
514    ## Visualisation methods
515    ############################
516    def quick_plot(self, band=0, ax=None, bfac=0.0, cfac=0.0, samples=False, tscale=False, rot=False, flipX=False, flipY=False,
517                   **kwds):
518        """
519        Plot a band using matplotlib.imshow(...).
520
521        Args:
522            band (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple is passed then
523                  each band in the tuple (string or index) will be mapped to rgb. Bands with negative wavelengths or indices will be inverted before plotting.
524            ax: an axis object to plot to. If none, plt.imshow( ... ) is used.
525            bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
526            cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
527            samples (bool): True if sample points (defined in the header file) should be plotted. Default is False. Otherwise, a list of
528                     [ (x,y), ... ] points can be passed.
529            tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False.
530                    When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or
531                    (constant) values (float).
532            rot (bool): if True, the x and y axis will be flipped (90 degree rotation) before plotting. Default is False.
533            flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
534            flipY (bool): if True, the y axis will be flippe before plotting (after applying rotations).
535            **kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:
536
537                 - mask = a 2D boolean mask containing true if pixels should be drawn and false otherwise.
538                 - path = a file path to save the image too (at matching resolution; use fig.savefig(..) if you want to save the figure).
539                 - ticks = True if x- and y- ticks should be plotted. Default is False.
540                 - ps, pc = the size and color of sample points to plot. Can be constant or list.
541                 - figsize = a figsize for the figure to create (if ax is None).
542
543        Returns:
544            Tuple containing
545
546            - fig: matplotlib figure object
547            - ax:  matplotlib axes object. If a colorbar is created, (band is an integer or a float), then this will be stored in ax.cbar.
548        """
549
550        #create new axes?
551        if ax is None:
552            fig, ax = plt.subplots(figsize=kwds.pop('figsize', (18,18*self.ydim()/self.xdim()) ))
553
554        # deal with ticks
555        if not kwds.pop('ticks', False ):
556            ax.set_xticks([])
557            ax.set_yticks([])
558
559        #map individual band using colourmap
560        if isinstance(band, str) or isinstance(band, int) or isinstance(band, float):
561            #get band
562            if isinstance(band, str):
563                data = self.data[:, :, self.get_band_index(band)]
564            else:
565                data = self.data[:, :, self.get_band_index(np.abs(band))]
566            if not isinstance(band, str) and band < 0:
567                data = np.nanmax(data) - data # flip
568
569            # convert integer vmin and vmax values to percentiles
570            if 'vmin' in kwds:
571                if isinstance(kwds['vmin'], int):
572                    kwds['vmin'] = np.nanpercentile( data, kwds['vmin'] )
573            if 'vmax' in kwds:
574                if isinstance(kwds['vmax'], int):
575                    kwds['vmax'] = np.nanpercentile( data, kwds['vmax'] )
576
577            #mask nans (and apply custom mask)
578            mask = np.isnan(data)
579            if not np.isnan(self.header.get_data_ignore_value()):
580                mask = mask + data == self.header.get_data_ignore_value()
581            if 'mask' in kwds:
582                mask = mask + kwds.get('mask')
583                del kwds['mask']
584            data = np.ma.array(data, mask = mask > 0 )
585
586            # apply rotations and flipping
587            if rot:
588                data = data.T
589            if flipX:
590                data = data[::-1, :]
591            if flipY:
592                data = data[:, ::-1]
593
594            # save?
595            if 'path' in kwds:
596                path = kwds.pop('path')
597                from matplotlib.pyplot import imsave
598                if not os.path.exists(os.path.dirname(path)):
599                    os.makedirs(os.path.dirname(path)) # ensure output directory exists
600                imsave(path, data.T, **kwds)  # save the image
601
602            ax.cbar = ax.imshow(data.T, interpolation=kwds.pop('interpolation', 'none'), **kwds) # change default interpolation to None
603
604        #map 3 bands to RGB
605        elif isinstance(band, tuple) or isinstance(band, list):
606            #get band indices and range
607            rgb = []
608            for b in band:
609                if isinstance(b, str):
610                    rgb.append(self.get_band_index(b))
611                else:
612                    rgb.append(self.get_band_index(np.abs(b)))
613
614            #slice image (as copy) and map to 0 - 1
615            img = np.array(self.data[:, :, rgb]).copy()
616            if np.isnan(img).all():
617                print("Warning - image contains no data.")
618                return ax.get_figure(), ax
619
620            # invert if needed
621            for i,b in enumerate(band):
622                if not isinstance(b, str) and (b < 0):
623                    img[..., i] = np.nanmax(img[..., i]) - img[..., i]
624
625            # do scaling
626            if tscale: # scale bands independently
627                for b in range(3):
628                    mn = kwds.get("vmin", float(np.nanmin(img)))
629                    mx = kwds.get("vmax", float(np.nanmax(img)))
630                    if isinstance (mn, int):
631                        assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
632                        mn = float(np.nanpercentile(img[...,b], mn ))
633                    if isinstance (mx, int):
634                        assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
635                        mx = float(np.nanpercentile(img[...,b], mx ))
636                    img[...,b] = (img[..., b] - mn) / (mx - mn)
637            else: # scale bands together
638                mn = kwds.get("vmin", float(np.nanmin(img)))
639                mx = kwds.get("vmax", float(np.nanmax(img)))
640                if isinstance(mn, int):
641                    assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
642                    mn = float(np.nanpercentile(img, mn))
643                if isinstance(mx, int):
644                    assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
645                    mx = float(np.nanpercentile(img, mx))
646                img = (img - mn) / (mx - mn)
647
648            #apply brightness/contrast mapping
649            img = np.clip((1.0 + cfac) * img + bfac, 0, 1.0 )
650
651            #apply masking so background is white
652            img[np.logical_not( np.isfinite( img ) )] = 1.0
653            if 'mask' in kwds:
654                img[kwds.pop("mask"),:] = 1.0
655
656            # apply rotations and flipping
657            if rot:
658                img = np.transpose( img, (1,0,2) )
659            if flipX:
660                img = img[::-1, :, :]
661            if flipY:
662                img = img[:, ::-1, :]
663
664            # save?
665            if 'path' in kwds:
666                path = kwds.pop('path')
667                from matplotlib.pyplot import imsave
668                if not os.path.exists(os.path.dirname(path)):
669                    os.makedirs(os.path.dirname(path)) # ensure output directory exists
670                imsave(path, np.transpose( np.clip( img*255, 0, 255).astype(np.uint8), (1, 0, 2)))  # save the image
671
672            # plot samples?
673            ps = kwds.pop('ps', 5)
674            pc = kwds.pop('pc', 'r')
675            if samples:
676                if isinstance(samples, list) or isinstance(samples, np.ndarray):
677                    ax.scatter([s[0] for s in samples], [s[1] for s in samples], s=ps, c=pc)
678                else:
679                    for n in self.header.get_class_names():
680                        points = np.array(self.header.get_sample_points(n))
681                        ax.scatter(points[:, 0], points[:, 1], s=ps, c=pc)
682
683            #plot
684            ax.imshow(np.transpose(img, (1,0,2)), interpolation=kwds.pop('interpolation', 'none'), **kwds)
685            ax.cbar = None  # no colorbar
686
687        return ax.get_figure(), ax
688
689    def createGIF(self, path, bands=None, figsize=(10,10), fps=10, **kwds):
690        """
691        Create and save an animated gif that loops through the bands of the image.
692
693        Args:
694            path (str): the path to save the .gif
695            bands (tuple): Tuple containing the range of band indices to draw. Default is the whole range.
696            figsize (tuple): the size of the image to draw. Default is (10,10).
697            fps (int): the framerate (frames per second) of the gif. Default is 10.
698            **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
699        """
700
701        frames = []
702        if bands is None:
703            bands = (0,self.band_count())
704        else:
705            assert 0 < bands[0] < self.band_count(), "Error - invalid range."
706            assert 0 < bands[1] < self.band_count(), "Error - invalid range."
707            assert bands[1] > bands[0], "Error - invalid range."
708
709        #plot frames
710        for i in range(bands[0],bands[1]):
711            fig, ax = plt.subplots(figsize=figsize)
712            ax.imshow(self.data[:, :, i], **kwds)
713            fig.canvas.draw()
714            frames.append(np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8'))
715            frames[-1] = np.reshape(frames[-1], (fig.canvas.get_width_height()[1], fig.canvas.get_width_height()[0], 3))
716            plt.close(fig)
717
718        #save gif
719        imageio.mimsave( os.path.splitext(path)[0] + ".gif", frames, fps=fps)
720
721    ## masking
722    def mask(self, mask=None, flag=np.nan, invert=False, crop=False, bands=None):
723        """
724         Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the
725         image in-situ.
726
727         Args:
728            flag (float): the value to use for masked pixels. Default is np.nan
729            mask (ndarray): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None is passed then
730                    pickPolygon( ... ) is used to interactively define a polygon. If a file path is passed then the polygon
731                    will be loaded using np.load( ... ). Alternatively if mask.shape == image.shape[0,1] then it is treated as a
732                    binary image mask (must be boolean) and True values will be masked across all bands. Default is None.
733            invert (bool): if True, pixels within the polygon will be masked. If False, pixels outside the polygon are masked. Default is False.
734            crop (bool): True if rows/columns containing only zeros should be removed. Default is False.
735            bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used.
736
737         Returns:
738            Tuple containing
739
740            - mask (ndarray): a boolean array with True where pixels are masked and False elsewhere.
741            - poly (ndarray): the mask polygon array in the format described above. Useful if the polygon was interactively defined.
742         """
743
744        if mask is None:  # pick mask interactively
745            if bands is None:
746                bands = int(self.band_count() / 2)
747
748            regions = self.pickPolygons(region_names=["mask"], bands=bands)
749
750            # the user bailed without picking a mask?
751            if len(regions) == 0:
752                print("Warning - no mask picked/applied.")
753                return
754
755            # extract polygon mask
756            mask = regions[0]
757
758        # convert polygon mask to binary mask
759        if mask.shape[1] == 2:
760
761            # build meshgrid with pixel coords
762            xx, yy = np.meshgrid(np.arange(self.xdim()), np.arange(self.ydim()))
763            xx = xx.flatten()
764            yy = yy.flatten()
765            points = np.vstack([xx, yy]).T  # coordinates of each pixel
766
767            # calculate per-pixel mask
768            mask = path.Path(mask).contains_points(points)
769            mask = mask.reshape((self.ydim(), self.xdim())).T
770
771            # flip as we want to mask (==True) outside points (unless invert is true)
772            if not invert:
773                mask = np.logical_not(mask)
774
775        # apply binary image mask
776        assert mask.shape[0] == self.data.shape[0] and mask.shape[1] == self.data.shape[1], \
777            "Error - mask shape %s does not match image shape %s" % (mask.shape, self.data.shape)
778        for b in range(self.band_count()):
779            self.data[:, :, b][mask] = flag
780
781        # crop image
782        if crop:
783            # calculate non-masked pixels
784            valid = np.logical_not(mask)
785
786            # integrate along axes
787            xdata = np.sum(valid, axis=1) > 0.0
788            ydata = np.sum(valid, axis=0) > 0.0
789
790            # calculate domain containing valid pixels
791            xmin = np.argmax(xdata)
792            xmax = xdata.shape[0] - np.argmax(xdata[::-1])
793            ymin = np.argmax(ydata)
794            ymax = ydata.shape[0] - np.argmax(ydata[::-1])
795
796            # crop
797            self.data = self.data[xmin:xmax, ymin:ymax, :]
798
799        return mask
800
801    def crop_to_data(self):
802        """
803        Remove padding of nan or zero pixels from image. Note that this is performed in place.
804        """
805
806        valid = np.isfinite(self.data).any(axis=-1) & (self.data != 0).any(axis=-1)
807        ymin, ymax = np.percentile(np.argwhere(np.sum(valid, axis=0) != 0), (0, 100))
808        xmin, xmax = np.percentile(np.argwhere(np.sum(valid, axis=1) != 0), (0, 100))
809        self.data = self.data[int(xmin):int(xmax), int(ymin):int(ymax), :]  # do clipping
810
811    ##################################################
812    ## Interactive tools for picking regions/pixels
813    ##################################################
814    def pickPolygons(self, region_names, bands=0):
815        """
816        Creates a matplotlib gui for selecting polygon regions in an image.
817
818        Args:
819            names (list, str): a list containing the names of the regions to pick. If a string is passed only one name is used.
820            bands (tuple): the bands of the image to plot.
821        """
822
823        if isinstance(region_names, str):
824            region_names = [region_names]
825
826        assert isinstance(region_names, list), "Error - names must be a list or a string."
827
828        # set matplotlib backend
829        backend = matplotlib.get_backend()
830        matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work
831
832        # plot image and extract roi's
833        fig, ax = self.quick_plot(bands)
834        roi = MultiRoi(roi_names=region_names)
835        plt.close(fig)  # close figure
836
837        # extract regions
838        regions = []
839        for name, r in roi.rois.items():
840            # store region
841            x = r.x
842            y = r.y
843            regions.append(np.vstack([x, y]).T)
844
845        # restore matplotlib backend (if possible)
846        try:
847            matplotlib.use(backend)
848        except:
849            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")
850            pass
851
852        return regions
853
854    def pickPoints(self, n=-1, bands=hylite.RGB, integer=True, title="Pick Points", **kwds):
855        """
856        Creates a matplotlib gui for picking pixels from an image.
857
858        Args:
859            n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
860            bands (tuple): the bands of the image to plot. Default is HyImage.RGB
861            integer (bool): True if points coordinates should be cast to integers (for use as indices). Default is True.
862            title (str): The title of the point picking window.
863            **kwds: Keywords are passed to HyImage.quick_plot( ... ).
864
865        Returns:
866            A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].
867        """
868
869        # set matplotlib backend
870        backend = matplotlib.get_backend()
871        matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work
872
873        # create figure
874        fig, ax = self.quick_plot( bands, **kwds )
875        ax.set_title(title)
876
877        # get points
878        points = fig.ginput( n )
879
880        if integer:
881            points = [ (int(p[0]), int(p[1])) for p in points ]
882
883        # restore matplotlib backend (if possible)
884        try:
885            matplotlib.use(backend)
886        except:
887            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")
888            pass
889
890        return points
891
892    def pickSamples(self, names=None, store=True, **kwds):
893        """
894        Pick sample probe points and store these in the image header file.
895
896        Args:
897            names (str, list): the name of the sample to pick, or a list of names to pick multiple.
898            store (bool): True if sample should be stored in the image header file (for later access). Default is True.
899            **kwds: Keywords are passed to HyImage.quick_plot( ... )
900
901        Returns:
902            a list containing a list of points for each sample.
903        """
904
905        if isinstance(names, str):
906            names = [names]
907
908        # pick points
909        points = []
910        for s in names:
911            pnts = self.pickPoints(title="%s" % s, **kwds)
912            if store:
913                self.header['sample %s' % s] = pnts # store in header
914            points.append(pnts)
915        # add class to header file
916        if store:
917            cls_names = self.header.get_class_names()
918            if cls_names is None:
919                cls_names = []
920            self.header['class names'] = cls_names + names
921
922        return points
class HyImage(hylite.hydata.HyData):
 20class HyImage( HyData ):
 21    """
 22    A class for hyperspectral image data. These can be individual scenes or hyperspectral orthoimages.
 23    """
 24
 25    def __init__(self, data, **kwds):
 26        """
 27        Args:
 28            data (ndarray): a numpy array such that data[x][y][band] gives each pixel value.
 29            **kwds:
 30                wav = A numpy array containing band wavelengths for this image.
 31                affine = an affine transform of the format returned by GDAL.GetGeoTransform().
 32                projection = string defining the project. Default is None.
 33                sensor = sensor name. Default is "unknown".
 34                header = path to associated header file. Default is None.
 35        """
 36
 37        #call constructor for HyData
 38        super().__init__(data, **kwds)
 39
 40        # special case - if dataset only has oneband, slice it so it still has
 41        # the format data[x,y,b].
 42        if not self.data is None:
 43            if len(self.data.shape) == 1:
 44                self.data = self.data[None, None, :] # single pixel image
 45            if len(self.data.shape) == 2:
 46                self.data = self.data[:, :, None] # single band iamge
 47
 48        #load any additional project information (specific to images)
 49        self.set_projection(kwds.get("projection",None))
 50        self.affine = kwds.get("affine",[0,1,0,0,0,1])
 51
 52        # wavelengths
 53        if 'wav' in kwds:
 54            self.set_wavelengths(kwds['wav'])
 55
 56        #special header formatting
 57        self.header['file type'] = 'ENVI Standard'
 58
 59    def copy(self,data=True):
 60        """
 61        Make a deep copy of this image instance.
 62
 63        Args:
 64            data (bool): True if a copy of the data should be made, otherwise only copy header.
 65
 66        Returns:
 67            a new HyImage instance.
 68        """
 69        if not data:
 70            return HyImage(None, header=self.header.copy(), projection=self.projection, affine=self.affine)
 71        else:
 72            return HyImage( self.data.copy(), header=self.header.copy(), projection=self.projection, affine=self.affine)
 73
 74    def T(self):
 75        """
 76        Return a transposed view of the data matrix (corresponding with the [y,x] indexing used by matplotlib, opencv etc.
 77        """
 78        return np.transpose(self.data, (1,0,2))
 79
 80    def xdim(self):
 81        """
 82        Return number of pixels in x (first dimension of data array)
 83        """
 84        return self.data.shape[0]
 85
 86    def ydim(self):
 87        """
 88        Return number of pixels in y (second dimension of data array)
 89        """
 90        return self.data.shape[1]
 91
 92    def aspx(self):
 93        """
 94        Return the aspect ratio of this image (width/height).
 95        """
 96        return self.ydim() / self.xdim()
 97
 98    def get_extent(self):
 99        """
100        Returns the width and height of this image in world coordinates.
101
102        Returns:
103            tuple with (width, height).
104        """
105        return self.xdim * self.pixel_size[0], self.ydim * self.pixel_size[1]
106
107    def set_projection(self,proj):
108        """
109        Set this project to an existing osgeo.osr.SpatialReference or GDAL georeference string.
110
111        Args:
112            proj (str, osgeo.osr.SpatialReference): the project to use as osgeo.osr.SpatialReference or GDAL georeference string.
113        """
114        if proj is None:
115            self.projection = None
116        else:
117            try:
118                from osgeo.osr import SpatialReference
119            except:
120                assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
121            if isinstance(proj, SpatialReference):
122                self.projection = proj
123            elif isinstance(proj, str):
124                self.projection = SpatialReference(proj)
125            else:
126                print("Invalid project %s" % proj)
127                raise
128
129    def set_projection_EPSG(self,EPSG):
130        """
131        Sets this image project using an EPSG code.
132
133        Args:
134            EPSG (str): string EPSG code that can be passed to SpatialReference.SetFromUserInput(...).
135        """
136
137        try:
138            from osgeo.osr import SpatialReference
139        except:
140            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
141
142        self.projection = SpatialReference()
143        self.projection.SetFromUserInput(EPSG)
144
145    def get_projection_EPSG(self):
146        """
147        Gets a string describing this projections EPSG code (if it is an EPSG project).
148
149        Returns:
150            an EPSG code string of the format "EPSG:XXXX".
151        """
152        if self.projection is None:
153            return None
154        else:
155            return "%s:%s" % (self.projection.GetAttrValue("AUTHORITY",0),self.projection.GetAttrValue("AUTHORITY",1))
156
157    def pix_to_world(self, px, py, proj=None):
158        """
159        Take pixel coordinates and return world coordinates
160
161        Args:
162            px (int): the pixel x-coord.
163            py (int): the pixel y-coord.
164            proj (str, osr.SpatialReference): the coordinate system to use. Default (None) uses the same system as this image. Otherwise
165                   an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
166        Returns:
167            the world coordinates in the coordinate system defined by get_projection_EPSG(...).
168        """
169
170        try:
171            from osgeo import osr
172            import osgeo.gdal as gdal
173            from osgeo import ogr
174        except:
175            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
176
177        # parse project
178        if proj is None:
179            proj = self.projection
180        elif isinstance(proj, str) or isinstance(proj, int):
181            epsg = proj
182            if isinstance(epsg, str):
183                try:
184                    epsg = int(str.split(':')[1])
185                except:
186                    assert False, "Error - %s is an invalid EPSG code." % proj
187            proj = osr.SpatialReference()
188            proj.ImportFromEPSG(epsg)
189
190        # check we have all the required info
191        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
192        assert (not self.affine is None) and (
193            not self.projection is None), "Error - project information is undefined."
194
195        #project to world coordinates in this images project/world coords
196        x,y = gdal.ApplyGeoTransform(self.affine, px, py)
197
198        #project to target coords (if different)
199        if not proj.IsSameGeogCS(self.projection):
200            P = ogr.Geometry(ogr.wkbPoint)
201            if proj.EPSGTreatsAsNorthingEasting():
202                P.AddPoint(x, y)
203            else:
204                P.AddPoint(y, x)
205            P.AssignSpatialReference(self.projection)  # tell the point what coordinates it's in
206            P.TransformTo(proj)  # reproject it to the out spatial reference
207            x, y = P.GetX(), P.GetY()
208
209            #do we need to transpose?
210            if proj.EPSGTreatsAsLatLong():
211                x,y=y,x #we want lon,lat not lat,lon
212        return x, y
213
214    def world_to_pix(self, x, y, proj = None):
215        """
216        Take world coordinates and return pixel coordinates
217
218        Args:
219            x (float): the world x-coord.
220            y (float): the world y-coord.
221            proj (str, osr.SpatialReference): the coordinate system of the input coordinates. Default (None) uses the same system as this image. Otherwise
222                   an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
223
224        Returns:
225            the pixel coordinates based on the affine transform stored in self.affine.
226        """
227
228        try:
229            from osgeo import osr
230            import osgeo.gdal as gdal
231            from osgeo import ogr
232        except:
233            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
234
235        # parse project
236        if proj is None:
237            proj = self.projection
238        elif isinstance(proj, str) or isinstance(proj, int):
239            epsg = proj
240            if isinstance(epsg, str):
241                try:
242                    epsg = int(str.split(':')[1])
243                except:
244                    assert False, "Error - %s is an invalid EPSG code." % proj
245            proj = osr.SpatialReference()
246            proj.ImportFromEPSG(epsg)
247
248
249        # check we have all the required info
250        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
251        assert (not self.affine is None) and (not self.projection is None), "Error - project information is undefined."
252
253        # project to this images CS (if different)
254        if not proj.IsSameGeogCS(self.projection):
255            P = ogr.Geometry(ogr.wkbPoint)
256            if proj.EPSGTreatsAsNorthingEasting():
257                P.AddPoint(x, y)
258            else:
259                P.AddPoint(y, x)
260            P.AssignSpatialReference(proj)  # tell the point what coordinates it's in
261            P.AddPoint(x, y)
262            P.TransformTo(self.projection)  # reproject it to the out spatial reference
263            x, y = P.GetX(), P.GetY()
264            if self.projection.EPSGTreatsAsLatLong(): # do we need to transpose?
265                x, y = y, x  # we want lon,lat not lat,lon
266
267        inv = gdal.InvGeoTransform(self.affine)
268        assert not inv is None, "Error - could not invert affine transform?"
269
270        #apply
271        return gdal.ApplyGeoTransform(inv, x, y)
272
273    def flip(self, axis='x'):
274        """
275        Flip the image on the x or y axis.
276
277        Args:
278            axis (str): 'x' or 'y' or both 'xy'.
279        """
280
281        if 'x' in axis.lower():
282            self.data = np.flip(self.data,axis=0)
283        if 'y' in axis.lower():
284            self.data = np.flip(self.data,axis=1)
285
286    def rot90(self):
287        """
288        Rotate this image by 90 degrees by transposing the underlying data array. Combine with flip('x') or flip('y')
289        to achieve positive/negative rotations.
290        """
291        self.data = np.transpose( self.data, (1,0,2) )
292        self.push_to_header()
293
294    #####################################
295    ##IMAGE FILTERING
296    #####################################
297    def fill_holes(self):
298        """
299        Replaces nan pixel with an average of their neighbours, thus removing 1-pixel large holes from an image. Note that
300        for performance reasons this assumes that holes line up across bands. Note that this is not vectorized so very slow...
301        """
302
303        # perform greyscale dilation
304        dilate = self.data.copy()
305        mask = np.logical_not(np.isfinite(dilate))
306        dilate[mask] = 0
307        for b in range(self.band_count()):
308            dilate[:, :, b] = sp.ndimage.grey_dilation(dilate[:, :, b], size=(3, 3))
309
310        # map back to holes in dataset
311        self.data[mask] = dilate[mask]
312        #self.data[self.data == 0] = np.nan  # replace remaining 0's with nans
313
314    def blur(self, n=3):
315        """
316        Applies a gaussian kernel of size n to the image using OpenCV.
317
318        Args:
319            n (int): the dimensions of the gaussian kernel to convolve. Default is 3. Increase for more blurry results.
320        """
321        import cv2 # import this here to avoid errors if opencv is not installed properly
322
323        nanmask = np.isnan(self.data)
324        assert isinstance(n, int) and n >= 3, "Error - invalid kernel. N must be an integer > 3. "
325        kernel = np.ones((n, n), np.float32) / (n ** 2)
326        self.data = cv2.filter2D(self.data, -1, kernel)
327        self.data[nanmask] = np.nan  # remove mask
328
329    def erode(self, size=3, iterations=1):
330        """
331        Apply an erode filter to this image to expand background (nan) pixels. Refer to open-cv's erode
332        function for more details.
333
334        Args:
335            size (int): the size of the erode filter. Default is a 3x3 kernel.
336            iterations (int): the number of erode iterations. Default is 1.
337        """
338        import cv2 # import this here to avoid errors if opencv is not installed properly
339
340        # erode
341        kernel = np.ones((size, size), np.uint8)
342        if self.is_float():
343            mask = np.isfinite(self.data).any(axis=-1)
344            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
345            self.data[mask == 0, :] = np.nan
346        else:
347            mask = (self.data != 0).any( axis=-1 )
348            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
349            self.data[mask == 0, :] = 0
350
351    def resize(self, newdims : tuple, interpolation : int = 1):
352        """
353        Resize this image with opencv.
354
355        Args:
356            newdims (tuple): the new image dimensions.
357            interpolation (int): opencv interpolation method. Default is cv2.INTER_LINEAR.
358        """
359        import cv2 # import this here to avoid errors if opencv is not installed properly
360        self.data = cv2.resize(self.data, (newdims[1],newdims[0]), interpolation=interpolation)
361
362    def despeckle(self, size=5):
363        """
364        Despeckle each band of this image (independently) using a median filter.
365
366        Args:
367            size (int): the size of the median filter kernel. Default is 5. Must be an odd number.
368        """
369
370        assert (size % 2) == 1, "Error - size must be an odd integer"
371        import cv2 # import this here to avoid errors if opencv is not installed properly
372        if self.is_float():
373            self.data = cv2.medianBlur( self.data.astype(np.float32), size )
374        else:
375            self.data = cv2.medianBlur( self.data, size )
376
377    #####################################
378    ##FEATURES AND FEATURE MATCHING
379    ######################################
380    def get_keypoints(self, band, eq=False, mask=True, method='sift', cfac=0.0,bfac=0.0, **kwds):
381        """
382        Get feature descriptors from the specified band.
383
384        Args:
385            band (int,float,str,tuple): the band index (int) or wavelength (float) to extract features from. Alternatively, a tuple can be passed
386                    containing a range of bands (min : max) to average before feature matching.
387            eq (bool): True if the image should be histogram equalized first. Default is False.
388            mask (bool): True if 0 value pixels should be masked. Default is True.
389            method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'.
390            cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
391            bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.
392            **kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are:
393
394                - contrastThreshold: default is 0.01.
395                - edgeThreshold: default is 10.
396                - sigma: default is 1.0
397
398                For ORB these are:
399
400                - nfeatures = the number of features to detect. Default is 5000.
401
402            Returns:
403                Tuple containing
404
405                    - k (ndarray): the keypoints detected
406                    - d (ndarray): corresponding feature descriptors
407         """
408        import cv2 # import this here to avoid errors if opencv is not installed properly
409
410        # get image
411        if isinstance(band, int) or isinstance(band, float): #single band
412            image = self.data[:, :, self.get_band_index(band)]
413        elif isinstance(band,tuple): #range of bands (averaged)
414            idx0 = self.get_band_index(band[0])
415            idx1 = self.get_band_index(band[1])
416
417            #deal with out of range errors
418            if idx0 is None:
419                idx0 = 0
420            if idx1 is None:
421                idx1 = self.band_count()
422
423            #average bands
424            image = np.nanmean(self.data[:,:,idx0:idx1],axis=2)
425        else:
426            assert False, "Error, unrecognised band %s" % band
427
428        #normalise image to range 0 - 1
429        image -= np.nanmin(image)
430        image = image / np.nanmax(image)
431
432        #apply brightness/contrast adjustment
433        image = (1.0+cfac)*image + bfac
434        image[image > 1.0] = 1.0
435        image[image < 0.0] = 0.0
436
437        #convert image to uint8 for opencv
438        image = np.uint8(255 * image)
439        if eq:
440            image = cv2.equalizeHist(image)
441
442        if mask:
443            mask = np.zeros(image.shape, dtype=np.uint8)
444            mask[image != 0] = 255  # include only non-zero pixels
445        else:
446            mask = None
447
448        if 'sift' in method.lower():  # SIFT
449
450            # setup default keywords
451            kwds["contrastThreshold"] = kwds.get("contrastThreshold", 0.01)
452            kwds["edgeThreshold"] = kwds.get("edgeThreshold", 10)
453            kwds["sigma"] = kwds.get("sigma", 1.0)
454
455            # make feature detector
456            #alg = cv2.xfeatures2d.SIFT_create(**kwds)
457            alg = cv2.SIFT_create()
458        elif 'orb' in method.lower():  # orb
459            kwds['nfeatures'] = kwds.get('nfeatures', 5000)
460            alg = cv2.ORB_create(scoreType=cv2.ORB_FAST_SCORE, **kwds)
461        else:
462            assert False, "Error - %s is not a recognised feature detector." % method
463
464        # detect keypoints
465        kp = alg.detect(image, mask)
466
467        # extract and return feature vectors
468        return alg.compute(image, kp)
469
470    @classmethod
471    def match_keypoints(cls, kp1, kp2, d1, d2, method='SIFT', dist=0.7, tree = 5, check = 100, min_count=5):
472        """
473        Compares keypoint feature vectors from two images and returns matching pairs.
474
475        Args:
476            kp1 (ndarray): keypoints from the first image
477            kp2 (ndarray): keypoints from the second image
478            d1 (ndarray): descriptors for the keypoints from the first image
479            d2 (ndarray): descriptors for the keypoints from the second image
480            method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'.
481            dist (float): minimum match distance (0 to 1), default is 0.7
482            tree (int): not sure what this does? Default is 5. See open-cv docs.
483            check (int): ditto. Default is 100.
484            min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are found,
485                       then the function returns None, None. Default is 5.
486        """
487        import cv2 # import this here to avoid errors if opencv is not installed properly
488        if 'sift' in method.lower():
489            algorithm = cv2.NORM_INF
490        elif 'orb' in method.lower():
491            algorithm = cv2.NORM_HAMMING
492        else:
493            assert False, "Error - unknown matching algorithm %s" % method
494
495        #calculate flann matches
496        index_params = dict(algorithm=algorithm, trees=tree)
497        search_params = dict(checks=check)
498        flann = cv2.FlannBasedMatcher(index_params, search_params)
499        matches = flann.knnMatch(d1, d2, k=2)
500
501        # store all the good matches as per Lowe's ratio test.
502        good = []
503        for m, n in matches:
504            if m.distance < dist * n.distance:
505                good.append(m)
506
507        if len(good) < min_count:
508            return None, None
509        else:
510            src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
511            dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
512            return src_pts, dst_pts
513
514    ############################
515    ## Visualisation methods
516    ############################
517    def quick_plot(self, band=0, ax=None, bfac=0.0, cfac=0.0, samples=False, tscale=False, rot=False, flipX=False, flipY=False,
518                   **kwds):
519        """
520        Plot a band using matplotlib.imshow(...).
521
522        Args:
523            band (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple is passed then
524                  each band in the tuple (string or index) will be mapped to rgb. Bands with negative wavelengths or indices will be inverted before plotting.
525            ax: an axis object to plot to. If none, plt.imshow( ... ) is used.
526            bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
527            cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
528            samples (bool): True if sample points (defined in the header file) should be plotted. Default is False. Otherwise, a list of
529                     [ (x,y), ... ] points can be passed.
530            tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False.
531                    When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or
532                    (constant) values (float).
533            rot (bool): if True, the x and y axis will be flipped (90 degree rotation) before plotting. Default is False.
534            flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
535            flipY (bool): if True, the y axis will be flippe before plotting (after applying rotations).
536            **kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:
537
538                 - mask = a 2D boolean mask containing true if pixels should be drawn and false otherwise.
539                 - path = a file path to save the image too (at matching resolution; use fig.savefig(..) if you want to save the figure).
540                 - ticks = True if x- and y- ticks should be plotted. Default is False.
541                 - ps, pc = the size and color of sample points to plot. Can be constant or list.
542                 - figsize = a figsize for the figure to create (if ax is None).
543
544        Returns:
545            Tuple containing
546
547            - fig: matplotlib figure object
548            - ax:  matplotlib axes object. If a colorbar is created, (band is an integer or a float), then this will be stored in ax.cbar.
549        """
550
551        #create new axes?
552        if ax is None:
553            fig, ax = plt.subplots(figsize=kwds.pop('figsize', (18,18*self.ydim()/self.xdim()) ))
554
555        # deal with ticks
556        if not kwds.pop('ticks', False ):
557            ax.set_xticks([])
558            ax.set_yticks([])
559
560        #map individual band using colourmap
561        if isinstance(band, str) or isinstance(band, int) or isinstance(band, float):
562            #get band
563            if isinstance(band, str):
564                data = self.data[:, :, self.get_band_index(band)]
565            else:
566                data = self.data[:, :, self.get_band_index(np.abs(band))]
567            if not isinstance(band, str) and band < 0:
568                data = np.nanmax(data) - data # flip
569
570            # convert integer vmin and vmax values to percentiles
571            if 'vmin' in kwds:
572                if isinstance(kwds['vmin'], int):
573                    kwds['vmin'] = np.nanpercentile( data, kwds['vmin'] )
574            if 'vmax' in kwds:
575                if isinstance(kwds['vmax'], int):
576                    kwds['vmax'] = np.nanpercentile( data, kwds['vmax'] )
577
578            #mask nans (and apply custom mask)
579            mask = np.isnan(data)
580            if not np.isnan(self.header.get_data_ignore_value()):
581                mask = mask + data == self.header.get_data_ignore_value()
582            if 'mask' in kwds:
583                mask = mask + kwds.get('mask')
584                del kwds['mask']
585            data = np.ma.array(data, mask = mask > 0 )
586
587            # apply rotations and flipping
588            if rot:
589                data = data.T
590            if flipX:
591                data = data[::-1, :]
592            if flipY:
593                data = data[:, ::-1]
594
595            # save?
596            if 'path' in kwds:
597                path = kwds.pop('path')
598                from matplotlib.pyplot import imsave
599                if not os.path.exists(os.path.dirname(path)):
600                    os.makedirs(os.path.dirname(path)) # ensure output directory exists
601                imsave(path, data.T, **kwds)  # save the image
602
603            ax.cbar = ax.imshow(data.T, interpolation=kwds.pop('interpolation', 'none'), **kwds) # change default interpolation to None
604
605        #map 3 bands to RGB
606        elif isinstance(band, tuple) or isinstance(band, list):
607            #get band indices and range
608            rgb = []
609            for b in band:
610                if isinstance(b, str):
611                    rgb.append(self.get_band_index(b))
612                else:
613                    rgb.append(self.get_band_index(np.abs(b)))
614
615            #slice image (as copy) and map to 0 - 1
616            img = np.array(self.data[:, :, rgb]).copy()
617            if np.isnan(img).all():
618                print("Warning - image contains no data.")
619                return ax.get_figure(), ax
620
621            # invert if needed
622            for i,b in enumerate(band):
623                if not isinstance(b, str) and (b < 0):
624                    img[..., i] = np.nanmax(img[..., i]) - img[..., i]
625
626            # do scaling
627            if tscale: # scale bands independently
628                for b in range(3):
629                    mn = kwds.get("vmin", float(np.nanmin(img)))
630                    mx = kwds.get("vmax", float(np.nanmax(img)))
631                    if isinstance (mn, int):
632                        assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
633                        mn = float(np.nanpercentile(img[...,b], mn ))
634                    if isinstance (mx, int):
635                        assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
636                        mx = float(np.nanpercentile(img[...,b], mx ))
637                    img[...,b] = (img[..., b] - mn) / (mx - mn)
638            else: # scale bands together
639                mn = kwds.get("vmin", float(np.nanmin(img)))
640                mx = kwds.get("vmax", float(np.nanmax(img)))
641                if isinstance(mn, int):
642                    assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
643                    mn = float(np.nanpercentile(img, mn))
644                if isinstance(mx, int):
645                    assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
646                    mx = float(np.nanpercentile(img, mx))
647                img = (img - mn) / (mx - mn)
648
649            #apply brightness/contrast mapping
650            img = np.clip((1.0 + cfac) * img + bfac, 0, 1.0 )
651
652            #apply masking so background is white
653            img[np.logical_not( np.isfinite( img ) )] = 1.0
654            if 'mask' in kwds:
655                img[kwds.pop("mask"),:] = 1.0
656
657            # apply rotations and flipping
658            if rot:
659                img = np.transpose( img, (1,0,2) )
660            if flipX:
661                img = img[::-1, :, :]
662            if flipY:
663                img = img[:, ::-1, :]
664
665            # save?
666            if 'path' in kwds:
667                path = kwds.pop('path')
668                from matplotlib.pyplot import imsave
669                if not os.path.exists(os.path.dirname(path)):
670                    os.makedirs(os.path.dirname(path)) # ensure output directory exists
671                imsave(path, np.transpose( np.clip( img*255, 0, 255).astype(np.uint8), (1, 0, 2)))  # save the image
672
673            # plot samples?
674            ps = kwds.pop('ps', 5)
675            pc = kwds.pop('pc', 'r')
676            if samples:
677                if isinstance(samples, list) or isinstance(samples, np.ndarray):
678                    ax.scatter([s[0] for s in samples], [s[1] for s in samples], s=ps, c=pc)
679                else:
680                    for n in self.header.get_class_names():
681                        points = np.array(self.header.get_sample_points(n))
682                        ax.scatter(points[:, 0], points[:, 1], s=ps, c=pc)
683
684            #plot
685            ax.imshow(np.transpose(img, (1,0,2)), interpolation=kwds.pop('interpolation', 'none'), **kwds)
686            ax.cbar = None  # no colorbar
687
688        return ax.get_figure(), ax
689
690    def createGIF(self, path, bands=None, figsize=(10,10), fps=10, **kwds):
691        """
692        Create and save an animated gif that loops through the bands of the image.
693
694        Args:
695            path (str): the path to save the .gif
696            bands (tuple): Tuple containing the range of band indices to draw. Default is the whole range.
697            figsize (tuple): the size of the image to draw. Default is (10,10).
698            fps (int): the framerate (frames per second) of the gif. Default is 10.
699            **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
700        """
701
702        frames = []
703        if bands is None:
704            bands = (0,self.band_count())
705        else:
706            assert 0 < bands[0] < self.band_count(), "Error - invalid range."
707            assert 0 < bands[1] < self.band_count(), "Error - invalid range."
708            assert bands[1] > bands[0], "Error - invalid range."
709
710        #plot frames
711        for i in range(bands[0],bands[1]):
712            fig, ax = plt.subplots(figsize=figsize)
713            ax.imshow(self.data[:, :, i], **kwds)
714            fig.canvas.draw()
715            frames.append(np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8'))
716            frames[-1] = np.reshape(frames[-1], (fig.canvas.get_width_height()[1], fig.canvas.get_width_height()[0], 3))
717            plt.close(fig)
718
719        #save gif
720        imageio.mimsave( os.path.splitext(path)[0] + ".gif", frames, fps=fps)
721
722    ## masking
723    def mask(self, mask=None, flag=np.nan, invert=False, crop=False, bands=None):
724        """
725         Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the
726         image in-situ.
727
728         Args:
729            flag (float): the value to use for masked pixels. Default is np.nan
730            mask (ndarray): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None is passed then
731                    pickPolygon( ... ) is used to interactively define a polygon. If a file path is passed then the polygon
732                    will be loaded using np.load( ... ). Alternatively if mask.shape == image.shape[0,1] then it is treated as a
733                    binary image mask (must be boolean) and True values will be masked across all bands. Default is None.
734            invert (bool): if True, pixels within the polygon will be masked. If False, pixels outside the polygon are masked. Default is False.
735            crop (bool): True if rows/columns containing only zeros should be removed. Default is False.
736            bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used.
737
738         Returns:
739            Tuple containing
740
741            - mask (ndarray): a boolean array with True where pixels are masked and False elsewhere.
742            - poly (ndarray): the mask polygon array in the format described above. Useful if the polygon was interactively defined.
743         """
744
745        if mask is None:  # pick mask interactively
746            if bands is None:
747                bands = int(self.band_count() / 2)
748
749            regions = self.pickPolygons(region_names=["mask"], bands=bands)
750
751            # the user bailed without picking a mask?
752            if len(regions) == 0:
753                print("Warning - no mask picked/applied.")
754                return
755
756            # extract polygon mask
757            mask = regions[0]
758
759        # convert polygon mask to binary mask
760        if mask.shape[1] == 2:
761
762            # build meshgrid with pixel coords
763            xx, yy = np.meshgrid(np.arange(self.xdim()), np.arange(self.ydim()))
764            xx = xx.flatten()
765            yy = yy.flatten()
766            points = np.vstack([xx, yy]).T  # coordinates of each pixel
767
768            # calculate per-pixel mask
769            mask = path.Path(mask).contains_points(points)
770            mask = mask.reshape((self.ydim(), self.xdim())).T
771
772            # flip as we want to mask (==True) outside points (unless invert is true)
773            if not invert:
774                mask = np.logical_not(mask)
775
776        # apply binary image mask
777        assert mask.shape[0] == self.data.shape[0] and mask.shape[1] == self.data.shape[1], \
778            "Error - mask shape %s does not match image shape %s" % (mask.shape, self.data.shape)
779        for b in range(self.band_count()):
780            self.data[:, :, b][mask] = flag
781
782        # crop image
783        if crop:
784            # calculate non-masked pixels
785            valid = np.logical_not(mask)
786
787            # integrate along axes
788            xdata = np.sum(valid, axis=1) > 0.0
789            ydata = np.sum(valid, axis=0) > 0.0
790
791            # calculate domain containing valid pixels
792            xmin = np.argmax(xdata)
793            xmax = xdata.shape[0] - np.argmax(xdata[::-1])
794            ymin = np.argmax(ydata)
795            ymax = ydata.shape[0] - np.argmax(ydata[::-1])
796
797            # crop
798            self.data = self.data[xmin:xmax, ymin:ymax, :]
799
800        return mask
801
802    def crop_to_data(self):
803        """
804        Remove padding of nan or zero pixels from image. Note that this is performed in place.
805        """
806
807        valid = np.isfinite(self.data).any(axis=-1) & (self.data != 0).any(axis=-1)
808        ymin, ymax = np.percentile(np.argwhere(np.sum(valid, axis=0) != 0), (0, 100))
809        xmin, xmax = np.percentile(np.argwhere(np.sum(valid, axis=1) != 0), (0, 100))
810        self.data = self.data[int(xmin):int(xmax), int(ymin):int(ymax), :]  # do clipping
811
812    ##################################################
813    ## Interactive tools for picking regions/pixels
814    ##################################################
815    def pickPolygons(self, region_names, bands=0):
816        """
817        Creates a matplotlib gui for selecting polygon regions in an image.
818
819        Args:
820            names (list, str): a list containing the names of the regions to pick. If a string is passed only one name is used.
821            bands (tuple): the bands of the image to plot.
822        """
823
824        if isinstance(region_names, str):
825            region_names = [region_names]
826
827        assert isinstance(region_names, list), "Error - names must be a list or a string."
828
829        # set matplotlib backend
830        backend = matplotlib.get_backend()
831        matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work
832
833        # plot image and extract roi's
834        fig, ax = self.quick_plot(bands)
835        roi = MultiRoi(roi_names=region_names)
836        plt.close(fig)  # close figure
837
838        # extract regions
839        regions = []
840        for name, r in roi.rois.items():
841            # store region
842            x = r.x
843            y = r.y
844            regions.append(np.vstack([x, y]).T)
845
846        # restore matplotlib backend (if possible)
847        try:
848            matplotlib.use(backend)
849        except:
850            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")
851            pass
852
853        return regions
854
855    def pickPoints(self, n=-1, bands=hylite.RGB, integer=True, title="Pick Points", **kwds):
856        """
857        Creates a matplotlib gui for picking pixels from an image.
858
859        Args:
860            n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
861            bands (tuple): the bands of the image to plot. Default is HyImage.RGB
862            integer (bool): True if points coordinates should be cast to integers (for use as indices). Default is True.
863            title (str): The title of the point picking window.
864            **kwds: Keywords are passed to HyImage.quick_plot( ... ).
865
866        Returns:
867            A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].
868        """
869
870        # set matplotlib backend
871        backend = matplotlib.get_backend()
872        matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work
873
874        # create figure
875        fig, ax = self.quick_plot( bands, **kwds )
876        ax.set_title(title)
877
878        # get points
879        points = fig.ginput( n )
880
881        if integer:
882            points = [ (int(p[0]), int(p[1])) for p in points ]
883
884        # restore matplotlib backend (if possible)
885        try:
886            matplotlib.use(backend)
887        except:
888            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")
889            pass
890
891        return points
892
893    def pickSamples(self, names=None, store=True, **kwds):
894        """
895        Pick sample probe points and store these in the image header file.
896
897        Args:
898            names (str, list): the name of the sample to pick, or a list of names to pick multiple.
899            store (bool): True if sample should be stored in the image header file (for later access). Default is True.
900            **kwds: Keywords are passed to HyImage.quick_plot( ... )
901
902        Returns:
903            a list containing a list of points for each sample.
904        """
905
906        if isinstance(names, str):
907            names = [names]
908
909        # pick points
910        points = []
911        for s in names:
912            pnts = self.pickPoints(title="%s" % s, **kwds)
913            if store:
914                self.header['sample %s' % s] = pnts # store in header
915            points.append(pnts)
916        # add class to header file
917        if store:
918            cls_names = self.header.get_class_names()
919            if cls_names is None:
920                cls_names = []
921            self.header['class names'] = cls_names + names
922
923        return points

A class for hyperspectral image data. These can be individual scenes or hyperspectral orthoimages.

HyImage(data, **kwds)
25    def __init__(self, data, **kwds):
26        """
27        Args:
28            data (ndarray): a numpy array such that data[x][y][band] gives each pixel value.
29            **kwds:
30                wav = A numpy array containing band wavelengths for this image.
31                affine = an affine transform of the format returned by GDAL.GetGeoTransform().
32                projection = string defining the project. Default is None.
33                sensor = sensor name. Default is "unknown".
34                header = path to associated header file. Default is None.
35        """
36
37        #call constructor for HyData
38        super().__init__(data, **kwds)
39
40        # special case - if dataset only has oneband, slice it so it still has
41        # the format data[x,y,b].
42        if not self.data is None:
43            if len(self.data.shape) == 1:
44                self.data = self.data[None, None, :] # single pixel image
45            if len(self.data.shape) == 2:
46                self.data = self.data[:, :, None] # single band iamge
47
48        #load any additional project information (specific to images)
49        self.set_projection(kwds.get("projection",None))
50        self.affine = kwds.get("affine",[0,1,0,0,0,1])
51
52        # wavelengths
53        if 'wav' in kwds:
54            self.set_wavelengths(kwds['wav'])
55
56        #special header formatting
57        self.header['file type'] = 'ENVI Standard'
Arguments:
  • data (ndarray): a numpy array such that data[x][y][band] gives each pixel value.
  • **kwds: wav = a numpy array containing band wavelengths for this image; affine = an affine transform of the format returned by GDAL.GetGeoTransform(); projection = string defining the projection (default is None); sensor = sensor name (default is "unknown"); header = path to associated header file (default is None).
def copy(self, data=True):
59    def copy(self,data=True):
60        """
61        Make a deep copy of this image instance.
62
63        Args:
64            data (bool): True if a copy of the data should be made, otherwise only copy header.
65
66        Returns:
67            a new HyImage instance.
68        """
69        if not data:
70            return HyImage(None, header=self.header.copy(), projection=self.projection, affine=self.affine)
71        else:
72            return HyImage( self.data.copy(), header=self.header.copy(), projection=self.projection, affine=self.affine)

Make a deep copy of this image instance.

Arguments:
  • data (bool): True if a copy of the data should be made, otherwise only copy header.
Returns:

a new HyImage instance.

def T(self):
74    def T(self):
75        """
76        Return a transposed view of the data matrix (corresponding with the [y,x] indexing used by matplotlib, opencv etc.
77        """
78        return np.transpose(self.data, (1,0,2))

Return a transposed view of the data matrix (corresponding with the [y,x] indexing used by matplotlib, opencv etc.

def xdim(self):
80    def xdim(self):
81        """
82        Return number of pixels in x (first dimension of data array)
83        """
84        return self.data.shape[0]

Return number of pixels in x (first dimension of data array)

def ydim(self):
86    def ydim(self):
87        """
88        Return number of pixels in y (second dimension of data array)
89        """
90        return self.data.shape[1]

Return number of pixels in y (second dimension of data array)

def aspx(self):
92    def aspx(self):
93        """
94        Return the aspect ratio of this image (width/height).
95        """
96        return self.ydim() / self.xdim()

Return the aspect ratio of this image, computed as ydim / xdim (pixels in y divided by pixels in x).

def get_extent(self):
 98    def get_extent(self):
 99        """
100        Returns the width and height of this image in world coordinates.
101
102        Returns:
103            tuple with (width, height).
104        """
105        return self.xdim * self.pixel_size[0], self.ydim * self.pixel_size[1]

Returns the width and height of this image in world coordinates.

Returns:

tuple with (width, height).

def set_projection(self, proj):
107    def set_projection(self,proj):
108        """
109        Set this project to an existing osgeo.osr.SpatialReference or GDAL georeference string.
110
111        Args:
112            proj (str, osgeo.osr.SpatialReference): the project to use as osgeo.osr.SpatialReference or GDAL georeference string.
113        """
114        if proj is None:
115            self.projection = None
116        else:
117            try:
118                from osgeo.osr import SpatialReference
119            except:
120                assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
121            if isinstance(proj, SpatialReference):
122                self.projection = proj
123            elif isinstance(proj, str):
124                self.projection = SpatialReference(proj)
125            else:
126                print("Invalid project %s" % proj)
127                raise

Set the projection of this image from an existing osgeo.osr.SpatialReference or GDAL georeference string.

Arguments:
  • proj (str, osgeo.osr.SpatialReference): the project to use as osgeo.osr.SpatialReference or GDAL georeference string.
def set_projection_EPSG(self, EPSG):
129    def set_projection_EPSG(self,EPSG):
130        """
131        Sets this image project using an EPSG code.
132
133        Args:
134            EPSG (str): string EPSG code that can be passed to SpatialReference.SetFromUserInput(...).
135        """
136
137        try:
138            from osgeo.osr import SpatialReference
139        except:
140            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
141
142        self.projection = SpatialReference()
143        self.projection.SetFromUserInput(EPSG)

Sets the projection of this image using an EPSG code.

Arguments:
  • EPSG (str): string EPSG code that can be passed to SpatialReference.SetFromUserInput(...).
def get_projection_EPSG(self):
145    def get_projection_EPSG(self):
146        """
147        Gets a string describing this projections EPSG code (if it is an EPSG project).
148
149        Returns:
150            an EPSG code string of the format "EPSG:XXXX".
151        """
152        if self.projection is None:
153            return None
154        else:
155            return "%s:%s" % (self.projection.GetAttrValue("AUTHORITY",0),self.projection.GetAttrValue("AUTHORITY",1))

Gets a string describing this projection's EPSG code (if it is an EPSG projection).

Returns:

an EPSG code string of the format "EPSG:XXXX".

def pix_to_world(self, px, py, proj=None):
157    def pix_to_world(self, px, py, proj=None):
158        """
159        Take pixel coordinates and return world coordinates
160
161        Args:
162            px (int): the pixel x-coord.
163            py (int): the pixel y-coord.
164            proj (str, osr.SpatialReference): the coordinate system to use. Default (None) uses the same system as this image. Otherwise
165                   an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
166        Returns:
167            the world coordinates in the coordinate system defined by get_projection_EPSG(...).
168        """
169
170        try:
171            from osgeo import osr
172            import osgeo.gdal as gdal
173            from osgeo import ogr
174        except:
175            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
176
177        # parse project
178        if proj is None:
179            proj = self.projection
180        elif isinstance(proj, str) or isinstance(proj, int):
181            epsg = proj
182            if isinstance(epsg, str):
183                try:
184                    epsg = int(str.split(':')[1])
185                except:
186                    assert False, "Error - %s is an invalid EPSG code." % proj
187            proj = osr.SpatialReference()
188            proj.ImportFromEPSG(epsg)
189
190        # check we have all the required info
191        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
192        assert (not self.affine is None) and (
193            not self.projection is None), "Error - project information is undefined."
194
195        #project to world coordinates in this images project/world coords
196        x,y = gdal.ApplyGeoTransform(self.affine, px, py)
197
198        #project to target coords (if different)
199        if not proj.IsSameGeogCS(self.projection):
200            P = ogr.Geometry(ogr.wkbPoint)
201            if proj.EPSGTreatsAsNorthingEasting():
202                P.AddPoint(x, y)
203            else:
204                P.AddPoint(y, x)
205            P.AssignSpatialReference(self.projection)  # tell the point what coordinates it's in
206            P.TransformTo(proj)  # reproject it to the out spatial reference
207            x, y = P.GetX(), P.GetY()
208
209            #do we need to transpose?
210            if proj.EPSGTreatsAsLatLong():
211                x,y=y,x #we want lon,lat not lat,lon
212        return x, y

Take pixel coordinates and return world coordinates

Arguments:
  • px (int): the pixel x-coord.
  • py (int): the pixel y-coord.
  • proj (str, osr.SpatialReference): the coordinate system to use. Default (None) uses the same system as this image. Otherwise an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
Returns:

the world coordinates in the coordinate system defined by get_projection_EPSG(...).

def world_to_pix(self, x, y, proj=None):
214    def world_to_pix(self, x, y, proj = None):
215        """
216        Take world coordinates and return pixel coordinates
217
218        Args:
219            x (float): the world x-coord.
220            y (float): the world y-coord.
221            proj (str, osr.SpatialReference): the coordinate system of the input coordinates. Default (None) uses the same system as this image. Otherwise
222                   an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
223
224        Returns:
225            the pixel coordinates based on the affine transform stored in self.affine.
226        """
227
228        try:
229            from osgeo import osr
230            import osgeo.gdal as gdal
231            from osgeo import ogr
232        except:
233            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
234
235        # parse project
236        if proj is None:
237            proj = self.projection
238        elif isinstance(proj, str) or isinstance(proj, int):
239            epsg = proj
240            if isinstance(epsg, str):
241                try:
242                    epsg = int(str.split(':')[1])
243                except:
244                    assert False, "Error - %s is an invalid EPSG code." % proj
245            proj = osr.SpatialReference()
246            proj.ImportFromEPSG(epsg)
247
248
249        # check we have all the required info
250        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
251        assert (not self.affine is None) and (not self.projection is None), "Error - project information is undefined."
252
253        # project to this images CS (if different)
254        if not proj.IsSameGeogCS(self.projection):
255            P = ogr.Geometry(ogr.wkbPoint)
256            if proj.EPSGTreatsAsNorthingEasting():
257                P.AddPoint(x, y)
258            else:
259                P.AddPoint(y, x)
260            P.AssignSpatialReference(proj)  # tell the point what coordinates it's in
261            P.AddPoint(x, y)
262            P.TransformTo(self.projection)  # reproject it to the out spatial reference
263            x, y = P.GetX(), P.GetY()
264            if self.projection.EPSGTreatsAsLatLong(): # do we need to transpose?
265                x, y = y, x  # we want lon,lat not lat,lon
266
267        inv = gdal.InvGeoTransform(self.affine)
268        assert not inv is None, "Error - could not invert affine transform?"
269
270        #apply
271        return gdal.ApplyGeoTransform(inv, x, y)

Take world coordinates and return pixel coordinates

Arguments:
  • x (float): the world x-coord.
  • y (float): the world y-coord.
  • proj (str, osr.SpatialReference): the coordinate system of the input coordinates. Default (None) uses the same system as this image. Otherwise an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
Returns:

the pixel coordinates based on the affine transform stored in self.affine.

def flip(self, axis='x'):
273    def flip(self, axis='x'):
274        """
275        Flip the image on the x or y axis.
276
277        Args:
278            axis (str): 'x' or 'y' or both 'xy'.
279        """
280
281        if 'x' in axis.lower():
282            self.data = np.flip(self.data,axis=0)
283        if 'y' in axis.lower():
284            self.data = np.flip(self.data,axis=1)

Flip the image on the x or y axis.

Arguments:
  • axis (str): 'x' or 'y' or both 'xy'.
def rot90(self):
286    def rot90(self):
287        """
288        Rotate this image by 90 degrees by transposing the underlying data array. Combine with flip('x') or flip('y')
289        to achieve positive/negative rotations.
290        """
291        self.data = np.transpose( self.data, (1,0,2) )
292        self.push_to_header()

Rotate this image by 90 degrees by transposing the underlying data array. Combine with flip('x') or flip('y') to achieve positive/negative rotations.

def fill_holes(self):
297    def fill_holes(self):
298        """
299        Replaces nan pixel with an average of their neighbours, thus removing 1-pixel large holes from an image. Note that
300        for performance reasons this assumes that holes line up across bands. Note that this is not vectorized so very slow...
301        """
302
303        # perform greyscale dilation
304        dilate = self.data.copy()
305        mask = np.logical_not(np.isfinite(dilate))
306        dilate[mask] = 0
307        for b in range(self.band_count()):
308            dilate[:, :, b] = sp.ndimage.grey_dilation(dilate[:, :, b], size=(3, 3))
309
310        # map back to holes in dataset
311        self.data[mask] = dilate[mask]
312        #self.data[self.data == 0] = np.nan  # replace remaining 0's with nans

Replaces non-finite (nan) pixels with the maximum of their 3x3 neighbours (a greyscale dilation computed with holes set to 0), thus removing 1-pixel large holes from an image. Note that the dilation is applied band by band, so this can be slow for images with many bands.

def blur(self, n=3):
314    def blur(self, n=3):
315        """
316        Applies a gaussian kernel of size n to the image using OpenCV.
317
318        Args:
319            n (int): the dimensions of the gaussian kernel to convolve. Default is 3. Increase for more blurry results.
320        """
321        import cv2 # import this here to avoid errors if opencv is not installed properly
322
323        nanmask = np.isnan(self.data)
324        assert isinstance(n, int) and n >= 3, "Error - invalid kernel. N must be an integer > 3. "
325        kernel = np.ones((n, n), np.float32) / (n ** 2)
326        self.data = cv2.filter2D(self.data, -1, kernel)
327        self.data[nanmask] = np.nan  # remove mask

Applies an n x n mean (box) kernel to the image using OpenCV. Note that the kernel weights are uniform (1 / n²) rather than gaussian.

Arguments:
  • n (int): the dimensions of the kernel to convolve. Must be an integer of at least 3. Default is 3. Increase for more blurry results.
def erode(self, size=3, iterations=1):
329    def erode(self, size=3, iterations=1):
330        """
331        Apply an erode filter to this image to expand background (nan) pixels. Refer to open-cv's erode
332        function for more details.
333
334        Args:
335            size (int): the size of the erode filter. Default is a 3x3 kernel.
336            iterations (int): the number of erode iterations. Default is 1.
337        """
338        import cv2 # import this here to avoid errors if opencv is not installed properly
339
340        # erode
341        kernel = np.ones((size, size), np.uint8)
342        if self.is_float():
343            mask = np.isfinite(self.data).any(axis=-1)
344            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
345            self.data[mask == 0, :] = np.nan
346        else:
347            mask = (self.data != 0).any( axis=-1 )
348            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
349            self.data[mask == 0, :] = 0

Apply an erode filter to this image to expand background (nan) pixels. Refer to open-cv's erode function for more details.

Arguments:
  • size (int): the size of the erode filter. Default is a 3x3 kernel.
  • iterations (int): the number of erode iterations. Default is 1.
def resize(self, newdims: tuple, interpolation: int = 1):
351    def resize(self, newdims : tuple, interpolation : int = 1):
352        """
353        Resize this image with opencv.
354
355        Args:
356            newdims (tuple): the new image dimensions.
357            interpolation (int): opencv interpolation method. Default is cv2.INTER_LINEAR.
358        """
359        import cv2 # import this here to avoid errors if opencv is not installed properly
360        self.data = cv2.resize(self.data, (newdims[1],newdims[0]), interpolation=interpolation)

Resize this image with opencv.

Arguments:
  • newdims (tuple): the new image dimensions.
  • interpolation (int): opencv interpolation method. Default is cv2.INTER_LINEAR.
def despeckle(self, size=5):
362    def despeckle(self, size=5):
363        """
364        Despeckle each band of this image (independently) using a median filter.
365
366        Args:
367            size (int): the size of the median filter kernel. Default is 5. Must be an odd number.
368        """
369
370        assert (size % 2) == 1, "Error - size must be an odd integer"
371        import cv2 # import this here to avoid errors if opencv is not installed properly
372        if self.is_float():
373            self.data = cv2.medianBlur( self.data.astype(np.float32), size )
374        else:
375            self.data = cv2.medianBlur( self.data, size )

Despeckle each band of this image (independently) using a median filter.

Arguments:
  • size (int): the size of the median filter kernel. Default is 5. Must be an odd number.
def get_keypoints( self, band, eq=False, mask=True, method='sift', cfac=0.0, bfac=0.0, **kwds):
380    def get_keypoints(self, band, eq=False, mask=True, method='sift', cfac=0.0,bfac=0.0, **kwds):
381        """
382        Get feature descriptors from the specified band.
383
384        Args:
385            band (int,float,str,tuple): the band index (int) or wavelength (float) to extract features from. Alternatively, a tuple can be passed
386                    containing a range of bands (min : max) to average before feature matching.
387            eq (bool): True if the image should be histogram equalized first. Default is False.
388            mask (bool): True if 0 value pixels should be masked. Default is True.
389            method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'.
390            cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
391            bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.
392            **kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are:
393
394                - contrastThreshold: default is 0.01.
395                - edgeThreshold: default is 10.
396                - sigma: default is 1.0
397
398                For ORB these are:
399
400                - nfeatures = the number of features to detect. Default is 5000.
401
402            Returns:
403                Tuple containing
404
405                    - k (ndarray): the keypoints detected
406                    - d (ndarray): corresponding feature descriptors
407         """
408        import cv2 # import this here to avoid errors if opencv is not installed properly
409
410        # get image
411        if isinstance(band, int) or isinstance(band, float): #single band
412            image = self.data[:, :, self.get_band_index(band)]
413        elif isinstance(band,tuple): #range of bands (averaged)
414            idx0 = self.get_band_index(band[0])
415            idx1 = self.get_band_index(band[1])
416
417            #deal with out of range errors
418            if idx0 is None:
419                idx0 = 0
420            if idx1 is None:
421                idx1 = self.band_count()
422
423            #average bands
424            image = np.nanmean(self.data[:,:,idx0:idx1],axis=2)
425        else:
426            assert False, "Error, unrecognised band %s" % band
427
428        #normalise image to range 0 - 1
429        image -= np.nanmin(image)
430        image = image / np.nanmax(image)
431
432        #apply brightness/contrast adjustment
433        image = (1.0+cfac)*image + bfac
434        image[image > 1.0] = 1.0
435        image[image < 0.0] = 0.0
436
437        #convert image to uint8 for opencv
438        image = np.uint8(255 * image)
439        if eq:
440            image = cv2.equalizeHist(image)
441
442        if mask:
443            mask = np.zeros(image.shape, dtype=np.uint8)
444            mask[image != 0] = 255  # include only non-zero pixels
445        else:
446            mask = None
447
448        if 'sift' in method.lower():  # SIFT
449
450            # setup default keywords
451            kwds["contrastThreshold"] = kwds.get("contrastThreshold", 0.01)
452            kwds["edgeThreshold"] = kwds.get("edgeThreshold", 10)
453            kwds["sigma"] = kwds.get("sigma", 1.0)
454
455            # make feature detector
456            #alg = cv2.xfeatures2d.SIFT_create(**kwds)
457            alg = cv2.SIFT_create()
458        elif 'orb' in method.lower():  # orb
459            kwds['nfeatures'] = kwds.get('nfeatures', 5000)
460            alg = cv2.ORB_create(scoreType=cv2.ORB_FAST_SCORE, **kwds)
461        else:
462            assert False, "Error - %s is not a recognised feature detector." % method
463
464        # detect keypoints
465        kp = alg.detect(image, mask)
466
467        # extract and return feature vectors
468        return alg.compute(image, kp)

Get feature descriptors from the specified band.

Arguments:
  • band (int,float,str,tuple): the band index (int) or wavelength (float) to extract features from. Alternatively, a tuple can be passed containing a range of bands (min : max) to average before feature matching.
  • eq (bool): True if the image should be histogram equalized first. Default is False.
  • mask (bool): True if 0 value pixels should be masked. Default is True.
  • method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'.
  • cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
  • bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.
  • **kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are:

    • contrastThreshold: default is 0.01.
    • edgeThreshold: default is 10.
    • sigma: default is 1.0

    For ORB these are:

    • nfeatures = the number of features to detect. Default is 5000.
Returns:

  Tuple containing

    • k (ndarray): the keypoints detected
    • d (ndarray): corresponding feature descriptors
@classmethod
def match_keypoints( cls, kp1, kp2, d1, d2, method='SIFT', dist=0.7, tree=5, check=100, min_count=5):
470    @classmethod
471    def match_keypoints(cls, kp1, kp2, d1, d2, method='SIFT', dist=0.7, tree = 5, check = 100, min_count=5):
472        """
473        Compares keypoint feature vectors from two images and returns matching pairs.
474
475        Args:
476            kp1 (ndarray): keypoints from the first image
477            kp2 (ndarray): keypoints from the second image
478            d1 (ndarray): descriptors for the keypoints from the first image
479            d2 (ndarray): descriptors for the keypoints from the second image
480            method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'.
481            dist (float): minimum match distance (0 to 1), default is 0.7
482            tree (int): not sure what this does? Default is 5. See open-cv docs.
483            check (int): ditto. Default is 100.
484            min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are found,
485                       then the function returns None, None. Default is 5.
486        """
487        import cv2 # import this here to avoid errors if opencv is not installed properly
488        if 'sift' in method.lower():
489            algorithm = cv2.NORM_INF
490        elif 'orb' in method.lower():
491            algorithm = cv2.NORM_HAMMING
492        else:
493            assert False, "Error - unknown matching algorithm %s" % method
494
495        #calculate flann matches
496        index_params = dict(algorithm=algorithm, trees=tree)
497        search_params = dict(checks=check)
498        flann = cv2.FlannBasedMatcher(index_params, search_params)
499        matches = flann.knnMatch(d1, d2, k=2)
500
501        # store all the good matches as per Lowe's ratio test.
502        good = []
503        for m, n in matches:
504            if m.distance < dist * n.distance:
505                good.append(m)
506
507        if len(good) < min_count:
508            return None, None
509        else:
510            src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
511            dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
512            return src_pts, dst_pts

Compares keypoint feature vectors from two images and returns matching pairs.

Arguments:
  • kp1 (ndarray): keypoints from the first image
  • kp2 (ndarray): keypoints from the second image
  • d1 (ndarray): descriptors for the keypoints from the first image
  • d2 (ndarray): descriptors for the keypoints from the second image
  • method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'.
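  • dist (float): ratio test threshold (0 to 1): a match is kept only if its distance is less than dist times the distance of the second-best match. Default is 0.7.
  • tree (int): the value passed as the 'trees' entry of the FLANN index parameters. Default is 5. See open-cv docs.
  • check (int): the value passed as the 'checks' entry of the FLANN search parameters. Default is 100.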
  • min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are found, then the function returns None, None. Default is 5.
def quick_plot( self, band=0, ax=None, bfac=0.0, cfac=0.0, samples=False, tscale=False, rot=False, flipX=False, flipY=False, **kwds):
517    def quick_plot(self, band=0, ax=None, bfac=0.0, cfac=0.0, samples=False, tscale=False, rot=False, flipX=False, flipY=False,
518                   **kwds):
519        """
520        Plot a band using matplotlib.imshow(...).
521
522        Args:
523            band (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple is passed then
524                  each band in the tuple (string or index) will be mapped to rgb. Bands with negative wavelengths or indices will be inverted before plotting.
525            ax: an axis object to plot to. If none, plt.imshow( ... ) is used.
526            bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
527            cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
528            samples (bool): True if sample points (defined in the header file) should be plotted. Default is False. Otherwise, a list of
529                     [ (x,y), ... ] points can be passed.
530            tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False.
531                    When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or
532                    (constant) values (float).
533            rot (bool): if True, the x and y axis will be flipped (90 degree rotation) before plotting. Default is False.
534            flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
535            flipY (bool): if True, the y axis will be flippe before plotting (after applying rotations).
536            **kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:
537
538                 - mask = a 2D boolean mask containing true if pixels should be drawn and false otherwise.
539                 - path = a file path to save the image too (at matching resolution; use fig.savefig(..) if you want to save the figure).
540                 - ticks = True if x- and y- ticks should be plotted. Default is False.
541                 - ps, pc = the size and color of sample points to plot. Can be constant or list.
542                 - figsize = a figsize for the figure to create (if ax is None).
543
544        Returns:
545            Tuple containing
546
547            - fig: matplotlib figure object
548            - ax:  matplotlib axes object. If a colorbar is created, (band is an integer or a float), then this will be stored in ax.cbar.
549        """
550
551        #create new axes?
552        if ax is None:
553            fig, ax = plt.subplots(figsize=kwds.pop('figsize', (18,18*self.ydim()/self.xdim()) ))
554
555        # deal with ticks
556        if not kwds.pop('ticks', False ):
557            ax.set_xticks([])
558            ax.set_yticks([])
559
560        #map individual band using colourmap
561        if isinstance(band, str) or isinstance(band, int) or isinstance(band, float):
562            #get band
563            if isinstance(band, str):
564                data = self.data[:, :, self.get_band_index(band)]
565            else:
566                data = self.data[:, :, self.get_band_index(np.abs(band))]
567            if not isinstance(band, str) and band < 0:
568                data = np.nanmax(data) - data # flip
569
570            # convert integer vmin and vmax values to percentiles
571            if 'vmin' in kwds:
572                if isinstance(kwds['vmin'], int):
573                    kwds['vmin'] = np.nanpercentile( data, kwds['vmin'] )
574            if 'vmax' in kwds:
575                if isinstance(kwds['vmax'], int):
576                    kwds['vmax'] = np.nanpercentile( data, kwds['vmax'] )
577
578            #mask nans (and apply custom mask)
579            mask = np.isnan(data)
580            if not np.isnan(self.header.get_data_ignore_value()):
581                mask = mask + data == self.header.get_data_ignore_value()
582            if 'mask' in kwds:
583                mask = mask + kwds.get('mask')
584                del kwds['mask']
585            data = np.ma.array(data, mask = mask > 0 )
586
587            # apply rotations and flipping
588            if rot:
589                data = data.T
590            if flipX:
591                data = data[::-1, :]
592            if flipY:
593                data = data[:, ::-1]
594
595            # save?
596            if 'path' in kwds:
597                path = kwds.pop('path')
598                from matplotlib.pyplot import imsave
599                if not os.path.exists(os.path.dirname(path)):
600                    os.makedirs(os.path.dirname(path)) # ensure output directory exists
601                imsave(path, data.T, **kwds)  # save the image
602
603            ax.cbar = ax.imshow(data.T, interpolation=kwds.pop('interpolation', 'none'), **kwds) # change default interpolation to None
604
605        #map 3 bands to RGB
606        elif isinstance(band, tuple) or isinstance(band, list):
607            #get band indices and range
608            rgb = []
609            for b in band:
610                if isinstance(b, str):
611                    rgb.append(self.get_band_index(b))
612                else:
613                    rgb.append(self.get_band_index(np.abs(b)))
614
615            #slice image (as copy) and map to 0 - 1
616            img = np.array(self.data[:, :, rgb]).copy()
617            if np.isnan(img).all():
618                print("Warning - image contains no data.")
619                return ax.get_figure(), ax
620
621            # invert if needed
622            for i,b in enumerate(band):
623                if not isinstance(b, str) and (b < 0):
624                    img[..., i] = np.nanmax(img[..., i]) - img[..., i]
625
626            # do scaling
627            if tscale: # scale bands independently
628                for b in range(3):
629                    mn = kwds.get("vmin", float(np.nanmin(img)))
630                    mx = kwds.get("vmax", float(np.nanmax(img)))
631                    if isinstance (mn, int):
632                        assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
633                        mn = float(np.nanpercentile(img[...,b], mn ))
634                    if isinstance (mx, int):
635                        assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
636                        mx = float(np.nanpercentile(img[...,b], mx ))
637                    img[...,b] = (img[..., b] - mn) / (mx - mn)
638            else: # scale bands together
639                mn = kwds.get("vmin", float(np.nanmin(img)))
640                mx = kwds.get("vmax", float(np.nanmax(img)))
641                if isinstance(mn, int):
642                    assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
643                    mn = float(np.nanpercentile(img, mn))
644                if isinstance(mx, int):
645                    assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
646                    mx = float(np.nanpercentile(img, mx))
647                img = (img - mn) / (mx - mn)
648
649            #apply brightness/contrast mapping
650            img = np.clip((1.0 + cfac) * img + bfac, 0, 1.0 )
651
652            #apply masking so background is white
653            img[np.logical_not( np.isfinite( img ) )] = 1.0
654            if 'mask' in kwds:
655                img[kwds.pop("mask"),:] = 1.0
656
657            # apply rotations and flipping
658            if rot:
659                img = np.transpose( img, (1,0,2) )
660            if flipX:
661                img = img[::-1, :, :]
662            if flipY:
663                img = img[:, ::-1, :]
664
665            # save?
666            if 'path' in kwds:
667                path = kwds.pop('path')
668                from matplotlib.pyplot import imsave
669                if not os.path.exists(os.path.dirname(path)):
670                    os.makedirs(os.path.dirname(path)) # ensure output directory exists
671                imsave(path, np.transpose( np.clip( img*255, 0, 255).astype(np.uint8), (1, 0, 2)))  # save the image
672
673            # plot samples?
674            ps = kwds.pop('ps', 5)
675            pc = kwds.pop('pc', 'r')
676            if samples:
677                if isinstance(samples, list) or isinstance(samples, np.ndarray):
678                    ax.scatter([s[0] for s in samples], [s[1] for s in samples], s=ps, c=pc)
679                else:
680                    for n in self.header.get_class_names():
681                        points = np.array(self.header.get_sample_points(n))
682                        ax.scatter(points[:, 0], points[:, 1], s=ps, c=pc)
683
684            #plot
685            ax.imshow(np.transpose(img, (1,0,2)), interpolation=kwds.pop('interpolation', 'none'), **kwds)
686            ax.cbar = None  # no colorbar
687
688        return ax.get_figure(), ax

Plot a band using matplotlib.imshow(...).

Arguments:
  • band (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple is passed then each band in the tuple (string or index) will be mapped to rgb. Bands with negative wavelengths or indices will be inverted before plotting.
  • ax: an axis object to plot to. If none, plt.imshow( ... ) is used.
  • bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
  • cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
  • samples (bool): True if sample points (defined in the header file) should be plotted. Default is False. Otherwise, a list of [ (x,y), ... ] points can be passed.
  • tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False. When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or (constant) values (float).
  • rot (bool): if True, the x and y axis will be flipped (90 degree rotation) before plotting. Default is False.
  • flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
  • flipY (bool): if True, the y axis will be flipped before plotting (after applying rotations).
  • **kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:

    • mask = a 2D boolean mask containing True where pixels should be hidden (masked out) and False where they should be drawn.
    • path = a file path to save the image to (at matching resolution; use fig.savefig(..) if you want to save the figure).
    • ticks = True if x- and y- ticks should be plotted. Default is False.
    • ps, pc = the size and color of sample points to plot. Can be constant or list.
    • figsize = a figsize for the figure to create (if ax is None).
Returns:

Tuple containing

  • fig: matplotlib figure object
  • ax: matplotlib axes object. If a colorbar is created, (band is an integer or a float), then this will be stored in ax.cbar.
def createGIF(self, path, bands=None, figsize=(10, 10), fps=10, **kwds):
690    def createGIF(self, path, bands=None, figsize=(10,10), fps=10, **kwds):
691        """
692        Create and save an animated gif that loops through the bands of the image.
693
694        Args:
695            path (str): the path to save the .gif
696            bands (tuple): Tuple containing the range of band indices to draw. Default is the whole range.
697            figsize (tuple): the size of the image to draw. Default is (10,10).
698            fps (int): the framerate (frames per second) of the gif. Default is 10.
699            **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
700        """
701
702        frames = []
703        if bands is None:
704            bands = (0,self.band_count())
705        else:
706            assert 0 < bands[0] < self.band_count(), "Error - invalid range."
707            assert 0 < bands[1] < self.band_count(), "Error - invalid range."
708            assert bands[1] > bands[0], "Error - invalid range."
709
710        #plot frames
711        for i in range(bands[0],bands[1]):
712            fig, ax = plt.subplots(figsize=figsize)
713            ax.imshow(self.data[:, :, i], **kwds)
714            fig.canvas.draw()
715            frames.append(np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8'))
716            frames[-1] = np.reshape(frames[-1], (fig.canvas.get_width_height()[1], fig.canvas.get_width_height()[0], 3))
717            plt.close(fig)
718
719        #save gif
720        imageio.mimsave( os.path.splitext(path)[0] + ".gif", frames, fps=fps)

Create and save an animated gif that loops through the bands of the image.

Arguments:
  • path (str): the path to save the .gif
  • bands (tuple): Tuple containing the range of band indices to draw. Default is the whole range.
  • figsize (tuple): the size of the image to draw. Default is (10,10).
  • fps (int): the framerate (frames per second) of the gif. Default is 10.
  • **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
def mask(self, mask=None, flag=nan, invert=False, crop=False, bands=None):
723    def mask(self, mask=None, flag=np.nan, invert=False, crop=False, bands=None):
724        """
725         Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the
726         image in-situ.
727
728         Args:
729            flag (float): the value to use for masked pixels. Default is np.nan
730            mask (ndarray): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None is passed then
731                    pickPolygon( ... ) is used to interactively define a polygon. If a file path is passed then the polygon
732                    will be loaded using np.load( ... ). Alternatively if mask.shape == image.shape[0,1] then it is treated as a
733                    binary image mask (must be boolean) and True values will be masked across all bands. Default is None.
734            invert (bool): if True, pixels within the polygon will be masked. If False, pixels outside the polygon are masked. Default is False.
735            crop (bool): True if rows/columns containing only zeros should be removed. Default is False.
736            bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used.
737
738         Returns:
739            Tuple containing
740
741            - mask (ndarray): a boolean array with True where pixels are masked and False elsewhere.
742            - poly (ndarray): the mask polygon array in the format described above. Useful if the polygon was interactively defined.
743         """
744
745        if mask is None:  # pick mask interactively
746            if bands is None:
747                bands = int(self.band_count() / 2)
748
749            regions = self.pickPolygons(region_names=["mask"], bands=bands)
750
751            # the user bailed without picking a mask?
752            if len(regions) == 0:
753                print("Warning - no mask picked/applied.")
754                return
755
756            # extract polygon mask
757            mask = regions[0]
758
759        # convert polygon mask to binary mask
760        if mask.shape[1] == 2:
761
762            # build meshgrid with pixel coords
763            xx, yy = np.meshgrid(np.arange(self.xdim()), np.arange(self.ydim()))
764            xx = xx.flatten()
765            yy = yy.flatten()
766            points = np.vstack([xx, yy]).T  # coordinates of each pixel
767
768            # calculate per-pixel mask
769            mask = path.Path(mask).contains_points(points)
770            mask = mask.reshape((self.ydim(), self.xdim())).T
771
772            # flip as we want to mask (==True) outside points (unless invert is true)
773            if not invert:
774                mask = np.logical_not(mask)
775
776        # apply binary image mask
777        assert mask.shape[0] == self.data.shape[0] and mask.shape[1] == self.data.shape[1], \
778            "Error - mask shape %s does not match image shape %s" % (mask.shape, self.data.shape)
779        for b in range(self.band_count()):
780            self.data[:, :, b][mask] = flag
781
782        # crop image
783        if crop:
784            # calculate non-masked pixels
785            valid = np.logical_not(mask)
786
787            # integrate along axes
788            xdata = np.sum(valid, axis=1) > 0.0
789            ydata = np.sum(valid, axis=0) > 0.0
790
791            # calculate domain containing valid pixels
792            xmin = np.argmax(xdata)
793            xmax = xdata.shape[0] - np.argmax(xdata[::-1])
794            ymin = np.argmax(ydata)
795            ymax = ydata.shape[0] - np.argmax(ydata[::-1])
796
797            # crop
798            self.data = self.data[xmin:xmax, ymin:ymax, :]
799
800        return mask

Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the image in-situ.

Arguments:
  • flag (float): the value to use for masked pixels. Default is np.nan
  • mask (ndarray): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None is passed then pickPolygon( ... ) is used to interactively define a polygon. If a file path is passed then the polygon will be loaded using np.load( ... ). Alternatively if mask.shape == image.shape[0,1] then it is treated as a binary image mask (must be boolean) and True values will be masked across all bands. Default is None.
  • invert (bool): if True, pixels within the polygon will be masked. If False, pixels outside the polygon are masked. Default is False.
  • crop (bool): True if rows/columns containing only zeros should be removed. Default is False.
  • bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used.
Returns:

The binary mask only:

  • mask (ndarray): a boolean array with True where pixels are masked and False elsewhere.

Note that the mask polygon array (poly, in the format described above) is not currently returned, and that nothing is returned if interactive picking is cancelled.
def crop_to_data(self):
802    def crop_to_data(self):
803        """
804        Remove padding of nan or zero pixels from image. Note that this is performed in place.
805        """
806
807        valid = np.isfinite(self.data).any(axis=-1) & (self.data != 0).any(axis=-1)
808        ymin, ymax = np.percentile(np.argwhere(np.sum(valid, axis=0) != 0), (0, 100))
809        xmin, xmax = np.percentile(np.argwhere(np.sum(valid, axis=1) != 0), (0, 100))
810        self.data = self.data[int(xmin):int(xmax), int(ymin):int(ymax), :]  # do clipping

Remove padding of nan or zero pixels from image. Note that this is performed in place.

def pickPolygons(self, region_names, bands=0):
815    def pickPolygons(self, region_names, bands=0):
816        """
817        Creates a matplotlib gui for selecting polygon regions in an image.
818
819        Args:
820            names (list, str): a list containing the names of the regions to pick. If a string is passed only one name is used.
821            bands (tuple): the bands of the image to plot.
822        """
823
824        if isinstance(region_names, str):
825            region_names = [region_names]
826
827        assert isinstance(region_names, list), "Error - names must be a list or a string."
828
829        # set matplotlib backend
830        backend = matplotlib.get_backend()
831        matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work
832
833        # plot image and extract roi's
834        fig, ax = self.quick_plot(bands)
835        roi = MultiRoi(roi_names=region_names)
836        plt.close(fig)  # close figure
837
838        # extract regions
839        regions = []
840        for name, r in roi.rois.items():
841            # store region
842            x = r.x
843            y = r.y
844            regions.append(np.vstack([x, y]).T)
845
846        # restore matplotlib backend (if possible)
847        try:
848            matplotlib.use(backend)
849        except:
850            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")
851            pass
852
853        return regions

Creates a matplotlib gui for selecting polygon regions in an image.

Arguments:
  • region_names (list, str): a list containing the names of the regions to pick. If a string is passed only one name is used.
  • bands (tuple): the bands of the image to plot.
def pickPoints( self, n=-1, bands=(680.0, 550.0, 505.0), integer=True, title='Pick Points', **kwds):
855    def pickPoints(self, n=-1, bands=hylite.RGB, integer=True, title="Pick Points", **kwds):
856        """
857        Creates a matplotlib gui for picking pixels from an image.
858
859        Args:
860            n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
861            bands (tuple): the bands of the image to plot. Default is HyImage.RGB
862            integer (bool): True if points coordinates should be cast to integers (for use as indices). Default is True.
863            title (str): The title of the point picking window.
864            **kwds: Keywords are passed to HyImage.quick_plot( ... ).
865
866        Returns:
867            A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].
868        """
869
870        # set matplotlib backend
871        backend = matplotlib.get_backend()
872        matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work
873
874        # create figure
875        fig, ax = self.quick_plot( bands, **kwds )
876        ax.set_title(title)
877
878        # get points
879        points = fig.ginput( n )
880
881        if integer:
882            points = [ (int(p[0]), int(p[1])) for p in points ]
883
884        # restore matplotlib backend (if possible)
885        try:
886            matplotlib.use(backend)
887        except:
888            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")
889            pass
890
891        return points

Creates a matplotlib gui for picking pixels from an image.

Arguments:
  • n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
  • bands (tuple): the bands of the image to plot. Default is hylite.RGB, i.e. (680.0, 550.0, 505.0).
  • integer (bool): True if points coordinates should be cast to integers (for use as indices). Default is True.
  • title (str): The title of the point picking window.
  • **kwds: Keywords are passed to HyImage.quick_plot( ... ).
Returns:

A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].

def pickSamples(self, names=None, store=True, **kwds):
893    def pickSamples(self, names=None, store=True, **kwds):
894        """
895        Pick sample probe points and store these in the image header file.
896
897        Args:
898            names (str, list): the name of the sample to pick, or a list of names to pick multiple.
899            store (bool): True if sample should be stored in the image header file (for later access). Default is True.
900            **kwds: Keywords are passed to HyImage.quick_plot( ... )
901
902        Returns:
903            a list containing a list of points for each sample.
904        """
905
906        if isinstance(names, str):
907            names = [names]
908
909        # pick points
910        points = []
911        for s in names:
912            pnts = self.pickPoints(title="%s" % s, **kwds)
913            if store:
914                self.header['sample %s' % s] = pnts # store in header
915            points.append(pnts)
916        # add class to header file
917        if store:
918            cls_names = self.header.get_class_names()
919            if cls_names is None:
920                cls_names = []
921            self.header['class names'] = cls_names + names
922
923        return points

Pick sample probe points and store these in the image header file.

Arguments:
  • names (str, list): the name of the sample to pick, or a list of names to pick multiple.
  • store (bool): True if sample should be stored in the image header file (for later access). Default is True.
  • **kwds: Keywords are passed to HyImage.quick_plot( ... )
Returns:

a list containing a list of points for each sample.