hylite.hyimage
Store and manipulate hyperspectral image data.
1""" 2Store and manipulate hyperspectral image data. 3""" 4 5import os 6import numpy as np 7import matplotlib 8import matplotlib.pyplot as plt 9from matplotlib import path 10from roipoly import MultiRoi 11import imageio 12import scipy as sp 13import hylite 14from hylite.hydata import HyData 15from hylite.hylibrary import HyLibrary 16 17 18 19class HyImage( HyData ): 20 """ 21 A class for hyperspectral image data. These can be individual scenes or hyperspectral orthoimages. 22 """ 23 24 def __init__(self, data, **kwds): 25 """ 26 Args: 27 data (ndarray): a numpy array such that data[x][y][band] gives each pixel value. 28 **kwds: 29 wav = A numpy array containing band wavelengths for this image. 30 affine = an affine transform of the format returned by GDAL.GetGeoTransform(). 31 projection = string defining the project. Default is None. 32 sensor = sensor name. Default is "unknown". 33 header = path to associated header file. Default is None. 34 """ 35 36 #call constructor for HyData 37 super().__init__(data, **kwds) 38 39 # special case - if dataset only has oneband, slice it so it still has 40 # the format data[x,y,b]. 41 if not self.data is None: 42 if len(self.data.shape) == 1: 43 self.data = self.data[None, None, :] # single pixel image 44 if len(self.data.shape) == 2: 45 self.data = self.data[:, :, None] # single band iamge 46 47 #load any additional project information (specific to images) 48 self.set_projection(kwds.get("projection",None)) 49 self.affine = np.array( kwds.get("affine",[0,1,0,0,0,1]) ) 50 self.header['affine'] = kwds.get("affine",[0,1,0,0,0,1]) # also store this here 51 52 # wavelengths 53 if 'wav' in kwds: 54 self.set_wavelengths(kwds['wav']) 55 56 #special header formatting 57 self.header['file type'] = 'ENVI Standard' 58 59 def copy(self,data=True): 60 """ 61 Make a deep copy of this image instance. 62 63 Args: 64 data (bool): True if a copy of the data should be made, otherwise only copy header. 65 66 Returns: 67 a new HyImage instance. 
68 """ 69 if not data: 70 return HyImage(None, header=self.header.copy(), projection=self.projection, affine=self.affine) 71 else: 72 return HyImage( self.data.copy(), header=self.header.copy(), projection=self.projection, affine=self.affine) 73 74 def T(self): 75 """ 76 Return a transposed view of the data matrix (corresponding with the [y,x] indexing used by matplotlib, opencv etc. 77 """ 78 return np.transpose(self.data, (1,0,2)) 79 80 def xdim(self): 81 """ 82 Return number of pixels in x (first dimension of data array) 83 """ 84 return self.data.shape[0] 85 86 def ydim(self): 87 """ 88 Return number of pixels in y (second dimension of data array) 89 """ 90 return self.data.shape[1] 91 92 def aspx(self): 93 """ 94 Return the aspect ratio of this image (width/height). 95 """ 96 return self.ydim() / self.xdim() 97 98 ##################################### 99 ## GEOREFERENCING METHODS 100 ##################################### 101 102 def get_extent(self): 103 """ 104 Returns the width and height of this image in world coordinates. 105 106 Returns: 107 tuple with (width, height). 108 """ 109 return self.xdim * self.pixel_size[0], self.ydim * self.pixel_size[1] 110 111 def set_projection(self,proj): 112 """ 113 Set this project to an existing osgeo.osr.SpatialReference or GDAL georeference string. 114 115 Args: 116 proj (str, osgeo.osr.SpatialReference): the project to use as osgeo.osr.SpatialReference or GDAL georeference string. 117 """ 118 if proj is None: 119 self.projection = None 120 else: 121 try: 122 from osgeo.osr import SpatialReference 123 except: 124 assert False, "Error - GDAL must be installed to work with spatial projections in hylite." 125 if isinstance(proj, SpatialReference): 126 self.projection = proj 127 elif isinstance(proj, str): 128 self.projection = SpatialReference(proj) 129 else: 130 print("Invalid project %s" % proj) 131 raise 132 133 def set_projection_EPSG(self,EPSG): 134 """ 135 Sets this image project using an EPSG code. 
136 137 Args: 138 EPSG (str): string EPSG code that can be passed to SpatialReference.SetFromUserInput(...). 139 """ 140 141 try: 142 from osgeo.osr import SpatialReference 143 except: 144 assert False, "Error - GDAL must be installed to work with spatial projections in hylite." 145 146 self.projection = SpatialReference() 147 self.projection.SetFromUserInput(EPSG) 148 149 def get_projection_EPSG(self): 150 """ 151 Gets a string describing this projections EPSG code (if it is an EPSG project). 152 153 Returns: 154 an EPSG code string of the format "EPSG:XXXX". 155 """ 156 if self.projection is None: 157 return None 158 else: 159 return "%s:%s" % (self.projection.GetAttrValue("AUTHORITY",0),self.projection.GetAttrValue("AUTHORITY",1)) 160 161 def pix_to_world(self, px, py, proj=None): 162 """ 163 Take pixel coordinates and return world coordinates 164 165 Args: 166 px (int): the pixel x-coord. 167 py (int): the pixel y-coord. 168 proj (str, osr.SpatialReference): the coordinate system to use. Default (None) uses the same system as this image. Otherwise 169 an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)). 170 Returns: 171 the world coordinates in the coordinate system defined by get_projection_EPSG(...). 172 """ 173 174 try: 175 from osgeo import osr 176 import osgeo.gdal as gdal 177 from osgeo import ogr 178 except: 179 assert False, "Error - GDAL must be installed to work with spatial projections in hylite." 180 181 # parse project 182 if proj is None: 183 proj = self.projection 184 elif isinstance(proj, str) or isinstance(proj, int): 185 epsg = proj 186 if isinstance(epsg, str): 187 try: 188 epsg = int(str.split(':')[1]) 189 except: 190 assert False, "Error - %s is an invalid EPSG code." 
% proj 191 proj = osr.SpatialReference() 192 proj.ImportFromEPSG(epsg) 193 194 # check we have all the required info 195 assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj 196 assert (not self.affine is None) and ( 197 not self.projection is None), "Error - project information is undefined." 198 199 #project to world coordinates in this images project/world coords 200 x,y = gdal.ApplyGeoTransform(self.affine, px, py) 201 202 #project to target coords (if different) 203 if not proj.IsSameGeogCS(self.projection): 204 P = ogr.Geometry(ogr.wkbPoint) 205 if proj.EPSGTreatsAsNorthingEasting(): 206 P.AddPoint(x, y) 207 else: 208 P.AddPoint(y, x) 209 P.AssignSpatialReference(self.projection) # tell the point what coordinates it's in 210 P.TransformTo(proj) # reproject it to the out spatial reference 211 x, y = P.GetX(), P.GetY() 212 213 #do we need to transpose? 214 if proj.EPSGTreatsAsLatLong(): 215 x,y=y,x #we want lon,lat not lat,lon 216 return x, y 217 218 def world_to_pix(self, x, y, proj = None): 219 """ 220 Take world coordinates and return pixel coordinates 221 222 Args: 223 x (float): the world x-coord. 224 y (float): the world y-coord. 225 proj (str, osr.SpatialReference): the coordinate system of the input coordinates. Default (None) uses the same system as this image. Otherwise 226 an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)). 227 228 Returns: 229 the pixel coordinates based on the affine transform stored in self.affine. 230 """ 231 232 try: 233 from osgeo import osr 234 import osgeo.gdal as gdal 235 from osgeo import ogr 236 except: 237 assert False, "Error - GDAL must be installed to work with spatial projections in hylite." 
238 239 # parse project 240 if proj is None: 241 proj = self.projection 242 elif isinstance(proj, str) or isinstance(proj, int): 243 epsg = proj 244 if isinstance(epsg, str): 245 try: 246 epsg = int(str.split(':')[1]) 247 except: 248 assert False, "Error - %s is an invalid EPSG code." % proj 249 proj = osr.SpatialReference() 250 proj.ImportFromEPSG(epsg) 251 252 253 # check we have all the required info 254 assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj 255 assert (not self.affine is None) and (not self.projection is None), "Error - project information is undefined." 256 257 # project to this images CS (if different) 258 if not proj.IsSameGeogCS(self.projection): 259 P = ogr.Geometry(ogr.wkbPoint) 260 if proj.EPSGTreatsAsNorthingEasting(): 261 P.AddPoint(x, y) 262 else: 263 P.AddPoint(y, x) 264 P.AssignSpatialReference(proj) # tell the point what coordinates it's in 265 P.AddPoint(x, y) 266 P.TransformTo(self.projection) # reproject it to the out spatial reference 267 x, y = P.GetX(), P.GetY() 268 if self.projection.EPSGTreatsAsLatLong(): # do we need to transpose? 269 x, y = y, x # we want lon,lat not lat,lon 270 271 inv = gdal.InvGeoTransform(self.affine) 272 assert not inv is None, "Error - could not invert affine transform?" 273 274 #apply 275 return gdal.ApplyGeoTransform(inv, x, y) 276 277 def crop(self, xmin, xmax, ymin, ymax, bands=None): 278 """ 279 Return a cropped copy of this image. 280 281 Args: 282 xmin, xmax (int): pixel bounds in x (rows) 283 ymin, ymax (int): pixel bounds in y (columns) 284 bands (None, list, tuple): optional band indices or (min,max) range 285 286 Returns: 287 HyImage: cropped image with updated affine transform 288 """ 289 290 # ---- validate bounds ---- 291 xmin = int(max(0, xmin)) 292 ymin = int(max(0, ymin)) 293 xmax = int(min(self.xdim(), xmax)) 294 ymax = int(min(self.ydim(), ymax)) 295 296 assert xmin < xmax and ymin < ymax, "Invalid crop extent." 
297 298 # ---- crop data ---- 299 if bands is None: 300 data = self.data[xmin:xmax, ymin:ymax, :].copy() 301 wav = self.get_wavelengths() 302 else: # band selection 303 if isinstance(bands, tuple): 304 b0 = self.get_band_index(bands[0]) 305 b1 = self.get_band_index(bands[1]) 306 data = self.data[xmin:xmax, ymin:ymax, b0:b1].copy() 307 wav = self.get_wavelengths()[b0:b1] 308 else: 309 idx = [self.get_band_index(b) for b in bands] 310 data = self.data[xmin:xmax, ymin:ymax, idx].copy() 311 wav = self.get_wavelengths()[idx] 312 313 # ---- update affine transform ---- 314 if self.affine is not None: 315 a = list(self.affine) 316 new_affine = a.copy() 317 318 # shift origin to new top-left pixel 319 new_affine[0] = a[0] + xmin*a[1] + ymin*a[2] 320 new_affine[3] = a[3] + xmin*a[4] + ymin*a[5] 321 else: 322 new_affine = None 323 324 # ---- construct output image ---- 325 out = HyImage( 326 data, 327 header=self.header.copy(), 328 projection=self.projection, 329 affine=new_affine, 330 wav=wav 331 ) 332 333 return out 334 335 def resize(self, newdims: tuple, interpolation: int = 1): 336 """ 337 Resize this image with opencv and update affine transform accordingly. 338 339 Args: 340 newdims (tuple): the new image dimensions (xdim, ydim) 341 interpolation (int): opencv interpolation method. Default is cv2.INTER_LINEAR. 342 """ 343 import cv2 # avoid import issues if opencv is missing 344 345 old_x, old_y = self.xdim(), self.ydim() 346 new_x, new_y = int(newdims[0]), int(newdims[1]) 347 348 assert new_x > 0 and new_y > 0, "Invalid resize dimensions." 
349 350 # resize data (opencv uses width, height = y, x) 351 self.data = cv2.resize( 352 self.data, 353 (new_y, new_x), 354 interpolation=interpolation 355 ) 356 357 # update affine transform 358 if self.affine is not None: 359 a = list(self.affine) 360 361 sx = old_x / new_x 362 sy = old_y / new_y 363 364 self.affine = [ 365 a[0], # x origin unchanged 366 a[1] * sx, # pixel width 367 a[2] * sy, # row rotation 368 a[3], # y origin unchanged 369 a[4] * sx, # column rotation 370 a[5] * sy # pixel height 371 ] 372 373 def tile(self, tile_size): 374 """ 375 Break image into tiles of given size and return a list of HyImage tiles. 376 Each tile has an updated affine transform reflecting its position in the original image. 377 378 Args: 379 tile_size (tuple): (tile_x, tile_y) in pixels 380 Returns: 381 list of HyImage 382 """ 383 tiles = [] 384 tx, ty = tile_size 385 nx, ny = self.xdim(), self.ydim() 386 for i in range(0, nx, tx): 387 for j in range(0, ny, ty): 388 data_tile = self.data[i:min(i+tx, nx), j:min(j+ty, ny), :].copy() # slice data 389 390 # compute new affine 391 # affine = [x0, dx, rotx, y0, roty, dy] 392 x0, dx, rotx, y0, roty, dy = self.affine 393 new_x0 = x0 + i*dx + j*rotx 394 new_y0 = y0 + i*roty + j*dy 395 new_affine = [new_x0, dx, rotx, new_y0, roty, dy] 396 397 # create tile 398 tile_img = HyImage( 399 data_tile, 400 affine=new_affine, 401 projection=self.projection, 402 wav=self.get_wavelengths(), 403 header=self.header.copy() 404 ) 405 tile_img.header['xleft'] = i 406 tile_img.header['ytop'] = j 407 tiles.append(tile_img) 408 409 return tiles 410 411 @staticmethod 412 def mosaic( 413 tiles, 414 blend="mean", 415 resampling="nearest", 416 out_affine=None, 417 out_shape=None, 418 ): 419 """ 420 Mosaic georeferenced HyImage tiles using GDAL. Note that this assumes all tiles are in the same coordinate system. 
421 422 Args: 423 tiles (list[HyImage]) 424 blend (str): 'first', 'min', 'max', 'mean', 'median' 425 resampling (str): 'nearest', 'bilinear', 'cubic' 426 out_affine (list): optional 6-element affine to define output grid. If None, the affine of the first tile is used. 427 out_shape (tuple): optional (xdim, ydim) shape of the output grid. If None, the extent of all tiles will be used. 428 Returns: 429 HyImage 430 """ 431 import numpy as np 432 from hylite.project.align import resample_raster 433 from osgeo import gdal, osr 434 435 assert len(tiles) > 0 436 assert blend in ("first", "min", "max", "mean", "median") 437 438 # compute bounds in world coordinates 439 # N.B. THIS ASSUMES ALL DATA ARE IN THE SAME CRS 440 points = [] 441 for t in tiles: 442 points.append( t.pix_to_world(0,0) ) 443 points.append( t.pix_to_world(t.xdim()+1,t.ydim()+1) ) 444 min_x, min_y = np.min(points, axis=0) 445 max_x, max_y = np.max(points, axis=0) 446 if out_shape is None: 447 out_shape = (np.array(tiles[0].world_to_pix(max_x, max_y)) + np.array(tiles[0].world_to_pix(min_x, min_y))).round().astype(int) 448 if out_affine is None: 449 out_affine = list(tiles[0].affine) 450 out_affine[0] = min_x 451 out_affine[3] = max_y 452 453 if blend == "first": # fill output, first come, first served. 
454 out = np.full( tuple(out_shape) + (tiles[0].band_count(),), np.nan, dtype=np.float32) 455 for t in tiles: 456 r = resample_raster( t.data, t.affine, out_affine, out_shape ) 457 mask = np.isnan(out) 458 out[mask] = r[mask] 459 elif blend == "min": # keep minimum value in case of overlap 460 out = np.nanmin( np.stack([ resample_raster( t.data, t.affine, out_affine, out_shape ) for t in tiles ], 461 axis=0), axis=0 ) 462 elif blend == "max": # keep maximum value in case of overlap 463 out = np.nanmax( np.stack([ resample_raster( t.data, t.affine, out_affine, out_shape ) for t in tiles ], 464 axis=0), axis=0 ) 465 elif blend == "mean": # use average in case of overlap 466 out = np.nanmean( np.stack([ resample_raster( t.data, t.affine, out_affine, out_shape ) for t in tiles ], 467 axis=0), axis=0 ) 468 elif blend == "median": # use mean in case of overlap 469 out = np.nanmedian( np.stack([ resample_raster( t.data, t.affine, out_affine, out_shape ) for t in tiles ], 470 axis=0), axis=0 ) 471 472 # Return HyImage 473 out = HyImage( 474 out, 475 affine=out_affine, 476 projection=tiles[0].projection, 477 wav=tiles[0].get_wavelengths(), 478 header=tiles[0].header.copy() 479 ) 480 if 'xleft' in out.header: del out.header['xleft'] # stored in tiles, but not meaningful here 481 if 'ytop' in out.header: del out.header['ytop'] # stored in tiles, but not meaningful here 482 483 return out 484 485 ##################################### 486 ## BASIC TRANSFORMS 487 ##################################### 488 489 def flip(self, axis='x'): 490 """ 491 Flip the image on the x or y axis. Note that this will remove any defined affine transform. 492 493 Args: 494 axis (str): 'x' or 'y' or both 'xy'. 
495 """ 496 497 if 'x' in axis.lower(): 498 self.data = np.flip(self.data,axis=0) 499 if 'y' in axis.lower(): 500 self.data = np.flip(self.data,axis=1) 501 self.affine = None 502 if 'affine' in self.header: del self.header['affine'] 503 self.push_to_header() # update width and height info 504 505 def rot90(self): 506 """ 507 Rotate this image by 90 degrees by transposing the underlying data array. Combine with flip('x') or flip('y') 508 to achieve positive/negative rotations. 509 """ 510 self.data = np.transpose( self.data, (1,0,2) ) 511 self.affine = None 512 if 'affine' in self.header: del self.header['affine'] 513 self.push_to_header() # update width and height info 514 515 ##################################### 516 ##IMAGE FILTERING 517 ##################################### 518 def fill_holes(self): 519 """ 520 Replaces nan pixel with an average of their neighbours, thus removing 1-pixel large holes from an image. Note that 521 for performance reasons this assumes that holes line up across bands. Note that this is not vectorized so very slow... 522 """ 523 524 # perform greyscale dilation 525 dilate = self.data.copy() 526 mask = np.logical_not(np.isfinite(dilate)) 527 dilate[mask] = 0 528 for b in range(self.band_count()): 529 dilate[:, :, b] = sp.ndimage.grey_dilation(dilate[:, :, b], size=(3, 3)) 530 531 # map back to holes in dataset 532 self.data[mask] = dilate[mask] 533 #self.data[self.data == 0] = np.nan # replace remaining 0's with nans 534 535 def blur(self, n=3): 536 """ 537 Applies a gaussian kernel of size n to the image using OpenCV. 538 539 Args: 540 n (int): the dimensions of the gaussian kernel to convolve. Default is 3. Increase for more blurry results. 541 """ 542 import cv2 # import this here to avoid errors if opencv is not installed properly 543 544 nanmask = np.isnan(self.data) 545 assert isinstance(n, int) and n >= 3, "Error - invalid kernel. N must be an integer > 3. 
" 546 kernel = np.ones((n, n), np.float32) / (n ** 2) 547 self.data = cv2.filter2D(self.data, -1, kernel) 548 self.data[nanmask] = np.nan # remove mask 549 550 def erode(self, size=3, iterations=1): 551 """ 552 Apply an erode filter to this image to expand background (nan) pixels. Refer to open-cv's erode 553 function for more details. 554 555 Args: 556 size (int): the size of the erode filter. Default is a 3x3 kernel. 557 iterations (int): the number of erode iterations. Default is 1. 558 """ 559 import cv2 # import this here to avoid errors if opencv is not installed properly 560 561 # erode 562 kernel = np.ones((size, size), np.uint8) 563 if self.is_float(): 564 mask = np.isfinite(self.data).any(axis=-1) 565 mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations) 566 self.data[mask == 0, :] = np.nan 567 else: 568 mask = (self.data != 0).any( axis=-1 ) 569 mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations) 570 self.data[mask == 0, :] = 0 571 572 def despeckle(self, size=5): 573 """ 574 Despeckle each band of this image (independently) using a median filter. 575 576 Args: 577 size (int): the size of the median filter kernel. Default is 5. Must be an odd number. 578 """ 579 580 assert (size % 2) == 1, "Error - size must be an odd integer" 581 import cv2 # import this here to avoid errors if opencv is not installed properly 582 if self.is_float(): 583 self.data = cv2.medianBlur( self.data.astype(np.float32), size ) 584 else: 585 self.data = cv2.medianBlur( self.data, size ) 586 587 ##################################### 588 ##FEATURES AND FEATURE MATCHING 589 ###################################### 590 def get_keypoints(self, band, eq=False, mask=True, method='sift', cfac=0.0,bfac=0.0, **kwds): 591 """ 592 Get feature descriptors from the specified band. 593 594 Args: 595 band (int,float,str,tuple): the band index (int) or wavelength (float) to extract features from. 
Alternatively, a tuple can be passed 596 containing a range of bands (min : max) to average before feature matching. 597 eq (bool): True if the image should be histogram equalized first. Default is False. 598 mask (bool): True if 0 value pixels should be masked. Default is True. 599 method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'. 600 cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0. 601 bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0. 602 **kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are: 603 604 - contrastThreshold: default is 0.01. 605 - edgeThreshold: default is 10. 606 - sigma: default is 1.0 607 608 For ORB these are: 609 610 - nfeatures = the number of features to detect. Default is 5000. 611 612 Returns: 613 Tuple containing 614 615 - k (ndarray): the keypoints detected 616 - d (ndarray): corresponding feature descriptors 617 """ 618 import cv2 # import this here to avoid errors if opencv is not installed properly 619 620 # get image 621 if isinstance(band, int) or isinstance(band, float): #single band 622 image = self.data[:, :, self.get_band_index(band)] 623 elif isinstance(band,tuple): #range of bands (averaged) 624 idx0 = self.get_band_index(band[0]) 625 idx1 = self.get_band_index(band[1]) 626 627 #deal with out of range errors 628 if idx0 is None: 629 idx0 = 0 630 if idx1 is None: 631 idx1 = self.band_count() 632 633 #average bands 634 image = np.nanmean(self.data[:,:,idx0:idx1],axis=2) 635 else: 636 assert False, "Error, unrecognised band %s" % band 637 638 #normalise image to range 0 - 1 639 image -= np.nanmin(image) 640 image = image / np.nanmax(image) 641 642 #apply brightness/contrast adjustment 643 image = (1.0+cfac)*image + bfac 644 image[image > 1.0] = 1.0 645 image[image < 0.0] = 0.0 646 647 #convert image to uint8 for opencv 648 image = 
np.uint8(255 * image) 649 if eq: 650 image = cv2.equalizeHist(image) 651 652 if mask: 653 mask = np.zeros(image.shape, dtype=np.uint8) 654 mask[image != 0] = 255 # include only non-zero pixels 655 else: 656 mask = None 657 658 if 'sift' in method.lower(): # SIFT 659 660 # setup default keywords 661 kwds["contrastThreshold"] = kwds.get("contrastThreshold", 0.01) 662 kwds["edgeThreshold"] = kwds.get("edgeThreshold", 10) 663 kwds["sigma"] = kwds.get("sigma", 1.0) 664 665 # make feature detector 666 #alg = cv2.xfeatures2d.SIFT_create(**kwds) 667 alg = cv2.SIFT_create() 668 elif 'orb' in method.lower(): # orb 669 kwds['nfeatures'] = kwds.get('nfeatures', 5000) 670 alg = cv2.ORB_create(scoreType=cv2.ORB_FAST_SCORE, **kwds) 671 else: 672 assert False, "Error - %s is not a recognised feature detector." % method 673 674 # detect keypoints 675 kp = alg.detect(image, mask) 676 677 # extract and return feature vectors 678 return alg.compute(image, kp) 679 680 @classmethod 681 def match_keypoints(cls, kp1, kp2, d1, d2, method='SIFT', dist=0.7, tree = 5, check = 100, min_count=5): 682 """ 683 Compares keypoint feature vectors from two images and returns matching pairs. 684 685 Args: 686 kp1 (ndarray): keypoints from the first image 687 kp2 (ndarray): keypoints from the second image 688 d1 (ndarray): descriptors for the keypoints from the first image 689 d2 (ndarray): descriptors for the keypoints from the second image 690 method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'. 691 dist (float): minimum match distance (0 to 1), default is 0.7 692 tree (int): not sure what this does? Default is 5. See open-cv docs. 693 check (int): ditto. Default is 100. 694 min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are found, 695 then the function returns None, None. Default is 5. 
696 """ 697 import cv2 # import this here to avoid errors if opencv is not installed properly 698 if 'sift' in method.lower(): 699 algorithm = cv2.NORM_INF 700 elif 'orb' in method.lower(): 701 algorithm = cv2.NORM_HAMMING 702 else: 703 assert False, "Error - unknown matching algorithm %s" % method 704 705 #calculate flann matches 706 index_params = dict(algorithm=algorithm, trees=tree) 707 search_params = dict(checks=check) 708 flann = cv2.FlannBasedMatcher(index_params, search_params) 709 matches = flann.knnMatch(d1, d2, k=2) 710 711 # store all the good matches as per Lowe's ratio test. 712 good = [] 713 for m, n in matches: 714 if m.distance < dist * n.distance: 715 good.append(m) 716 717 if len(good) < min_count: 718 return None, None 719 else: 720 src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2) 721 dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2) 722 return src_pts, dst_pts 723 724 ############################ 725 ## Visualisation methods 726 ############################ 727 def quick_plot(self, bands=0, ax=None, bfac=0.0, cfac=0.0, samples=False, tscale=False, invert=False, rot=False, flipX=False, flipY=False, 728 **kwds): 729 """ 730 Plot a band using matplotlib.imshow(...). 731 732 Args: 733 bands (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple is passed then 734 each band in the tuple (string or index) will be mapped to rgb. Bands with negative wavelengths or indices will be inverted before plotting. 735 ax: an axis object to plot to. If none, plt.imshow( ... ) is used. 736 bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1) 737 cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1) 738 samples (bool): True if sample points (defined in the header file) should be plotted. Default is False. Otherwise, a list of 739 [ (x,y), ... ] points can be passed. 
740 tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False. 741 When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or 742 (constant) values (float). 743 invert (bool) : True if each band should be inverted before plotting. Only works for multiband (ternary) images. 744 rot (bool): if True, the x and y axis will be flipped (90 degree rotation) before plotting. Default is False. 745 flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations). 746 flipY (bool): if True, the y axis will be flippe before plotting (after applying rotations). 747 **kwds: keywords are passed to matplotlib.imshow( ... ), except for the following: 748 749 - mask = a 2D boolean mask containing true if pixels should be drawn and false otherwise. 750 - path = a file path to save the image too (at matching resolution; use fig.savefig(..) if you want to save the figure). 751 - ticks = True if x- and y- ticks should be plotted. Default is False. 752 - ps, pc = the size and color of sample points to plot. Can be constant or list. 753 - figsize = a figsize for the figure to create (if ax is None). 754 755 Returns: 756 Tuple containing 757 758 - fig: matplotlib figure object 759 - ax: matplotlib axes object. If a colorbar is created, (band is an integer or a float), then this will be stored in ax.cbar. 760 """ 761 762 #create new axes? 
763 if ax is None: 764 fig, ax = plt.subplots(figsize=kwds.pop('figsize', (18,18*self.ydim()/self.xdim()) )) 765 766 # deal with ticks 767 if not kwds.pop('ticks', False ): 768 ax.set_xticks([]) 769 ax.set_yticks([]) 770 771 #map individual band using colourmap 772 if isinstance(bands, str) or isinstance(bands, int) or isinstance(bands, float): 773 #get band 774 if isinstance(bands, str): 775 data = self.data[:, :, self.get_band_index(bands)] 776 else: 777 data = self.data[:, :, self.get_band_index(np.abs(bands))] 778 if not isinstance(bands, str) and bands < 0: 779 data = np.nanmax(data) - data # flip 780 781 # convert integer vmin and vmax values to percentiles 782 if 'vmin' in kwds: 783 if isinstance(kwds['vmin'], int): 784 kwds['vmin'] = np.nanpercentile( data, kwds['vmin'] ) 785 if 'vmax' in kwds: 786 if isinstance(kwds['vmax'], int): 787 kwds['vmax'] = np.nanpercentile( data, kwds['vmax'] ) 788 789 #mask nans (and apply custom mask) 790 mask = np.isnan(data) 791 if not np.isnan(self.header.get_data_ignore_value()): 792 mask = mask + data == self.header.get_data_ignore_value() 793 if 'mask' in kwds: 794 mask = mask + kwds.get('mask') 795 del kwds['mask'] 796 data = np.ma.array(data, mask = mask > 0 ) 797 798 # apply rotations and flipping 799 if rot: 800 data = data.T 801 if flipX: 802 data = data[::-1, :] 803 if flipY: 804 data = data[:, ::-1] 805 806 # save? 
807 if 'path' in kwds: 808 path = kwds.pop('path') 809 from matplotlib.pyplot import imsave 810 if not os.path.exists(os.path.dirname(path)): 811 os.makedirs(os.path.dirname(path)) # ensure output directory exists 812 imsave(path, data.T, **kwds) # save the image 813 814 ax.cbar = ax.imshow(data.T, interpolation=kwds.pop('interpolation', 'none'), **kwds) # change default interpolation to None 815 816 #map 3 bands to RGB 817 elif isinstance(bands, tuple) or isinstance(bands, list): 818 #get band indices and range 819 rgb = [] 820 for b in bands: 821 if isinstance(b, str): 822 rgb.append(self.get_band_index(b)) 823 else: 824 rgb.append(self.get_band_index(np.abs(b))) 825 826 #slice image (as copy) and map to 0 - 1 827 img = np.array(self.data[:, :, rgb]).copy() 828 if np.isnan(img).all(): 829 print("Warning - image contains no data.") 830 return ax.get_figure(), ax 831 832 # invert if needed 833 if invert: 834 bands = [-b for b in bands] 835 for i,b in enumerate(bands): 836 if not isinstance(b, str) and (b < 0): 837 img[..., i] = np.nanmax(img[..., i]) - img[..., i] 838 839 # do scaling 840 if tscale: # scale bands independently 841 for b in range(3): 842 mn = kwds.get("vmin", float(np.nanmin(img))) 843 mx = kwds.get("vmax", float(np.nanmax(img))) 844 if isinstance (mn, int): 845 assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile." 846 mn = float(np.nanpercentile(img[...,b], mn )) 847 if isinstance (mx, int): 848 assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile." 849 mx = float(np.nanpercentile(img[...,b], mx )) 850 img[...,b] = (img[..., b] - mn) / (mx - mn) 851 else: # scale bands together 852 mn = kwds.get("vmin", float(np.nanmin(img))) 853 mx = kwds.get("vmax", float(np.nanmax(img))) 854 if isinstance(mn, int): 855 assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile." 
856 mn = float(np.nanpercentile(img, mn)) 857 if isinstance(mx, int): 858 assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile." 859 mx = float(np.nanpercentile(img, mx)) 860 img = (img - mn) / (mx - mn) 861 862 #apply brightness/contrast mapping 863 img = np.clip((1.0 + cfac) * img + bfac, 0, 1.0 ) 864 865 #apply masking so background is white 866 img[np.logical_not( np.isfinite( img ) )] = 1.0 867 if 'mask' in kwds: 868 img[kwds.pop("mask"),:] = 1.0 869 870 # apply rotations and flipping 871 if rot: 872 img = np.transpose( img, (1,0,2) ) 873 if flipX: 874 img = img[::-1, :, :] 875 if flipY: 876 img = img[:, ::-1, :] 877 878 # save? 879 if 'path' in kwds: 880 path = kwds.pop('path') 881 from matplotlib.pyplot import imsave 882 if not os.path.exists(os.path.dirname(path)): 883 os.makedirs(os.path.dirname(path)) # ensure output directory exists 884 imsave(path, np.transpose( np.clip( img*255, 0, 255).astype(np.uint8), (1, 0, 2))) # save the image 885 886 # plot samples? 887 ps = kwds.pop('ps', 5) 888 pc = kwds.pop('pc', 'r') 889 if samples: 890 if isinstance(samples, list) or isinstance(samples, np.ndarray): 891 ax.scatter([s[0] for s in samples], [s[1] for s in samples], s=ps, c=pc) 892 else: 893 for n in self.header.get_class_names(): 894 points = np.array(self.header.get_sample_points(n)) 895 ax.scatter(points[:, 0], points[:, 1], s=ps, c=pc) 896 897 #plot 898 ax.imshow(np.transpose(img, (1,0,2)), interpolation=kwds.pop('interpolation', 'none'), **kwds) 899 ax.cbar = None # no colorbar 900 901 return ax.get_figure(), ax 902 903 def createGIF(self, path, bands=None, figsize=(10,10), fps=10, **kwds): 904 """ 905 Create and save an animated gif that loops through the bands of the image. 906 907 Args: 908 path (str): the path to save the .gif 909 bands (tuple): Tuple containing the range of band indices to draw. Default is the whole range. 910 figsize (tuple): the size of the image to draw. Default is (10,10). 
911 fps (int): the framerate (frames per second) of the gif. Default is 10. 912 **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc. 913 """ 914 915 frames = [] 916 if bands is None: 917 bands = (0,self.band_count()) 918 else: 919 assert 0 < bands[0] < self.band_count(), "Error - invalid range." 920 assert 0 < bands[1] < self.band_count(), "Error - invalid range." 921 assert bands[1] > bands[0], "Error - invalid range." 922 923 #plot frames 924 for i in range(bands[0],bands[1]): 925 fig, ax = plt.subplots(figsize=figsize) 926 ax.imshow(self.data[:, :, i], **kwds) 927 fig.canvas.draw() 928 frames.append(np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')) 929 frames[-1] = np.reshape(frames[-1], (fig.canvas.get_width_height()[1], fig.canvas.get_width_height()[0], 3)) 930 plt.close(fig) 931 932 #save gif 933 imageio.mimsave( os.path.splitext(path)[0] + ".gif", frames, fps=fps) 934 935 ## masking 936 def drop_bbl(self, drop=True): 937 """ 938 Remove bad bands as stored in the 'bbl' key in the image header. Note that this operates in-place. 939 940 Args: 941 drop (bool): True if bad bands should be completely dropped. If False, these bands will be kept but replaced with nans. 942 """ 943 assert 'bbl' in self.header, "Please specify a bad band list ('bbl') in the image header, as per the ENVI format definition." 944 mask = self.header.get_list('bbl') == 0 945 self.data[...,mask] = np.nan 946 if drop: 947 self.delete_nan_bands(inplace=True) 948 949 def mask(self, mask=None, flag=np.nan, invert=False, crop=False, bands=None): 950 """ 951 Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the 952 image in-situ. 953 954 Args: 955 flag (float): the value to use for masked pixels. Default is np.nan 956 mask (ndarray): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None is passed then 957 pickPolygon( ... 
) is used to interactively define a polygon. If a file path is passed then the polygon 958 will be loaded using np.load( ... ). Alternatively if mask.shape == image.shape[0,1] then it is treated as a 959 binary image mask (must be boolean) and True values will be masked across all bands. Default is None. 960 invert (bool): if True, pixels within the polygon will be masked. If False, pixels outside the polygon are masked. Default is False. 961 crop (bool): True if rows/columns containing only zeros should be removed. Default is False. 962 bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used. 963 964 Returns: 965 Tuple containing 966 967 - mask (ndarray): a boolean array with True where pixels are masked and False elsewhere. 968 - poly (ndarray): the mask polygon array in the format described above. Useful if the polygon was interactively defined. 969 """ 970 971 if mask is None: # pick mask interactively 972 if bands is None: 973 bands = int(self.band_count() / 2) 974 975 regions = self.pickPolygons(region_names=["mask"], bands=bands) 976 977 # the user bailed without picking a mask? 
978 if len(regions) == 0: 979 print("Warning - no mask picked/applied.") 980 return 981 982 # extract polygon mask 983 mask = regions[0] 984 985 # convert polygon mask to binary mask 986 if mask.shape[1] == 2: 987 988 # build meshgrid with pixel coords 989 xx, yy = np.meshgrid(np.arange(self.xdim()), np.arange(self.ydim())) 990 xx = xx.flatten() 991 yy = yy.flatten() 992 points = np.vstack([xx, yy]).T # coordinates of each pixel 993 994 # calculate per-pixel mask 995 mask = path.Path(mask).contains_points(points) 996 mask = mask.reshape((self.ydim(), self.xdim())).T 997 998 # flip as we want to mask (==True) outside points (unless invert is true) 999 if not invert: 1000 mask = np.logical_not(mask) 1001 1002 # apply binary image mask 1003 assert mask.shape[0] == self.data.shape[0] and mask.shape[1] == self.data.shape[1], \ 1004 "Error - mask shape %s does not match image shape %s" % (mask.shape, self.data.shape) 1005 for b in range(self.band_count()): 1006 self.data[:, :, b][mask] = flag 1007 1008 # crop image 1009 if crop: 1010 self.crop_to_data() 1011 1012 return mask 1013 1014 def crop_to_data(self): 1015 """ 1016 Remove padding of nan or zero pixels from image. Note that this is performed in place. 
1017 """ 1018 valid = np.isfinite(self.data).any(axis=-1) & (self.data != 0).any(axis=-1) 1019 1020 # integrate along axes 1021 xdata = np.sum(valid, axis=1) > 0.0 1022 ydata = np.sum(valid, axis=0) > 0.0 1023 1024 # calculate domain containing valid pixels 1025 xmin = np.argmax(xdata) 1026 xmax = xdata.shape[0] - np.argmax(xdata[::-1]) 1027 ymin = np.argmax(ydata) 1028 ymax = ydata.shape[0] - np.argmax(ydata[::-1]) 1029 1030 # crop 1031 self.data = self.data[xmin:xmax, ymin:ymax, :] 1032 1033 # shift affine origin to new top-left pixel 1034 if self.affine is not None: 1035 a = self.affine # shorthand for affine 1036 new_affine = list(self.affine) 1037 new_affine[0] = a[0] + xmin*a[1] + ymin*a[2] 1038 new_affine[3] = a[3] + xmin*a[4] + ymin*a[5] 1039 self.affine = np.array(new_affine) 1040 self.header['affine'] = self.affine 1041 1042 ################################################## 1043 ## Interactive tools for picking regions/pixels 1044 ################################################## 1045 def pickPolygons(self, region_names, bands=0): 1046 """ 1047 Creates a matplotlib gui for selecting polygon regions in an image. 1048 1049 Args: 1050 names (list, str): a list containing the names of the regions to pick. If a string is passed only one name is used. 1051 bands (tuple): the bands of the image to plot. 1052 """ 1053 1054 if isinstance(region_names, str): 1055 region_names = [region_names] 1056 1057 assert isinstance(region_names, list), "Error - names must be a list or a string." 
1058 1059 # set matplotlib backend 1060 backend = matplotlib.get_backend() 1061 matplotlib.use('Qt5Agg') # need this backend for ROIPoly to work 1062 1063 # plot image and extract roi's 1064 fig, ax = self.quick_plot(bands) 1065 roi = MultiRoi(roi_names=region_names) 1066 plt.close(fig) # close figure 1067 1068 # extract regions 1069 regions = [] 1070 for name, r in roi.rois.items(): 1071 # store region 1072 x = r.x 1073 y = r.y 1074 regions.append(np.vstack([x, y]).T) 1075 1076 # restore matplotlib backend (if possible) 1077 try: 1078 matplotlib.use(backend) 1079 except: 1080 print("Warning: could not reset matplotlib backend. Plots will remain interactive...") 1081 pass 1082 1083 return regions 1084 1085 def pickPoints(self, n=-1, bands=hylite.RGB, integer=True, title="Pick Points", **kwds): 1086 """ 1087 Creates a matplotlib gui for picking pixels from an image. 1088 1089 Args: 1090 n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1. 1091 bands (tuple): the bands of the image to plot. Default is HyImage.RGB 1092 integer (bool): True if points coordinates should be cast to integers (for use as indices). Default is True. 1093 title (str): The title of the point picking window. 1094 **kwds: Keywords are passed to HyImage.quick_plot( ... ). 1095 1096 Returns: 1097 A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ]. 1098 """ 1099 1100 # set matplotlib backend 1101 backend = matplotlib.get_backend() 1102 matplotlib.use('Qt5Agg') # need this backend for ROIPoly to work 1103 1104 # create figure 1105 fig, ax = self.quick_plot( bands, **kwds ) 1106 ax.set_title(title) 1107 1108 # get points 1109 points = fig.ginput( n ) 1110 1111 if integer: 1112 points = [ (int(p[0]), int(p[1])) for p in points ] 1113 1114 # restore matplotlib backend (if possible) 1115 try: 1116 matplotlib.use(backend) 1117 except: 1118 print("Warning: could not reset matplotlib backend. 
Plots will remain interactive...") 1119 pass 1120 1121 return points 1122 1123 def pickSamples(self, names=None, store=True, **kwds): 1124 """ 1125 Pick sample probe points and store these in the image header file. 1126 1127 Args: 1128 names (str, list): the name of the sample to pick, or a list of names to pick multiple. 1129 store (bool): True if sample should be stored in the image header file (for later access). Default is True. 1130 **kwds: Keywords are passed to HyImage.quick_plot( ... ) 1131 1132 Returns: 1133 a list containing a list of points for each sample. 1134 """ 1135 1136 if isinstance(names, str): 1137 names = [names] 1138 1139 # pick points 1140 points = [] 1141 for s in names: 1142 pnts = self.pickPoints(title="%s" % s, **kwds) 1143 if store: 1144 self.header['sample %s' % s] = pnts # store in header 1145 points.append(pnts) 1146 # add class to header file 1147 if store: 1148 cls_names = self.header.get_class_names() 1149 if cls_names is None: 1150 cls_names = [] 1151 self.header['class names'] = cls_names + names 1152 1153 return points
20class HyImage( HyData ): 21 """ 22 A class for hyperspectral image data. These can be individual scenes or hyperspectral orthoimages. 23 """ 24 25 def __init__(self, data, **kwds): 26 """ 27 Args: 28 data (ndarray): a numpy array such that data[x][y][band] gives each pixel value. 29 **kwds: 30 wav = A numpy array containing band wavelengths for this image. 31 affine = an affine transform of the format returned by GDAL.GetGeoTransform(). 32 projection = string defining the project. Default is None. 33 sensor = sensor name. Default is "unknown". 34 header = path to associated header file. Default is None. 35 """ 36 37 #call constructor for HyData 38 super().__init__(data, **kwds) 39 40 # special case - if dataset only has oneband, slice it so it still has 41 # the format data[x,y,b]. 42 if not self.data is None: 43 if len(self.data.shape) == 1: 44 self.data = self.data[None, None, :] # single pixel image 45 if len(self.data.shape) == 2: 46 self.data = self.data[:, :, None] # single band iamge 47 48 #load any additional project information (specific to images) 49 self.set_projection(kwds.get("projection",None)) 50 self.affine = np.array( kwds.get("affine",[0,1,0,0,0,1]) ) 51 self.header['affine'] = kwds.get("affine",[0,1,0,0,0,1]) # also store this here 52 53 # wavelengths 54 if 'wav' in kwds: 55 self.set_wavelengths(kwds['wav']) 56 57 #special header formatting 58 self.header['file type'] = 'ENVI Standard' 59 60 def copy(self,data=True): 61 """ 62 Make a deep copy of this image instance. 63 64 Args: 65 data (bool): True if a copy of the data should be made, otherwise only copy header. 66 67 Returns: 68 a new HyImage instance. 
69 """ 70 if not data: 71 return HyImage(None, header=self.header.copy(), projection=self.projection, affine=self.affine) 72 else: 73 return HyImage( self.data.copy(), header=self.header.copy(), projection=self.projection, affine=self.affine) 74 75 def T(self): 76 """ 77 Return a transposed view of the data matrix (corresponding with the [y,x] indexing used by matplotlib, opencv etc. 78 """ 79 return np.transpose(self.data, (1,0,2)) 80 81 def xdim(self): 82 """ 83 Return number of pixels in x (first dimension of data array) 84 """ 85 return self.data.shape[0] 86 87 def ydim(self): 88 """ 89 Return number of pixels in y (second dimension of data array) 90 """ 91 return self.data.shape[1] 92 93 def aspx(self): 94 """ 95 Return the aspect ratio of this image (width/height). 96 """ 97 return self.ydim() / self.xdim() 98 99 ##################################### 100 ## GEOREFERENCING METHODS 101 ##################################### 102 103 def get_extent(self): 104 """ 105 Returns the width and height of this image in world coordinates. 106 107 Returns: 108 tuple with (width, height). 109 """ 110 return self.xdim * self.pixel_size[0], self.ydim * self.pixel_size[1] 111 112 def set_projection(self,proj): 113 """ 114 Set this project to an existing osgeo.osr.SpatialReference or GDAL georeference string. 115 116 Args: 117 proj (str, osgeo.osr.SpatialReference): the project to use as osgeo.osr.SpatialReference or GDAL georeference string. 118 """ 119 if proj is None: 120 self.projection = None 121 else: 122 try: 123 from osgeo.osr import SpatialReference 124 except: 125 assert False, "Error - GDAL must be installed to work with spatial projections in hylite." 126 if isinstance(proj, SpatialReference): 127 self.projection = proj 128 elif isinstance(proj, str): 129 self.projection = SpatialReference(proj) 130 else: 131 print("Invalid project %s" % proj) 132 raise 133 134 def set_projection_EPSG(self,EPSG): 135 """ 136 Sets this image project using an EPSG code. 
137 138 Args: 139 EPSG (str): string EPSG code that can be passed to SpatialReference.SetFromUserInput(...). 140 """ 141 142 try: 143 from osgeo.osr import SpatialReference 144 except: 145 assert False, "Error - GDAL must be installed to work with spatial projections in hylite." 146 147 self.projection = SpatialReference() 148 self.projection.SetFromUserInput(EPSG) 149 150 def get_projection_EPSG(self): 151 """ 152 Gets a string describing this projections EPSG code (if it is an EPSG project). 153 154 Returns: 155 an EPSG code string of the format "EPSG:XXXX". 156 """ 157 if self.projection is None: 158 return None 159 else: 160 return "%s:%s" % (self.projection.GetAttrValue("AUTHORITY",0),self.projection.GetAttrValue("AUTHORITY",1)) 161 162 def pix_to_world(self, px, py, proj=None): 163 """ 164 Take pixel coordinates and return world coordinates 165 166 Args: 167 px (int): the pixel x-coord. 168 py (int): the pixel y-coord. 169 proj (str, osr.SpatialReference): the coordinate system to use. Default (None) uses the same system as this image. Otherwise 170 an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)). 171 Returns: 172 the world coordinates in the coordinate system defined by get_projection_EPSG(...). 173 """ 174 175 try: 176 from osgeo import osr 177 import osgeo.gdal as gdal 178 from osgeo import ogr 179 except: 180 assert False, "Error - GDAL must be installed to work with spatial projections in hylite." 181 182 # parse project 183 if proj is None: 184 proj = self.projection 185 elif isinstance(proj, str) or isinstance(proj, int): 186 epsg = proj 187 if isinstance(epsg, str): 188 try: 189 epsg = int(str.split(':')[1]) 190 except: 191 assert False, "Error - %s is an invalid EPSG code." 
% proj 192 proj = osr.SpatialReference() 193 proj.ImportFromEPSG(epsg) 194 195 # check we have all the required info 196 assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj 197 assert (not self.affine is None) and ( 198 not self.projection is None), "Error - project information is undefined." 199 200 #project to world coordinates in this images project/world coords 201 x,y = gdal.ApplyGeoTransform(self.affine, px, py) 202 203 #project to target coords (if different) 204 if not proj.IsSameGeogCS(self.projection): 205 P = ogr.Geometry(ogr.wkbPoint) 206 if proj.EPSGTreatsAsNorthingEasting(): 207 P.AddPoint(x, y) 208 else: 209 P.AddPoint(y, x) 210 P.AssignSpatialReference(self.projection) # tell the point what coordinates it's in 211 P.TransformTo(proj) # reproject it to the out spatial reference 212 x, y = P.GetX(), P.GetY() 213 214 #do we need to transpose? 215 if proj.EPSGTreatsAsLatLong(): 216 x,y=y,x #we want lon,lat not lat,lon 217 return x, y 218 219 def world_to_pix(self, x, y, proj = None): 220 """ 221 Take world coordinates and return pixel coordinates 222 223 Args: 224 x (float): the world x-coord. 225 y (float): the world y-coord. 226 proj (str, osr.SpatialReference): the coordinate system of the input coordinates. Default (None) uses the same system as this image. Otherwise 227 an osr.SpatialReference can be passed (HyImage.project), or an EPSG string (e.g. get_projection_EPSG(...)). 228 229 Returns: 230 the pixel coordinates based on the affine transform stored in self.affine. 231 """ 232 233 try: 234 from osgeo import osr 235 import osgeo.gdal as gdal 236 from osgeo import ogr 237 except: 238 assert False, "Error - GDAL must be installed to work with spatial projections in hylite." 
239 240 # parse project 241 if proj is None: 242 proj = self.projection 243 elif isinstance(proj, str) or isinstance(proj, int): 244 epsg = proj 245 if isinstance(epsg, str): 246 try: 247 epsg = int(str.split(':')[1]) 248 except: 249 assert False, "Error - %s is an invalid EPSG code." % proj 250 proj = osr.SpatialReference() 251 proj.ImportFromEPSG(epsg) 252 253 254 # check we have all the required info 255 assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj 256 assert (not self.affine is None) and (not self.projection is None), "Error - project information is undefined." 257 258 # project to this images CS (if different) 259 if not proj.IsSameGeogCS(self.projection): 260 P = ogr.Geometry(ogr.wkbPoint) 261 if proj.EPSGTreatsAsNorthingEasting(): 262 P.AddPoint(x, y) 263 else: 264 P.AddPoint(y, x) 265 P.AssignSpatialReference(proj) # tell the point what coordinates it's in 266 P.AddPoint(x, y) 267 P.TransformTo(self.projection) # reproject it to the out spatial reference 268 x, y = P.GetX(), P.GetY() 269 if self.projection.EPSGTreatsAsLatLong(): # do we need to transpose? 270 x, y = y, x # we want lon,lat not lat,lon 271 272 inv = gdal.InvGeoTransform(self.affine) 273 assert not inv is None, "Error - could not invert affine transform?" 274 275 #apply 276 return gdal.ApplyGeoTransform(inv, x, y) 277 278 def crop(self, xmin, xmax, ymin, ymax, bands=None): 279 """ 280 Return a cropped copy of this image. 281 282 Args: 283 xmin, xmax (int): pixel bounds in x (rows) 284 ymin, ymax (int): pixel bounds in y (columns) 285 bands (None, list, tuple): optional band indices or (min,max) range 286 287 Returns: 288 HyImage: cropped image with updated affine transform 289 """ 290 291 # ---- validate bounds ---- 292 xmin = int(max(0, xmin)) 293 ymin = int(max(0, ymin)) 294 xmax = int(min(self.xdim(), xmax)) 295 ymax = int(min(self.ydim(), ymax)) 296 297 assert xmin < xmax and ymin < ymax, "Invalid crop extent." 
298 299 # ---- crop data ---- 300 if bands is None: 301 data = self.data[xmin:xmax, ymin:ymax, :].copy() 302 wav = self.get_wavelengths() 303 else: # band selection 304 if isinstance(bands, tuple): 305 b0 = self.get_band_index(bands[0]) 306 b1 = self.get_band_index(bands[1]) 307 data = self.data[xmin:xmax, ymin:ymax, b0:b1].copy() 308 wav = self.get_wavelengths()[b0:b1] 309 else: 310 idx = [self.get_band_index(b) for b in bands] 311 data = self.data[xmin:xmax, ymin:ymax, idx].copy() 312 wav = self.get_wavelengths()[idx] 313 314 # ---- update affine transform ---- 315 if self.affine is not None: 316 a = list(self.affine) 317 new_affine = a.copy() 318 319 # shift origin to new top-left pixel 320 new_affine[0] = a[0] + xmin*a[1] + ymin*a[2] 321 new_affine[3] = a[3] + xmin*a[4] + ymin*a[5] 322 else: 323 new_affine = None 324 325 # ---- construct output image ---- 326 out = HyImage( 327 data, 328 header=self.header.copy(), 329 projection=self.projection, 330 affine=new_affine, 331 wav=wav 332 ) 333 334 return out 335 336 def resize(self, newdims: tuple, interpolation: int = 1): 337 """ 338 Resize this image with opencv and update affine transform accordingly. 339 340 Args: 341 newdims (tuple): the new image dimensions (xdim, ydim) 342 interpolation (int): opencv interpolation method. Default is cv2.INTER_LINEAR. 343 """ 344 import cv2 # avoid import issues if opencv is missing 345 346 old_x, old_y = self.xdim(), self.ydim() 347 new_x, new_y = int(newdims[0]), int(newdims[1]) 348 349 assert new_x > 0 and new_y > 0, "Invalid resize dimensions." 
350 351 # resize data (opencv uses width, height = y, x) 352 self.data = cv2.resize( 353 self.data, 354 (new_y, new_x), 355 interpolation=interpolation 356 ) 357 358 # update affine transform 359 if self.affine is not None: 360 a = list(self.affine) 361 362 sx = old_x / new_x 363 sy = old_y / new_y 364 365 self.affine = [ 366 a[0], # x origin unchanged 367 a[1] * sx, # pixel width 368 a[2] * sy, # row rotation 369 a[3], # y origin unchanged 370 a[4] * sx, # column rotation 371 a[5] * sy # pixel height 372 ] 373 374 def tile(self, tile_size): 375 """ 376 Break image into tiles of given size and return a list of HyImage tiles. 377 Each tile has an updated affine transform reflecting its position in the original image. 378 379 Args: 380 tile_size (tuple): (tile_x, tile_y) in pixels 381 Returns: 382 list of HyImage 383 """ 384 tiles = [] 385 tx, ty = tile_size 386 nx, ny = self.xdim(), self.ydim() 387 for i in range(0, nx, tx): 388 for j in range(0, ny, ty): 389 data_tile = self.data[i:min(i+tx, nx), j:min(j+ty, ny), :].copy() # slice data 390 391 # compute new affine 392 # affine = [x0, dx, rotx, y0, roty, dy] 393 x0, dx, rotx, y0, roty, dy = self.affine 394 new_x0 = x0 + i*dx + j*rotx 395 new_y0 = y0 + i*roty + j*dy 396 new_affine = [new_x0, dx, rotx, new_y0, roty, dy] 397 398 # create tile 399 tile_img = HyImage( 400 data_tile, 401 affine=new_affine, 402 projection=self.projection, 403 wav=self.get_wavelengths(), 404 header=self.header.copy() 405 ) 406 tile_img.header['xleft'] = i 407 tile_img.header['ytop'] = j 408 tiles.append(tile_img) 409 410 return tiles 411 412 @staticmethod 413 def mosaic( 414 tiles, 415 blend="mean", 416 resampling="nearest", 417 out_affine=None, 418 out_shape=None, 419 ): 420 """ 421 Mosaic georeferenced HyImage tiles using GDAL. Note that this assumes all tiles are in the same coordinate system. 
422 423 Args: 424 tiles (list[HyImage]) 425 blend (str): 'first', 'min', 'max', 'mean', 'median' 426 resampling (str): 'nearest', 'bilinear', 'cubic' 427 out_affine (list): optional 6-element affine to define output grid. If None, the affine of the first tile is used. 428 out_shape (tuple): optional (xdim, ydim) shape of the output grid. If None, the extent of all tiles will be used. 429 Returns: 430 HyImage 431 """ 432 import numpy as np 433 from hylite.project.align import resample_raster 434 from osgeo import gdal, osr 435 436 assert len(tiles) > 0 437 assert blend in ("first", "min", "max", "mean", "median") 438 439 # compute bounds in world coordinates 440 # N.B. THIS ASSUMES ALL DATA ARE IN THE SAME CRS 441 points = [] 442 for t in tiles: 443 points.append( t.pix_to_world(0,0) ) 444 points.append( t.pix_to_world(t.xdim()+1,t.ydim()+1) ) 445 min_x, min_y = np.min(points, axis=0) 446 max_x, max_y = np.max(points, axis=0) 447 if out_shape is None: 448 out_shape = (np.array(tiles[0].world_to_pix(max_x, max_y)) + np.array(tiles[0].world_to_pix(min_x, min_y))).round().astype(int) 449 if out_affine is None: 450 out_affine = list(tiles[0].affine) 451 out_affine[0] = min_x 452 out_affine[3] = max_y 453 454 if blend == "first": # fill output, first come, first served. 
455 out = np.full( tuple(out_shape) + (tiles[0].band_count(),), np.nan, dtype=np.float32) 456 for t in tiles: 457 r = resample_raster( t.data, t.affine, out_affine, out_shape ) 458 mask = np.isnan(out) 459 out[mask] = r[mask] 460 elif blend == "min": # keep minimum value in case of overlap 461 out = np.nanmin( np.stack([ resample_raster( t.data, t.affine, out_affine, out_shape ) for t in tiles ], 462 axis=0), axis=0 ) 463 elif blend == "max": # keep maximum value in case of overlap 464 out = np.nanmax( np.stack([ resample_raster( t.data, t.affine, out_affine, out_shape ) for t in tiles ], 465 axis=0), axis=0 ) 466 elif blend == "mean": # use average in case of overlap 467 out = np.nanmean( np.stack([ resample_raster( t.data, t.affine, out_affine, out_shape ) for t in tiles ], 468 axis=0), axis=0 ) 469 elif blend == "median": # use mean in case of overlap 470 out = np.nanmedian( np.stack([ resample_raster( t.data, t.affine, out_affine, out_shape ) for t in tiles ], 471 axis=0), axis=0 ) 472 473 # Return HyImage 474 out = HyImage( 475 out, 476 affine=out_affine, 477 projection=tiles[0].projection, 478 wav=tiles[0].get_wavelengths(), 479 header=tiles[0].header.copy() 480 ) 481 if 'xleft' in out.header: del out.header['xleft'] # stored in tiles, but not meaningful here 482 if 'ytop' in out.header: del out.header['ytop'] # stored in tiles, but not meaningful here 483 484 return out 485 486 ##################################### 487 ## BASIC TRANSFORMS 488 ##################################### 489 490 def flip(self, axis='x'): 491 """ 492 Flip the image on the x or y axis. Note that this will remove any defined affine transform. 493 494 Args: 495 axis (str): 'x' or 'y' or both 'xy'. 
496 """ 497 498 if 'x' in axis.lower(): 499 self.data = np.flip(self.data,axis=0) 500 if 'y' in axis.lower(): 501 self.data = np.flip(self.data,axis=1) 502 self.affine = None 503 if 'affine' in self.header: del self.header['affine'] 504 self.push_to_header() # update width and height info 505 506 def rot90(self): 507 """ 508 Rotate this image by 90 degrees by transposing the underlying data array. Combine with flip('x') or flip('y') 509 to achieve positive/negative rotations. 510 """ 511 self.data = np.transpose( self.data, (1,0,2) ) 512 self.affine = None 513 if 'affine' in self.header: del self.header['affine'] 514 self.push_to_header() # update width and height info 515 516 ##################################### 517 ##IMAGE FILTERING 518 ##################################### 519 def fill_holes(self): 520 """ 521 Replaces nan pixel with an average of their neighbours, thus removing 1-pixel large holes from an image. Note that 522 for performance reasons this assumes that holes line up across bands. Note that this is not vectorized so very slow... 523 """ 524 525 # perform greyscale dilation 526 dilate = self.data.copy() 527 mask = np.logical_not(np.isfinite(dilate)) 528 dilate[mask] = 0 529 for b in range(self.band_count()): 530 dilate[:, :, b] = sp.ndimage.grey_dilation(dilate[:, :, b], size=(3, 3)) 531 532 # map back to holes in dataset 533 self.data[mask] = dilate[mask] 534 #self.data[self.data == 0] = np.nan # replace remaining 0's with nans 535 536 def blur(self, n=3): 537 """ 538 Applies a gaussian kernel of size n to the image using OpenCV. 539 540 Args: 541 n (int): the dimensions of the gaussian kernel to convolve. Default is 3. Increase for more blurry results. 542 """ 543 import cv2 # import this here to avoid errors if opencv is not installed properly 544 545 nanmask = np.isnan(self.data) 546 assert isinstance(n, int) and n >= 3, "Error - invalid kernel. N must be an integer > 3. 
" 547 kernel = np.ones((n, n), np.float32) / (n ** 2) 548 self.data = cv2.filter2D(self.data, -1, kernel) 549 self.data[nanmask] = np.nan # remove mask 550 551 def erode(self, size=3, iterations=1): 552 """ 553 Apply an erode filter to this image to expand background (nan) pixels. Refer to open-cv's erode 554 function for more details. 555 556 Args: 557 size (int): the size of the erode filter. Default is a 3x3 kernel. 558 iterations (int): the number of erode iterations. Default is 1. 559 """ 560 import cv2 # import this here to avoid errors if opencv is not installed properly 561 562 # erode 563 kernel = np.ones((size, size), np.uint8) 564 if self.is_float(): 565 mask = np.isfinite(self.data).any(axis=-1) 566 mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations) 567 self.data[mask == 0, :] = np.nan 568 else: 569 mask = (self.data != 0).any( axis=-1 ) 570 mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations) 571 self.data[mask == 0, :] = 0 572 573 def despeckle(self, size=5): 574 """ 575 Despeckle each band of this image (independently) using a median filter. 576 577 Args: 578 size (int): the size of the median filter kernel. Default is 5. Must be an odd number. 579 """ 580 581 assert (size % 2) == 1, "Error - size must be an odd integer" 582 import cv2 # import this here to avoid errors if opencv is not installed properly 583 if self.is_float(): 584 self.data = cv2.medianBlur( self.data.astype(np.float32), size ) 585 else: 586 self.data = cv2.medianBlur( self.data, size ) 587 588 ##################################### 589 ##FEATURES AND FEATURE MATCHING 590 ###################################### 591 def get_keypoints(self, band, eq=False, mask=True, method='sift', cfac=0.0,bfac=0.0, **kwds): 592 """ 593 Get feature descriptors from the specified band. 594 595 Args: 596 band (int,float,str,tuple): the band index (int) or wavelength (float) to extract features from. 
Alternatively, a tuple can be passed 597 containing a range of bands (min : max) to average before feature matching. 598 eq (bool): True if the image should be histogram equalized first. Default is False. 599 mask (bool): True if 0 value pixels should be masked. Default is True. 600 method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'. 601 cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0. 602 bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0. 603 **kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are: 604 605 - contrastThreshold: default is 0.01. 606 - edgeThreshold: default is 10. 607 - sigma: default is 1.0 608 609 For ORB these are: 610 611 - nfeatures = the number of features to detect. Default is 5000. 612 613 Returns: 614 Tuple containing 615 616 - k (ndarray): the keypoints detected 617 - d (ndarray): corresponding feature descriptors 618 """ 619 import cv2 # import this here to avoid errors if opencv is not installed properly 620 621 # get image 622 if isinstance(band, int) or isinstance(band, float): #single band 623 image = self.data[:, :, self.get_band_index(band)] 624 elif isinstance(band,tuple): #range of bands (averaged) 625 idx0 = self.get_band_index(band[0]) 626 idx1 = self.get_band_index(band[1]) 627 628 #deal with out of range errors 629 if idx0 is None: 630 idx0 = 0 631 if idx1 is None: 632 idx1 = self.band_count() 633 634 #average bands 635 image = np.nanmean(self.data[:,:,idx0:idx1],axis=2) 636 else: 637 assert False, "Error, unrecognised band %s" % band 638 639 #normalise image to range 0 - 1 640 image -= np.nanmin(image) 641 image = image / np.nanmax(image) 642 643 #apply brightness/contrast adjustment 644 image = (1.0+cfac)*image + bfac 645 image[image > 1.0] = 1.0 646 image[image < 0.0] = 0.0 647 648 #convert image to uint8 for opencv 649 image = 
np.uint8(255 * image) 650 if eq: 651 image = cv2.equalizeHist(image) 652 653 if mask: 654 mask = np.zeros(image.shape, dtype=np.uint8) 655 mask[image != 0] = 255 # include only non-zero pixels 656 else: 657 mask = None 658 659 if 'sift' in method.lower(): # SIFT 660 661 # setup default keywords 662 kwds["contrastThreshold"] = kwds.get("contrastThreshold", 0.01) 663 kwds["edgeThreshold"] = kwds.get("edgeThreshold", 10) 664 kwds["sigma"] = kwds.get("sigma", 1.0) 665 666 # make feature detector 667 #alg = cv2.xfeatures2d.SIFT_create(**kwds) 668 alg = cv2.SIFT_create() 669 elif 'orb' in method.lower(): # orb 670 kwds['nfeatures'] = kwds.get('nfeatures', 5000) 671 alg = cv2.ORB_create(scoreType=cv2.ORB_FAST_SCORE, **kwds) 672 else: 673 assert False, "Error - %s is not a recognised feature detector." % method 674 675 # detect keypoints 676 kp = alg.detect(image, mask) 677 678 # extract and return feature vectors 679 return alg.compute(image, kp) 680 681 @classmethod 682 def match_keypoints(cls, kp1, kp2, d1, d2, method='SIFT', dist=0.7, tree = 5, check = 100, min_count=5): 683 """ 684 Compares keypoint feature vectors from two images and returns matching pairs. 685 686 Args: 687 kp1 (ndarray): keypoints from the first image 688 kp2 (ndarray): keypoints from the second image 689 d1 (ndarray): descriptors for the keypoints from the first image 690 d2 (ndarray): descriptors for the keypoints from the second image 691 method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'. 692 dist (float): minimum match distance (0 to 1), default is 0.7 693 tree (int): not sure what this does? Default is 5. See open-cv docs. 694 check (int): ditto. Default is 100. 695 min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are found, 696 then the function returns None, None. Default is 5. 
697 """ 698 import cv2 # import this here to avoid errors if opencv is not installed properly 699 if 'sift' in method.lower(): 700 algorithm = cv2.NORM_INF 701 elif 'orb' in method.lower(): 702 algorithm = cv2.NORM_HAMMING 703 else: 704 assert False, "Error - unknown matching algorithm %s" % method 705 706 #calculate flann matches 707 index_params = dict(algorithm=algorithm, trees=tree) 708 search_params = dict(checks=check) 709 flann = cv2.FlannBasedMatcher(index_params, search_params) 710 matches = flann.knnMatch(d1, d2, k=2) 711 712 # store all the good matches as per Lowe's ratio test. 713 good = [] 714 for m, n in matches: 715 if m.distance < dist * n.distance: 716 good.append(m) 717 718 if len(good) < min_count: 719 return None, None 720 else: 721 src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2) 722 dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2) 723 return src_pts, dst_pts 724 725 ############################ 726 ## Visualisation methods 727 ############################ 728 def quick_plot(self, bands=0, ax=None, bfac=0.0, cfac=0.0, samples=False, tscale=False, invert=False, rot=False, flipX=False, flipY=False, 729 **kwds): 730 """ 731 Plot a band using matplotlib.imshow(...). 732 733 Args: 734 bands (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple is passed then 735 each band in the tuple (string or index) will be mapped to rgb. Bands with negative wavelengths or indices will be inverted before plotting. 736 ax: an axis object to plot to. If none, plt.imshow( ... ) is used. 737 bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1) 738 cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1) 739 samples (bool): True if sample points (defined in the header file) should be plotted. Default is False. Otherwise, a list of 740 [ (x,y), ... ] points can be passed. 
741 tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False. 742 When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or 743 (constant) values (float). 744 invert (bool) : True if each band should be inverted before plotting. Only works for multiband (ternary) images. 745 rot (bool): if True, the x and y axis will be flipped (90 degree rotation) before plotting. Default is False. 746 flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations). 747 flipY (bool): if True, the y axis will be flippe before plotting (after applying rotations). 748 **kwds: keywords are passed to matplotlib.imshow( ... ), except for the following: 749 750 - mask = a 2D boolean mask containing true if pixels should be drawn and false otherwise. 751 - path = a file path to save the image too (at matching resolution; use fig.savefig(..) if you want to save the figure). 752 - ticks = True if x- and y- ticks should be plotted. Default is False. 753 - ps, pc = the size and color of sample points to plot. Can be constant or list. 754 - figsize = a figsize for the figure to create (if ax is None). 755 756 Returns: 757 Tuple containing 758 759 - fig: matplotlib figure object 760 - ax: matplotlib axes object. If a colorbar is created, (band is an integer or a float), then this will be stored in ax.cbar. 761 """ 762 763 #create new axes? 
764 if ax is None: 765 fig, ax = plt.subplots(figsize=kwds.pop('figsize', (18,18*self.ydim()/self.xdim()) )) 766 767 # deal with ticks 768 if not kwds.pop('ticks', False ): 769 ax.set_xticks([]) 770 ax.set_yticks([]) 771 772 #map individual band using colourmap 773 if isinstance(bands, str) or isinstance(bands, int) or isinstance(bands, float): 774 #get band 775 if isinstance(bands, str): 776 data = self.data[:, :, self.get_band_index(bands)] 777 else: 778 data = self.data[:, :, self.get_band_index(np.abs(bands))] 779 if not isinstance(bands, str) and bands < 0: 780 data = np.nanmax(data) - data # flip 781 782 # convert integer vmin and vmax values to percentiles 783 if 'vmin' in kwds: 784 if isinstance(kwds['vmin'], int): 785 kwds['vmin'] = np.nanpercentile( data, kwds['vmin'] ) 786 if 'vmax' in kwds: 787 if isinstance(kwds['vmax'], int): 788 kwds['vmax'] = np.nanpercentile( data, kwds['vmax'] ) 789 790 #mask nans (and apply custom mask) 791 mask = np.isnan(data) 792 if not np.isnan(self.header.get_data_ignore_value()): 793 mask = mask + data == self.header.get_data_ignore_value() 794 if 'mask' in kwds: 795 mask = mask + kwds.get('mask') 796 del kwds['mask'] 797 data = np.ma.array(data, mask = mask > 0 ) 798 799 # apply rotations and flipping 800 if rot: 801 data = data.T 802 if flipX: 803 data = data[::-1, :] 804 if flipY: 805 data = data[:, ::-1] 806 807 # save? 
808 if 'path' in kwds: 809 path = kwds.pop('path') 810 from matplotlib.pyplot import imsave 811 if not os.path.exists(os.path.dirname(path)): 812 os.makedirs(os.path.dirname(path)) # ensure output directory exists 813 imsave(path, data.T, **kwds) # save the image 814 815 ax.cbar = ax.imshow(data.T, interpolation=kwds.pop('interpolation', 'none'), **kwds) # change default interpolation to None 816 817 #map 3 bands to RGB 818 elif isinstance(bands, tuple) or isinstance(bands, list): 819 #get band indices and range 820 rgb = [] 821 for b in bands: 822 if isinstance(b, str): 823 rgb.append(self.get_band_index(b)) 824 else: 825 rgb.append(self.get_band_index(np.abs(b))) 826 827 #slice image (as copy) and map to 0 - 1 828 img = np.array(self.data[:, :, rgb]).copy() 829 if np.isnan(img).all(): 830 print("Warning - image contains no data.") 831 return ax.get_figure(), ax 832 833 # invert if needed 834 if invert: 835 bands = [-b for b in bands] 836 for i,b in enumerate(bands): 837 if not isinstance(b, str) and (b < 0): 838 img[..., i] = np.nanmax(img[..., i]) - img[..., i] 839 840 # do scaling 841 if tscale: # scale bands independently 842 for b in range(3): 843 mn = kwds.get("vmin", float(np.nanmin(img))) 844 mx = kwds.get("vmax", float(np.nanmax(img))) 845 if isinstance (mn, int): 846 assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile." 847 mn = float(np.nanpercentile(img[...,b], mn )) 848 if isinstance (mx, int): 849 assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile." 850 mx = float(np.nanpercentile(img[...,b], mx )) 851 img[...,b] = (img[..., b] - mn) / (mx - mn) 852 else: # scale bands together 853 mn = kwds.get("vmin", float(np.nanmin(img))) 854 mx = kwds.get("vmax", float(np.nanmax(img))) 855 if isinstance(mn, int): 856 assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile." 
857 mn = float(np.nanpercentile(img, mn)) 858 if isinstance(mx, int): 859 assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile." 860 mx = float(np.nanpercentile(img, mx)) 861 img = (img - mn) / (mx - mn) 862 863 #apply brightness/contrast mapping 864 img = np.clip((1.0 + cfac) * img + bfac, 0, 1.0 ) 865 866 #apply masking so background is white 867 img[np.logical_not( np.isfinite( img ) )] = 1.0 868 if 'mask' in kwds: 869 img[kwds.pop("mask"),:] = 1.0 870 871 # apply rotations and flipping 872 if rot: 873 img = np.transpose( img, (1,0,2) ) 874 if flipX: 875 img = img[::-1, :, :] 876 if flipY: 877 img = img[:, ::-1, :] 878 879 # save? 880 if 'path' in kwds: 881 path = kwds.pop('path') 882 from matplotlib.pyplot import imsave 883 if not os.path.exists(os.path.dirname(path)): 884 os.makedirs(os.path.dirname(path)) # ensure output directory exists 885 imsave(path, np.transpose( np.clip( img*255, 0, 255).astype(np.uint8), (1, 0, 2))) # save the image 886 887 # plot samples? 888 ps = kwds.pop('ps', 5) 889 pc = kwds.pop('pc', 'r') 890 if samples: 891 if isinstance(samples, list) or isinstance(samples, np.ndarray): 892 ax.scatter([s[0] for s in samples], [s[1] for s in samples], s=ps, c=pc) 893 else: 894 for n in self.header.get_class_names(): 895 points = np.array(self.header.get_sample_points(n)) 896 ax.scatter(points[:, 0], points[:, 1], s=ps, c=pc) 897 898 #plot 899 ax.imshow(np.transpose(img, (1,0,2)), interpolation=kwds.pop('interpolation', 'none'), **kwds) 900 ax.cbar = None # no colorbar 901 902 return ax.get_figure(), ax 903 904 def createGIF(self, path, bands=None, figsize=(10,10), fps=10, **kwds): 905 """ 906 Create and save an animated gif that loops through the bands of the image. 907 908 Args: 909 path (str): the path to save the .gif 910 bands (tuple): Tuple containing the range of band indices to draw. Default is the whole range. 911 figsize (tuple): the size of the image to draw. Default is (10,10). 
912 fps (int): the framerate (frames per second) of the gif. Default is 10. 913 **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc. 914 """ 915 916 frames = [] 917 if bands is None: 918 bands = (0,self.band_count()) 919 else: 920 assert 0 < bands[0] < self.band_count(), "Error - invalid range." 921 assert 0 < bands[1] < self.band_count(), "Error - invalid range." 922 assert bands[1] > bands[0], "Error - invalid range." 923 924 #plot frames 925 for i in range(bands[0],bands[1]): 926 fig, ax = plt.subplots(figsize=figsize) 927 ax.imshow(self.data[:, :, i], **kwds) 928 fig.canvas.draw() 929 frames.append(np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')) 930 frames[-1] = np.reshape(frames[-1], (fig.canvas.get_width_height()[1], fig.canvas.get_width_height()[0], 3)) 931 plt.close(fig) 932 933 #save gif 934 imageio.mimsave( os.path.splitext(path)[0] + ".gif", frames, fps=fps) 935 936 ## masking 937 def drop_bbl(self, drop=True): 938 """ 939 Remove bad bands as stored in the 'bbl' key in the image header. Note that this operates in-place. 940 941 Args: 942 drop (bool): True if bad bands should be completely dropped. If False, these bands will be kept but replaced with nans. 943 """ 944 assert 'bbl' in self.header, "Please specify a bad band list ('bbl') in the image header, as per the ENVI format definition." 945 mask = self.header.get_list('bbl') == 0 946 self.data[...,mask] = np.nan 947 if drop: 948 self.delete_nan_bands(inplace=True) 949 950 def mask(self, mask=None, flag=np.nan, invert=False, crop=False, bands=None): 951 """ 952 Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the 953 image in-situ. 954 955 Args: 956 flag (float): the value to use for masked pixels. Default is np.nan 957 mask (ndarray): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None is passed then 958 pickPolygon( ... 
) is used to interactively define a polygon. If a file path is passed then the polygon 959 will be loaded using np.load( ... ). Alternatively if mask.shape == image.shape[0,1] then it is treated as a 960 binary image mask (must be boolean) and True values will be masked across all bands. Default is None. 961 invert (bool): if True, pixels within the polygon will be masked. If False, pixels outside the polygon are masked. Default is False. 962 crop (bool): True if rows/columns containing only zeros should be removed. Default is False. 963 bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used. 964 965 Returns: 966 Tuple containing 967 968 - mask (ndarray): a boolean array with True where pixels are masked and False elsewhere. 969 - poly (ndarray): the mask polygon array in the format described above. Useful if the polygon was interactively defined. 970 """ 971 972 if mask is None: # pick mask interactively 973 if bands is None: 974 bands = int(self.band_count() / 2) 975 976 regions = self.pickPolygons(region_names=["mask"], bands=bands) 977 978 # the user bailed without picking a mask? 
979 if len(regions) == 0: 980 print("Warning - no mask picked/applied.") 981 return 982 983 # extract polygon mask 984 mask = regions[0] 985 986 # convert polygon mask to binary mask 987 if mask.shape[1] == 2: 988 989 # build meshgrid with pixel coords 990 xx, yy = np.meshgrid(np.arange(self.xdim()), np.arange(self.ydim())) 991 xx = xx.flatten() 992 yy = yy.flatten() 993 points = np.vstack([xx, yy]).T # coordinates of each pixel 994 995 # calculate per-pixel mask 996 mask = path.Path(mask).contains_points(points) 997 mask = mask.reshape((self.ydim(), self.xdim())).T 998 999 # flip as we want to mask (==True) outside points (unless invert is true) 1000 if not invert: 1001 mask = np.logical_not(mask) 1002 1003 # apply binary image mask 1004 assert mask.shape[0] == self.data.shape[0] and mask.shape[1] == self.data.shape[1], \ 1005 "Error - mask shape %s does not match image shape %s" % (mask.shape, self.data.shape) 1006 for b in range(self.band_count()): 1007 self.data[:, :, b][mask] = flag 1008 1009 # crop image 1010 if crop: 1011 self.crop_to_data() 1012 1013 return mask 1014 1015 def crop_to_data(self): 1016 """ 1017 Remove padding of nan or zero pixels from image. Note that this is performed in place. 
1018 """ 1019 valid = np.isfinite(self.data).any(axis=-1) & (self.data != 0).any(axis=-1) 1020 1021 # integrate along axes 1022 xdata = np.sum(valid, axis=1) > 0.0 1023 ydata = np.sum(valid, axis=0) > 0.0 1024 1025 # calculate domain containing valid pixels 1026 xmin = np.argmax(xdata) 1027 xmax = xdata.shape[0] - np.argmax(xdata[::-1]) 1028 ymin = np.argmax(ydata) 1029 ymax = ydata.shape[0] - np.argmax(ydata[::-1]) 1030 1031 # crop 1032 self.data = self.data[xmin:xmax, ymin:ymax, :] 1033 1034 # shift affine origin to new top-left pixel 1035 if self.affine is not None: 1036 a = self.affine # shorthand for affine 1037 new_affine = list(self.affine) 1038 new_affine[0] = a[0] + xmin*a[1] + ymin*a[2] 1039 new_affine[3] = a[3] + xmin*a[4] + ymin*a[5] 1040 self.affine = np.array(new_affine) 1041 self.header['affine'] = self.affine 1042 1043 ################################################## 1044 ## Interactive tools for picking regions/pixels 1045 ################################################## 1046 def pickPolygons(self, region_names, bands=0): 1047 """ 1048 Creates a matplotlib gui for selecting polygon regions in an image. 1049 1050 Args: 1051 names (list, str): a list containing the names of the regions to pick. If a string is passed only one name is used. 1052 bands (tuple): the bands of the image to plot. 1053 """ 1054 1055 if isinstance(region_names, str): 1056 region_names = [region_names] 1057 1058 assert isinstance(region_names, list), "Error - names must be a list or a string." 
1059 1060 # set matplotlib backend 1061 backend = matplotlib.get_backend() 1062 matplotlib.use('Qt5Agg') # need this backend for ROIPoly to work 1063 1064 # plot image and extract roi's 1065 fig, ax = self.quick_plot(bands) 1066 roi = MultiRoi(roi_names=region_names) 1067 plt.close(fig) # close figure 1068 1069 # extract regions 1070 regions = [] 1071 for name, r in roi.rois.items(): 1072 # store region 1073 x = r.x 1074 y = r.y 1075 regions.append(np.vstack([x, y]).T) 1076 1077 # restore matplotlib backend (if possible) 1078 try: 1079 matplotlib.use(backend) 1080 except: 1081 print("Warning: could not reset matplotlib backend. Plots will remain interactive...") 1082 pass 1083 1084 return regions 1085 1086 def pickPoints(self, n=-1, bands=hylite.RGB, integer=True, title="Pick Points", **kwds): 1087 """ 1088 Creates a matplotlib gui for picking pixels from an image. 1089 1090 Args: 1091 n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1. 1092 bands (tuple): the bands of the image to plot. Default is HyImage.RGB 1093 integer (bool): True if points coordinates should be cast to integers (for use as indices). Default is True. 1094 title (str): The title of the point picking window. 1095 **kwds: Keywords are passed to HyImage.quick_plot( ... ). 1096 1097 Returns: 1098 A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ]. 1099 """ 1100 1101 # set matplotlib backend 1102 backend = matplotlib.get_backend() 1103 matplotlib.use('Qt5Agg') # need this backend for ROIPoly to work 1104 1105 # create figure 1106 fig, ax = self.quick_plot( bands, **kwds ) 1107 ax.set_title(title) 1108 1109 # get points 1110 points = fig.ginput( n ) 1111 1112 if integer: 1113 points = [ (int(p[0]), int(p[1])) for p in points ] 1114 1115 # restore matplotlib backend (if possible) 1116 try: 1117 matplotlib.use(backend) 1118 except: 1119 print("Warning: could not reset matplotlib backend. 
Plots will remain interactive...") 1120 pass 1121 1122 return points 1123 1124 def pickSamples(self, names=None, store=True, **kwds): 1125 """ 1126 Pick sample probe points and store these in the image header file. 1127 1128 Args: 1129 names (str, list): the name of the sample to pick, or a list of names to pick multiple. 1130 store (bool): True if sample should be stored in the image header file (for later access). Default is True. 1131 **kwds: Keywords are passed to HyImage.quick_plot( ... ) 1132 1133 Returns: 1134 a list containing a list of points for each sample. 1135 """ 1136 1137 if isinstance(names, str): 1138 names = [names] 1139 1140 # pick points 1141 points = [] 1142 for s in names: 1143 pnts = self.pickPoints(title="%s" % s, **kwds) 1144 if store: 1145 self.header['sample %s' % s] = pnts # store in header 1146 points.append(pnts) 1147 # add class to header file 1148 if store: 1149 cls_names = self.header.get_class_names() 1150 if cls_names is None: 1151 cls_names = [] 1152 self.header['class names'] = cls_names + names 1153 1154 return points
A class for hyperspectral image data. These can be individual scenes or hyperspectral orthoimages.
def __init__(self, data, **kwds):
    """
    Create a hyperspectral image.

    Args:
        data (ndarray): a numpy array such that data[x][y][band] gives each pixel value. A 1-D array is
            treated as a single pixel and a 2-D array as a single band; both are expanded to 3-D.
        **kwds:
            wav = A numpy array containing band wavelengths for this image.
            affine = an affine transform of the format returned by GDAL.GetGeoTransform().
            projection = string (or osgeo.osr.SpatialReference) defining the projection. Default is None.
            sensor = sensor name. Default is "unknown".
            header = path to associated header file. Default is None.
    """
    # the generic HyData constructor stores the data array and builds the header
    super().__init__(data, **kwds)

    # always keep pixels in data[x, y, band] layout, even for single pixels / single bands
    if self.data is not None:
        ndim = len(self.data.shape)
        if ndim == 1:
            self.data = self.data[None, None, :]  # single pixel image
        elif ndim == 2:
            self.data = self.data[:, :, None]  # single band image

    # image-specific georeferencing information
    identity = [0, 1, 0, 0, 0, 1]  # default GDAL-style geotransform
    self.set_projection(kwds.get("projection", None))
    self.affine = np.array(kwds.get("affine", identity))
    self.header['affine'] = kwds.get("affine", identity)  # mirror the transform in the header

    # optional band wavelengths
    if 'wav' in kwds:
        self.set_wavelengths(kwds['wav'])

    # header formatting expected by ENVI writers
    self.header['file type'] = 'ENVI Standard'
Arguments:
- data (ndarray): a numpy array such that data[x][y][band] gives each pixel value.
- **kwds: optional keywords: wav = a numpy array containing band wavelengths for this image; affine = an affine transform of the format returned by GDAL.GetGeoTransform(); projection = string defining the projection (default None); sensor = sensor name (default "unknown"); header = path to associated header file (default None).
def copy(self, data=True):
    """
    Return a copy of this image.

    The header is always copied and the projection and affine transform are passed on to the new image
    (the projection object itself is shared, not duplicated).

    Args:
        data (bool): True (default) to also copy the pixel array; False to create an image with no data
            that only carries the copied header and georeferencing.

    Returns:
        a new HyImage instance.
    """
    pixels = self.data.copy() if data else None
    return HyImage(pixels, header=self.header.copy(), projection=self.projection, affine=self.affine)
Make a deep copy of this image instance.
Arguments:
- data (bool): True if a copy of the data should be made, otherwise only copy header.
Returns:
a new HyImage instance.
def T(self):
    """
    Return the data array with its first two axes swapped, i.e. indexed as [y, x, band].

    This matches the [row, column] indexing convention used by matplotlib, opencv etc. The result is a
    transposed view of the data, not a copy.
    """
    swapped_axes = (1, 0, 2)
    return np.transpose(self.data, axes=swapped_axes)
Return a transposed view of the data matrix (corresponding to the [y,x] indexing used by matplotlib, opencv, etc.).
def xdim(self):
    """
    Number of pixels along x, i.e. the length of the first axis of the data array.
    """
    shape = self.data.shape
    return shape[0]
Return number of pixels in x (first dimension of data array)
def ydim(self):
    """
    Number of pixels along y, i.e. the length of the second axis of the data array.
    """
    shape = self.data.shape
    return shape[1]
Return number of pixels in y (second dimension of data array)
def aspx(self):
    """
    Return the aspect ratio of this image, computed as ydim / xdim.

    Because x is treated as the horizontal axis when plotting (quick_plot displays data.T), this is
    height / width. (The previous docstring said width/height, which contradicted the computation.)
    """
    return self.ydim() / self.xdim()
Return the aspect ratio of this image, computed as ydim / xdim (i.e. height/width when x is the horizontal axis).
def get_extent(self):
    """
    Returns the width and height of this image in world coordinates.

    Computed as the pixel counts (xdim(), ydim()) multiplied by the matching entries of self.pixel_size
    (assumed to be an (x, y) pair of pixel sizes, defined elsewhere in the class; TODO confirm).

    Returns:
        tuple with (width, height).
    """
    # xdim/ydim are methods: the original multiplied the bound methods themselves, raising TypeError
    return self.xdim() * self.pixel_size[0], self.ydim() * self.pixel_size[1]
Returns the width and height of this image in world coordinates.
Returns:
tuple with (width, height).
def set_projection(self, proj):
    """
    Set this image's projection from an existing osgeo.osr.SpatialReference or a GDAL georeference string.

    Args:
        proj (str, osgeo.osr.SpatialReference, None): the projection to use. Strings are passed to the
            osgeo.osr.SpatialReference constructor; None clears the projection (and does not need GDAL).

    Raises:
        TypeError: if proj is not None, a string or a SpatialReference.
    """
    if proj is None:
        self.projection = None
        return

    try:
        from osgeo.osr import SpatialReference
    except ImportError:
        assert False, "Error - GDAL must be installed to work with spatial projections in hylite."

    if isinstance(proj, SpatialReference):
        self.projection = proj
    elif isinstance(proj, str):
        self.projection = SpatialReference(proj)
    else:
        # the original printed a message and then used a bare `raise` with no active exception,
        # which surfaced as a confusing "No active exception to reraise" RuntimeError
        raise TypeError("Invalid projection %s" % proj)
Set this image's projection to an existing osgeo.osr.SpatialReference or a GDAL georeference string.
Arguments:
- proj (str, osgeo.osr.SpatialReference): the projection to use, as an osgeo.osr.SpatialReference or a GDAL georeference string.
def set_projection_EPSG(self, EPSG):
    """
    Sets this image's projection using an EPSG code.

    Args:
        EPSG (str, int): an EPSG code. Strings (e.g. "EPSG:4326") are passed unchanged to
            SpatialReference.SetFromUserInput(...); integer codes (e.g. 4326) are first converted to
            "EPSG:<code>", since SetFromUserInput expects a string.
    """
    try:
        from osgeo.osr import SpatialReference
    except ImportError:
        assert False, "Error - GDAL must be installed to work with spatial projections in hylite."

    # accept bare integer codes for convenience (backward compatible: strings are untouched)
    if isinstance(EPSG, (int, np.integer)) and not isinstance(EPSG, bool):
        EPSG = "EPSG:%d" % int(EPSG)

    self.projection = SpatialReference()
    self.projection.SetFromUserInput(EPSG)
Sets this image's projection using an EPSG code.
Arguments:
- EPSG (str): string EPSG code that can be passed to SpatialReference.SetFromUserInput(...).
def get_projection_EPSG(self):
    """
    Describe this image's projection by its authority code.

    Returns:
        a string "<authority>:<code>" built from the projection's AUTHORITY attribute (e.g. "EPSG:4326"),
        or None if no projection is set.
    """
    srs = self.projection
    if srs is None:
        return None
    authority, code = (srs.GetAttrValue("AUTHORITY", i) for i in (0, 1))
    return "%s:%s" % (authority, code)
Gets a string describing this projection's EPSG code (if it is an EPSG projection).
Returns:
an EPSG code string of the format "EPSG:XXXX".
def pix_to_world(self, px, py, proj=None):
    """
    Take pixel coordinates and return world coordinates.

    Args:
        px (int): the pixel x-coord.
        py (int): the pixel y-coord.
        proj (str, int, osr.SpatialReference): the coordinate system to return coordinates in. Default
            (None) uses this image's projection. Otherwise an osr.SpatialReference (e.g. HyImage.projection),
            an EPSG string such as "EPSG:4326" (e.g. from get_projection_EPSG(...)), a bare code string
            such as "4326", or an integer EPSG code can be passed.

    Returns:
        the world coordinates (x, y) in the requested coordinate system.
    """
    try:
        from osgeo import osr
        import osgeo.gdal as gdal
        from osgeo import ogr
    except ImportError:
        assert False, "Error - GDAL must be installed to work with spatial projections in hylite."

    # parse projection
    if proj is None:
        proj = self.projection
    elif isinstance(proj, str) or isinstance(proj, int):
        epsg = proj
        if isinstance(epsg, str):
            try:
                # the original called str.split(':') on the *type* str, so every string was rejected
                epsg = int(epsg.split(':')[-1])
            except ValueError:
                assert False, "Error - %s is an invalid EPSG code." % proj
        proj = osr.SpatialReference()
        proj.ImportFromEPSG(epsg)

    # check we have all the required info
    assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
    assert (self.affine is not None) and (self.projection is not None), \
        "Error - project information is undefined."

    # project to world coordinates in this image's projection
    x, y = gdal.ApplyGeoTransform(self.affine, px, py)

    # project to target coords (if different)
    if not proj.IsSameGeogCS(self.projection):
        P = ogr.Geometry(ogr.wkbPoint)
        if proj.EPSGTreatsAsNorthingEasting():
            P.AddPoint(x, y)
        else:
            P.AddPoint(y, x)
        P.AssignSpatialReference(self.projection)  # tell the point what coordinates it's in
        P.TransformTo(proj)  # reproject it to the output spatial reference
        x, y = P.GetX(), P.GetY()

        # do we need to transpose?
        if proj.EPSGTreatsAsLatLong():
            x, y = y, x  # we want lon,lat not lat,lon
    return x, y
Take pixel coordinates and return world coordinates
Arguments:
- px (int): the pixel x-coord.
- py (int): the pixel y-coord.
- proj (str, osr.SpatialReference): the coordinate system to use. Default (None) uses the same system as this image. Otherwise an osr.SpatialReference can be passed (HyImage.projection), or an EPSG string (e.g. get_projection_EPSG(...)).
Returns:
the world coordinates in the requested coordinate system (by default, this image's projection, as reported by get_projection_EPSG(...)).
def world_to_pix(self, x, y, proj=None):
    """
    Take world coordinates and return pixel coordinates.

    Args:
        x (float): the world x-coord.
        y (float): the world y-coord.
        proj (str, int, osr.SpatialReference): the coordinate system of the input coordinates. Default
            (None) uses this image's projection. Otherwise an osr.SpatialReference (e.g. HyImage.projection),
            an EPSG string such as "EPSG:4326" (e.g. from get_projection_EPSG(...)), a bare code string
            such as "4326", or an integer EPSG code can be passed.

    Returns:
        the pixel coordinates based on the affine transform stored in self.affine.
    """
    try:
        from osgeo import osr
        import osgeo.gdal as gdal
        from osgeo import ogr
    except ImportError:
        assert False, "Error - GDAL must be installed to work with spatial projections in hylite."

    # parse projection
    if proj is None:
        proj = self.projection
    elif isinstance(proj, str) or isinstance(proj, int):
        epsg = proj
        if isinstance(epsg, str):
            try:
                # the original called str.split(':') on the *type* str, so every string was rejected
                epsg = int(epsg.split(':')[-1])
            except ValueError:
                assert False, "Error - %s is an invalid EPSG code." % proj
        proj = osr.SpatialReference()
        proj.ImportFromEPSG(epsg)

    # check we have all the required info
    assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
    assert (self.affine is not None) and (self.projection is not None), "Error - project information is undefined."

    # project to this image's CS (if different)
    if not proj.IsSameGeogCS(self.projection):
        P = ogr.Geometry(ogr.wkbPoint)
        if proj.EPSGTreatsAsNorthingEasting():
            P.AddPoint(x, y)
        else:
            P.AddPoint(y, x)
        # N.B. the original called P.AddPoint(x, y) again here, overwriting the axis-order-dependent
        # point set above (pix_to_world has no such second call)
        P.AssignSpatialReference(proj)  # tell the point what coordinates it's in
        P.TransformTo(self.projection)  # reproject it to this image's spatial reference
        x, y = P.GetX(), P.GetY()
        if self.projection.EPSGTreatsAsLatLong():  # do we need to transpose?
            x, y = y, x  # we want lon,lat not lat,lon

    inv = gdal.InvGeoTransform(self.affine)
    assert inv is not None, "Error - could not invert affine transform?"

    # apply
    return gdal.ApplyGeoTransform(inv, x, y)
Take world coordinates and return pixel coordinates
Arguments:
- x (float): the world x-coord.
- y (float): the world y-coord.
- proj (str, osr.SpatialReference): the coordinate system of the input coordinates. Default (None) uses the same system as this image. Otherwise an osr.SpatialReference can be passed (HyImage.projection), or an EPSG string (e.g. get_projection_EPSG(...)).
Returns:
the pixel coordinates based on the affine transform stored in self.affine.
def crop(self, xmin, xmax, ymin, ymax, bands=None):
    """
    Return a cropped copy of this image.

    Args:
        xmin, xmax (int): pixel bounds along x (first data axis); clamped to [0, xdim()].
        ymin, ymax (int): pixel bounds along y (second data axis); clamped to [0, ydim()].
        bands (None, list, tuple): None keeps every band; a tuple (min, max) keeps the band range
            [min, max) with each end resolved by get_band_index; any other iterable selects the listed bands.

    Returns:
        HyImage: cropped image whose affine origin is moved to the new top-left pixel.
    """
    # clamp the requested window to the image extent
    x0 = int(max(0, xmin))
    y0 = int(max(0, ymin))
    x1 = int(min(self.xdim(), xmax))
    y1 = int(min(self.ydim(), ymax))
    assert x0 < x1 and y0 < y1, "Invalid crop extent."

    # slice pixels (and matching wavelengths) as independent copies
    if bands is None:
        data = self.data[x0:x1, y0:y1, :].copy()
        wav = self.get_wavelengths()
    elif isinstance(bands, tuple):
        first = self.get_band_index(bands[0])
        last = self.get_band_index(bands[1])
        data = self.data[x0:x1, y0:y1, first:last].copy()
        wav = self.get_wavelengths()[first:last]
    else:
        chosen = [self.get_band_index(b) for b in bands]
        data = self.data[x0:x1, y0:y1, chosen].copy()
        wav = self.get_wavelengths()[chosen]

    # shift the geotransform origin to the new top-left pixel
    new_affine = None
    if self.affine is not None:
        src = list(self.affine)
        new_affine = list(src)
        new_affine[0] = src[0] + x0 * src[1] + y0 * src[2]
        new_affine[3] = src[3] + x0 * src[4] + y0 * src[5]

    return HyImage(
        data,
        header=self.header.copy(),
        projection=self.projection,
        affine=new_affine,
        wav=wav,
    )
Return a cropped copy of this image.
Arguments:
- xmin, xmax (int): pixel bounds in x (rows)
- ymin, ymax (int): pixel bounds in y (columns)
- bands (None, list, tuple): optional band indices or (min,max) range
Returns:
HyImage: cropped image with updated affine transform
def resize(self, newdims: tuple, interpolation: int = 1):
    """
    Resize this image in place with opencv and update the affine transform accordingly.

    Args:
        newdims (tuple): the new image dimensions (xdim, ydim).
        interpolation (int): opencv interpolation flag. Default is 1 (cv2.INTER_LINEAR).
    """
    import cv2  # avoid import issues if opencv is missing

    old_x, old_y = self.xdim(), self.ydim()
    new_x, new_y = int(newdims[0]), int(newdims[1])

    assert new_x > 0 and new_y > 0, "Invalid resize dimensions."

    # opencv takes dsize as (width, height), which for data[x, y, band] is (y, x)
    resized = cv2.resize(self.data, (new_y, new_x), interpolation=interpolation)
    if resized.ndim == 2:
        # opencv drops a trailing channel axis of length 1 (single-band images);
        # restore the data[x, y, band] layout every other method relies on
        resized = resized[:, :, None]
    self.data = resized

    # update affine transform: pixel sizes (and rotation terms) scale with the old/new pixel ratio
    if self.affine is not None:
        a = list(self.affine)
        sx = old_x / new_x
        sy = old_y / new_y
        self.affine = np.array([
            a[0],       # x origin unchanged
            a[1] * sx,  # pixel width
            a[2] * sy,  # row rotation
            a[3],       # y origin unchanged
            a[4] * sx,  # column rotation
            a[5] * sy,  # pixel height
        ])
        # keep the header copy in sync (as crop_to_data does) and the array type consistent with __init__
        self.header['affine'] = self.affine
Resize this image with opencv and update affine transform accordingly.
Arguments:
- newdims (tuple): the new image dimensions (xdim, ydim)
- interpolation (int): opencv interpolation method. Default is cv2.INTER_LINEAR.
def tile(self, tile_size):
    """
    Split this image into tiles and return them as a list of HyImage instances.

    Tiles are produced with x as the outer loop and y as the inner loop; tiles on the far x/y edges are
    truncated to the image bounds. Each tile gets an affine transform whose origin is moved to its first
    pixel, and its header records the tile offset as 'xleft' (x) and 'ytop' (y).

    Args:
        tile_size (tuple): (tile_x, tile_y) in pixels
    Returns:
        list of HyImage
    """
    step_x, step_y = tile_size
    nx, ny = self.xdim(), self.ydim()
    result = []
    for x0 in range(0, nx, step_x):
        for y0 in range(0, ny, step_y):
            pixels = self.data[x0:min(x0 + step_x, nx), y0:min(y0 + step_y, ny), :].copy()

            # geotransform layout: [x0, dx, rotx, y0, roty, dy]
            ox, dx, rx, oy, ry, dy = self.affine
            shifted = [ox + x0 * dx + y0 * rx, dx, rx, oy + x0 * ry + y0 * dy, ry, dy]

            piece = HyImage(
                pixels,
                affine=shifted,
                projection=self.projection,
                wav=self.get_wavelengths(),
                header=self.header.copy(),
            )
            piece.header['xleft'] = x0
            piece.header['ytop'] = y0
            result.append(piece)

    return result
Break image into tiles of given size and return a list of HyImage tiles. Each tile has an updated affine transform reflecting its position in the original image.
Arguments:
- tile_size (tuple): (tile_x, tile_y) in pixels
Returns:
list of HyImage
@staticmethod
def mosaic(
    tiles,
    blend="mean",
    resampling="nearest",
    out_affine=None,
    out_shape=None,
):
    """
    Mosaic georeferenced HyImage tiles into a single image. Note that this assumes all tiles are in the same
    coordinate system.

    Args:
        tiles (list[HyImage]): the tiles to mosaic. Must contain at least one tile.
        blend (str): how overlapping pixels are combined. 'first' keeps the first valid (non-NaN) value; 'min',
                     'max', 'mean' and 'median' combine overlapping values while ignoring NaNs.
        resampling (str): 'nearest', 'bilinear', 'cubic'. NOTE: this argument is currently not passed on to
                          resample_raster, and so has no effect.
        out_affine (list): optional 6-element affine to define output grid. If None, the affine of the first tile is
                           used, with its origin moved to the corner of the combined extent of all tiles.
        out_shape (tuple): optional (xdim, ydim) shape of the output grid. If None, the extent of all tiles
                           (measured in pixels of the first tile) will be used.
    Returns:
        HyImage
    """
    import warnings
    from hylite.project.align import resample_raster

    assert len(tiles) > 0, "Error - no tiles to mosaic."
    assert blend in ("first", "min", "max", "mean", "median"), "Error - unknown blend mode %s" % blend

    # compute bounds in world coordinates
    # N.B. THIS ASSUMES ALL DATA ARE IN THE SAME CRS
    points = []
    for t in tiles:
        points.append(t.pix_to_world(0, 0))
        # NOTE(review): depending on pix_to_world's convention, the +1 may extend the extent by one pixel.
        points.append(t.pix_to_world(t.xdim() + 1, t.ydim() + 1))
    min_x, min_y = np.min(points, axis=0)
    max_x, max_y = np.max(points, axis=0)

    if out_shape is None:
        # the size of the extent in pixels is the absolute difference between its corners, which is independent of
        # the tile order and of the sign of the pixel size
        c0 = np.array(tiles[0].world_to_pix(min_x, min_y), dtype=float)
        c1 = np.array(tiles[0].world_to_pix(max_x, max_y), dtype=float)
        out_shape = np.abs(c1 - c0).round().astype(int)
    out_shape = tuple(int(s) for s in out_shape)

    if out_affine is None:
        out_affine = list(tiles[0].affine)
        # pixel (0, 0) sits at the extent corner that is minimal along axes with positive pixel size and maximal
        # along axes with negative pixel size (e.g. north-up images with a negative pixel height start at max_y)
        out_affine[0] = max_x if out_affine[1] < 0 else min_x
        out_affine[3] = min_y if out_affine[5] > 0 else max_y

    def _resample(t):
        # resample one tile onto the output grid (pixels the tile does not cover are presumably NaN)
        return resample_raster(t.data, t.affine, out_affine, out_shape)

    if blend == "first":  # fill output, first come, first served.
        out = np.full(out_shape + (tiles[0].band_count(),), np.nan, dtype=np.float32)
        for t in tiles:
            r = _resample(t)
            empty = np.isnan(out)
            out[empty] = r[empty]
    else:
        reducer = {
            "min": np.nanmin,        # keep minimum value in case of overlap
            "max": np.nanmax,        # keep maximum value in case of overlap
            "mean": np.nanmean,      # use average in case of overlap
            "median": np.nanmedian,  # use median in case of overlap
        }[blend]
        stack = np.stack([_resample(t) for t in tiles], axis=0)
        with warnings.catch_warnings():
            # pixels covered by no tile are all-NaN along the stack axis; numpy warns about these, but NaN is the
            # intended result there
            warnings.simplefilter("ignore", category=RuntimeWarning)
            out = reducer(stack, axis=0)

    # Return HyImage
    out = HyImage(
        out,
        affine=out_affine,
        projection=tiles[0].projection,
        wav=tiles[0].get_wavelengths(),
        header=tiles[0].header.copy()
    )
    for key in ('xleft', 'ytop'):  # stored in tiles, but not meaningful here
        if key in out.header:
            del out.header[key]

    return out
Mosaic georeferenced HyImage tiles using GDAL. Note that this assumes all tiles are in the same coordinate system.
Arguments:
- tiles (list[HyImage])
- blend (str): 'first', 'min', 'max', 'mean', 'median'
- resampling (str): 'nearest', 'bilinear', 'cubic'. Note that this argument is currently not used.
- out_affine (list): optional 6-element affine to define output grid. If None, the affine of the first tile is used.
- out_shape (tuple): optional (xdim, ydim) shape of the output grid. If None, the extent of all tiles will be used.
Returns:
HyImage
def flip(self, axis='x'):
    """
    Mirror the image along its x and/or y axis, in place.

    Because the pixel grid is reordered, any affine transform is discarded: both the attribute and the 'affine'
    header entry are removed.

    Args:
        axis (str): the axes to flip: 'x', 'y' or 'xy' for both (case-insensitive).
    """
    requested = axis.lower()
    flip_axes = tuple(dim for dim, key in enumerate('xy') if key in requested)
    if flip_axes:
        self.data = np.flip(self.data, axis=flip_axes)

    # the old geotransform no longer describes the mirrored pixel grid
    self.affine = None
    if 'affine' in self.header:
        del self.header['affine']

    self.push_to_header()  # refresh width and height info
Flip the image on the x or y axis. Note that this will remove any defined affine transform.
Arguments:
- axis (str): 'x' or 'y' or both 'xy'.
def rot90(self):
    """
    Swap the x and y axes of this image in place (a transpose of the underlying data array).

    On its own this is a reflection about the diagonal; combine it with flip('x') or flip('y') to achieve
    positive/negative 90 degree rotations. Any affine transform is discarded.
    """
    self.data = np.swapaxes(self.data, 0, 1)

    # the geotransform no longer matches the reordered pixel grid
    self.affine = None
    if 'affine' in self.header:
        del self.header['affine']

    self.push_to_header()  # width and height have been swapped
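Swap the x and y axes of this image by transposing the underlying data array. On its own this is a reflection about the diagonal; combine with flip('x') or flip('y') to achieve positive/negative 90 degree rotations. Note that this removes any defined affine transform.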
def fill_holes(self):
    """
    Fill non-finite (NaN / inf) pixels with the maximum of their finite 3x3 neighbours, removing 1-pixel holes.

    Each band is processed independently with a greyscale dilation, so holes do not need to line up across bands.
    Holes without any finite neighbour (e.g. the interior of a large NaN background) are left as NaN rather than
    being filled with an arbitrary value. The image data is modified in place.
    """
    from scipy import ndimage  # explicit submodule import; `import scipy` alone does not always load ndimage

    holes = np.logical_not(np.isfinite(self.data))
    if not holes.any():
        return  # nothing to fill (this also covers integer data, which cannot hold NaN)

    # use -inf for holes so they never win the max() of the dilation (0 would beat negative neighbours), and
    # dilate within each band only: the footprint covers 3 x 3 pixels and a single band
    dilate = np.where(holes, -np.inf, self.data)
    dilate = ndimage.grey_dilation(dilate, size=(3, 3, 1))

    # holes still at -inf had no finite neighbour; keep them as NaN
    dilate[np.isneginf(dilate)] = np.nan

    # map back to holes in dataset
    self.data[holes] = dilate[holes]
Replaces NaN pixels with the maximum value in their 3x3 neighbourhood (a greyscale dilation applied to each band separately), thus removing 1-pixel holes from an image.
def blur(self, n=3):
    """
    Smooth each band of the image with an n x n box (mean) filter using OpenCV.

    For floating point data, NaN pixels are ignored when averaging: each output value is the mean of the valid
    pixels in its window, so holes do not grow. NaN pixels remain NaN afterwards. Other data types are filtered
    directly. Bands are processed in chunks of at most 512 (OpenCV's channel limit), and single-band images keep
    their (x, y, 1) shape.

    Args:
        n (int): the width of the square averaging kernel. Must be an integer >= 3. Default is 3. Increase for more
                 blurry results.
    """
    import cv2  # import this here to avoid errors if opencv is not installed properly

    assert isinstance(n, int) and n >= 3, "Error - invalid kernel. N must be an integer >= 3. "
    kernel = np.ones((n, n), np.float32) / (n ** 2)

    def _filter(arr):
        # apply the kernel in chunks of <= 512 bands and restore the band axis, which OpenCV drops for
        # single-channel results
        parts = []
        for start in range(0, arr.shape[-1], 512):
            chunk = np.ascontiguousarray(arr[..., start:start + 512])
            parts.append(cv2.filter2D(chunk, -1, kernel).reshape(chunk.shape))
        return np.concatenate(parts, axis=-1)

    if not np.issubdtype(self.data.dtype, np.floating):
        self.data = _filter(self.data)  # integer data cannot hold NaN
        return

    nanmask = np.isnan(self.data)
    if not nanmask.any():
        self.data = _filter(self.data)
        return

    # normalised convolution: sum of valid values / number of valid values in each window
    total = _filter(np.where(nanmask, 0, self.data).astype(self.data.dtype))
    weight = _filter(np.logical_not(nanmask).astype(self.data.dtype))
    with np.errstate(invalid='ignore', divide='ignore'):
        self.data = total / weight
    self.data[nanmask] = np.nan  # keep the original holes
Applies an n x n box (mean) filter to each band of the image using OpenCV.
Arguments:
- n (int): the width of the square averaging kernel. Must be an integer >= 3. Default is 3. Increase for more blurry results.
def erode(self, size=3, iterations=1):
    """
    Grow the background of this image with a morphological erosion (see open-cv's erode function).

    A pixel belongs to the foreground if any of its bands is finite (floating point images) or non-zero (other
    images). This foreground mask is eroded with a size x size kernel. Every band of the pixels outside the eroded
    mask is then set to NaN (floating point images) or 0 (other images). The image is modified in place.

    Args:
        size (int): width of the square erosion kernel. Default is 3 (a 3x3 kernel).
        iterations (int): the number of erode iterations. Default is 1.
    """
    import cv2  # import this here to avoid errors if opencv is not installed properly

    if self.is_float():
        foreground, background_value = np.isfinite(self.data).any(axis=-1), np.nan
    else:
        foreground, background_value = (self.data != 0).any(axis=-1), 0

    structuring_element = np.ones((size, size), np.uint8)
    kept = cv2.erode(foreground.astype(np.uint8), structuring_element, iterations=iterations)
    self.data[kept == 0, :] = background_value
Apply an erode filter to this image to expand background pixels (NaN for floating point images, 0 otherwise). Refer to open-cv's erode function for more details.
Arguments:
- size (int): the size of the erode filter. Default is a 3x3 kernel.
- iterations (int): the number of erode iterations. Default is 1.
def despeckle(self, size=5):
    """
    Despeckle each band of this image (independently) using a median filter.

    Bands are filtered one at a time, so images with any number of bands are supported; OpenCV's medianBlur itself
    only accepts 1, 3 or 4 channel arrays. Floating point data is converted to float32, and single-band images keep
    their (x, y, 1) shape. OpenCV only supports kernels larger than 5 for 8-bit data. In that case, and for data
    types medianBlur does not support, scipy's median filter is used instead, with edge replication to match
    OpenCV's border handling.

    Args:
        size (int): the size of the median filter kernel. Default is 5. Must be an odd number.
    """

    assert (size % 2) == 1, "Error - size must be an odd integer"
    import cv2  # import this here to avoid errors if opencv is not installed properly

    data = self.data.astype(np.float32) if self.is_float() else self.data

    # medianBlur supports uint8 for any odd kernel, and uint16 / float32 only for kernel sizes 3 and 5
    use_cv2 = data.dtype == np.uint8 or (size <= 5 and data.dtype in (np.uint16, np.float32))

    def _median(band):
        band = np.ascontiguousarray(band)
        if use_cv2:
            return cv2.medianBlur(band, size)
        from scipy import ndimage
        return ndimage.median_filter(band, size=size, mode='nearest')  # 'nearest' == BORDER_REPLICATE

    self.data = np.stack([_median(data[..., b]) for b in range(data.shape[-1])], axis=-1)
Despeckle each band of this image (independently) using a median filter.
Arguments:
- size (int): the size of the median filter kernel. Default is 5. Must be an odd number.
def get_keypoints(self, band, eq=False, mask=True, method='sift', cfac=0.0, bfac=0.0, **kwds):
    """
    Get feature descriptors from the specified band.

    Args:
        band (int,float,str,tuple): the band index (int), wavelength (float) or name (str) to extract features
            from. Alternatively, a tuple can be passed containing a range of bands (min : max) to average before
            feature matching.
        eq (bool): True if the image should be histogram equalized first. Default is False.
        mask (bool): True if 0 value pixels (including NaN pixels) should be masked. Default is True.
        method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate).
            Default is 'SIFT'.
        cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
        bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.
        **kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are:

            - contrastThreshold: default is 0.01.
            - edgeThreshold: default is 10.
            - sigma: default is 1.0

            For ORB these are:

            - nfeatures = the number of features to detect. Default is 5000.

    Returns:
        Tuple containing

        - k (ndarray): the keypoints detected
        - d (ndarray): corresponding feature descriptors
    """
    import cv2  # import this here to avoid errors if opencv is not installed properly

    # get image as a float copy, so the in-place normalisation below never modifies self.data
    if isinstance(band, (str, int, float, np.integer, np.floating)):  # single band
        image = np.array(self.data[:, :, self.get_band_index(band)], dtype=np.float64)
    elif isinstance(band, tuple):  # range of bands (averaged)
        idx0 = self.get_band_index(band[0])
        idx1 = self.get_band_index(band[1])

        # deal with out of range errors
        if idx0 is None:
            idx0 = 0
        if idx1 is None:
            idx1 = self.band_count()

        # average bands
        image = np.nanmean(self.data[:, :, idx0:idx1], axis=2).astype(np.float64)
    else:
        assert False, "Error, unrecognised band %s" % band

    # normalise image to range 0 - 1 (a constant image has range 0 and is left at 0)
    image -= np.nanmin(image)
    vmax = np.nanmax(image)
    if vmax > 0:
        image /= vmax

    # apply brightness/contrast adjustment
    image = np.clip((1.0 + cfac) * image + bfac, 0.0, 1.0)

    # NaN pixels cannot be cast to uint8; set them to 0 so they are excluded by the mask below
    image[np.logical_not(np.isfinite(image))] = 0.0

    # convert image to uint8 for opencv
    image = np.uint8(255 * image)
    if eq:
        image = cv2.equalizeHist(image)

    if mask:
        detect_mask = np.zeros(image.shape, dtype=np.uint8)
        detect_mask[image != 0] = 255  # include only non-zero pixels
    else:
        detect_mask = None

    if 'sift' in method.lower():  # SIFT
        # setup default keywords (as documented above) and pass them to the detector
        kwds.setdefault("contrastThreshold", 0.01)
        kwds.setdefault("edgeThreshold", 10)
        kwds.setdefault("sigma", 1.0)
        alg = cv2.SIFT_create(**kwds)
    elif 'orb' in method.lower():  # orb
        kwds.setdefault('nfeatures', 5000)
        alg = cv2.ORB_create(scoreType=cv2.ORB_FAST_SCORE, **kwds)
    else:
        assert False, "Error - %s is not a recognised feature detector." % method

    # detect keypoints
    kp = alg.detect(image, detect_mask)

    # extract and return feature vectors
    return alg.compute(image, kp)
Get feature descriptors from the specified band.
Arguments:
- band (int,float,str,tuple): the band index (int) or wavelength (float) to extract features from. Alternatively, a tuple can be passed containing a range of bands (min : max) to average before feature matching.
- eq (bool): True if the image should be histogram equalized first. Default is False.
- mask (bool): True if 0 value pixels should be masked. Default is True.
- method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'.
- cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
- bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.
**kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are:
- contrastThreshold: default is 0.01.
- edgeThreshold: default is 10.
- sigma: default is 1.0
For ORB these are:
- nfeatures = the number of features to detect. Default is 5000.
Returns: Tuple containing
- k (ndarray): the keypoints detected
- d (ndarray): corresponding feature descriptors
@classmethod
def match_keypoints(cls, kp1, kp2, d1, d2, method='SIFT', dist=0.7, tree=5, check=100, min_count=5):
    """
    Compares keypoint feature vectors from two images and returns matching pairs.

    Args:
        kp1 (ndarray): keypoints from the first image
        kp2 (ndarray): keypoints from the second image
        d1 (ndarray): descriptors for the keypoints from the first image
        d2 (ndarray): descriptors for the keypoints from the second image
        method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'.
            Default is 'sift'.
        dist (float): ratio threshold (0 to 1) for Lowe's ratio test. A match is kept only if its distance is less
            than dist times the distance to the second-best match. Default is 0.7.
        tree (int): the number of randomized kd-trees built by the FLANN index (SIFT). Default is 5.
        check (int): the number of times the FLANN tree(s) are recursively traversed when searching (higher is more
            accurate but slower). Default is 100.
        min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches
            are found (or either descriptor set is empty), then the function returns None, None. Default is 5.

    Returns:
        Tuple (src_pts, dst_pts) of float32 arrays with shape (N, 1, 2) holding matched keypoint coordinates in the
        first and second image, or (None, None).
    """
    import cv2  # import this here to avoid errors if opencv is not installed properly

    # FLANN index types. These have the same values as the cv2.NORM_INF / cv2.NORM_HAMMING constants used
    # previously, but name what is actually selected.
    flann_index_kdtree = 1  # float descriptors (SIFT)
    flann_index_lsh = 6  # binary descriptors (ORB)
    if 'sift' in method.lower():
        algorithm = flann_index_kdtree
    elif 'orb' in method.lower():
        algorithm = flann_index_lsh
    else:
        assert False, "Error - unknown matching algorithm %s" % method

    # opencv returns None descriptors when no keypoints were found; nothing can be matched then
    if d1 is None or d2 is None or len(d1) == 0 or len(d2) == 0:
        return None, None

    # calculate flann matches
    index_params = dict(algorithm=algorithm, trees=tree)
    search_params = dict(checks=check)
    flann = cv2.FlannBasedMatcher(index_params, search_params)
    matches = flann.knnMatch(d1, d2, k=2)

    # store all the good matches as per Lowe's ratio test. knnMatch can return fewer than 2 neighbours for some
    # descriptors (e.g. with LSH); these cannot be ratio-tested and are skipped.
    good = [pair[0] for pair in matches
            if len(pair) == 2 and pair[0].distance < dist * pair[1].distance]

    if len(good) < min_count:
        return None, None
    src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
    dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
    return src_pts, dst_pts
Compares keypoint feature vectors from two images and returns matching pairs.
Arguments:
- kp1 (ndarray): keypoints from the first image
- kp2 (ndarray): keypoints from the second image
- d1 (ndarray): descriptors for the keypoints from the first image
- d2 (ndarray): descriptors for the keypoints from the second image
- method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'.
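- dist (float): the ratio threshold (0 to 1) for Lowe's ratio test; a match is kept only if its distance is less than dist times the distance to the second-best match. Default is 0.7.
- tree (int): the number of randomized kd-trees built by the FLANN index (see the OpenCV FLANN documentation). Default is 5.
- check (int): the number of times the FLANN tree(s) are recursively traversed during a search; higher values are more accurate but slower. Default is 100.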
- min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are found, then the function returns None, None. Default is 5.
def quick_plot(self, bands=0, ax=None, bfac=0.0, cfac=0.0, samples=False, tscale=False, invert=False, rot=False,
               flipX=False, flipY=False, **kwds):
    """
    Plot a band using matplotlib.imshow(...).

    Args:
        bands (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default
            is 0. If a tuple is passed then each band in the tuple (string or index) will be mapped to rgb. Bands
            with negative wavelengths or indices will be inverted before plotting.
        ax: an axis object to plot to. If None, a new figure and axes are created with plt.subplots( ... ).
        bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
        cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
        samples (bool): True if sample points (defined in the header file) should be plotted. Default is False.
            Otherwise, a list (or array) of [ (x,y), ... ] points can be passed.
        tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False.
            When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or
            (constant) values (float).
        invert (bool) : True if each band should be inverted before plotting. Only works for multiband (ternary)
            images.
        rot (bool): if True, the x and y axes will be swapped (transposed) before plotting. Default is False.
        flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
        flipY (bool): if True, the y axis will be flipped before plotting (after applying rotations).
        **kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:

            - mask = a 2D boolean mask containing True for pixels that should be hidden (masked) and False for
              pixels that should be drawn.
            - path = a file path to save the image to (at matching resolution; use fig.savefig(..) if you want to
              save the figure).
            - ticks = True if x- and y- ticks should be plotted. Default is False.
            - ps, pc = the size and color of sample points to plot. Can be constant or list.
            - figsize = a figsize for the figure to create (if ax is None).

    Returns:
        Tuple containing

        - fig: matplotlib figure object
        - ax: matplotlib axes object. If a colorbar is created, (band is an integer or a float), then this will be
          stored in ax.cbar.
    """

    # options consumed here; they must not be forwarded to imshow
    figsize = kwds.pop('figsize', None)
    ps = kwds.pop('ps', 5)
    pc = kwds.pop('pc', 'r')

    # create new axes?
    if ax is None:
        if figsize is None:
            figsize = (18, 18 * self.ydim() / self.xdim())
        _, ax = plt.subplots(figsize=figsize)

    # deal with ticks
    if not kwds.pop('ticks', False):
        ax.set_xticks([])
        ax.set_yticks([])

    def _ensure_dir(file_path):
        # create the output directory if needed (dirname is '' for bare file names, which makedirs rejects)
        directory = os.path.dirname(file_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

    def _plot_samples():
        # check the type first: the truth value of a multi-element numpy array is ambiguous
        if isinstance(samples, (list, np.ndarray)):
            if len(samples) > 0:
                ax.scatter([s[0] for s in samples], [s[1] for s in samples], s=ps, c=pc)
        elif samples:
            for name in self.header.get_class_names():
                points = np.array(self.header.get_sample_points(name))
                ax.scatter(points[:, 0], points[:, 1], s=ps, c=pc)

    # map individual band using colourmap
    if isinstance(bands, (str, int, float)):
        # get band
        if isinstance(bands, str):
            data = self.data[:, :, self.get_band_index(bands)]
        else:
            data = self.data[:, :, self.get_band_index(np.abs(bands))]
            if bands < 0:
                data = np.nanmax(data) - data  # flip

        # convert integer vmin and vmax values to percentiles
        for key in ('vmin', 'vmax'):
            if isinstance(kwds.get(key), int):
                kwds[key] = np.nanpercentile(data, kwds[key])

        # mask nans, the header's data ignore value and (optionally) a custom mask
        mask = np.isnan(data)
        ignore = self.header.get_data_ignore_value()
        if not np.isnan(ignore):
            mask = mask | (data == ignore)
        if 'mask' in kwds:
            mask = mask + kwds.pop('mask')
        data = np.ma.array(data, mask=mask > 0)

        # apply rotations and flipping
        if rot:
            data = data.T
        if flipX:
            data = data[::-1, :]
        if flipY:
            data = data[:, ::-1]

        # save?
        if 'path' in kwds:
            out_path = kwds.pop('path')
            _ensure_dir(out_path)
            plt.imsave(out_path, data.T, **kwds)  # save the image

        _plot_samples()
        ax.cbar = ax.imshow(data.T, interpolation=kwds.pop('interpolation', 'none'), **kwds)  # default: no interpolation

    # map 3 bands to RGB
    elif isinstance(bands, (tuple, list)):
        # get band indices
        rgb = [self.get_band_index(b) if isinstance(b, str) else self.get_band_index(np.abs(b)) for b in bands]

        # slice image (as copy). Use floats so that per-band scaling is not truncated for integer data.
        img = np.array(self.data[:, :, rgb])
        if not np.issubdtype(img.dtype, np.floating):
            img = img.astype(np.float64)
        if np.isnan(img).all():
            print("Warning - image contains no data.")
            return ax.get_figure(), ax

        # invert if needed
        if invert:
            bands = [-b for b in bands]
        for i, b in enumerate(bands):
            if not isinstance(b, str) and (b < 0):
                img[..., i] = np.nanmax(img[..., i]) - img[..., i]

        def _limit(key, values, default):
            # vmin/vmax from kwds (integers are percentiles of values), else the default computed from values
            value = kwds.get(key, float(default(values)))
            if isinstance(value, int):
                assert 0 <= value <= 100, "Error - integer %s values must be a percentile." % key
                value = float(np.nanpercentile(values, value))
            return value

        # do scaling
        if tscale:  # scale bands independently
            for b in range(img.shape[-1]):
                mn = _limit("vmin", img[..., b], np.nanmin)
                mx = _limit("vmax", img[..., b], np.nanmax)
                img[..., b] = (img[..., b] - mn) / (mx - mn)
        else:  # scale bands together
            mn = _limit("vmin", img, np.nanmin)
            mx = _limit("vmax", img, np.nanmax)
            img = (img - mn) / (mx - mn)

        # apply brightness/contrast mapping
        img = np.clip((1.0 + cfac) * img + bfac, 0, 1.0)

        # apply masking so background is white
        img[np.logical_not(np.isfinite(img))] = 1.0
        if 'mask' in kwds:
            img[kwds.pop("mask"), :] = 1.0

        # apply rotations and flipping
        if rot:
            img = np.transpose(img, (1, 0, 2))
        if flipX:
            img = img[::-1, :, :]
        if flipY:
            img = img[:, ::-1, :]

        # save?
        if 'path' in kwds:
            out_path = kwds.pop('path')
            _ensure_dir(out_path)
            plt.imsave(out_path, np.transpose(np.clip(img * 255, 0, 255).astype(np.uint8), (1, 0, 2)))

        # plot samples, then the image
        _plot_samples()
        ax.imshow(np.transpose(img, (1, 0, 2)), interpolation=kwds.pop('interpolation', 'none'), **kwds)
        ax.cbar = None  # no colorbar

    return ax.get_figure(), ax
Plot a band using matplotlib.imshow(...).
Arguments:
- bands (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple is passed then each band in the tuple (string or index) will be mapped to rgb. Bands with negative wavelengths or indices will be inverted before plotting.
- ax: an axis object to plot to. If None, a new figure and axes are created using plt.subplots( ... ).
- bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
- cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
- samples (bool): True if sample points (defined in the header file) should be plotted. Default is False. Otherwise, a list of [ (x,y), ... ] points can be passed.
- tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False. When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or (constant) values (float).
- invert (bool) : True if each band should be inverted before plotting. Only works for multiband (ternary) images.
- rot (bool): if True, the x and y axes will be swapped (transposed) before plotting. Default is False.
- flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
- flipY (bool): if True, the y axis will be flipped before plotting (after applying rotations).
**kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:
- mask = a 2D boolean mask containing True for pixels that should be hidden (masked) and False for pixels that should be drawn.
- path = a file path to save the image to (at matching resolution; use fig.savefig(..) if you want to save the figure).
- ticks = True if x- and y- ticks should be plotted. Default is False.
- ps, pc = the size and color of sample points to plot. Can be constant or list.
- figsize = a figsize for the figure to create (if ax is None).
Returns:
Tuple containing
- fig: matplotlib figure object
- ax: matplotlib axes object. If a colorbar is created, (band is an integer or a float), then this will be stored in ax.cbar.
def createGIF(self, path, bands=None, figsize=(10, 10), fps=10, **kwds):
    """
    Create and save an animated gif that loops through the bands of the image.

    Args:
        path (str): the path to save the .gif (any other extension is replaced by .gif).
        bands (tuple): Tuple (start, end) containing the range of band indices to draw (end is exclusive).
            Default is the whole range.
        figsize (tuple): the size of the image to draw. Default is (10,10).
        fps (int): the framerate (frames per second) of the gif. Default is 10.
        **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
    """

    if bands is None:
        bands = (0, self.band_count())
    else:
        # end is exclusive (as used by range below), so it may equal the band count; start may be 0
        assert 0 <= bands[0] < bands[1] <= self.band_count(), "Error - invalid range."

    # plot frames. Data is stored as [x, y, band] so it is transposed for imshow (as in quick_plot).
    frames = []
    for i in range(bands[0], bands[1]):
        fig, ax = plt.subplots(figsize=figsize)
        ax.imshow(self.data[:, :, i].T, **kwds)
        fig.canvas.draw()
        # buffer_rgba replaces the removed tostring_rgb and already has the (rows, cols, 4) pixel shape
        frames.append(np.asarray(fig.canvas.buffer_rgba())[..., :3].copy())
        plt.close(fig)

    # save gif
    out_path = os.path.splitext(path)[0] + ".gif"
    try:
        imageio.mimsave(out_path, frames, fps=fps)
    except TypeError:
        # newer imageio (pillow plugin) rejects `fps` and expects the frame duration in milliseconds instead
        imageio.mimsave(out_path, frames, duration=1000.0 / fps)
Create and save an animated gif that loops through the bands of the image.
Arguments:
- path (str): the path to save the .gif
- bands (tuple): Tuple containing the range of band indices to draw. Default is the whole range.
- figsize (tuple): the size of the image to draw. Default is (10,10).
- fps (int): the framerate (frames per second) of the gif. Default is 10.
- **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
def drop_bbl(self, drop=True):
    """
    Remove the bad bands listed in the header's 'bbl' (bad band list) entry. Note that this operates in-place.

    Bands whose 'bbl' entry equals 0 are filled with NaN. When drop is True, delete_nan_bands(inplace=True) is
    then called to remove the NaN bands (NOTE(review): presumably this also removes any band that was already
    entirely NaN - confirm).

    Args:
        drop (bool): True if bad bands should be completely dropped. If False, these bands will be kept but
            replaced with nans.
    """
    assert 'bbl' in self.header, "Please specify a bad band list ('bbl') in the image header, as per the ENVI format definition."

    is_bad = self.header.get_list('bbl') == 0  # True for bands flagged as bad
    self.data[..., is_bad] = np.nan

    if not drop:
        return
    self.delete_nan_bands(inplace=True)
Remove bad bands as stored in the 'bbl' key in the image header. Note that this operates in-place.
Arguments:
- drop (bool): True if bad bands should be completely dropped. If False, these bands will be kept but replaced with nans.
def mask(self, mask=None, flag=np.nan, invert=False, crop=False, bands=None):
    """
    Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to
    the image in-situ.

    Args:
        mask (ndarray, str): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None
            is passed then pickPolygons( ... ) is used to interactively define a polygon. If a file path is passed
            then the mask (polygon or binary mask) will be loaded using np.load( ... ). Alternatively, a boolean
            array with mask.shape == image.shape[0,1] is treated as a binary image mask, and True values will be
            masked across all bands. Default is None.
        flag (float): the value to use for masked pixels. Default is np.nan
        invert (bool): if True, pixels within the polygon will be masked. If False, pixels outside the polygon are
            masked. Default is False. Not applied to binary image masks.
        crop (bool): True if rows/columns containing only zeros or NaNs should be removed. Default is False.
        bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used.

    Returns:
        mask (ndarray): a boolean array with True where pixels are masked and False elsewhere. None is returned if
        no mask was picked interactively.
    """

    if mask is None:  # pick mask interactively
        if bands is None:
            bands = int(self.band_count() / 2)

        regions = self.pickPolygons(region_names=["mask"], bands=bands)

        # the user bailed without picking a mask?
        if len(regions) == 0:
            print("Warning - no mask picked/applied.")
            return

        # extract polygon mask
        mask = regions[0]
    elif isinstance(mask, (str, os.PathLike)):
        mask = np.load(mask)  # polygon or binary mask stored with np.save( ... )
    mask = np.asarray(mask)

    # a non-boolean (N, 2) array is a polygon; convert it to a binary mask. (Boolean arrays are always binary
    # masks, even for images with a y-dimension of 2.)
    if mask.dtype != bool and mask.ndim == 2 and mask.shape[1] == 2:

        # build meshgrid with pixel coords
        xx, yy = np.meshgrid(np.arange(self.xdim()), np.arange(self.ydim()))
        points = np.vstack([xx.ravel(), yy.ravel()]).T  # coordinates of each pixel

        # calculate per-pixel mask (meshgrid is y-major, hence the reshape + transpose)
        mask = path.Path(mask).contains_points(points)
        mask = mask.reshape((self.ydim(), self.xdim())).T

        # flip as we want to mask (==True) outside points (unless invert is true)
        if not invert:
            mask = np.logical_not(mask)

    # apply binary image mask to all bands
    assert mask.ndim >= 2 and mask.shape[0] == self.data.shape[0] and mask.shape[1] == self.data.shape[1], \
        "Error - mask shape %s does not match image shape %s" % (mask.shape, self.data.shape)
    self.data[mask] = flag

    # crop image
    if crop:
        self.crop_to_data()

    return mask
Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the image in-situ.
Arguments:
- flag (float): the value to use for masked pixels. Default is np.nan
- mask (ndarray): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None is passed then pickPolygons( ... ) is used to interactively define a polygon. If a file path is passed then the polygon will be loaded using np.load( ... ). Alternatively if mask.shape == image.shape[0,1] then it is treated as a binary image mask (must be boolean) and True values will be masked across all bands. Default is None.
- invert (bool): if True, pixels within the polygon will be masked. If False, pixels outside the polygon are masked. Default is False.
- crop (bool): True if rows/columns containing only zeros or NaNs should be removed. Default is False.
- bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used.
Returns:
- mask (ndarray): a boolean array with True where pixels are masked and False elsewhere. None is returned if no mask was picked interactively.
def crop_to_data(self):
    """
    Remove padding of nan or zero pixels from image. Note that this is performed in place.

    A pixel counts as data if at least one of its bands is both finite and non-zero. The image is cropped to the
    bounding box of these pixels, the affine origin is shifted to the new top-left pixel, and the header
    width/height info is updated. Images without any data pixels are left unchanged.
    """
    # test finiteness and non-zero-ness on the same band (a pixel like [nan, 0] is padding, not data)
    valid = (np.isfinite(self.data) & (self.data != 0)).any(axis=-1)
    if not valid.any():
        return  # nothing to crop to

    # calculate domain containing valid pixels
    xs = np.flatnonzero(valid.any(axis=1))  # x indices containing data
    ys = np.flatnonzero(valid.any(axis=0))  # y indices containing data
    xmin, xmax = int(xs[0]), int(xs[-1]) + 1
    ymin, ymax = int(ys[0]), int(ys[-1]) + 1

    # crop
    self.data = self.data[xmin:xmax, ymin:ymax, :]

    # shift affine origin to new top-left pixel
    if self.affine is not None:
        a = self.affine  # shorthand for affine
        new_affine = list(a)
        new_affine[0] = a[0] + xmin * a[1] + ymin * a[2]
        new_affine[3] = a[3] + xmin * a[4] + ymin * a[5]
        self.affine = np.array(new_affine)
        self.header['affine'] = self.affine

    self.push_to_header()  # update width and height info
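Remove padding of nan or zero pixels from the edges of the image (rows and columns that contain no valid data). Note that this is performed in place, and the affine transform is shifted to match the new top-left pixel.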
def pickPolygons(self, region_names, bands=0):
    """
    Creates a matplotlib gui for selecting polygon regions in an image.

    Args:
        region_names (list, tuple, str): the names of the regions to pick. If a string is passed only one
                                         region is picked.
        bands (tuple): the bands of the image to plot (passed to HyImage.quick_plot( ... )). Default is 0.

    Returns:
        a list of (n, 2) numpy arrays, one per region returned by roipoly, each holding the polygon
        vertex coordinates [[x1, y1], [x2, y2], ...].
    """
    if isinstance(region_names, str):
        region_names = [region_names]
    elif isinstance(region_names, tuple):
        region_names = list(region_names)

    assert isinstance(region_names, list), "Error - names must be a list or a string."

    # ROIPoly needs the interactive Qt backend; remember the current one so it can be restored.
    backend = matplotlib.get_backend()
    matplotlib.use('Qt5Agg')
    try:
        fig, ax = self.quick_plot(bands)
        try:
            roi = MultiRoi(roi_names=region_names)
        finally:
            plt.close(fig)  # close the figure even if picking failed
    finally:
        # restore matplotlib backend (if possible), also when an error occurred above
        try:
            matplotlib.use(backend)
        except Exception:
            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")

    # convert each picked roi into an (n, 2) array of vertex coordinates
    return [np.vstack([r.x, r.y]).T for r in roi.rois.values()]
Creates a matplotlib gui for selecting polygon regions in an image.
Arguments:
- region_names (list, str): a list containing the names of the regions to pick. If a string is passed, only one region is picked.
- bands (tuple): the bands of the image to plot.
def pickPoints(self, n=-1, bands=hylite.RGB, integer=True, title="Pick Points", **kwds):
    """
    Creates a matplotlib gui for picking pixels from an image.

    Args:
        n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
        bands (tuple): the bands of the image to plot. Default is hylite.RGB.
        integer (bool): True if points coordinates should be cast to integers (for use as indices). Default is True.
        title (str): The title of the point picking window.
        **kwds: Keywords are passed to HyImage.quick_plot( ... ).

    Returns:
        A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].
    """
    # point picking needs the interactive Qt backend; remember the current one so it can be restored.
    backend = matplotlib.get_backend()
    matplotlib.use('Qt5Agg')
    try:
        fig, ax = self.quick_plot(bands, **kwds)
        try:
            ax.set_title(title)
            points = fig.ginput(n)
        finally:
            # close the window so repeated calls do not accumulate open figures
            plt.close(fig)
    finally:
        # restore matplotlib backend (if possible), also when an error occurred above
        try:
            matplotlib.use(backend)
        except Exception:
            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")

    if integer:
        # NOTE(review): int() truncates toward zero; if quick_plot draws pixel centres at integer
        # coordinates, rounding may select the clicked pixel more accurately - confirm.
        points = [(int(p[0]), int(p[1])) for p in points]

    return points
Creates a matplotlib gui for picking pixels from an image.
Arguments:
- n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
- bands (tuple): the bands of the image to plot. Default is hylite.RGB.
- integer (bool): True if points coordinates should be cast to integers (for use as indices). Default is True.
- title (str): The title of the point picking window.
- **kwds: Keywords are passed to HyImage.quick_plot( ... ).
Returns:
A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].
def pickSamples(self, names=None, store=True, **kwds):
    """
    Pick sample probe points and store these in the image header file.

    Args:
        names (str, list, tuple): the name of the sample to pick, or a list of names to pick multiple.
                                  Must be provided; None raises a TypeError.
        store (bool): True if sample should be stored in the image header file (for later access). Points for
                      sample "X" are written to the header key 'sample X', and each name not already present
                      is appended to the header's 'class names'. Default is True.
        **kwds: Keywords are passed to HyImage.pickPoints( ... ) (and from there to HyImage.quick_plot( ... )).

    Returns:
        a list containing a list of points for each sample.

    Raises:
        TypeError: if names is None.
    """
    if names is None:
        raise TypeError("Error - names must be a string or a list of strings.")
    if isinstance(names, str):
        names = [names]
    names = list(names)  # accept any sequence of names (e.g. tuples)

    # pick points
    points = []
    for s in names:
        pnts = self.pickPoints(title="%s" % s, **kwds)
        if store:
            self.header['sample %s' % s] = pnts  # store in header
        points.append(pnts)

    # add classes to header file
    if store:
        cls_names = self.header.get_class_names()
        # copy into a plain list so concatenation works whatever sequence type the header returns
        cls_names = [] if cls_names is None else list(cls_names)
        # only register new names so re-picking a sample does not duplicate its class
        for s in names:
            if s not in cls_names:
                cls_names.append(s)
        self.header['class names'] = cls_names

    return points
Pick sample probe points and store these in the image header file.
Arguments:
- names (str, list): the name of the sample to pick, or a list of names to pick multiple.
- store (bool): True if sample should be stored in the image header file (for later access). Default is True.
- **kwds: Keywords are passed to HyImage.quick_plot( ... )
Returns:
a list containing a list of points for each sample.
Inherited Members
- hylite.hydata.HyData
- to_grey
- data
- set_header
- push_to_header
- has_wavelengths
- get_wavelengths
- has_band_names
- get_band_names
- has_fwhm
- get_fwhm
- set_wavelengths
- set_band_names
- set_fwhm
- is_image
- is_point
- is_classification
- band_count
- samples
- lines
- is_int
- is_float
- export_bands
- delete_nan_bands
- set_as_nan
- mask_bands
- mask_water_features
- get_band
- get_band_grey
- get_raveled
- X
- eval
- set_raveled
- get_band_index
- resample
- contiguous_chunks
- smooth_median
- smooth_savgol
- fill_gaps
- plot_spectra
- compress
- decompress
- getQuantized
- fromQuanta
- normalise
- percent_clip
- correct_spectral_shift