hylite documentation|Stable (master)Development (dev)

hylite.hyimage

Store and manipulate hyperspectral image data.

   1"""
   2Store and manipulate hyperspectral image data.
   3"""
   4
   5import os
   6import numpy as np
   7import matplotlib
   8import matplotlib.pyplot as plt
   9from matplotlib import path
  10from roipoly import MultiRoi
  11import imageio
  12import scipy as sp
  13import hylite
  14from hylite.hydata import HyData
  15from hylite.hylibrary import HyLibrary
  16
  17
  18
  19class HyImage( HyData ):
  20    """
  21    A class for hyperspectral image data. These can be individual scenes or hyperspectral orthoimages.
  22    """
  23
  24    def __init__(self, data, **kwds):
  25        """
  26        Args:
  27            data (ndarray): a numpy array such that data[x][y][band] gives each pixel value.
  28            **kwds:
  29                wav = A numpy array containing band wavelengths for this image.
  30                affine = an affine transform of the format returned by GDAL.GetGeoTransform().
  31                projection = string defining the project. Default is None.
  32                sensor = sensor name. Default is "unknown".
  33                header = path to associated header file. Default is None.
  34        """
  35
  36        #call constructor for HyData
  37        super().__init__(data, **kwds)
  38
  39        # special case - if dataset only has oneband, slice it so it still has
  40        # the format data[x,y,b].
  41        if not self.data is None:
  42            if len(self.data.shape) == 1:
  43                self.data = self.data[None, None, :] # single pixel image
  44            if len(self.data.shape) == 2:
  45                self.data = self.data[:, :, None] # single band iamge
  46
  47        #load any additional project information (specific to images)
  48        self.set_projection(kwds.get("projection",None))
  49        self.affine = np.array( kwds.get("affine",[0,1,0,0,0,1]) )
  50        self.header['affine'] = kwds.get("affine",[0,1,0,0,0,1]) # also store this here
  51
  52        # wavelengths
  53        if 'wav' in kwds:
  54            self.set_wavelengths(kwds['wav'])
  55
  56        #special header formatting
  57        self.header['file type'] = 'ENVI Standard'
  58
  59    def copy(self,data=True):
  60        """
  61        Make a deep copy of this image instance.
  62
  63        Args:
  64            data (bool): True if a copy of the data should be made, otherwise only copy header.
  65
  66        Returns:
  67            a new HyImage instance.
  68        """
  69        if not data:
  70            return HyImage(None, header=self.header.copy(), projection=self.projection, affine=self.affine)
  71        else:
  72            return HyImage( self.data.copy(), header=self.header.copy(), projection=self.projection, affine=self.affine)
  73
  74    def T(self):
  75        """
  76        Return a transposed view of the data matrix (corresponding with the [y,x] indexing used by matplotlib, opencv etc.
  77        """
  78        return np.transpose(self.data, (1,0,2))
  79
  80    def xdim(self):
  81        """
  82        Return number of pixels in x (first dimension of data array)
  83        """
  84        return self.data.shape[0]
  85
  86    def ydim(self):
  87        """
  88        Return number of pixels in y (second dimension of data array)
  89        """
  90        return self.data.shape[1]
  91
  92    def aspx(self):
  93        """
  94        Return the aspect ratio of this image (width/height).
  95        """
  96        return self.ydim() / self.xdim()
  97    
  98    #####################################
  99    ## GEOREFERENCING METHODS
 100    #####################################
 101
 102    def get_extent(self):
 103        """
 104        Returns the width and height of this image in world coordinates.
 105
 106        Returns:
 107            tuple with (width, height).
 108        """
 109        return self.xdim * self.pixel_size[0], self.ydim * self.pixel_size[1]
 110
 111    def set_projection(self,proj):
 112        """
 113        Set this project to an existing osgeo.osr.SpatialReference or GDAL georeference string.
 114
 115        Args:
 116            proj (str, osgeo.osr.SpatialReference): the project to use as osgeo.osr.SpatialReference or GDAL georeference string.
 117        """
 118        if proj is None:
 119            self.projection = None
 120        else:
 121            try:
 122                from osgeo.osr import SpatialReference
 123            except:
 124                assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
 125            if isinstance(proj, SpatialReference):
 126                self.projection = proj
 127            elif isinstance(proj, str):
 128                self.projection = SpatialReference(proj)
 129            else:
 130                print("Invalid project %s" % proj)
 131                raise
 132
 133    def set_projection_EPSG(self,EPSG):
 134        """
 135        Sets this image project using an EPSG code.
 136
 137        Args:
 138            EPSG (str): string EPSG code that can be passed to SpatialReference.SetFromUserInput(...).
 139        """
 140
 141        try:
 142            from osgeo.osr import SpatialReference
 143        except:
 144            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
 145
 146        self.projection = SpatialReference()
 147        self.projection.SetFromUserInput(EPSG)
 148
 149    def get_projection_EPSG(self):
 150        """
 151        Gets a string describing this projections EPSG code (if it is an EPSG project).
 152
 153        Returns:
 154            an EPSG code string of the format "EPSG:XXXX".
 155        """
 156        if self.projection is None:
 157            return None
 158        else:
 159            return "%s:%s" % (self.projection.GetAttrValue("AUTHORITY",0),self.projection.GetAttrValue("AUTHORITY",1))
 160
 161    def pix_to_world(self, px, py, proj=None):
 162        """
 163        Take pixel coordinates and return world coordinates
 164
 165        Args:
 166            px (int): the pixel x-coord.
 167            py (int): the pixel y-coord.
 168            proj (str, osr.SpatialReference): the coordinate system to use. Default (None) uses the same system as this image. Otherwise
 169                   an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
 170        Returns:
 171            the world coordinates in the coordinate system defined by get_projection_EPSG(...).
 172        """
 173
 174        try:
 175            from osgeo import osr
 176            import osgeo.gdal as gdal
 177            from osgeo import ogr
 178        except:
 179            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
 180
 181        # parse project
 182        if proj is None:
 183            proj = self.projection
 184        elif isinstance(proj, str) or isinstance(proj, int):
 185            epsg = proj
 186            if isinstance(epsg, str):
 187                try:
 188                    epsg = int(str.split(':')[1])
 189                except:
 190                    assert False, "Error - %s is an invalid EPSG code." % proj
 191            proj = osr.SpatialReference()
 192            proj.ImportFromEPSG(epsg)
 193
 194        # check we have all the required info
 195        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
 196        assert (not self.affine is None) and (
 197            not self.projection is None), "Error - project information is undefined."
 198
 199        #project to world coordinates in this images project/world coords
 200        x,y = gdal.ApplyGeoTransform(self.affine, px, py)
 201
 202        #project to target coords (if different)
 203        if not proj.IsSameGeogCS(self.projection):
 204            P = ogr.Geometry(ogr.wkbPoint)
 205            if proj.EPSGTreatsAsNorthingEasting():
 206                P.AddPoint(x, y)
 207            else:
 208                P.AddPoint(y, x)
 209            P.AssignSpatialReference(self.projection)  # tell the point what coordinates it's in
 210            P.TransformTo(proj)  # reproject it to the out spatial reference
 211            x, y = P.GetX(), P.GetY()
 212
 213            #do we need to transpose?
 214            if proj.EPSGTreatsAsLatLong():
 215                x,y=y,x #we want lon,lat not lat,lon
 216        return x, y
 217
 218    def world_to_pix(self, x, y, proj = None):
 219        """
 220        Take world coordinates and return pixel coordinates
 221
 222        Args:
 223            x (float): the world x-coord.
 224            y (float): the world y-coord.
 225            proj (str, osr.SpatialReference): the coordinate system of the input coordinates. Default (None) uses the same system as this image. Otherwise
 226                   an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
 227
 228        Returns:
 229            the pixel coordinates based on the affine transform stored in self.affine.
 230        """
 231
 232        try:
 233            from osgeo import osr
 234            import osgeo.gdal as gdal
 235            from osgeo import ogr
 236        except:
 237            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
 238
 239        # parse project
 240        if proj is None:
 241            proj = self.projection
 242        elif isinstance(proj, str) or isinstance(proj, int):
 243            epsg = proj
 244            if isinstance(epsg, str):
 245                try:
 246                    epsg = int(str.split(':')[1])
 247                except:
 248                    assert False, "Error - %s is an invalid EPSG code." % proj
 249            proj = osr.SpatialReference()
 250            proj.ImportFromEPSG(epsg)
 251
 252
 253        # check we have all the required info
 254        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
 255        assert (not self.affine is None) and (not self.projection is None), "Error - project information is undefined."
 256
 257        # project to this images CS (if different)
 258        if not proj.IsSameGeogCS(self.projection):
 259            P = ogr.Geometry(ogr.wkbPoint)
 260            if proj.EPSGTreatsAsNorthingEasting():
 261                P.AddPoint(x, y)
 262            else:
 263                P.AddPoint(y, x)
 264            P.AssignSpatialReference(proj)  # tell the point what coordinates it's in
 265            P.AddPoint(x, y)
 266            P.TransformTo(self.projection)  # reproject it to the out spatial reference
 267            x, y = P.GetX(), P.GetY()
 268            if self.projection.EPSGTreatsAsLatLong(): # do we need to transpose?
 269                x, y = y, x  # we want lon,lat not lat,lon
 270
 271        inv = gdal.InvGeoTransform(self.affine)
 272        assert not inv is None, "Error - could not invert affine transform?"
 273
 274        #apply
 275        return gdal.ApplyGeoTransform(inv, x, y)
 276
 277    def crop(self, xmin, xmax, ymin, ymax, bands=None):
 278        """
 279        Return a cropped copy of this image.
 280
 281        Args:
 282            xmin, xmax (int): pixel bounds in x (rows)
 283            ymin, ymax (int): pixel bounds in y (columns)
 284            bands (None, list, tuple): optional band indices or (min,max) range
 285
 286        Returns:
 287            HyImage: cropped image with updated affine transform
 288        """
 289
 290        # ---- validate bounds ----
 291        xmin = int(max(0, xmin))
 292        ymin = int(max(0, ymin))
 293        xmax = int(min(self.xdim(), xmax))
 294        ymax = int(min(self.ydim(), ymax))
 295
 296        assert xmin < xmax and ymin < ymax, "Invalid crop extent."
 297
 298        # ---- crop data ----
 299        if bands is None:
 300            data = self.data[xmin:xmax, ymin:ymax, :].copy()
 301            wav = self.get_wavelengths()
 302        else: # band selection
 303            if isinstance(bands, tuple):
 304                b0 = self.get_band_index(bands[0])
 305                b1 = self.get_band_index(bands[1])
 306                data = self.data[xmin:xmax, ymin:ymax, b0:b1].copy()
 307                wav = self.get_wavelengths()[b0:b1]
 308            else:
 309                idx = [self.get_band_index(b) for b in bands]
 310                data = self.data[xmin:xmax, ymin:ymax, idx].copy()
 311                wav = self.get_wavelengths()[idx]
 312
 313        # ---- update affine transform ----
 314        if self.affine is not None:
 315            a = list(self.affine)
 316            new_affine = a.copy()
 317
 318            # shift origin to new top-left pixel
 319            new_affine[0] = a[0] + xmin*a[1] + ymin*a[2]
 320            new_affine[3] = a[3] + xmin*a[4] + ymin*a[5]
 321        else:
 322            new_affine = None
 323
 324        # ---- construct output image ----
 325        out = HyImage(
 326            data,
 327            header=self.header.copy(),
 328            projection=self.projection,
 329            affine=new_affine,
 330            wav=wav
 331        )
 332
 333        return out
 334
 335    def resize(self, newdims: tuple, interpolation: int = 1):
 336        """
 337        Resize this image with opencv and update affine transform accordingly.
 338
 339        Args:
 340            newdims (tuple): the new image dimensions (xdim, ydim)
 341            interpolation (int): opencv interpolation method. Default is cv2.INTER_LINEAR.
 342        """
 343        import cv2  # avoid import issues if opencv is missing
 344
 345        old_x, old_y = self.xdim(), self.ydim()
 346        new_x, new_y = int(newdims[0]), int(newdims[1])
 347
 348        assert new_x > 0 and new_y > 0, "Invalid resize dimensions."
 349
 350        # resize data (opencv uses width, height = y, x)
 351        self.data = cv2.resize(
 352            self.data,
 353            (new_y, new_x),
 354            interpolation=interpolation
 355        )
 356
 357        # update affine transform
 358        if self.affine is not None:
 359            a = list(self.affine)
 360
 361            sx = old_x / new_x
 362            sy = old_y / new_y
 363
 364            self.affine = [
 365                a[0],          # x origin unchanged
 366                a[1] * sx,     # pixel width
 367                a[2] * sy,     # row rotation
 368                a[3],          # y origin unchanged
 369                a[4] * sx,     # column rotation
 370                a[5] * sy      # pixel height
 371            ]
 372
 373    def tile(self, tile_size):
 374        """
 375        Break image into tiles of given size and return a list of HyImage tiles.
 376        Each tile has an updated affine transform reflecting its position in the original image.
 377
 378        Args:
 379            tile_size (tuple): (tile_x, tile_y) in pixels
 380        Returns:
 381            list of HyImage
 382        """
 383        tiles = []
 384        tx, ty = tile_size
 385        nx, ny = self.xdim(), self.ydim()
 386        for i in range(0, nx, tx):
 387            for j in range(0, ny, ty):
 388                data_tile = self.data[i:min(i+tx, nx), j:min(j+ty, ny), :].copy() # slice data
 389
 390                # compute new affine
 391                # affine = [x0, dx, rotx, y0, roty, dy]
 392                x0, dx, rotx, y0, roty, dy = self.affine
 393                new_x0 = x0 + i*dx + j*rotx
 394                new_y0 = y0 + i*roty + j*dy
 395                new_affine = [new_x0, dx, rotx, new_y0, roty, dy]
 396
 397                # create tile
 398                tile_img = HyImage(
 399                    data_tile,
 400                    affine=new_affine,
 401                    projection=self.projection,
 402                    wav=self.get_wavelengths(),
 403                    header=self.header.copy()
 404                )
 405                tile_img.header['xleft'] = i
 406                tile_img.header['ytop'] = j
 407                tiles.append(tile_img)
 408
 409        return tiles
 410
 411    @staticmethod
 412    def mosaic(
 413        tiles,
 414        blend="mean",
 415        resampling="nearest",
 416        out_affine=None,
 417        out_shape=None,
 418    ):
 419        """
 420        Mosaic georeferenced HyImage tiles using GDAL. Note that this assumes all tiles are in the same coordinate system.
 421
 422        Args:
 423            tiles (list[HyImage])
 424            blend (str): 'first', 'min', 'max', 'mean', 'median'
 425            resampling (str): 'nearest', 'bilinear', 'cubic'
 426            out_affine (list): optional 6-element affine to define output grid. If None, the affine of the first tile is used.
 427            out_shape (tuple): optional (xdim, ydim) shape of the output grid. If None, the extent of all tiles will be used.
 428        Returns:
 429            HyImage
 430        """
 431        import numpy as np
 432        from hylite.project.align import resample_raster
 433        from osgeo import gdal, osr
 434
 435        assert len(tiles) > 0
 436        assert blend in ("first", "min", "max", "mean", "median")
 437
 438        # compute bounds in world coordinates
 439        # N.B. THIS ASSUMES ALL DATA ARE IN THE SAME CRS
 440        points = []
 441        for t in tiles:
 442            points.append( t.pix_to_world(0,0) )
 443            points.append( t.pix_to_world(t.xdim()+1,t.ydim()+1) )
 444        min_x, min_y = np.min(points, axis=0)
 445        max_x, max_y = np.max(points, axis=0)
 446        if out_shape is None:
 447            out_shape = (np.array(tiles[0].world_to_pix(max_x, max_y)) + np.array(tiles[0].world_to_pix(min_x, min_y))).round().astype(int)
 448        if out_affine is None:
 449            out_affine = list(tiles[0].affine)
 450            out_affine[0] = min_x
 451            out_affine[3] = max_y
 452
 453        if blend == "first": # fill output, first come, first served.
 454            out = np.full( tuple(out_shape) + (tiles[0].band_count(),), np.nan, dtype=np.float32)
 455            for t in tiles:
 456                r = resample_raster( t.data, t.affine, out_affine, out_shape )
 457                mask = np.isnan(out)
 458                out[mask] = r[mask]
 459        elif blend == "min": # keep minimum value in case of overlap
 460            out = np.nanmin( np.stack([ resample_raster( t.data, t.affine, out_affine, out_shape ) for t in tiles ], 
 461                                      axis=0), axis=0 )
 462        elif blend == "max": # keep maximum value in case of overlap
 463            out = np.nanmax( np.stack([ resample_raster( t.data, t.affine, out_affine, out_shape ) for t in tiles ], 
 464                                      axis=0), axis=0 )
 465        elif blend == "mean": # use average in case of overlap
 466            out = np.nanmean( np.stack([ resample_raster( t.data, t.affine, out_affine, out_shape ) for t in tiles ], 
 467                                       axis=0), axis=0 )
 468        elif blend == "median": # use mean in case of overlap
 469            out = np.nanmedian( np.stack([ resample_raster( t.data, t.affine, out_affine, out_shape ) for t in tiles ], 
 470                                         axis=0), axis=0 )
 471
 472        # Return HyImage
 473        out =  HyImage(
 474            out,
 475            affine=out_affine,
 476            projection=tiles[0].projection,
 477            wav=tiles[0].get_wavelengths(),
 478            header=tiles[0].header.copy()
 479        )
 480        if 'xleft' in out.header: del out.header['xleft'] # stored in tiles, but not meaningful here
 481        if 'ytop' in out.header: del out.header['ytop'] # stored in tiles, but not meaningful here
 482
 483        return out
 484
 485    #####################################
 486    ## BASIC TRANSFORMS
 487    #####################################
 488
 489    def flip(self, axis='x'):
 490        """
 491        Flip the image on the x or y axis. Note that this will remove any defined affine transform.
 492
 493        Args:
 494            axis (str): 'x' or 'y' or both 'xy'.
 495        """
 496
 497        if 'x' in axis.lower():
 498            self.data = np.flip(self.data,axis=0)
 499        if 'y' in axis.lower():
 500            self.data = np.flip(self.data,axis=1)
 501        self.affine = None
 502        if 'affine' in self.header: del self.header['affine']
 503        self.push_to_header() # update width and height info
 504
 505    def rot90(self):
 506        """
 507        Rotate this image by 90 degrees by transposing the underlying data array. Combine with flip('x') or flip('y')
 508        to achieve positive/negative rotations.
 509        """
 510        self.data = np.transpose( self.data, (1,0,2) )
 511        self.affine = None
 512        if 'affine' in self.header: del self.header['affine']
 513        self.push_to_header() # update width and height info
 514
 515    #####################################
 516    ##IMAGE FILTERING
 517    #####################################
 518    def fill_holes(self):
 519        """
 520        Replaces nan pixel with an average of their neighbours, thus removing 1-pixel large holes from an image. Note that
 521        for performance reasons this assumes that holes line up across bands. Note that this is not vectorized so very slow...
 522        """
 523
 524        # perform greyscale dilation
 525        dilate = self.data.copy()
 526        mask = np.logical_not(np.isfinite(dilate))
 527        dilate[mask] = 0
 528        for b in range(self.band_count()):
 529            dilate[:, :, b] = sp.ndimage.grey_dilation(dilate[:, :, b], size=(3, 3))
 530
 531        # map back to holes in dataset
 532        self.data[mask] = dilate[mask]
 533        #self.data[self.data == 0] = np.nan  # replace remaining 0's with nans
 534
 535    def blur(self, n=3):
 536        """
 537        Applies a gaussian kernel of size n to the image using OpenCV.
 538
 539        Args:
 540            n (int): the dimensions of the gaussian kernel to convolve. Default is 3. Increase for more blurry results.
 541        """
 542        import cv2 # import this here to avoid errors if opencv is not installed properly
 543
 544        nanmask = np.isnan(self.data)
 545        assert isinstance(n, int) and n >= 3, "Error - invalid kernel. N must be an integer > 3. "
 546        kernel = np.ones((n, n), np.float32) / (n ** 2)
 547        self.data = cv2.filter2D(self.data, -1, kernel)
 548        self.data[nanmask] = np.nan  # remove mask
 549
 550    def erode(self, size=3, iterations=1):
 551        """
 552        Apply an erode filter to this image to expand background (nan) pixels. Refer to open-cv's erode
 553        function for more details.
 554
 555        Args:
 556            size (int): the size of the erode filter. Default is a 3x3 kernel.
 557            iterations (int): the number of erode iterations. Default is 1.
 558        """
 559        import cv2 # import this here to avoid errors if opencv is not installed properly
 560
 561        # erode
 562        kernel = np.ones((size, size), np.uint8)
 563        if self.is_float():
 564            mask = np.isfinite(self.data).any(axis=-1)
 565            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
 566            self.data[mask == 0, :] = np.nan
 567        else:
 568            mask = (self.data != 0).any( axis=-1 )
 569            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
 570            self.data[mask == 0, :] = 0
 571
 572    def despeckle(self, size=5):
 573        """
 574        Despeckle each band of this image (independently) using a median filter.
 575
 576        Args:
 577            size (int): the size of the median filter kernel. Default is 5. Must be an odd number.
 578        """
 579
 580        assert (size % 2) == 1, "Error - size must be an odd integer"
 581        import cv2 # import this here to avoid errors if opencv is not installed properly
 582        if self.is_float():
 583            self.data = cv2.medianBlur( self.data.astype(np.float32), size )
 584        else:
 585            self.data = cv2.medianBlur( self.data, size )
 586
 587    #####################################
 588    ##FEATURES AND FEATURE MATCHING
 589    ######################################
 590    def get_keypoints(self, band, eq=False, mask=True, method='sift', cfac=0.0,bfac=0.0, **kwds):
 591        """
 592        Get feature descriptors from the specified band.
 593
 594        Args:
 595            band (int,float,str,tuple): the band index (int) or wavelength (float) to extract features from. Alternatively, a tuple can be passed
 596                    containing a range of bands (min : max) to average before feature matching.
 597            eq (bool): True if the image should be histogram equalized first. Default is False.
 598            mask (bool): True if 0 value pixels should be masked. Default is True.
 599            method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'.
 600            cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
 601            bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.
 602            **kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are:
 603
 604                - contrastThreshold: default is 0.01.
 605                - edgeThreshold: default is 10.
 606                - sigma: default is 1.0
 607
 608                For ORB these are:
 609
 610                - nfeatures = the number of features to detect. Default is 5000.
 611
 612            Returns:
 613                Tuple containing
 614
 615                    - k (ndarray): the keypoints detected
 616                    - d (ndarray): corresponding feature descriptors
 617         """
 618        import cv2 # import this here to avoid errors if opencv is not installed properly
 619
 620        # get image
 621        if isinstance(band, int) or isinstance(band, float): #single band
 622            image = self.data[:, :, self.get_band_index(band)]
 623        elif isinstance(band,tuple): #range of bands (averaged)
 624            idx0 = self.get_band_index(band[0])
 625            idx1 = self.get_band_index(band[1])
 626
 627            #deal with out of range errors
 628            if idx0 is None:
 629                idx0 = 0
 630            if idx1 is None:
 631                idx1 = self.band_count()
 632
 633            #average bands
 634            image = np.nanmean(self.data[:,:,idx0:idx1],axis=2)
 635        else:
 636            assert False, "Error, unrecognised band %s" % band
 637
 638        #normalise image to range 0 - 1
 639        image -= np.nanmin(image)
 640        image = image / np.nanmax(image)
 641
 642        #apply brightness/contrast adjustment
 643        image = (1.0+cfac)*image + bfac
 644        image[image > 1.0] = 1.0
 645        image[image < 0.0] = 0.0
 646
 647        #convert image to uint8 for opencv
 648        image = np.uint8(255 * image)
 649        if eq:
 650            image = cv2.equalizeHist(image)
 651
 652        if mask:
 653            mask = np.zeros(image.shape, dtype=np.uint8)
 654            mask[image != 0] = 255  # include only non-zero pixels
 655        else:
 656            mask = None
 657
 658        if 'sift' in method.lower():  # SIFT
 659
 660            # setup default keywords
 661            kwds["contrastThreshold"] = kwds.get("contrastThreshold", 0.01)
 662            kwds["edgeThreshold"] = kwds.get("edgeThreshold", 10)
 663            kwds["sigma"] = kwds.get("sigma", 1.0)
 664
 665            # make feature detector
 666            #alg = cv2.xfeatures2d.SIFT_create(**kwds)
 667            alg = cv2.SIFT_create()
 668        elif 'orb' in method.lower():  # orb
 669            kwds['nfeatures'] = kwds.get('nfeatures', 5000)
 670            alg = cv2.ORB_create(scoreType=cv2.ORB_FAST_SCORE, **kwds)
 671        else:
 672            assert False, "Error - %s is not a recognised feature detector." % method
 673
 674        # detect keypoints
 675        kp = alg.detect(image, mask)
 676
 677        # extract and return feature vectors
 678        return alg.compute(image, kp)
 679
 680    @classmethod
 681    def match_keypoints(cls, kp1, kp2, d1, d2, method='SIFT', dist=0.7, tree = 5, check = 100, min_count=5):
 682        """
 683        Compares keypoint feature vectors from two images and returns matching pairs.
 684
 685        Args:
 686            kp1 (ndarray): keypoints from the first image
 687            kp2 (ndarray): keypoints from the second image
 688            d1 (ndarray): descriptors for the keypoints from the first image
 689            d2 (ndarray): descriptors for the keypoints from the second image
 690            method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'.
 691            dist (float): minimum match distance (0 to 1), default is 0.7
 692            tree (int): not sure what this does? Default is 5. See open-cv docs.
 693            check (int): ditto. Default is 100.
 694            min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are found,
 695                       then the function returns None, None. Default is 5.
 696        """
 697        import cv2 # import this here to avoid errors if opencv is not installed properly
 698        if 'sift' in method.lower():
 699            algorithm = cv2.NORM_INF
 700        elif 'orb' in method.lower():
 701            algorithm = cv2.NORM_HAMMING
 702        else:
 703            assert False, "Error - unknown matching algorithm %s" % method
 704
 705        #calculate flann matches
 706        index_params = dict(algorithm=algorithm, trees=tree)
 707        search_params = dict(checks=check)
 708        flann = cv2.FlannBasedMatcher(index_params, search_params)
 709        matches = flann.knnMatch(d1, d2, k=2)
 710
 711        # store all the good matches as per Lowe's ratio test.
 712        good = []
 713        for m, n in matches:
 714            if m.distance < dist * n.distance:
 715                good.append(m)
 716
 717        if len(good) < min_count:
 718            return None, None
 719        else:
 720            src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
 721            dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
 722            return src_pts, dst_pts
 723
 724    ############################
 725    ## Visualisation methods
 726    ############################
 727    def quick_plot(self, bands=0, ax=None, bfac=0.0, cfac=0.0, samples=False, tscale=False, invert=False, rot=False, flipX=False, flipY=False,
 728                   **kwds):
 729        """
 730        Plot a band using matplotlib.imshow(...).
 731
 732        Args:
 733            bands (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple is passed then
 734                  each band in the tuple (string or index) will be mapped to rgb. Bands with negative wavelengths or indices will be inverted before plotting.
 735            ax: an axis object to plot to. If none, plt.imshow( ... ) is used.
 736            bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
 737            cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
 738            samples (bool): True if sample points (defined in the header file) should be plotted. Default is False. Otherwise, a list of
 739                     [ (x,y), ... ] points can be passed.
 740            tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False.
 741                    When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or
 742                    (constant) values (float).
 743            invert (bool) : True if each band should be inverted before plotting. Only works for multiband (ternary) images.
 744            rot (bool): if True, the x and y axis will be flipped (90 degree rotation) before plotting. Default is False.
 745            flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
 746            flipY (bool): if True, the y axis will be flippe before plotting (after applying rotations).
 747            **kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:
 748
 749                 - mask = a 2D boolean mask containing true if pixels should be drawn and false otherwise.
 750                 - path = a file path to save the image too (at matching resolution; use fig.savefig(..) if you want to save the figure).
 751                 - ticks = True if x- and y- ticks should be plotted. Default is False.
 752                 - ps, pc = the size and color of sample points to plot. Can be constant or list.
 753                 - figsize = a figsize for the figure to create (if ax is None).
 754
 755        Returns:
 756            Tuple containing
 757
 758            - fig: matplotlib figure object
 759            - ax:  matplotlib axes object. If a colorbar is created, (band is an integer or a float), then this will be stored in ax.cbar.
 760        """
 761
 762        #create new axes?
 763        if ax is None:
 764            fig, ax = plt.subplots(figsize=kwds.pop('figsize', (18,18*self.ydim()/self.xdim()) ))
 765
 766        # deal with ticks
 767        if not kwds.pop('ticks', False ):
 768            ax.set_xticks([])
 769            ax.set_yticks([])
 770
 771        #map individual band using colourmap
 772        if isinstance(bands, str) or isinstance(bands, int) or isinstance(bands, float):
 773            #get band
 774            if isinstance(bands, str):
 775                data = self.data[:, :, self.get_band_index(bands)]
 776            else:
 777                data = self.data[:, :, self.get_band_index(np.abs(bands))]
 778            if not isinstance(bands, str) and bands < 0:
 779                data = np.nanmax(data) - data # flip
 780
 781            # convert integer vmin and vmax values to percentiles
 782            if 'vmin' in kwds:
 783                if isinstance(kwds['vmin'], int):
 784                    kwds['vmin'] = np.nanpercentile( data, kwds['vmin'] )
 785            if 'vmax' in kwds:
 786                if isinstance(kwds['vmax'], int):
 787                    kwds['vmax'] = np.nanpercentile( data, kwds['vmax'] )
 788
 789            #mask nans (and apply custom mask)
 790            mask = np.isnan(data)
 791            if not np.isnan(self.header.get_data_ignore_value()):
 792                mask = mask + data == self.header.get_data_ignore_value()
 793            if 'mask' in kwds:
 794                mask = mask + kwds.get('mask')
 795                del kwds['mask']
 796            data = np.ma.array(data, mask = mask > 0 )
 797
 798            # apply rotations and flipping
 799            if rot:
 800                data = data.T
 801            if flipX:
 802                data = data[::-1, :]
 803            if flipY:
 804                data = data[:, ::-1]
 805
 806            # save?
 807            if 'path' in kwds:
 808                path = kwds.pop('path')
 809                from matplotlib.pyplot import imsave
 810                if not os.path.exists(os.path.dirname(path)):
 811                    os.makedirs(os.path.dirname(path)) # ensure output directory exists
 812                imsave(path, data.T, **kwds)  # save the image
 813
 814            ax.cbar = ax.imshow(data.T, interpolation=kwds.pop('interpolation', 'none'), **kwds) # change default interpolation to None
 815
 816        #map 3 bands to RGB
 817        elif isinstance(bands, tuple) or isinstance(bands, list):
 818            #get band indices and range
 819            rgb = []
 820            for b in bands:
 821                if isinstance(b, str):
 822                    rgb.append(self.get_band_index(b))
 823                else:
 824                    rgb.append(self.get_band_index(np.abs(b)))
 825
 826            #slice image (as copy) and map to 0 - 1
 827            img = np.array(self.data[:, :, rgb]).copy()
 828            if np.isnan(img).all():
 829                print("Warning - image contains no data.")
 830                return ax.get_figure(), ax
 831
 832            # invert if needed
 833            if invert:
 834                bands = [-b for b in bands]
 835            for i,b in enumerate(bands):
 836                if not isinstance(b, str) and (b < 0):
 837                    img[..., i] = np.nanmax(img[..., i]) - img[..., i]
 838
 839            # do scaling
 840            if tscale: # scale bands independently
 841                for b in range(3):
 842                    mn = kwds.get("vmin", float(np.nanmin(img)))
 843                    mx = kwds.get("vmax", float(np.nanmax(img)))
 844                    if isinstance (mn, int):
 845                        assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
 846                        mn = float(np.nanpercentile(img[...,b], mn ))
 847                    if isinstance (mx, int):
 848                        assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
 849                        mx = float(np.nanpercentile(img[...,b], mx ))
 850                    img[...,b] = (img[..., b] - mn) / (mx - mn)
 851            else: # scale bands together
 852                mn = kwds.get("vmin", float(np.nanmin(img)))
 853                mx = kwds.get("vmax", float(np.nanmax(img)))
 854                if isinstance(mn, int):
 855                    assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
 856                    mn = float(np.nanpercentile(img, mn))
 857                if isinstance(mx, int):
 858                    assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
 859                    mx = float(np.nanpercentile(img, mx))
 860                img = (img - mn) / (mx - mn)
 861
 862            #apply brightness/contrast mapping
 863            img = np.clip((1.0 + cfac) * img + bfac, 0, 1.0 )
 864
 865            #apply masking so background is white
 866            img[np.logical_not( np.isfinite( img ) )] = 1.0
 867            if 'mask' in kwds:
 868                img[kwds.pop("mask"),:] = 1.0
 869
 870            # apply rotations and flipping
 871            if rot:
 872                img = np.transpose( img, (1,0,2) )
 873            if flipX:
 874                img = img[::-1, :, :]
 875            if flipY:
 876                img = img[:, ::-1, :]
 877
 878            # save?
 879            if 'path' in kwds:
 880                path = kwds.pop('path')
 881                from matplotlib.pyplot import imsave
 882                if not os.path.exists(os.path.dirname(path)):
 883                    os.makedirs(os.path.dirname(path)) # ensure output directory exists
 884                imsave(path, np.transpose( np.clip( img*255, 0, 255).astype(np.uint8), (1, 0, 2)))  # save the image
 885
 886            # plot samples?
 887            ps = kwds.pop('ps', 5)
 888            pc = kwds.pop('pc', 'r')
 889            if samples:
 890                if isinstance(samples, list) or isinstance(samples, np.ndarray):
 891                    ax.scatter([s[0] for s in samples], [s[1] for s in samples], s=ps, c=pc)
 892                else:
 893                    for n in self.header.get_class_names():
 894                        points = np.array(self.header.get_sample_points(n))
 895                        ax.scatter(points[:, 0], points[:, 1], s=ps, c=pc)
 896
 897            #plot
 898            ax.imshow(np.transpose(img, (1,0,2)), interpolation=kwds.pop('interpolation', 'none'), **kwds)
 899            ax.cbar = None  # no colorbar
 900
 901        return ax.get_figure(), ax
 902
 903    def createGIF(self, path, bands=None, figsize=(10,10), fps=10, **kwds):
 904        """
 905        Create and save an animated gif that loops through the bands of the image.
 906
 907        Args:
 908            path (str): the path to save the .gif
 909            bands (tuple): Tuple containing the range of band indices to draw. Default is the whole range.
 910            figsize (tuple): the size of the image to draw. Default is (10,10).
 911            fps (int): the framerate (frames per second) of the gif. Default is 10.
 912            **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
 913        """
 914
 915        frames = []
 916        if bands is None:
 917            bands = (0,self.band_count())
 918        else:
 919            assert 0 < bands[0] < self.band_count(), "Error - invalid range."
 920            assert 0 < bands[1] < self.band_count(), "Error - invalid range."
 921            assert bands[1] > bands[0], "Error - invalid range."
 922
 923        #plot frames
 924        for i in range(bands[0],bands[1]):
 925            fig, ax = plt.subplots(figsize=figsize)
 926            ax.imshow(self.data[:, :, i], **kwds)
 927            fig.canvas.draw()
 928            frames.append(np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8'))
 929            frames[-1] = np.reshape(frames[-1], (fig.canvas.get_width_height()[1], fig.canvas.get_width_height()[0], 3))
 930            plt.close(fig)
 931
 932        #save gif
 933        imageio.mimsave( os.path.splitext(path)[0] + ".gif", frames, fps=fps)
 934
 935    ## masking
 936    def drop_bbl(self, drop=True):
 937        """
 938        Remove bad bands as stored in the 'bbl' key in the image header. Note that this operates in-place.
 939
 940        Args:
 941            drop (bool): True if bad bands should be completely dropped. If False, these bands will be kept but replaced with nans.
 942        """
 943        assert 'bbl' in self.header, "Please specify a bad band list ('bbl') in the image header, as per the ENVI format definition."
 944        mask = self.header.get_list('bbl') == 0
 945        self.data[...,mask] = np.nan
 946        if drop:
 947            self.delete_nan_bands(inplace=True)
 948    
 949    def mask(self, mask=None, flag=np.nan, invert=False, crop=False, bands=None):
 950        """
 951         Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the
 952         image in-situ.
 953
 954         Args:
 955            flag (float): the value to use for masked pixels. Default is np.nan
 956            mask (ndarray): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None is passed then
 957                    pickPolygon( ... ) is used to interactively define a polygon. If a file path is passed then the polygon
 958                    will be loaded using np.load( ... ). Alternatively if mask.shape == image.shape[0,1] then it is treated as a
 959                    binary image mask (must be boolean) and True values will be masked across all bands. Default is None.
 960            invert (bool): if True, pixels within the polygon will be masked. If False, pixels outside the polygon are masked. Default is False.
 961            crop (bool): True if rows/columns containing only zeros should be removed. Default is False.
 962            bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used.
 963
 964         Returns:
 965            Tuple containing
 966
 967            - mask (ndarray): a boolean array with True where pixels are masked and False elsewhere.
 968            - poly (ndarray): the mask polygon array in the format described above. Useful if the polygon was interactively defined.
 969         """
 970
 971        if mask is None:  # pick mask interactively
 972            if bands is None:
 973                bands = int(self.band_count() / 2)
 974
 975            regions = self.pickPolygons(region_names=["mask"], bands=bands)
 976
 977            # the user bailed without picking a mask?
 978            if len(regions) == 0:
 979                print("Warning - no mask picked/applied.")
 980                return
 981
 982            # extract polygon mask
 983            mask = regions[0]
 984
 985        # convert polygon mask to binary mask
 986        if mask.shape[1] == 2:
 987
 988            # build meshgrid with pixel coords
 989            xx, yy = np.meshgrid(np.arange(self.xdim()), np.arange(self.ydim()))
 990            xx = xx.flatten()
 991            yy = yy.flatten()
 992            points = np.vstack([xx, yy]).T  # coordinates of each pixel
 993
 994            # calculate per-pixel mask
 995            mask = path.Path(mask).contains_points(points)
 996            mask = mask.reshape((self.ydim(), self.xdim())).T
 997
 998            # flip as we want to mask (==True) outside points (unless invert is true)
 999            if not invert:
1000                mask = np.logical_not(mask)
1001
1002        # apply binary image mask
1003        assert mask.shape[0] == self.data.shape[0] and mask.shape[1] == self.data.shape[1], \
1004            "Error - mask shape %s does not match image shape %s" % (mask.shape, self.data.shape)
1005        for b in range(self.band_count()):
1006            self.data[:, :, b][mask] = flag
1007
1008        # crop image
1009        if crop:
1010            self.crop_to_data()
1011
1012        return mask
1013
1014    def crop_to_data(self):
1015        """
1016        Remove padding of nan or zero pixels from image. Note that this is performed in place.
1017        """
1018        valid = np.isfinite(self.data).any(axis=-1) & (self.data != 0).any(axis=-1)
1019
1020        # integrate along axes
1021        xdata = np.sum(valid, axis=1) > 0.0
1022        ydata = np.sum(valid, axis=0) > 0.0
1023
1024        # calculate domain containing valid pixels
1025        xmin = np.argmax(xdata)
1026        xmax = xdata.shape[0] - np.argmax(xdata[::-1])
1027        ymin = np.argmax(ydata)
1028        ymax = ydata.shape[0] - np.argmax(ydata[::-1])
1029
1030        # crop
1031        self.data = self.data[xmin:xmax, ymin:ymax, :]
1032        
1033        # shift affine origin to new top-left pixel
1034        if self.affine is not None: 
1035            a = self.affine # shorthand for affine
1036            new_affine = list(self.affine)
1037            new_affine[0] = a[0] + xmin*a[1] + ymin*a[2]
1038            new_affine[3] = a[3] + xmin*a[4] + ymin*a[5]
1039            self.affine = np.array(new_affine)
1040            self.header['affine'] = self.affine
1041            
1042    ##################################################
1043    ## Interactive tools for picking regions/pixels
1044    ##################################################
1045    def pickPolygons(self, region_names, bands=0):
1046        """
1047        Creates a matplotlib gui for selecting polygon regions in an image.
1048
1049        Args:
1050            names (list, str): a list containing the names of the regions to pick. If a string is passed only one name is used.
1051            bands (tuple): the bands of the image to plot.
1052        """
1053
1054        if isinstance(region_names, str):
1055            region_names = [region_names]
1056
1057        assert isinstance(region_names, list), "Error - names must be a list or a string."
1058
1059        # set matplotlib backend
1060        backend = matplotlib.get_backend()
1061        matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work
1062
1063        # plot image and extract roi's
1064        fig, ax = self.quick_plot(bands)
1065        roi = MultiRoi(roi_names=region_names)
1066        plt.close(fig)  # close figure
1067
1068        # extract regions
1069        regions = []
1070        for name, r in roi.rois.items():
1071            # store region
1072            x = r.x
1073            y = r.y
1074            regions.append(np.vstack([x, y]).T)
1075
1076        # restore matplotlib backend (if possible)
1077        try:
1078            matplotlib.use(backend)
1079        except:
1080            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")
1081            pass
1082
1083        return regions
1084
1085    def pickPoints(self, n=-1, bands=hylite.RGB, integer=True, title="Pick Points", **kwds):
1086        """
1087        Creates a matplotlib gui for picking pixels from an image.
1088
1089        Args:
1090            n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
1091            bands (tuple): the bands of the image to plot. Default is HyImage.RGB
1092            integer (bool): True if points coordinates should be cast to integers (for use as indices). Default is True.
1093            title (str): The title of the point picking window.
1094            **kwds: Keywords are passed to HyImage.quick_plot( ... ).
1095
1096        Returns:
1097            A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].
1098        """
1099
1100        # set matplotlib backend
1101        backend = matplotlib.get_backend()
1102        matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work
1103
1104        # create figure
1105        fig, ax = self.quick_plot( bands, **kwds )
1106        ax.set_title(title)
1107
1108        # get points
1109        points = fig.ginput( n )
1110
1111        if integer:
1112            points = [ (int(p[0]), int(p[1])) for p in points ]
1113
1114        # restore matplotlib backend (if possible)
1115        try:
1116            matplotlib.use(backend)
1117        except:
1118            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")
1119            pass
1120
1121        return points
1122
1123    def pickSamples(self, names=None, store=True, **kwds):
1124        """
1125        Pick sample probe points and store these in the image header file.
1126
1127        Args:
1128            names (str, list): the name of the sample to pick, or a list of names to pick multiple.
1129            store (bool): True if sample should be stored in the image header file (for later access). Default is True.
1130            **kwds: Keywords are passed to HyImage.quick_plot( ... )
1131
1132        Returns:
1133            a list containing a list of points for each sample.
1134        """
1135
1136        if isinstance(names, str):
1137            names = [names]
1138
1139        # pick points
1140        points = []
1141        for s in names:
1142            pnts = self.pickPoints(title="%s" % s, **kwds)
1143            if store:
1144                self.header['sample %s' % s] = pnts # store in header
1145            points.append(pnts)
1146        # add class to header file
1147        if store:
1148            cls_names = self.header.get_class_names()
1149            if cls_names is None:
1150                cls_names = []
1151            self.header['class names'] = cls_names + names
1152
1153        return points
class HyImage(hylite.hydata.HyData):
  20class HyImage( HyData ):
  21    """
  22    A class for hyperspectral image data. These can be individual scenes or hyperspectral orthoimages.
  23    """
  24
  25    def __init__(self, data, **kwds):
  26        """
  27        Args:
  28            data (ndarray): a numpy array such that data[x][y][band] gives each pixel value.
  29            **kwds:
  30                wav = A numpy array containing band wavelengths for this image.
  31                affine = an affine transform of the format returned by GDAL.GetGeoTransform().
  32                projection = string defining the project. Default is None.
  33                sensor = sensor name. Default is "unknown".
  34                header = path to associated header file. Default is None.
  35        """
  36
  37        #call constructor for HyData
  38        super().__init__(data, **kwds)
  39
  40        # special case - if dataset only has oneband, slice it so it still has
  41        # the format data[x,y,b].
  42        if not self.data is None:
  43            if len(self.data.shape) == 1:
  44                self.data = self.data[None, None, :] # single pixel image
  45            if len(self.data.shape) == 2:
  46                self.data = self.data[:, :, None] # single band iamge
  47
  48        #load any additional project information (specific to images)
  49        self.set_projection(kwds.get("projection",None))
  50        self.affine = np.array( kwds.get("affine",[0,1,0,0,0,1]) )
  51        self.header['affine'] = kwds.get("affine",[0,1,0,0,0,1]) # also store this here
  52
  53        # wavelengths
  54        if 'wav' in kwds:
  55            self.set_wavelengths(kwds['wav'])
  56
  57        #special header formatting
  58        self.header['file type'] = 'ENVI Standard'
  59
  60    def copy(self,data=True):
  61        """
  62        Make a deep copy of this image instance.
  63
  64        Args:
  65            data (bool): True if a copy of the data should be made, otherwise only copy header.
  66
  67        Returns:
  68            a new HyImage instance.
  69        """
  70        if not data:
  71            return HyImage(None, header=self.header.copy(), projection=self.projection, affine=self.affine)
  72        else:
  73            return HyImage( self.data.copy(), header=self.header.copy(), projection=self.projection, affine=self.affine)
  74
  75    def T(self):
  76        """
  77        Return a transposed view of the data matrix (corresponding with the [y,x] indexing used by matplotlib, opencv etc.
  78        """
  79        return np.transpose(self.data, (1,0,2))
  80
  81    def xdim(self):
  82        """
  83        Return number of pixels in x (first dimension of data array)
  84        """
  85        return self.data.shape[0]
  86
  87    def ydim(self):
  88        """
  89        Return number of pixels in y (second dimension of data array)
  90        """
  91        return self.data.shape[1]
  92
  93    def aspx(self):
  94        """
  95        Return the aspect ratio of this image (width/height).
  96        """
  97        return self.ydim() / self.xdim()
  98    
  99    #####################################
 100    ## GEOREFERENCING METHODS
 101    #####################################
 102
 103    def get_extent(self):
 104        """
 105        Returns the width and height of this image in world coordinates.
 106
 107        Returns:
 108            tuple with (width, height).
 109        """
 110        return self.xdim * self.pixel_size[0], self.ydim * self.pixel_size[1]
 111
 112    def set_projection(self,proj):
 113        """
 114        Set this project to an existing osgeo.osr.SpatialReference or GDAL georeference string.
 115
 116        Args:
 117            proj (str, osgeo.osr.SpatialReference): the project to use as osgeo.osr.SpatialReference or GDAL georeference string.
 118        """
 119        if proj is None:
 120            self.projection = None
 121        else:
 122            try:
 123                from osgeo.osr import SpatialReference
 124            except:
 125                assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
 126            if isinstance(proj, SpatialReference):
 127                self.projection = proj
 128            elif isinstance(proj, str):
 129                self.projection = SpatialReference(proj)
 130            else:
 131                print("Invalid project %s" % proj)
 132                raise
 133
 134    def set_projection_EPSG(self,EPSG):
 135        """
 136        Sets this image project using an EPSG code.
 137
 138        Args:
 139            EPSG (str): string EPSG code that can be passed to SpatialReference.SetFromUserInput(...).
 140        """
 141
 142        try:
 143            from osgeo.osr import SpatialReference
 144        except:
 145            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
 146
 147        self.projection = SpatialReference()
 148        self.projection.SetFromUserInput(EPSG)
 149
 150    def get_projection_EPSG(self):
 151        """
 152        Gets a string describing this projections EPSG code (if it is an EPSG project).
 153
 154        Returns:
 155            an EPSG code string of the format "EPSG:XXXX".
 156        """
 157        if self.projection is None:
 158            return None
 159        else:
 160            return "%s:%s" % (self.projection.GetAttrValue("AUTHORITY",0),self.projection.GetAttrValue("AUTHORITY",1))
 161
 162    def pix_to_world(self, px, py, proj=None):
 163        """
 164        Take pixel coordinates and return world coordinates
 165
 166        Args:
 167            px (int): the pixel x-coord.
 168            py (int): the pixel y-coord.
 169            proj (str, osr.SpatialReference): the coordinate system to use. Default (None) uses the same system as this image. Otherwise
 170                   an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
 171        Returns:
 172            the world coordinates in the coordinate system defined by get_projection_EPSG(...).
 173        """
 174
 175        try:
 176            from osgeo import osr
 177            import osgeo.gdal as gdal
 178            from osgeo import ogr
 179        except:
 180            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
 181
 182        # parse project
 183        if proj is None:
 184            proj = self.projection
 185        elif isinstance(proj, str) or isinstance(proj, int):
 186            epsg = proj
 187            if isinstance(epsg, str):
 188                try:
 189                    epsg = int(str.split(':')[1])
 190                except:
 191                    assert False, "Error - %s is an invalid EPSG code." % proj
 192            proj = osr.SpatialReference()
 193            proj.ImportFromEPSG(epsg)
 194
 195        # check we have all the required info
 196        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
 197        assert (not self.affine is None) and (
 198            not self.projection is None), "Error - project information is undefined."
 199
 200        #project to world coordinates in this images project/world coords
 201        x,y = gdal.ApplyGeoTransform(self.affine, px, py)
 202
 203        #project to target coords (if different)
 204        if not proj.IsSameGeogCS(self.projection):
 205            P = ogr.Geometry(ogr.wkbPoint)
 206            if proj.EPSGTreatsAsNorthingEasting():
 207                P.AddPoint(x, y)
 208            else:
 209                P.AddPoint(y, x)
 210            P.AssignSpatialReference(self.projection)  # tell the point what coordinates it's in
 211            P.TransformTo(proj)  # reproject it to the out spatial reference
 212            x, y = P.GetX(), P.GetY()
 213
 214            #do we need to transpose?
 215            if proj.EPSGTreatsAsLatLong():
 216                x,y=y,x #we want lon,lat not lat,lon
 217        return x, y
 218
 219    def world_to_pix(self, x, y, proj = None):
 220        """
 221        Take world coordinates and return pixel coordinates
 222
 223        Args:
 224            x (float): the world x-coord.
 225            y (float): the world y-coord.
 226            proj (str, osr.SpatialReference): the coordinate system of the input coordinates. Default (None) uses the same system as this image. Otherwise
 227                   an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
 228
 229        Returns:
 230            the pixel coordinates based on the affine transform stored in self.affine.
 231        """
 232
 233        try:
 234            from osgeo import osr
 235            import osgeo.gdal as gdal
 236            from osgeo import ogr
 237        except:
 238            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
 239
 240        # parse project
 241        if proj is None:
 242            proj = self.projection
 243        elif isinstance(proj, str) or isinstance(proj, int):
 244            epsg = proj
 245            if isinstance(epsg, str):
 246                try:
 247                    epsg = int(str.split(':')[1])
 248                except:
 249                    assert False, "Error - %s is an invalid EPSG code." % proj
 250            proj = osr.SpatialReference()
 251            proj.ImportFromEPSG(epsg)
 252
 253
 254        # check we have all the required info
 255        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
 256        assert (not self.affine is None) and (not self.projection is None), "Error - project information is undefined."
 257
 258        # project to this images CS (if different)
 259        if not proj.IsSameGeogCS(self.projection):
 260            P = ogr.Geometry(ogr.wkbPoint)
 261            if proj.EPSGTreatsAsNorthingEasting():
 262                P.AddPoint(x, y)
 263            else:
 264                P.AddPoint(y, x)
 265            P.AssignSpatialReference(proj)  # tell the point what coordinates it's in
 266            P.AddPoint(x, y)
 267            P.TransformTo(self.projection)  # reproject it to the out spatial reference
 268            x, y = P.GetX(), P.GetY()
 269            if self.projection.EPSGTreatsAsLatLong(): # do we need to transpose?
 270                x, y = y, x  # we want lon,lat not lat,lon
 271
 272        inv = gdal.InvGeoTransform(self.affine)
 273        assert not inv is None, "Error - could not invert affine transform?"
 274
 275        #apply
 276        return gdal.ApplyGeoTransform(inv, x, y)
 277
 278    def crop(self, xmin, xmax, ymin, ymax, bands=None):
 279        """
 280        Return a cropped copy of this image.
 281
 282        Args:
 283            xmin, xmax (int): pixel bounds in x (rows)
 284            ymin, ymax (int): pixel bounds in y (columns)
 285            bands (None, list, tuple): optional band indices or (min,max) range
 286
 287        Returns:
 288            HyImage: cropped image with updated affine transform
 289        """
 290
 291        # ---- validate bounds ----
 292        xmin = int(max(0, xmin))
 293        ymin = int(max(0, ymin))
 294        xmax = int(min(self.xdim(), xmax))
 295        ymax = int(min(self.ydim(), ymax))
 296
 297        assert xmin < xmax and ymin < ymax, "Invalid crop extent."
 298
 299        # ---- crop data ----
 300        if bands is None:
 301            data = self.data[xmin:xmax, ymin:ymax, :].copy()
 302            wav = self.get_wavelengths()
 303        else: # band selection
 304            if isinstance(bands, tuple):
 305                b0 = self.get_band_index(bands[0])
 306                b1 = self.get_band_index(bands[1])
 307                data = self.data[xmin:xmax, ymin:ymax, b0:b1].copy()
 308                wav = self.get_wavelengths()[b0:b1]
 309            else:
 310                idx = [self.get_band_index(b) for b in bands]
 311                data = self.data[xmin:xmax, ymin:ymax, idx].copy()
 312                wav = self.get_wavelengths()[idx]
 313
 314        # ---- update affine transform ----
 315        if self.affine is not None:
 316            a = list(self.affine)
 317            new_affine = a.copy()
 318
 319            # shift origin to new top-left pixel
 320            new_affine[0] = a[0] + xmin*a[1] + ymin*a[2]
 321            new_affine[3] = a[3] + xmin*a[4] + ymin*a[5]
 322        else:
 323            new_affine = None
 324
 325        # ---- construct output image ----
 326        out = HyImage(
 327            data,
 328            header=self.header.copy(),
 329            projection=self.projection,
 330            affine=new_affine,
 331            wav=wav
 332        )
 333
 334        return out
 335
 336    def resize(self, newdims: tuple, interpolation: int = 1):
 337        """
 338        Resize this image with opencv and update affine transform accordingly.
 339
 340        Args:
 341            newdims (tuple): the new image dimensions (xdim, ydim)
 342            interpolation (int): opencv interpolation method. Default is cv2.INTER_LINEAR.
 343        """
 344        import cv2  # avoid import issues if opencv is missing
 345
 346        old_x, old_y = self.xdim(), self.ydim()
 347        new_x, new_y = int(newdims[0]), int(newdims[1])
 348
 349        assert new_x > 0 and new_y > 0, "Invalid resize dimensions."
 350
 351        # resize data (opencv uses width, height = y, x)
 352        self.data = cv2.resize(
 353            self.data,
 354            (new_y, new_x),
 355            interpolation=interpolation
 356        )
 357
 358        # update affine transform
 359        if self.affine is not None:
 360            a = list(self.affine)
 361
 362            sx = old_x / new_x
 363            sy = old_y / new_y
 364
 365            self.affine = [
 366                a[0],          # x origin unchanged
 367                a[1] * sx,     # pixel width
 368                a[2] * sy,     # row rotation
 369                a[3],          # y origin unchanged
 370                a[4] * sx,     # column rotation
 371                a[5] * sy      # pixel height
 372            ]
 373
 374    def tile(self, tile_size):
 375        """
 376        Break image into tiles of given size and return a list of HyImage tiles.
 377        Each tile has an updated affine transform reflecting its position in the original image.
 378
 379        Args:
 380            tile_size (tuple): (tile_x, tile_y) in pixels
 381        Returns:
 382            list of HyImage
 383        """
 384        tiles = []
 385        tx, ty = tile_size
 386        nx, ny = self.xdim(), self.ydim()
 387        for i in range(0, nx, tx):
 388            for j in range(0, ny, ty):
 389                data_tile = self.data[i:min(i+tx, nx), j:min(j+ty, ny), :].copy() # slice data
 390
 391                # compute new affine
 392                # affine = [x0, dx, rotx, y0, roty, dy]
 393                x0, dx, rotx, y0, roty, dy = self.affine
 394                new_x0 = x0 + i*dx + j*rotx
 395                new_y0 = y0 + i*roty + j*dy
 396                new_affine = [new_x0, dx, rotx, new_y0, roty, dy]
 397
 398                # create tile
 399                tile_img = HyImage(
 400                    data_tile,
 401                    affine=new_affine,
 402                    projection=self.projection,
 403                    wav=self.get_wavelengths(),
 404                    header=self.header.copy()
 405                )
 406                tile_img.header['xleft'] = i
 407                tile_img.header['ytop'] = j
 408                tiles.append(tile_img)
 409
 410        return tiles
 411
 412    @staticmethod
 413    def mosaic(
 414        tiles,
 415        blend="mean",
 416        resampling="nearest",
 417        out_affine=None,
 418        out_shape=None,
 419    ):
 420        """
 421        Mosaic georeferenced HyImage tiles using GDAL. Note that this assumes all tiles are in the same coordinate system.
 422
 423        Args:
 424            tiles (list[HyImage])
 425            blend (str): 'first', 'min', 'max', 'mean', 'median'
 426            resampling (str): 'nearest', 'bilinear', 'cubic'
 427            out_affine (list): optional 6-element affine to define output grid. If None, the affine of the first tile is used.
 428            out_shape (tuple): optional (xdim, ydim) shape of the output grid. If None, the extent of all tiles will be used.
 429        Returns:
 430            HyImage
 431        """
 432        import numpy as np
 433        from hylite.project.align import resample_raster
 434        from osgeo import gdal, osr
 435
 436        assert len(tiles) > 0
 437        assert blend in ("first", "min", "max", "mean", "median")
 438
 439        # compute bounds in world coordinates
 440        # N.B. THIS ASSUMES ALL DATA ARE IN THE SAME CRS
 441        points = []
 442        for t in tiles:
 443            points.append( t.pix_to_world(0,0) )
 444            points.append( t.pix_to_world(t.xdim()+1,t.ydim()+1) )
 445        min_x, min_y = np.min(points, axis=0)
 446        max_x, max_y = np.max(points, axis=0)
 447        if out_shape is None:
 448            out_shape = (np.array(tiles[0].world_to_pix(max_x, max_y)) + np.array(tiles[0].world_to_pix(min_x, min_y))).round().astype(int)
 449        if out_affine is None:
 450            out_affine = list(tiles[0].affine)
 451            out_affine[0] = min_x
 452            out_affine[3] = max_y
 453
 454        if blend == "first": # fill output, first come, first served.
 455            out = np.full( tuple(out_shape) + (tiles[0].band_count(),), np.nan, dtype=np.float32)
 456            for t in tiles:
 457                r = resample_raster( t.data, t.affine, out_affine, out_shape )
 458                mask = np.isnan(out)
 459                out[mask] = r[mask]
 460        elif blend == "min": # keep minimum value in case of overlap
 461            out = np.nanmin( np.stack([ resample_raster( t.data, t.affine, out_affine, out_shape ) for t in tiles ], 
 462                                      axis=0), axis=0 )
 463        elif blend == "max": # keep maximum value in case of overlap
 464            out = np.nanmax( np.stack([ resample_raster( t.data, t.affine, out_affine, out_shape ) for t in tiles ], 
 465                                      axis=0), axis=0 )
 466        elif blend == "mean": # use average in case of overlap
 467            out = np.nanmean( np.stack([ resample_raster( t.data, t.affine, out_affine, out_shape ) for t in tiles ], 
 468                                       axis=0), axis=0 )
 469        elif blend == "median": # use mean in case of overlap
 470            out = np.nanmedian( np.stack([ resample_raster( t.data, t.affine, out_affine, out_shape ) for t in tiles ], 
 471                                         axis=0), axis=0 )
 472
 473        # Return HyImage
 474        out =  HyImage(
 475            out,
 476            affine=out_affine,
 477            projection=tiles[0].projection,
 478            wav=tiles[0].get_wavelengths(),
 479            header=tiles[0].header.copy()
 480        )
 481        if 'xleft' in out.header: del out.header['xleft'] # stored in tiles, but not meaningful here
 482        if 'ytop' in out.header: del out.header['ytop'] # stored in tiles, but not meaningful here
 483
 484        return out
 485
 486    #####################################
 487    ## BASIC TRANSFORMS
 488    #####################################
 489
 490    def flip(self, axis='x'):
 491        """
 492        Flip the image on the x or y axis. Note that this will remove any defined affine transform.
 493
 494        Args:
 495            axis (str): 'x' or 'y' or both 'xy'.
 496        """
 497
 498        if 'x' in axis.lower():
 499            self.data = np.flip(self.data,axis=0)
 500        if 'y' in axis.lower():
 501            self.data = np.flip(self.data,axis=1)
 502        self.affine = None
 503        if 'affine' in self.header: del self.header['affine']
 504        self.push_to_header() # update width and height info
 505
 506    def rot90(self):
 507        """
 508        Rotate this image by 90 degrees by transposing the underlying data array. Combine with flip('x') or flip('y')
 509        to achieve positive/negative rotations.
 510        """
 511        self.data = np.transpose( self.data, (1,0,2) )
 512        self.affine = None
 513        if 'affine' in self.header: del self.header['affine']
 514        self.push_to_header() # update width and height info
 515
 516    #####################################
 517    ##IMAGE FILTERING
 518    #####################################
 519    def fill_holes(self):
 520        """
 521        Replaces nan pixel with an average of their neighbours, thus removing 1-pixel large holes from an image. Note that
 522        for performance reasons this assumes that holes line up across bands. Note that this is not vectorized so very slow...
 523        """
 524
 525        # perform greyscale dilation
 526        dilate = self.data.copy()
 527        mask = np.logical_not(np.isfinite(dilate))
 528        dilate[mask] = 0
 529        for b in range(self.band_count()):
 530            dilate[:, :, b] = sp.ndimage.grey_dilation(dilate[:, :, b], size=(3, 3))
 531
 532        # map back to holes in dataset
 533        self.data[mask] = dilate[mask]
 534        #self.data[self.data == 0] = np.nan  # replace remaining 0's with nans
 535
 536    def blur(self, n=3):
 537        """
 538        Applies a gaussian kernel of size n to the image using OpenCV.
 539
 540        Args:
 541            n (int): the dimensions of the gaussian kernel to convolve. Default is 3. Increase for more blurry results.
 542        """
 543        import cv2 # import this here to avoid errors if opencv is not installed properly
 544
 545        nanmask = np.isnan(self.data)
 546        assert isinstance(n, int) and n >= 3, "Error - invalid kernel. N must be an integer > 3. "
 547        kernel = np.ones((n, n), np.float32) / (n ** 2)
 548        self.data = cv2.filter2D(self.data, -1, kernel)
 549        self.data[nanmask] = np.nan  # remove mask
 550
 551    def erode(self, size=3, iterations=1):
 552        """
 553        Apply an erode filter to this image to expand background (nan) pixels. Refer to open-cv's erode
 554        function for more details.
 555
 556        Args:
 557            size (int): the size of the erode filter. Default is a 3x3 kernel.
 558            iterations (int): the number of erode iterations. Default is 1.
 559        """
 560        import cv2 # import this here to avoid errors if opencv is not installed properly
 561
 562        # erode
 563        kernel = np.ones((size, size), np.uint8)
 564        if self.is_float():
 565            mask = np.isfinite(self.data).any(axis=-1)
 566            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
 567            self.data[mask == 0, :] = np.nan
 568        else:
 569            mask = (self.data != 0).any( axis=-1 )
 570            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
 571            self.data[mask == 0, :] = 0
 572
 573    def despeckle(self, size=5):
 574        """
 575        Despeckle each band of this image (independently) using a median filter.
 576
 577        Args:
 578            size (int): the size of the median filter kernel. Default is 5. Must be an odd number.
 579        """
 580
 581        assert (size % 2) == 1, "Error - size must be an odd integer"
 582        import cv2 # import this here to avoid errors if opencv is not installed properly
 583        if self.is_float():
 584            self.data = cv2.medianBlur( self.data.astype(np.float32), size )
 585        else:
 586            self.data = cv2.medianBlur( self.data, size )
 587
 588    #####################################
 589    ##FEATURES AND FEATURE MATCHING
 590    ######################################
 591    def get_keypoints(self, band, eq=False, mask=True, method='sift', cfac=0.0,bfac=0.0, **kwds):
 592        """
 593        Get feature descriptors from the specified band.
 594
 595        Args:
 596            band (int,float,str,tuple): the band index (int) or wavelength (float) to extract features from. Alternatively, a tuple can be passed
 597                    containing a range of bands (min : max) to average before feature matching.
 598            eq (bool): True if the image should be histogram equalized first. Default is False.
 599            mask (bool): True if 0 value pixels should be masked. Default is True.
 600            method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'.
 601            cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
 602            bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.
 603            **kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are:
 604
 605                - contrastThreshold: default is 0.01.
 606                - edgeThreshold: default is 10.
 607                - sigma: default is 1.0
 608
 609                For ORB these are:
 610
 611                - nfeatures = the number of features to detect. Default is 5000.
 612
 613            Returns:
 614                Tuple containing
 615
 616                    - k (ndarray): the keypoints detected
 617                    - d (ndarray): corresponding feature descriptors
 618         """
 619        import cv2 # import this here to avoid errors if opencv is not installed properly
 620
 621        # get image
 622        if isinstance(band, int) or isinstance(band, float): #single band
 623            image = self.data[:, :, self.get_band_index(band)]
 624        elif isinstance(band,tuple): #range of bands (averaged)
 625            idx0 = self.get_band_index(band[0])
 626            idx1 = self.get_band_index(band[1])
 627
 628            #deal with out of range errors
 629            if idx0 is None:
 630                idx0 = 0
 631            if idx1 is None:
 632                idx1 = self.band_count()
 633
 634            #average bands
 635            image = np.nanmean(self.data[:,:,idx0:idx1],axis=2)
 636        else:
 637            assert False, "Error, unrecognised band %s" % band
 638
 639        #normalise image to range 0 - 1
 640        image -= np.nanmin(image)
 641        image = image / np.nanmax(image)
 642
 643        #apply brightness/contrast adjustment
 644        image = (1.0+cfac)*image + bfac
 645        image[image > 1.0] = 1.0
 646        image[image < 0.0] = 0.0
 647
 648        #convert image to uint8 for opencv
 649        image = np.uint8(255 * image)
 650        if eq:
 651            image = cv2.equalizeHist(image)
 652
 653        if mask:
 654            mask = np.zeros(image.shape, dtype=np.uint8)
 655            mask[image != 0] = 255  # include only non-zero pixels
 656        else:
 657            mask = None
 658
 659        if 'sift' in method.lower():  # SIFT
 660
 661            # setup default keywords
 662            kwds["contrastThreshold"] = kwds.get("contrastThreshold", 0.01)
 663            kwds["edgeThreshold"] = kwds.get("edgeThreshold", 10)
 664            kwds["sigma"] = kwds.get("sigma", 1.0)
 665
 666            # make feature detector
 667            #alg = cv2.xfeatures2d.SIFT_create(**kwds)
 668            alg = cv2.SIFT_create()
 669        elif 'orb' in method.lower():  # orb
 670            kwds['nfeatures'] = kwds.get('nfeatures', 5000)
 671            alg = cv2.ORB_create(scoreType=cv2.ORB_FAST_SCORE, **kwds)
 672        else:
 673            assert False, "Error - %s is not a recognised feature detector." % method
 674
 675        # detect keypoints
 676        kp = alg.detect(image, mask)
 677
 678        # extract and return feature vectors
 679        return alg.compute(image, kp)
 680
 681    @classmethod
 682    def match_keypoints(cls, kp1, kp2, d1, d2, method='SIFT', dist=0.7, tree = 5, check = 100, min_count=5):
 683        """
 684        Compares keypoint feature vectors from two images and returns matching pairs.
 685
 686        Args:
 687            kp1 (ndarray): keypoints from the first image
 688            kp2 (ndarray): keypoints from the second image
 689            d1 (ndarray): descriptors for the keypoints from the first image
 690            d2 (ndarray): descriptors for the keypoints from the second image
 691            method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'.
 692            dist (float): minimum match distance (0 to 1), default is 0.7
 693            tree (int): not sure what this does? Default is 5. See open-cv docs.
 694            check (int): ditto. Default is 100.
 695            min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are found,
 696                       then the function returns None, None. Default is 5.
 697        """
 698        import cv2 # import this here to avoid errors if opencv is not installed properly
 699        if 'sift' in method.lower():
 700            algorithm = cv2.NORM_INF
 701        elif 'orb' in method.lower():
 702            algorithm = cv2.NORM_HAMMING
 703        else:
 704            assert False, "Error - unknown matching algorithm %s" % method
 705
 706        #calculate flann matches
 707        index_params = dict(algorithm=algorithm, trees=tree)
 708        search_params = dict(checks=check)
 709        flann = cv2.FlannBasedMatcher(index_params, search_params)
 710        matches = flann.knnMatch(d1, d2, k=2)
 711
 712        # store all the good matches as per Lowe's ratio test.
 713        good = []
 714        for m, n in matches:
 715            if m.distance < dist * n.distance:
 716                good.append(m)
 717
 718        if len(good) < min_count:
 719            return None, None
 720        else:
 721            src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
 722            dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
 723            return src_pts, dst_pts
 724
 725    ############################
 726    ## Visualisation methods
 727    ############################
 728    def quick_plot(self, bands=0, ax=None, bfac=0.0, cfac=0.0, samples=False, tscale=False, invert=False, rot=False, flipX=False, flipY=False,
 729                   **kwds):
 730        """
 731        Plot a band using matplotlib.imshow(...).
 732
 733        Args:
 734            bands (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple is passed then
 735                  each band in the tuple (string or index) will be mapped to rgb. Bands with negative wavelengths or indices will be inverted before plotting.
 736            ax: an axis object to plot to. If none, plt.imshow( ... ) is used.
 737            bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
 738            cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
 739            samples (bool): True if sample points (defined in the header file) should be plotted. Default is False. Otherwise, a list of
 740                     [ (x,y), ... ] points can be passed.
 741            tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False.
 742                    When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or
 743                    (constant) values (float).
 744            invert (bool) : True if each band should be inverted before plotting. Only works for multiband (ternary) images.
 745            rot (bool): if True, the x and y axis will be flipped (90 degree rotation) before plotting. Default is False.
 746            flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
 747            flipY (bool): if True, the y axis will be flippe before plotting (after applying rotations).
 748            **kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:
 749
 750                 - mask = a 2D boolean mask containing true if pixels should be drawn and false otherwise.
 751                 - path = a file path to save the image too (at matching resolution; use fig.savefig(..) if you want to save the figure).
 752                 - ticks = True if x- and y- ticks should be plotted. Default is False.
 753                 - ps, pc = the size and color of sample points to plot. Can be constant or list.
 754                 - figsize = a figsize for the figure to create (if ax is None).
 755
 756        Returns:
 757            Tuple containing
 758
 759            - fig: matplotlib figure object
 760            - ax:  matplotlib axes object. If a colorbar is created, (band is an integer or a float), then this will be stored in ax.cbar.
 761        """
 762
 763        #create new axes?
 764        if ax is None:
 765            fig, ax = plt.subplots(figsize=kwds.pop('figsize', (18,18*self.ydim()/self.xdim()) ))
 766
 767        # deal with ticks
 768        if not kwds.pop('ticks', False ):
 769            ax.set_xticks([])
 770            ax.set_yticks([])
 771
 772        #map individual band using colourmap
 773        if isinstance(bands, str) or isinstance(bands, int) or isinstance(bands, float):
 774            #get band
 775            if isinstance(bands, str):
 776                data = self.data[:, :, self.get_band_index(bands)]
 777            else:
 778                data = self.data[:, :, self.get_band_index(np.abs(bands))]
 779            if not isinstance(bands, str) and bands < 0:
 780                data = np.nanmax(data) - data # flip
 781
 782            # convert integer vmin and vmax values to percentiles
 783            if 'vmin' in kwds:
 784                if isinstance(kwds['vmin'], int):
 785                    kwds['vmin'] = np.nanpercentile( data, kwds['vmin'] )
 786            if 'vmax' in kwds:
 787                if isinstance(kwds['vmax'], int):
 788                    kwds['vmax'] = np.nanpercentile( data, kwds['vmax'] )
 789
 790            #mask nans (and apply custom mask)
 791            mask = np.isnan(data)
 792            if not np.isnan(self.header.get_data_ignore_value()):
 793                mask = mask + data == self.header.get_data_ignore_value()
 794            if 'mask' in kwds:
 795                mask = mask + kwds.get('mask')
 796                del kwds['mask']
 797            data = np.ma.array(data, mask = mask > 0 )
 798
 799            # apply rotations and flipping
 800            if rot:
 801                data = data.T
 802            if flipX:
 803                data = data[::-1, :]
 804            if flipY:
 805                data = data[:, ::-1]
 806
 807            # save?
 808            if 'path' in kwds:
 809                path = kwds.pop('path')
 810                from matplotlib.pyplot import imsave
 811                if not os.path.exists(os.path.dirname(path)):
 812                    os.makedirs(os.path.dirname(path)) # ensure output directory exists
 813                imsave(path, data.T, **kwds)  # save the image
 814
 815            ax.cbar = ax.imshow(data.T, interpolation=kwds.pop('interpolation', 'none'), **kwds) # change default interpolation to None
 816
 817        #map 3 bands to RGB
 818        elif isinstance(bands, tuple) or isinstance(bands, list):
 819            #get band indices and range
 820            rgb = []
 821            for b in bands:
 822                if isinstance(b, str):
 823                    rgb.append(self.get_band_index(b))
 824                else:
 825                    rgb.append(self.get_band_index(np.abs(b)))
 826
 827            #slice image (as copy) and map to 0 - 1
 828            img = np.array(self.data[:, :, rgb]).copy()
 829            if np.isnan(img).all():
 830                print("Warning - image contains no data.")
 831                return ax.get_figure(), ax
 832
 833            # invert if needed
 834            if invert:
 835                bands = [-b for b in bands]
 836            for i,b in enumerate(bands):
 837                if not isinstance(b, str) and (b < 0):
 838                    img[..., i] = np.nanmax(img[..., i]) - img[..., i]
 839
 840            # do scaling
 841            if tscale: # scale bands independently
 842                for b in range(3):
 843                    mn = kwds.get("vmin", float(np.nanmin(img)))
 844                    mx = kwds.get("vmax", float(np.nanmax(img)))
 845                    if isinstance (mn, int):
 846                        assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
 847                        mn = float(np.nanpercentile(img[...,b], mn ))
 848                    if isinstance (mx, int):
 849                        assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
 850                        mx = float(np.nanpercentile(img[...,b], mx ))
 851                    img[...,b] = (img[..., b] - mn) / (mx - mn)
 852            else: # scale bands together
 853                mn = kwds.get("vmin", float(np.nanmin(img)))
 854                mx = kwds.get("vmax", float(np.nanmax(img)))
 855                if isinstance(mn, int):
 856                    assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
 857                    mn = float(np.nanpercentile(img, mn))
 858                if isinstance(mx, int):
 859                    assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
 860                    mx = float(np.nanpercentile(img, mx))
 861                img = (img - mn) / (mx - mn)
 862
 863            #apply brightness/contrast mapping
 864            img = np.clip((1.0 + cfac) * img + bfac, 0, 1.0 )
 865
 866            #apply masking so background is white
 867            img[np.logical_not( np.isfinite( img ) )] = 1.0
 868            if 'mask' in kwds:
 869                img[kwds.pop("mask"),:] = 1.0
 870
 871            # apply rotations and flipping
 872            if rot:
 873                img = np.transpose( img, (1,0,2) )
 874            if flipX:
 875                img = img[::-1, :, :]
 876            if flipY:
 877                img = img[:, ::-1, :]
 878
 879            # save?
 880            if 'path' in kwds:
 881                path = kwds.pop('path')
 882                from matplotlib.pyplot import imsave
 883                if not os.path.exists(os.path.dirname(path)):
 884                    os.makedirs(os.path.dirname(path)) # ensure output directory exists
 885                imsave(path, np.transpose( np.clip( img*255, 0, 255).astype(np.uint8), (1, 0, 2)))  # save the image
 886
 887            # plot samples?
 888            ps = kwds.pop('ps', 5)
 889            pc = kwds.pop('pc', 'r')
 890            if samples:
 891                if isinstance(samples, list) or isinstance(samples, np.ndarray):
 892                    ax.scatter([s[0] for s in samples], [s[1] for s in samples], s=ps, c=pc)
 893                else:
 894                    for n in self.header.get_class_names():
 895                        points = np.array(self.header.get_sample_points(n))
 896                        ax.scatter(points[:, 0], points[:, 1], s=ps, c=pc)
 897
 898            #plot
 899            ax.imshow(np.transpose(img, (1,0,2)), interpolation=kwds.pop('interpolation', 'none'), **kwds)
 900            ax.cbar = None  # no colorbar
 901
 902        return ax.get_figure(), ax
 903
 904    def createGIF(self, path, bands=None, figsize=(10,10), fps=10, **kwds):
 905        """
 906        Create and save an animated gif that loops through the bands of the image.
 907
 908        Args:
 909            path (str): the path to save the .gif
 910            bands (tuple): Tuple containing the range of band indices to draw. Default is the whole range.
 911            figsize (tuple): the size of the image to draw. Default is (10,10).
 912            fps (int): the framerate (frames per second) of the gif. Default is 10.
 913            **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
 914        """
 915
 916        frames = []
 917        if bands is None:
 918            bands = (0,self.band_count())
 919        else:
 920            assert 0 < bands[0] < self.band_count(), "Error - invalid range."
 921            assert 0 < bands[1] < self.band_count(), "Error - invalid range."
 922            assert bands[1] > bands[0], "Error - invalid range."
 923
 924        #plot frames
 925        for i in range(bands[0],bands[1]):
 926            fig, ax = plt.subplots(figsize=figsize)
 927            ax.imshow(self.data[:, :, i], **kwds)
 928            fig.canvas.draw()
 929            frames.append(np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8'))
 930            frames[-1] = np.reshape(frames[-1], (fig.canvas.get_width_height()[1], fig.canvas.get_width_height()[0], 3))
 931            plt.close(fig)
 932
 933        #save gif
 934        imageio.mimsave( os.path.splitext(path)[0] + ".gif", frames, fps=fps)
 935
 936    ## masking
 937    def drop_bbl(self, drop=True):
 938        """
 939        Remove bad bands as stored in the 'bbl' key in the image header. Note that this operates in-place.
 940
 941        Args:
 942            drop (bool): True if bad bands should be completely dropped. If False, these bands will be kept but replaced with nans.
 943        """
 944        assert 'bbl' in self.header, "Please specify a bad band list ('bbl') in the image header, as per the ENVI format definition."
 945        mask = self.header.get_list('bbl') == 0
 946        self.data[...,mask] = np.nan
 947        if drop:
 948            self.delete_nan_bands(inplace=True)
 949    
 950    def mask(self, mask=None, flag=np.nan, invert=False, crop=False, bands=None):
 951        """
 952         Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the
 953         image in-situ.
 954
 955         Args:
 956            flag (float): the value to use for masked pixels. Default is np.nan
 957            mask (ndarray): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None is passed then
 958                    pickPolygon( ... ) is used to interactively define a polygon. If a file path is passed then the polygon
 959                    will be loaded using np.load( ... ). Alternatively if mask.shape == image.shape[0,1] then it is treated as a
 960                    binary image mask (must be boolean) and True values will be masked across all bands. Default is None.
 961            invert (bool): if True, pixels within the polygon will be masked. If False, pixels outside the polygon are masked. Default is False.
 962            crop (bool): True if rows/columns containing only zeros should be removed. Default is False.
 963            bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used.
 964
 965         Returns:
 966            Tuple containing
 967
 968            - mask (ndarray): a boolean array with True where pixels are masked and False elsewhere.
 969            - poly (ndarray): the mask polygon array in the format described above. Useful if the polygon was interactively defined.
 970         """
 971
 972        if mask is None:  # pick mask interactively
 973            if bands is None:
 974                bands = int(self.band_count() / 2)
 975
 976            regions = self.pickPolygons(region_names=["mask"], bands=bands)
 977
 978            # the user bailed without picking a mask?
 979            if len(regions) == 0:
 980                print("Warning - no mask picked/applied.")
 981                return
 982
 983            # extract polygon mask
 984            mask = regions[0]
 985
 986        # convert polygon mask to binary mask
 987        if mask.shape[1] == 2:
 988
 989            # build meshgrid with pixel coords
 990            xx, yy = np.meshgrid(np.arange(self.xdim()), np.arange(self.ydim()))
 991            xx = xx.flatten()
 992            yy = yy.flatten()
 993            points = np.vstack([xx, yy]).T  # coordinates of each pixel
 994
 995            # calculate per-pixel mask
 996            mask = path.Path(mask).contains_points(points)
 997            mask = mask.reshape((self.ydim(), self.xdim())).T
 998
 999            # flip as we want to mask (==True) outside points (unless invert is true)
1000            if not invert:
1001                mask = np.logical_not(mask)
1002
1003        # apply binary image mask
1004        assert mask.shape[0] == self.data.shape[0] and mask.shape[1] == self.data.shape[1], \
1005            "Error - mask shape %s does not match image shape %s" % (mask.shape, self.data.shape)
1006        for b in range(self.band_count()):
1007            self.data[:, :, b][mask] = flag
1008
1009        # crop image
1010        if crop:
1011            self.crop_to_data()
1012
1013        return mask
1014
1015    def crop_to_data(self):
1016        """
1017        Remove padding of nan or zero pixels from image. Note that this is performed in place.
1018        """
1019        valid = np.isfinite(self.data).any(axis=-1) & (self.data != 0).any(axis=-1)
1020
1021        # integrate along axes
1022        xdata = np.sum(valid, axis=1) > 0.0
1023        ydata = np.sum(valid, axis=0) > 0.0
1024
1025        # calculate domain containing valid pixels
1026        xmin = np.argmax(xdata)
1027        xmax = xdata.shape[0] - np.argmax(xdata[::-1])
1028        ymin = np.argmax(ydata)
1029        ymax = ydata.shape[0] - np.argmax(ydata[::-1])
1030
1031        # crop
1032        self.data = self.data[xmin:xmax, ymin:ymax, :]
1033        
1034        # shift affine origin to new top-left pixel
1035        if self.affine is not None: 
1036            a = self.affine # shorthand for affine
1037            new_affine = list(self.affine)
1038            new_affine[0] = a[0] + xmin*a[1] + ymin*a[2]
1039            new_affine[3] = a[3] + xmin*a[4] + ymin*a[5]
1040            self.affine = np.array(new_affine)
1041            self.header['affine'] = self.affine
1042            
1043    ##################################################
1044    ## Interactive tools for picking regions/pixels
1045    ##################################################
1046    def pickPolygons(self, region_names, bands=0):
1047        """
1048        Creates a matplotlib gui for selecting polygon regions in an image.
1049
1050        Args:
1051            names (list, str): a list containing the names of the regions to pick. If a string is passed only one name is used.
1052            bands (tuple): the bands of the image to plot.
1053        """
1054
1055        if isinstance(region_names, str):
1056            region_names = [region_names]
1057
1058        assert isinstance(region_names, list), "Error - names must be a list or a string."
1059
1060        # set matplotlib backend
1061        backend = matplotlib.get_backend()
1062        matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work
1063
1064        # plot image and extract roi's
1065        fig, ax = self.quick_plot(bands)
1066        roi = MultiRoi(roi_names=region_names)
1067        plt.close(fig)  # close figure
1068
1069        # extract regions
1070        regions = []
1071        for name, r in roi.rois.items():
1072            # store region
1073            x = r.x
1074            y = r.y
1075            regions.append(np.vstack([x, y]).T)
1076
1077        # restore matplotlib backend (if possible)
1078        try:
1079            matplotlib.use(backend)
1080        except:
1081            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")
1082            pass
1083
1084        return regions
1085
1086    def pickPoints(self, n=-1, bands=hylite.RGB, integer=True, title="Pick Points", **kwds):
1087        """
1088        Creates a matplotlib gui for picking pixels from an image.
1089
1090        Args:
1091            n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
1092            bands (tuple): the bands of the image to plot. Default is HyImage.RGB
1093            integer (bool): True if points coordinates should be cast to integers (for use as indices). Default is True.
1094            title (str): The title of the point picking window.
1095            **kwds: Keywords are passed to HyImage.quick_plot( ... ).
1096
1097        Returns:
1098            A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].
1099        """
1100
1101        # set matplotlib backend
1102        backend = matplotlib.get_backend()
1103        matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work
1104
1105        # create figure
1106        fig, ax = self.quick_plot( bands, **kwds )
1107        ax.set_title(title)
1108
1109        # get points
1110        points = fig.ginput( n )
1111
1112        if integer:
1113            points = [ (int(p[0]), int(p[1])) for p in points ]
1114
1115        # restore matplotlib backend (if possible)
1116        try:
1117            matplotlib.use(backend)
1118        except:
1119            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")
1120            pass
1121
1122        return points
1123
1124    def pickSamples(self, names=None, store=True, **kwds):
1125        """
1126        Pick sample probe points and store these in the image header file.
1127
1128        Args:
1129            names (str, list): the name of the sample to pick, or a list of names to pick multiple.
1130            store (bool): True if sample should be stored in the image header file (for later access). Default is True.
1131            **kwds: Keywords are passed to HyImage.quick_plot( ... )
1132
1133        Returns:
1134            a list containing a list of points for each sample.
1135        """
1136
1137        if isinstance(names, str):
1138            names = [names]
1139
1140        # pick points
1141        points = []
1142        for s in names:
1143            pnts = self.pickPoints(title="%s" % s, **kwds)
1144            if store:
1145                self.header['sample %s' % s] = pnts # store in header
1146            points.append(pnts)
1147        # add class to header file
1148        if store:
1149            cls_names = self.header.get_class_names()
1150            if cls_names is None:
1151                cls_names = []
1152            self.header['class names'] = cls_names + names
1153
1154        return points

A class for hyperspectral image data. These can be individual scenes or hyperspectral orthoimages.

HyImage(data, **kwds)
25    def __init__(self, data, **kwds):
26        """
27        Args:
28            data (ndarray): a numpy array such that data[x][y][band] gives each pixel value.
29            **kwds:
30                wav = A numpy array containing band wavelengths for this image.
31                affine = an affine transform of the format returned by GDAL.GetGeoTransform().
32                projection = string defining the project. Default is None.
33                sensor = sensor name. Default is "unknown".
34                header = path to associated header file. Default is None.
35        """
36
37        #call constructor for HyData
38        super().__init__(data, **kwds)
39
40        # special case - if dataset only has oneband, slice it so it still has
41        # the format data[x,y,b].
42        if not self.data is None:
43            if len(self.data.shape) == 1:
44                self.data = self.data[None, None, :] # single pixel image
45            if len(self.data.shape) == 2:
46                self.data = self.data[:, :, None] # single band iamge
47
48        #load any additional project information (specific to images)
49        self.set_projection(kwds.get("projection",None))
50        self.affine = np.array( kwds.get("affine",[0,1,0,0,0,1]) )
51        self.header['affine'] = kwds.get("affine",[0,1,0,0,0,1]) # also store this here
52
53        # wavelengths
54        if 'wav' in kwds:
55            self.set_wavelengths(kwds['wav'])
56
57        #special header formatting
58        self.header['file type'] = 'ENVI Standard'
Arguments:
  • data (ndarray): a numpy array such that data[x][y][band] gives each pixel value.
  • **kwds: wav = A numpy array containing band wavelengths for this image. affine = an affine transform of the format returned by GDAL.GetGeoTransform(). projection = string defining the project. Default is None. sensor = sensor name. Default is "unknown". header = path to associated header file. Default is None.
affine
def copy(self, data=True):
60    def copy(self,data=True):
61        """
62        Make a deep copy of this image instance.
63
64        Args:
65            data (bool): True if a copy of the data should be made, otherwise only copy header.
66
67        Returns:
68            a new HyImage instance.
69        """
70        if not data:
71            return HyImage(None, header=self.header.copy(), projection=self.projection, affine=self.affine)
72        else:
73            return HyImage( self.data.copy(), header=self.header.copy(), projection=self.projection, affine=self.affine)

Make a deep copy of this image instance.

Arguments:
  • data (bool): True if a copy of the data should be made, otherwise only copy header.
Returns:

a new HyImage instance.

def T(self):
75    def T(self):
76        """
77        Return a transposed view of the data matrix (corresponding with the [y,x] indexing used by matplotlib, opencv etc.
78        """
79        return np.transpose(self.data, (1,0,2))

Return a transposed view of the data matrix (corresponding with the [y,x] indexing used by matplotlib, opencv etc.).

def xdim(self):
81    def xdim(self):
82        """
83        Return number of pixels in x (first dimension of data array)
84        """
85        return self.data.shape[0]

Return number of pixels in x (first dimension of data array)

def ydim(self):
87    def ydim(self):
88        """
89        Return number of pixels in y (second dimension of data array)
90        """
91        return self.data.shape[1]

Return number of pixels in y (second dimension of data array)

def aspx(self):
93    def aspx(self):
94        """
95        Return the aspect ratio of this image (width/height).
96        """
97        return self.ydim() / self.xdim()

Return the aspect ratio of this image (width/height).

def get_extent(self):
103    def get_extent(self):
104        """
105        Returns the width and height of this image in world coordinates.
106
107        Returns:
108            tuple with (width, height).
109        """
110        return self.xdim * self.pixel_size[0], self.ydim * self.pixel_size[1]

Returns the width and height of this image in world coordinates.

Returns:

tuple with (width, height).

def set_projection(self, proj):
112    def set_projection(self,proj):
113        """
114        Set this project to an existing osgeo.osr.SpatialReference or GDAL georeference string.
115
116        Args:
117            proj (str, osgeo.osr.SpatialReference): the project to use as osgeo.osr.SpatialReference or GDAL georeference string.
118        """
119        if proj is None:
120            self.projection = None
121        else:
122            try:
123                from osgeo.osr import SpatialReference
124            except:
125                assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
126            if isinstance(proj, SpatialReference):
127                self.projection = proj
128            elif isinstance(proj, str):
129                self.projection = SpatialReference(proj)
130            else:
131                print("Invalid project %s" % proj)
132                raise

Set this project to an existing osgeo.osr.SpatialReference or GDAL georeference string.

Arguments:
  • proj (str, osgeo.osr.SpatialReference): the project to use as osgeo.osr.SpatialReference or GDAL georeference string.
def set_projection_EPSG(self, EPSG):
134    def set_projection_EPSG(self,EPSG):
135        """
136        Sets this image project using an EPSG code.
137
138        Args:
139            EPSG (str): string EPSG code that can be passed to SpatialReference.SetFromUserInput(...).
140        """
141
142        try:
143            from osgeo.osr import SpatialReference
144        except:
145            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
146
147        self.projection = SpatialReference()
148        self.projection.SetFromUserInput(EPSG)

Sets this image project using an EPSG code.

Arguments:
  • EPSG (str): string EPSG code that can be passed to SpatialReference.SetFromUserInput(...).
def get_projection_EPSG(self):
150    def get_projection_EPSG(self):
151        """
152        Gets a string describing this projections EPSG code (if it is an EPSG project).
153
154        Returns:
155            an EPSG code string of the format "EPSG:XXXX".
156        """
157        if self.projection is None:
158            return None
159        else:
160            return "%s:%s" % (self.projection.GetAttrValue("AUTHORITY",0),self.projection.GetAttrValue("AUTHORITY",1))

Gets a string describing this projection's EPSG code (if it is an EPSG projection).

Returns:

an EPSG code string of the format "EPSG:XXXX".

def pix_to_world(self, px, py, proj=None):
162    def pix_to_world(self, px, py, proj=None):
163        """
164        Take pixel coordinates and return world coordinates
165
166        Args:
167            px (int): the pixel x-coord.
168            py (int): the pixel y-coord.
169            proj (str, osr.SpatialReference): the coordinate system to use. Default (None) uses the same system as this image. Otherwise
170                   an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
171        Returns:
172            the world coordinates in the coordinate system defined by get_projection_EPSG(...).
173        """
174
175        try:
176            from osgeo import osr
177            import osgeo.gdal as gdal
178            from osgeo import ogr
179        except:
180            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
181
182        # parse project
183        if proj is None:
184            proj = self.projection
185        elif isinstance(proj, str) or isinstance(proj, int):
186            epsg = proj
187            if isinstance(epsg, str):
188                try:
189                    epsg = int(str.split(':')[1])
190                except:
191                    assert False, "Error - %s is an invalid EPSG code." % proj
192            proj = osr.SpatialReference()
193            proj.ImportFromEPSG(epsg)
194
195        # check we have all the required info
196        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
197        assert (not self.affine is None) and (
198            not self.projection is None), "Error - project information is undefined."
199
200        #project to world coordinates in this images project/world coords
201        x,y = gdal.ApplyGeoTransform(self.affine, px, py)
202
203        #project to target coords (if different)
204        if not proj.IsSameGeogCS(self.projection):
205            P = ogr.Geometry(ogr.wkbPoint)
206            if proj.EPSGTreatsAsNorthingEasting():
207                P.AddPoint(x, y)
208            else:
209                P.AddPoint(y, x)
210            P.AssignSpatialReference(self.projection)  # tell the point what coordinates it's in
211            P.TransformTo(proj)  # reproject it to the out spatial reference
212            x, y = P.GetX(), P.GetY()
213
214            #do we need to transpose?
215            if proj.EPSGTreatsAsLatLong():
216                x,y=y,x #we want lon,lat not lat,lon
217        return x, y

Take pixel coordinates and return world coordinates

Arguments:
  • px (int): the pixel x-coord.
  • py (int): the pixel y-coord.
  • proj (str, osr.SpatialReference): the coordinate system to use. Default (None) uses the same system as this image. Otherwise an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
Returns:

the world coordinates in the coordinate system defined by get_projection_EPSG(...).

def world_to_pix(self, x, y, proj=None):
219    def world_to_pix(self, x, y, proj = None):
220        """
221        Take world coordinates and return pixel coordinates
222
223        Args:
224            x (float): the world x-coord.
225            y (float): the world y-coord.
226            proj (str, osr.SpatialReference): the coordinate system of the input coordinates. Default (None) uses the same system as this image. Otherwise
227                   an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
228
229        Returns:
230            the pixel coordinates based on the affine transform stored in self.affine.
231        """
232
233        try:
234            from osgeo import osr
235            import osgeo.gdal as gdal
236            from osgeo import ogr
237        except:
238            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."
239
240        # parse project
241        if proj is None:
242            proj = self.projection
243        elif isinstance(proj, str) or isinstance(proj, int):
244            epsg = proj
245            if isinstance(epsg, str):
246                try:
247                    epsg = int(str.split(':')[1])
248                except:
249                    assert False, "Error - %s is an invalid EPSG code." % proj
250            proj = osr.SpatialReference()
251            proj.ImportFromEPSG(epsg)
252
253
254        # check we have all the required info
255        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
256        assert (not self.affine is None) and (not self.projection is None), "Error - project information is undefined."
257
258        # project to this images CS (if different)
259        if not proj.IsSameGeogCS(self.projection):
260            P = ogr.Geometry(ogr.wkbPoint)
261            if proj.EPSGTreatsAsNorthingEasting():
262                P.AddPoint(x, y)
263            else:
264                P.AddPoint(y, x)
265            P.AssignSpatialReference(proj)  # tell the point what coordinates it's in
266            P.AddPoint(x, y)
267            P.TransformTo(self.projection)  # reproject it to the out spatial reference
268            x, y = P.GetX(), P.GetY()
269            if self.projection.EPSGTreatsAsLatLong(): # do we need to transpose?
270                x, y = y, x  # we want lon,lat not lat,lon
271
272        inv = gdal.InvGeoTransform(self.affine)
273        assert not inv is None, "Error - could not invert affine transform?"
274
275        #apply
276        return gdal.ApplyGeoTransform(inv, x, y)

Take world coordinates and return pixel coordinates

Arguments:
  • x (float): the world x-coord.
  • y (float): the world y-coord.
  • proj (str, osr.SpatialReference): the coordinate system of the input coordinates. Default (None) uses the same system as this image. Otherwise an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)).
Returns:

the pixel coordinates based on the affine transform stored in self.affine.

def crop(self, xmin, xmax, ymin, ymax, bands=None):
278    def crop(self, xmin, xmax, ymin, ymax, bands=None):
279        """
280        Return a cropped copy of this image.
281
282        Args:
283            xmin, xmax (int): pixel bounds in x (rows)
284            ymin, ymax (int): pixel bounds in y (columns)
285            bands (None, list, tuple): optional band indices or (min,max) range
286
287        Returns:
288            HyImage: cropped image with updated affine transform
289        """
290
291        # ---- validate bounds ----
292        xmin = int(max(0, xmin))
293        ymin = int(max(0, ymin))
294        xmax = int(min(self.xdim(), xmax))
295        ymax = int(min(self.ydim(), ymax))
296
297        assert xmin < xmax and ymin < ymax, "Invalid crop extent."
298
299        # ---- crop data ----
300        if bands is None:
301            data = self.data[xmin:xmax, ymin:ymax, :].copy()
302            wav = self.get_wavelengths()
303        else: # band selection
304            if isinstance(bands, tuple):
305                b0 = self.get_band_index(bands[0])
306                b1 = self.get_band_index(bands[1])
307                data = self.data[xmin:xmax, ymin:ymax, b0:b1].copy()
308                wav = self.get_wavelengths()[b0:b1]
309            else:
310                idx = [self.get_band_index(b) for b in bands]
311                data = self.data[xmin:xmax, ymin:ymax, idx].copy()
312                wav = self.get_wavelengths()[idx]
313
314        # ---- update affine transform ----
315        if self.affine is not None:
316            a = list(self.affine)
317            new_affine = a.copy()
318
319            # shift origin to new top-left pixel
320            new_affine[0] = a[0] + xmin*a[1] + ymin*a[2]
321            new_affine[3] = a[3] + xmin*a[4] + ymin*a[5]
322        else:
323            new_affine = None
324
325        # ---- construct output image ----
326        out = HyImage(
327            data,
328            header=self.header.copy(),
329            projection=self.projection,
330            affine=new_affine,
331            wav=wav
332        )
333
334        return out

Return a cropped copy of this image.

Arguments:
  • xmin, xmax (int): pixel bounds in x (rows)
  • ymin, ymax (int): pixel bounds in y (columns)
  • bands (None, list, tuple): optional band indices or (min,max) range
Returns:

HyImage: cropped image with updated affine transform

def resize(self, newdims: tuple, interpolation: int = 1):
336    def resize(self, newdims: tuple, interpolation: int = 1):
337        """
338        Resize this image with opencv and update affine transform accordingly.
339
340        Args:
341            newdims (tuple): the new image dimensions (xdim, ydim)
342            interpolation (int): opencv interpolation method. Default is cv2.INTER_LINEAR.
343        """
344        import cv2  # avoid import issues if opencv is missing
345
346        old_x, old_y = self.xdim(), self.ydim()
347        new_x, new_y = int(newdims[0]), int(newdims[1])
348
349        assert new_x > 0 and new_y > 0, "Invalid resize dimensions."
350
351        # resize data (opencv uses width, height = y, x)
352        self.data = cv2.resize(
353            self.data,
354            (new_y, new_x),
355            interpolation=interpolation
356        )
357
358        # update affine transform
359        if self.affine is not None:
360            a = list(self.affine)
361
362            sx = old_x / new_x
363            sy = old_y / new_y
364
365            self.affine = [
366                a[0],          # x origin unchanged
367                a[1] * sx,     # pixel width
368                a[2] * sy,     # row rotation
369                a[3],          # y origin unchanged
370                a[4] * sx,     # column rotation
371                a[5] * sy      # pixel height
372            ]

Resize this image with opencv and update affine transform accordingly.

Arguments:
  • newdims (tuple): the new image dimensions (xdim, ydim)
  • interpolation (int): opencv interpolation method. Default is cv2.INTER_LINEAR.
def tile(self, tile_size):
374    def tile(self, tile_size):
375        """
376        Break image into tiles of given size and return a list of HyImage tiles.
377        Each tile has an updated affine transform reflecting its position in the original image.
378
379        Args:
380            tile_size (tuple): (tile_x, tile_y) in pixels
381        Returns:
382            list of HyImage
383        """
384        tiles = []
385        tx, ty = tile_size
386        nx, ny = self.xdim(), self.ydim()
387        for i in range(0, nx, tx):
388            for j in range(0, ny, ty):
389                data_tile = self.data[i:min(i+tx, nx), j:min(j+ty, ny), :].copy() # slice data
390
391                # compute new affine
392                # affine = [x0, dx, rotx, y0, roty, dy]
393                x0, dx, rotx, y0, roty, dy = self.affine
394                new_x0 = x0 + i*dx + j*rotx
395                new_y0 = y0 + i*roty + j*dy
396                new_affine = [new_x0, dx, rotx, new_y0, roty, dy]
397
398                # create tile
399                tile_img = HyImage(
400                    data_tile,
401                    affine=new_affine,
402                    projection=self.projection,
403                    wav=self.get_wavelengths(),
404                    header=self.header.copy()
405                )
406                tile_img.header['xleft'] = i
407                tile_img.header['ytop'] = j
408                tiles.append(tile_img)
409
410        return tiles

Break image into tiles of given size and return a list of HyImage tiles. Each tile has an updated affine transform reflecting its position in the original image.

Arguments:
  • tile_size (tuple): (tile_x, tile_y) in pixels
Returns:

list of HyImage

@staticmethod
def mosaic( tiles, blend='mean', resampling='nearest', out_affine=None, out_shape=None):
    @staticmethod
    def mosaic(
        tiles,
        blend="mean",
        resampling="nearest",
        out_affine=None,
        out_shape=None,
    ):
        """
        Mosaic georeferenced HyImage tiles onto a common grid. Each tile is resampled onto the output grid
        using hylite.project.align.resample_raster and overlapping values are combined using the specified
        blend mode. Note that this assumes all tiles are in the same coordinate system.

        Args:
            tiles (list[HyImage]): the tiles to combine. Must contain at least one tile.
            blend (str): 'first', 'min', 'max', 'mean', 'median'. The reductions ignore nan values; 'first'
                         keeps the value from the earliest tile in the list that is not nan.
            resampling (str): 'nearest', 'bilinear', 'cubic'.
                         NOTE(review): this argument is not referenced anywhere in this function body - confirm.
            out_affine (list): optional 6-element affine to define output grid. If None, the affine of the first tile is used,
                         with its origin replaced by (min_x, max_y) of the tile bounds.
            out_shape (tuple): optional (xdim, ydim) shape of the output grid. If None, it is derived from the
                         world-coordinate bounds of all tiles.
        Returns:
            HyImage: the mosaic, using the projection, wavelengths and a copy of the header of the first tile
            (with the tile-specific 'xleft' and 'ytop' entries removed).
        """
        import numpy as np
        from hylite.project.align import resample_raster
        from osgeo import gdal, osr

        assert len(tiles) > 0
        assert blend in ("first", "min", "max", "mean", "median")

        # compute bounds in world coordinates
        # N.B. THIS ASSUMES ALL DATA ARE IN THE SAME CRS
        points = []
        for t in tiles:
            points.append( t.pix_to_world(0,0) )
            points.append( t.pix_to_world(t.xdim()+1,t.ydim()+1) )
        min_x, min_y = np.min(points, axis=0)
        max_x, max_y = np.max(points, axis=0)
        if out_shape is None:
            # NOTE(review): the two pixel coordinates are added here; a size would normally be a
            # difference between the two corners - confirm this gives the intended shape.
            out_shape = (np.array(tiles[0].world_to_pix(max_x, max_y)) + np.array(tiles[0].world_to_pix(min_x, min_y))).round().astype(int)
        if out_affine is None:
            # keep pixel size / rotation of the first tile, but move the origin to the bounds of all tiles
            out_affine = list(tiles[0].affine)
            out_affine[0] = min_x
            out_affine[3] = max_y

        if blend == "first": # fill output, first come, first served.
            out = np.full( tuple(out_shape) + (tiles[0].band_count(),), np.nan, dtype=np.float32)
            for t in tiles:
                r = resample_raster( t.data, t.affine, out_affine, out_shape )
                mask = np.isnan(out) # only fill values that are still empty
                out[mask] = r[mask]
        elif blend == "min": # keep minimum value in case of overlap
            out = np.nanmin( np.stack([ resample_raster( t.data, t.affine, out_affine, out_shape ) for t in tiles ], 
                                      axis=0), axis=0 )
        elif blend == "max": # keep maximum value in case of overlap
            out = np.nanmax( np.stack([ resample_raster( t.data, t.affine, out_affine, out_shape ) for t in tiles ], 
                                      axis=0), axis=0 )
        elif blend == "mean": # use average in case of overlap
            out = np.nanmean( np.stack([ resample_raster( t.data, t.affine, out_affine, out_shape ) for t in tiles ], 
                                       axis=0), axis=0 )
        elif blend == "median": # use median in case of overlap
            out = np.nanmedian( np.stack([ resample_raster( t.data, t.affine, out_affine, out_shape ) for t in tiles ], 
                                         axis=0), axis=0 )

        # Return HyImage
        out =  HyImage(
            out,
            affine=out_affine,
            projection=tiles[0].projection,
            wav=tiles[0].get_wavelengths(),
            header=tiles[0].header.copy()
        )
        if 'xleft' in out.header: del out.header['xleft'] # stored in tiles, but not meaningful here
        if 'ytop' in out.header: del out.header['ytop'] # stored in tiles, but not meaningful here

        return out

Mosaic georeferenced HyImage tiles using GDAL. Note that this assumes all tiles are in the same coordinate system.

Arguments:
  • tiles (list[HyImage])
  • blend (str): 'first', 'min', 'max', 'mean', 'median'
  • resampling (str): 'nearest', 'bilinear', 'cubic'
  • out_affine (list): optional 6-element affine to define output grid. If None, the affine of the first tile is used.
  • out_shape (tuple): optional (xdim, ydim) shape of the output grid. If None, the extent of all tiles will be used.
Returns:

HyImage

def flip(self, axis='x'):
490    def flip(self, axis='x'):
491        """
492        Flip the image on the x or y axis. Note that this will remove any defined affine transform.
493
494        Args:
495            axis (str): 'x' or 'y' or both 'xy'.
496        """
497
498        if 'x' in axis.lower():
499            self.data = np.flip(self.data,axis=0)
500        if 'y' in axis.lower():
501            self.data = np.flip(self.data,axis=1)
502        self.affine = None
503        if 'affine' in self.header: del self.header['affine']
504        self.push_to_header() # update width and height info

Flip the image on the x or y axis. Note that this will remove any defined affine transform.

Arguments:
  • axis (str): 'x' or 'y' or both 'xy'.
def rot90(self):
506    def rot90(self):
507        """
508        Rotate this image by 90 degrees by transposing the underlying data array. Combine with flip('x') or flip('y')
509        to achieve positive/negative rotations.
510        """
511        self.data = np.transpose( self.data, (1,0,2) )
512        self.affine = None
513        if 'affine' in self.header: del self.header['affine']
514        self.push_to_header() # update width and height info

Rotate this image by 90 degrees by transposing the underlying data array. Combine with flip('x') or flip('y') to achieve positive/negative rotations.

def fill_holes(self):
519    def fill_holes(self):
520        """
521        Replaces nan pixel with an average of their neighbours, thus removing 1-pixel large holes from an image. Note that
522        for performance reasons this assumes that holes line up across bands. Note that this is not vectorized so very slow...
523        """
524
525        # perform greyscale dilation
526        dilate = self.data.copy()
527        mask = np.logical_not(np.isfinite(dilate))
528        dilate[mask] = 0
529        for b in range(self.band_count()):
530            dilate[:, :, b] = sp.ndimage.grey_dilation(dilate[:, :, b], size=(3, 3))
531
532        # map back to holes in dataset
533        self.data[mask] = dilate[mask]
534        #self.data[self.data == 0] = np.nan  # replace remaining 0's with nans

Replaces nan pixel with an average of their neighbours, thus removing 1-pixel large holes from an image. Note that for performance reasons this assumes that holes line up across bands. Note that this is not vectorized so very slow...

def blur(self, n=3):
536    def blur(self, n=3):
537        """
538        Applies a gaussian kernel of size n to the image using OpenCV.
539
540        Args:
541            n (int): the dimensions of the gaussian kernel to convolve. Default is 3. Increase for more blurry results.
542        """
543        import cv2 # import this here to avoid errors if opencv is not installed properly
544
545        nanmask = np.isnan(self.data)
546        assert isinstance(n, int) and n >= 3, "Error - invalid kernel. N must be an integer > 3. "
547        kernel = np.ones((n, n), np.float32) / (n ** 2)
548        self.data = cv2.filter2D(self.data, -1, kernel)
549        self.data[nanmask] = np.nan  # remove mask

Applies a normalised box (mean) kernel of size n x n to the image using OpenCV.

Arguments:
  • n (int): the dimensions of the gaussian kernel to convolve. Default is 3. Increase for more blurry results.
def erode(self, size=3, iterations=1):
551    def erode(self, size=3, iterations=1):
552        """
553        Apply an erode filter to this image to expand background (nan) pixels. Refer to open-cv's erode
554        function for more details.
555
556        Args:
557            size (int): the size of the erode filter. Default is a 3x3 kernel.
558            iterations (int): the number of erode iterations. Default is 1.
559        """
560        import cv2 # import this here to avoid errors if opencv is not installed properly
561
562        # erode
563        kernel = np.ones((size, size), np.uint8)
564        if self.is_float():
565            mask = np.isfinite(self.data).any(axis=-1)
566            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
567            self.data[mask == 0, :] = np.nan
568        else:
569            mask = (self.data != 0).any( axis=-1 )
570            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
571            self.data[mask == 0, :] = 0

Apply an erode filter to this image to expand background (nan) pixels. Refer to open-cv's erode function for more details.

Arguments:
  • size (int): the size of the erode filter. Default is a 3x3 kernel.
  • iterations (int): the number of erode iterations. Default is 1.
def despeckle(self, size=5):
573    def despeckle(self, size=5):
574        """
575        Despeckle each band of this image (independently) using a median filter.
576
577        Args:
578            size (int): the size of the median filter kernel. Default is 5. Must be an odd number.
579        """
580
581        assert (size % 2) == 1, "Error - size must be an odd integer"
582        import cv2 # import this here to avoid errors if opencv is not installed properly
583        if self.is_float():
584            self.data = cv2.medianBlur( self.data.astype(np.float32), size )
585        else:
586            self.data = cv2.medianBlur( self.data, size )

Despeckle each band of this image (independently) using a median filter.

Arguments:
  • size (int): the size of the median filter kernel. Default is 5. Must be an odd number.
def get_keypoints( self, band, eq=False, mask=True, method='sift', cfac=0.0, bfac=0.0, **kwds):
591    def get_keypoints(self, band, eq=False, mask=True, method='sift', cfac=0.0,bfac=0.0, **kwds):
592        """
593        Get feature descriptors from the specified band.
594
595        Args:
596            band (int,float,str,tuple): the band index (int) or wavelength (float) to extract features from. Alternatively, a tuple can be passed
597                    containing a range of bands (min : max) to average before feature matching.
598            eq (bool): True if the image should be histogram equalized first. Default is False.
599            mask (bool): True if 0 value pixels should be masked. Default is True.
600            method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'.
601            cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
602            bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.
603            **kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are:
604
605                - contrastThreshold: default is 0.01.
606                - edgeThreshold: default is 10.
607                - sigma: default is 1.0
608
609                For ORB these are:
610
611                - nfeatures = the number of features to detect. Default is 5000.
612
613            Returns:
614                Tuple containing
615
616                    - k (ndarray): the keypoints detected
617                    - d (ndarray): corresponding feature descriptors
618         """
619        import cv2 # import this here to avoid errors if opencv is not installed properly
620
621        # get image
622        if isinstance(band, int) or isinstance(band, float): #single band
623            image = self.data[:, :, self.get_band_index(band)]
624        elif isinstance(band,tuple): #range of bands (averaged)
625            idx0 = self.get_band_index(band[0])
626            idx1 = self.get_band_index(band[1])
627
628            #deal with out of range errors
629            if idx0 is None:
630                idx0 = 0
631            if idx1 is None:
632                idx1 = self.band_count()
633
634            #average bands
635            image = np.nanmean(self.data[:,:,idx0:idx1],axis=2)
636        else:
637            assert False, "Error, unrecognised band %s" % band
638
639        #normalise image to range 0 - 1
640        image -= np.nanmin(image)
641        image = image / np.nanmax(image)
642
643        #apply brightness/contrast adjustment
644        image = (1.0+cfac)*image + bfac
645        image[image > 1.0] = 1.0
646        image[image < 0.0] = 0.0
647
648        #convert image to uint8 for opencv
649        image = np.uint8(255 * image)
650        if eq:
651            image = cv2.equalizeHist(image)
652
653        if mask:
654            mask = np.zeros(image.shape, dtype=np.uint8)
655            mask[image != 0] = 255  # include only non-zero pixels
656        else:
657            mask = None
658
659        if 'sift' in method.lower():  # SIFT
660
661            # setup default keywords
662            kwds["contrastThreshold"] = kwds.get("contrastThreshold", 0.01)
663            kwds["edgeThreshold"] = kwds.get("edgeThreshold", 10)
664            kwds["sigma"] = kwds.get("sigma", 1.0)
665
666            # make feature detector
667            #alg = cv2.xfeatures2d.SIFT_create(**kwds)
668            alg = cv2.SIFT_create()
669        elif 'orb' in method.lower():  # orb
670            kwds['nfeatures'] = kwds.get('nfeatures', 5000)
671            alg = cv2.ORB_create(scoreType=cv2.ORB_FAST_SCORE, **kwds)
672        else:
673            assert False, "Error - %s is not a recognised feature detector." % method
674
675        # detect keypoints
676        kp = alg.detect(image, mask)
677
678        # extract and return feature vectors
679        return alg.compute(image, kp)

Get feature descriptors from the specified band.

Arguments:
  • band (int,float,str,tuple): the band index (int) or wavelength (float) to extract features from. Alternatively, a tuple can be passed containing a range of bands (min : max) to average before feature matching.
  • eq (bool): True if the image should be histogram equalized first. Default is False.
  • mask (bool): True if 0 value pixels should be masked. Default is True.
  • method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'.
  • cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
  • bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.
  • **kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are:

    • contrastThreshold: default is 0.01.
    • edgeThreshold: default is 10.
    • sigma: default is 1.0

    For ORB these are:

    • nfeatures = the number of features to detect. Default is 5000.
  • Returns: Tuple containing

    • k (ndarray): the keypoints detected
    • d (ndarray): corresponding feature descriptors
@classmethod
def match_keypoints( cls, kp1, kp2, d1, d2, method='SIFT', dist=0.7, tree=5, check=100, min_count=5):
681    @classmethod
682    def match_keypoints(cls, kp1, kp2, d1, d2, method='SIFT', dist=0.7, tree = 5, check = 100, min_count=5):
683        """
684        Compares keypoint feature vectors from two images and returns matching pairs.
685
686        Args:
687            kp1 (ndarray): keypoints from the first image
688            kp2 (ndarray): keypoints from the second image
689            d1 (ndarray): descriptors for the keypoints from the first image
690            d2 (ndarray): descriptors for the keypoints from the second image
691            method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'.
692            dist (float): minimum match distance (0 to 1), default is 0.7
693            tree (int): not sure what this does? Default is 5. See open-cv docs.
694            check (int): ditto. Default is 100.
695            min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are found,
696                       then the function returns None, None. Default is 5.
697        """
698        import cv2 # import this here to avoid errors if opencv is not installed properly
699        if 'sift' in method.lower():
700            algorithm = cv2.NORM_INF
701        elif 'orb' in method.lower():
702            algorithm = cv2.NORM_HAMMING
703        else:
704            assert False, "Error - unknown matching algorithm %s" % method
705
706        #calculate flann matches
707        index_params = dict(algorithm=algorithm, trees=tree)
708        search_params = dict(checks=check)
709        flann = cv2.FlannBasedMatcher(index_params, search_params)
710        matches = flann.knnMatch(d1, d2, k=2)
711
712        # store all the good matches as per Lowe's ratio test.
713        good = []
714        for m, n in matches:
715            if m.distance < dist * n.distance:
716                good.append(m)
717
718        if len(good) < min_count:
719            return None, None
720        else:
721            src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
722            dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
723            return src_pts, dst_pts

Compares keypoint feature vectors from two images and returns matching pairs.

Arguments:
  • kp1 (ndarray): keypoints from the first image
  • kp2 (ndarray): keypoints from the second image
  • d1 (ndarray): descriptors for the keypoints from the first image
  • d2 (ndarray): descriptors for the keypoints from the second image
  • method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'.
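  • dist (float): ratio test threshold (0 to 1); a match is kept if its distance is less than dist times the distance of the second-best match. Default is 0.7.
  • tree (int): passed to FLANN as the 'trees' index parameter (number of randomised kd-trees). Default is 5. See open-cv docs.
  • check (int): passed to FLANN as the 'checks' search parameter (number of index traversals per query; higher is more accurate but slower). Default is 100.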
  • min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are found, then the function returns None, None. Default is 5.
def quick_plot( self, bands=0, ax=None, bfac=0.0, cfac=0.0, samples=False, tscale=False, invert=False, rot=False, flipX=False, flipY=False, **kwds):
728    def quick_plot(self, bands=0, ax=None, bfac=0.0, cfac=0.0, samples=False, tscale=False, invert=False, rot=False, flipX=False, flipY=False,
729                   **kwds):
730        """
731        Plot a band using matplotlib.imshow(...).
732
733        Args:
734            bands (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple is passed then
735                  each band in the tuple (string or index) will be mapped to rgb. Bands with negative wavelengths or indices will be inverted before plotting.
736            ax: an axis object to plot to. If none, plt.imshow( ... ) is used.
737            bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
738            cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
739            samples (bool): True if sample points (defined in the header file) should be plotted. Default is False. Otherwise, a list of
740                     [ (x,y), ... ] points can be passed.
741            tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False.
742                    When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or
743                    (constant) values (float).
744            invert (bool) : True if each band should be inverted before plotting. Only works for multiband (ternary) images.
745            rot (bool): if True, the x and y axis will be flipped (90 degree rotation) before plotting. Default is False.
746            flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
747            flipY (bool): if True, the y axis will be flippe before plotting (after applying rotations).
748            **kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:
749
750                 - mask = a 2D boolean mask containing true if pixels should be drawn and false otherwise.
751                 - path = a file path to save the image too (at matching resolution; use fig.savefig(..) if you want to save the figure).
752                 - ticks = True if x- and y- ticks should be plotted. Default is False.
753                 - ps, pc = the size and color of sample points to plot. Can be constant or list.
754                 - figsize = a figsize for the figure to create (if ax is None).
755
756        Returns:
757            Tuple containing
758
759            - fig: matplotlib figure object
760            - ax:  matplotlib axes object. If a colorbar is created, (band is an integer or a float), then this will be stored in ax.cbar.
761        """
762
763        #create new axes?
764        if ax is None:
765            fig, ax = plt.subplots(figsize=kwds.pop('figsize', (18,18*self.ydim()/self.xdim()) ))
766
767        # deal with ticks
768        if not kwds.pop('ticks', False ):
769            ax.set_xticks([])
770            ax.set_yticks([])
771
772        #map individual band using colourmap
773        if isinstance(bands, str) or isinstance(bands, int) or isinstance(bands, float):
774            #get band
775            if isinstance(bands, str):
776                data = self.data[:, :, self.get_band_index(bands)]
777            else:
778                data = self.data[:, :, self.get_band_index(np.abs(bands))]
779            if not isinstance(bands, str) and bands < 0:
780                data = np.nanmax(data) - data # flip
781
782            # convert integer vmin and vmax values to percentiles
783            if 'vmin' in kwds:
784                if isinstance(kwds['vmin'], int):
785                    kwds['vmin'] = np.nanpercentile( data, kwds['vmin'] )
786            if 'vmax' in kwds:
787                if isinstance(kwds['vmax'], int):
788                    kwds['vmax'] = np.nanpercentile( data, kwds['vmax'] )
789
790            #mask nans (and apply custom mask)
791            mask = np.isnan(data)
792            if not np.isnan(self.header.get_data_ignore_value()):
793                mask = mask + data == self.header.get_data_ignore_value()
794            if 'mask' in kwds:
795                mask = mask + kwds.get('mask')
796                del kwds['mask']
797            data = np.ma.array(data, mask = mask > 0 )
798
799            # apply rotations and flipping
800            if rot:
801                data = data.T
802            if flipX:
803                data = data[::-1, :]
804            if flipY:
805                data = data[:, ::-1]
806
807            # save?
808            if 'path' in kwds:
809                path = kwds.pop('path')
810                from matplotlib.pyplot import imsave
811                if not os.path.exists(os.path.dirname(path)):
812                    os.makedirs(os.path.dirname(path)) # ensure output directory exists
813                imsave(path, data.T, **kwds)  # save the image
814
815            ax.cbar = ax.imshow(data.T, interpolation=kwds.pop('interpolation', 'none'), **kwds) # change default interpolation to None
816
817        #map 3 bands to RGB
818        elif isinstance(bands, tuple) or isinstance(bands, list):
819            #get band indices and range
820            rgb = []
821            for b in bands:
822                if isinstance(b, str):
823                    rgb.append(self.get_band_index(b))
824                else:
825                    rgb.append(self.get_band_index(np.abs(b)))
826
827            #slice image (as copy) and map to 0 - 1
828            img = np.array(self.data[:, :, rgb]).copy()
829            if np.isnan(img).all():
830                print("Warning - image contains no data.")
831                return ax.get_figure(), ax
832
833            # invert if needed
834            if invert:
835                bands = [-b for b in bands]
836            for i,b in enumerate(bands):
837                if not isinstance(b, str) and (b < 0):
838                    img[..., i] = np.nanmax(img[..., i]) - img[..., i]
839
840            # do scaling
841            if tscale: # scale bands independently
842                for b in range(3):
843                    mn = kwds.get("vmin", float(np.nanmin(img)))
844                    mx = kwds.get("vmax", float(np.nanmax(img)))
845                    if isinstance (mn, int):
846                        assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
847                        mn = float(np.nanpercentile(img[...,b], mn ))
848                    if isinstance (mx, int):
849                        assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
850                        mx = float(np.nanpercentile(img[...,b], mx ))
851                    img[...,b] = (img[..., b] - mn) / (mx - mn)
852            else: # scale bands together
853                mn = kwds.get("vmin", float(np.nanmin(img)))
854                mx = kwds.get("vmax", float(np.nanmax(img)))
855                if isinstance(mn, int):
856                    assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
857                    mn = float(np.nanpercentile(img, mn))
858                if isinstance(mx, int):
859                    assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
860                    mx = float(np.nanpercentile(img, mx))
861                img = (img - mn) / (mx - mn)
862
863            #apply brightness/contrast mapping
864            img = np.clip((1.0 + cfac) * img + bfac, 0, 1.0 )
865
866            #apply masking so background is white
867            img[np.logical_not( np.isfinite( img ) )] = 1.0
868            if 'mask' in kwds:
869                img[kwds.pop("mask"),:] = 1.0
870
871            # apply rotations and flipping
872            if rot:
873                img = np.transpose( img, (1,0,2) )
874            if flipX:
875                img = img[::-1, :, :]
876            if flipY:
877                img = img[:, ::-1, :]
878
879            # save?
880            if 'path' in kwds:
881                path = kwds.pop('path')
882                from matplotlib.pyplot import imsave
883                if not os.path.exists(os.path.dirname(path)):
884                    os.makedirs(os.path.dirname(path)) # ensure output directory exists
885                imsave(path, np.transpose( np.clip( img*255, 0, 255).astype(np.uint8), (1, 0, 2)))  # save the image
886
887            # plot samples?
888            ps = kwds.pop('ps', 5)
889            pc = kwds.pop('pc', 'r')
890            if samples:
891                if isinstance(samples, list) or isinstance(samples, np.ndarray):
892                    ax.scatter([s[0] for s in samples], [s[1] for s in samples], s=ps, c=pc)
893                else:
894                    for n in self.header.get_class_names():
895                        points = np.array(self.header.get_sample_points(n))
896                        ax.scatter(points[:, 0], points[:, 1], s=ps, c=pc)
897
898            #plot
899            ax.imshow(np.transpose(img, (1,0,2)), interpolation=kwds.pop('interpolation', 'none'), **kwds)
900            ax.cbar = None  # no colorbar
901
902        return ax.get_figure(), ax

Plot a band using matplotlib.imshow(...).

Arguments:
  • bands (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple is passed then each band in the tuple (string or index) will be mapped to rgb. Bands with negative wavelengths or indices will be inverted before plotting.
  • ax: an axis object to plot to. If none, plt.imshow( ... ) is used.
  • bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
  • cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
  • samples (bool): True if sample points (defined in the header file) should be plotted. Default is False. Otherwise, a list of [ (x,y), ... ] points can be passed.
  • tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False. When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or (constant) values (float).
  • invert (bool) : True if each band should be inverted before plotting. Only works for multiband (ternary) images.
  • rot (bool): if True, the x and y axis will be flipped (90 degree rotation) before plotting. Default is False.
  • flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
  • flipY (bool): if True, the y axis will be flipped before plotting (after applying rotations).
  • **kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:

    • mask = a 2D boolean mask containing True for pixels that should be hidden (masked) and False for pixels that should be drawn.
    • path = a file path to save the image to (at matching resolution; use fig.savefig(..) if you want to save the figure).
    • ticks = True if x- and y- ticks should be plotted. Default is False.
    • ps, pc = the size and color of sample points to plot. Can be constant or list.
    • figsize = a figsize for the figure to create (if ax is None).
Returns:

Tuple containing

  • fig: matplotlib figure object
  • ax: matplotlib axes object. If a colorbar is created, (band is an integer or a float), then this will be stored in ax.cbar.
def createGIF(self, path, bands=None, figsize=(10, 10), fps=10, **kwds):
904    def createGIF(self, path, bands=None, figsize=(10,10), fps=10, **kwds):
905        """
906        Create and save an animated gif that loops through the bands of the image.
907
908        Args:
909            path (str): the path to save the .gif
910            bands (tuple): Tuple containing the range of band indices to draw. Default is the whole range.
911            figsize (tuple): the size of the image to draw. Default is (10,10).
912            fps (int): the framerate (frames per second) of the gif. Default is 10.
913            **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
914        """
915
916        frames = []
917        if bands is None:
918            bands = (0,self.band_count())
919        else:
920            assert 0 < bands[0] < self.band_count(), "Error - invalid range."
921            assert 0 < bands[1] < self.band_count(), "Error - invalid range."
922            assert bands[1] > bands[0], "Error - invalid range."
923
924        #plot frames
925        for i in range(bands[0],bands[1]):
926            fig, ax = plt.subplots(figsize=figsize)
927            ax.imshow(self.data[:, :, i], **kwds)
928            fig.canvas.draw()
929            frames.append(np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8'))
930            frames[-1] = np.reshape(frames[-1], (fig.canvas.get_width_height()[1], fig.canvas.get_width_height()[0], 3))
931            plt.close(fig)
932
933        #save gif
934        imageio.mimsave( os.path.splitext(path)[0] + ".gif", frames, fps=fps)

Create and save an animated gif that loops through the bands of the image.

Arguments:
  • path (str): the path to save the .gif
  • bands (tuple): Tuple containing the range of band indices to draw. Default is the whole range.
  • figsize (tuple): the size of the image to draw. Default is (10,10).
  • fps (int): the framerate (frames per second) of the gif. Default is 10.
  • **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
def drop_bbl(self, drop=True):
937    def drop_bbl(self, drop=True):
938        """
939        Remove bad bands as stored in the 'bbl' key in the image header. Note that this operates in-place.
940
941        Args:
942            drop (bool): True if bad bands should be completely dropped. If False, these bands will be kept but replaced with nans.
943        """
944        assert 'bbl' in self.header, "Please specify a bad band list ('bbl') in the image header, as per the ENVI format definition."
945        mask = self.header.get_list('bbl') == 0
946        self.data[...,mask] = np.nan
947        if drop:
948            self.delete_nan_bands(inplace=True)

Remove bad bands as stored in the 'bbl' key in the image header. Note that this operates in-place.

Arguments:
  • drop (bool): True if bad bands should be completely dropped. If False, these bands will be kept but replaced with nans.
def mask(self, mask=None, flag=nan, invert=False, crop=False, bands=None):
 950    def mask(self, mask=None, flag=np.nan, invert=False, crop=False, bands=None):
 951        """
 952         Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the
 953         image in-situ.
 954
 955         Args:
 956            flag (float): the value to use for masked pixels. Default is np.nan
 957            mask (ndarray): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None is passed then
 958                    pickPolygon( ... ) is used to interactively define a polygon. If a file path is passed then the polygon
 959                    will be loaded using np.load( ... ). Alternatively if mask.shape == image.shape[0,1] then it is treated as a
 960                    binary image mask (must be boolean) and True values will be masked across all bands. Default is None.
 961            invert (bool): if True, pixels within the polygon will be masked. If False, pixels outside the polygon are masked. Default is False.
 962            crop (bool): True if rows/columns containing only zeros should be removed. Default is False.
 963            bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used.
 964
 965         Returns:
 966            Tuple containing
 967
 968            - mask (ndarray): a boolean array with True where pixels are masked and False elsewhere.
 969            - poly (ndarray): the mask polygon array in the format described above. Useful if the polygon was interactively defined.
 970         """
 971
 972        if mask is None:  # pick mask interactively
 973            if bands is None:
 974                bands = int(self.band_count() / 2)
 975
 976            regions = self.pickPolygons(region_names=["mask"], bands=bands)
 977
 978            # the user bailed without picking a mask?
 979            if len(regions) == 0:
 980                print("Warning - no mask picked/applied.")
 981                return
 982
 983            # extract polygon mask
 984            mask = regions[0]
 985
 986        # convert polygon mask to binary mask
 987        if mask.shape[1] == 2:
 988
 989            # build meshgrid with pixel coords
 990            xx, yy = np.meshgrid(np.arange(self.xdim()), np.arange(self.ydim()))
 991            xx = xx.flatten()
 992            yy = yy.flatten()
 993            points = np.vstack([xx, yy]).T  # coordinates of each pixel
 994
 995            # calculate per-pixel mask
 996            mask = path.Path(mask).contains_points(points)
 997            mask = mask.reshape((self.ydim(), self.xdim())).T
 998
 999            # flip as we want to mask (==True) outside points (unless invert is true)
1000            if not invert:
1001                mask = np.logical_not(mask)
1002
1003        # apply binary image mask
1004        assert mask.shape[0] == self.data.shape[0] and mask.shape[1] == self.data.shape[1], \
1005            "Error - mask shape %s does not match image shape %s" % (mask.shape, self.data.shape)
1006        for b in range(self.band_count()):
1007            self.data[:, :, b][mask] = flag
1008
1009        # crop image
1010        if crop:
1011            self.crop_to_data()
1012
1013        return mask

Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the image in-situ.

Arguments:
  • flag (float): the value to use for masked pixels. Default is np.nan
  • mask (ndarray): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None is passed then pickPolygon( ... ) is used to interactively define a polygon. If a file path is passed then the polygon will be loaded using np.load( ... ). Alternatively if mask.shape == image.shape[0,1] then it is treated as a binary image mask (must be boolean) and True values will be masked across all bands. Default is None.
  • invert (bool): if True, pixels within the polygon will be masked. If False, pixels outside the polygon are masked. Default is False.
  • crop (bool): True if rows/columns containing only zeros should be removed. Default is False.
  • bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used.
Returns:

The boolean mask only (not a tuple):

  • mask (ndarray): a boolean array with True where pixels are masked and False elsewhere. None is returned if interactive picking was cancelled without defining a polygon.
def crop_to_data(self):
1015    def crop_to_data(self):
1016        """
1017        Remove padding of nan or zero pixels from image. Note that this is performed in place.
1018        """
1019        valid = np.isfinite(self.data).any(axis=-1) & (self.data != 0).any(axis=-1)
1020
1021        # integrate along axes
1022        xdata = np.sum(valid, axis=1) > 0.0
1023        ydata = np.sum(valid, axis=0) > 0.0
1024
1025        # calculate domain containing valid pixels
1026        xmin = np.argmax(xdata)
1027        xmax = xdata.shape[0] - np.argmax(xdata[::-1])
1028        ymin = np.argmax(ydata)
1029        ymax = ydata.shape[0] - np.argmax(ydata[::-1])
1030
1031        # crop
1032        self.data = self.data[xmin:xmax, ymin:ymax, :]
1033        
1034        # shift affine origin to new top-left pixel
1035        if self.affine is not None: 
1036            a = self.affine # shorthand for affine
1037            new_affine = list(self.affine)
1038            new_affine[0] = a[0] + xmin*a[1] + ymin*a[2]
1039            new_affine[3] = a[3] + xmin*a[4] + ymin*a[5]
1040            self.affine = np.array(new_affine)
1041            self.header['affine'] = self.affine

Remove padding of nan or zero pixels from image. Note that this is performed in place.

def pickPolygons(self, region_names, bands=0):
1046    def pickPolygons(self, region_names, bands=0):
1047        """
1048        Creates a matplotlib gui for selecting polygon regions in an image.
1049
1050        Args:
1051            names (list, str): a list containing the names of the regions to pick. If a string is passed only one name is used.
1052            bands (tuple): the bands of the image to plot.
1053        """
1054
1055        if isinstance(region_names, str):
1056            region_names = [region_names]
1057
1058        assert isinstance(region_names, list), "Error - names must be a list or a string."
1059
1060        # set matplotlib backend
1061        backend = matplotlib.get_backend()
1062        matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work
1063
1064        # plot image and extract roi's
1065        fig, ax = self.quick_plot(bands)
1066        roi = MultiRoi(roi_names=region_names)
1067        plt.close(fig)  # close figure
1068
1069        # extract regions
1070        regions = []
1071        for name, r in roi.rois.items():
1072            # store region
1073            x = r.x
1074            y = r.y
1075            regions.append(np.vstack([x, y]).T)
1076
1077        # restore matplotlib backend (if possible)
1078        try:
1079            matplotlib.use(backend)
1080        except:
1081            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")
1082            pass
1083
1084        return regions

Creates a matplotlib gui for selecting polygon regions in an image.

Arguments:
  • region_names (list, str): a list containing the names of the regions to pick. If a string is passed only one name is used.
  • bands (tuple): the bands of the image to plot.
def pickPoints( self, n=-1, bands=(680.0, 550.0, 505.0), integer=True, title='Pick Points', **kwds):
1086    def pickPoints(self, n=-1, bands=hylite.RGB, integer=True, title="Pick Points", **kwds):
1087        """
1088        Creates a matplotlib gui for picking pixels from an image.
1089
1090        Args:
1091            n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
1092            bands (tuple): the bands of the image to plot. Default is HyImage.RGB
1093            integer (bool): True if points coordinates should be cast to integers (for use as indices). Default is True.
1094            title (str): The title of the point picking window.
1095            **kwds: Keywords are passed to HyImage.quick_plot( ... ).
1096
1097        Returns:
1098            A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].
1099        """
1100
1101        # set matplotlib backend
1102        backend = matplotlib.get_backend()
1103        matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work
1104
1105        # create figure
1106        fig, ax = self.quick_plot( bands, **kwds )
1107        ax.set_title(title)
1108
1109        # get points
1110        points = fig.ginput( n )
1111
1112        if integer:
1113            points = [ (int(p[0]), int(p[1])) for p in points ]
1114
1115        # restore matplotlib backend (if possible)
1116        try:
1117            matplotlib.use(backend)
1118        except:
1119            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")
1120            pass
1121
1122        return points

Creates a matplotlib gui for picking pixels from an image.

Arguments:
  • n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
  • bands (tuple): the bands of the image to plot. Default is hylite.RGB.
  • integer (bool): True if point coordinates should be cast to integers (for use as indices). Default is True.
  • title (str): The title of the point picking window.
  • **kwds: Keywords are passed to HyImage.quick_plot( ... ).
Returns:

A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].

def pickSamples(self, names=None, store=True, **kwds):
1124    def pickSamples(self, names=None, store=True, **kwds):
1125        """
1126        Pick sample probe points and store these in the image header file.
1127
1128        Args:
1129            names (str, list): the name of the sample to pick, or a list of names to pick multiple.
1130            store (bool): True if sample should be stored in the image header file (for later access). Default is True.
1131            **kwds: Keywords are passed to HyImage.quick_plot( ... )
1132
1133        Returns:
1134            a list containing a list of points for each sample.
1135        """
1136
1137        if isinstance(names, str):
1138            names = [names]
1139
1140        # pick points
1141        points = []
1142        for s in names:
1143            pnts = self.pickPoints(title="%s" % s, **kwds)
1144            if store:
1145                self.header['sample %s' % s] = pnts # store in header
1146            points.append(pnts)
1147        # add class to header file
1148        if store:
1149            cls_names = self.header.get_class_names()
1150            if cls_names is None:
1151                cls_names = []
1152            self.header['class names'] = cls_names + names
1153
1154        return points

Pick sample probe points and store these in the image header file.

Arguments:
  • names (str, list): the name of the sample to pick, or a list of names to pick multiple.
  • store (bool): True if sample should be stored in the image header file (for later access). Default is True.
  • **kwds: Keywords are passed to HyImage.quick_plot( ... )
Returns:

a list containing a list of points for each sample.