hylite.hyimage
Store and manipulate hyperspectral image data.
"""
Store and manipulate hyperspectral image data.
"""

import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import path
from roipoly import MultiRoi
import imageio
import scipy as sp
import hylite
from hylite.hydata import HyData
from hylite.hylibrary import HyLibrary

# Message used when GDAL is unavailable. It is raised as an AssertionError to match the original error type.
_GDAL_MISSING = "Error - GDAL must be installed to work with spatial projections in hylite."


class HyImage(HyData):
    """
    A class for hyperspectral image data. These can be individual scenes or hyperspectral orthoimages.
    """

    def __init__(self, data, **kwds):
        """
        Args:
            data (ndarray): a numpy array such that data[x][y][band] gives each pixel value. 1-D arrays are treated as
                            a single pixel and 2-D arrays as a single band image.
            **kwds:
                wav = A numpy array containing band wavelengths for this image.
                affine = an affine transform of the format returned by GDAL.GetGeoTransform(). Default is [0,1,0,0,0,1].
                projection = an osgeo.osr.SpatialReference or GDAL georeference string defining the projection. Default is None.
                sensor = sensor name. Default is "unknown".
                header = header information, passed on to the HyData constructor. Default is None.
        """

        # call constructor for HyData
        super().__init__(data, **kwds)

        # special case - if dataset only has one band, slice it so it still has the format data[x,y,b].
        if self.data is not None:
            if len(self.data.shape) == 1:
                self.data = self.data[None, None, :]  # single pixel image
            if len(self.data.shape) == 2:
                self.data = self.data[:, :, None]  # single band image

        # load any additional projection information (specific to images)
        self.set_projection(kwds.get("projection", None))
        self.affine = kwds.get("affine", [0, 1, 0, 0, 0, 1])

        # wavelengths
        if 'wav' in kwds:
            self.set_wavelengths(kwds['wav'])

        # special header formatting
        self.header['file type'] = 'ENVI Standard'

    def copy(self, data=True):
        """
        Make a deep copy of this image instance.

        Args:
            data (bool): True if a copy of the data should be made, otherwise only copy header.

        Returns:
            a new HyImage instance.
        """
        if not data:
            return HyImage(None, header=self.header.copy(), projection=self.projection, affine=self.affine)
        return HyImage(self.data.copy(), header=self.header.copy(), projection=self.projection, affine=self.affine)

    def T(self):
        """
        Return a transposed view of the data matrix (corresponding with the [y,x] indexing used by matplotlib, opencv etc.).
        """
        return np.transpose(self.data, (1, 0, 2))

    def xdim(self):
        """
        Return number of pixels in x (first dimension of data array).
        """
        return self.data.shape[0]

    def ydim(self):
        """
        Return number of pixels in y (second dimension of data array).
        """
        return self.data.shape[1]

    def aspx(self):
        """
        Return the aspect ratio ydim / xdim of this image (i.e. height / width, where width is the x dimension as
        used by get_extent(...) and quick_plot(...)).
        """
        return self.ydim() / self.xdim()

    def get_extent(self):
        """
        Returns the width and height of this image in world coordinates.

        If this object has no `pixel_size` attribute, the pixel size is taken from the GDAL-style geotransform
        (affine[1] = pixel width, affine[5] = pixel height).

        Returns:
            tuple with (width, height).
        """
        pixel_size = getattr(self, 'pixel_size', None)
        if pixel_size is None:
            pixel_size = (abs(self.affine[1]), abs(self.affine[5]))
        # xdim / ydim are methods; they must be called, not multiplied.
        return self.xdim() * pixel_size[0], self.ydim() * pixel_size[1]

    ##################################
    ## Projection helpers
    ##################################
    @staticmethod
    def _import_gdal():
        """
        Import the GDAL python bindings.

        Returns:
            a tuple (osr, gdal, ogr) of osgeo modules.

        Raises:
            AssertionError: if GDAL is not installed.
        """
        try:
            from osgeo import osr, ogr
            import osgeo.gdal as gdal
        except ImportError:
            raise AssertionError(_GDAL_MISSING) from None
        return osr, gdal, ogr

    def _parse_projection(self, proj):
        """
        Resolve a user-specified coordinate system into an osr.SpatialReference.

        Args:
            proj (None, str, int, osr.SpatialReference): None (use this image's projection), an EPSG code
                 ("EPSG:XXXX", "XXXX" or an int) or an existing osr.SpatialReference.

        Returns:
            an osr.SpatialReference.

        Raises:
            AssertionError: if the EPSG code is invalid, the reference is of an unknown type, or this image has no
                            projection / affine transform.
        """
        osr, _, _ = self._import_gdal()
        if proj is None:
            proj = self.projection
        elif isinstance(proj, (str, int)):
            epsg = proj
            if isinstance(epsg, str):
                try:
                    epsg = int(epsg.split(':')[-1])
                except ValueError:
                    raise AssertionError("Error - %s is an invalid EPSG code." % proj) from None
            proj = osr.SpatialReference()
            proj.ImportFromEPSG(epsg)

        # check we have all the required info
        assert (self.affine is not None) and (self.projection is not None), \
            "Error - project information is undefined."
        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % (proj,)
        return proj

    @classmethod
    def _transform_point(cls, x, y, src, dst):
        """
        Reproject a single point from spatial reference `src` to `dst`.

        Coordinates are passed in and returned in traditional GIS order (x = easting/longitude,
        y = northing/latitude), which is the order used by GDAL geotransforms. This holds even for EPSG
        systems whose authority axis order is lat/lon.

        Returns:
            the transformed (x, y) tuple.
        """
        osr, _, ogr = cls._import_gdal()
        src, dst = src.Clone(), dst.Clone()
        if hasattr(src, 'SetAxisMappingStrategy'):  # GDAL >= 3; earlier versions have no axis-mapping setting
            src.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
            dst.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
        point = ogr.Geometry(ogr.wkbPoint)
        point.AddPoint(x, y)
        point.AssignSpatialReference(src)  # tell the point what coordinates it's in
        point.TransformTo(dst)  # reproject it to the output spatial reference
        return point.GetX(), point.GetY()

    def set_projection(self, proj):
        """
        Set this projection to an existing osgeo.osr.SpatialReference or GDAL georeference string.

        Args:
            proj (str, osgeo.osr.SpatialReference): the projection to use as osgeo.osr.SpatialReference or GDAL
                  georeference string. None clears the projection.

        Raises:
            AssertionError: if GDAL is not installed.
            TypeError: if proj is of an unsupported type.
        """
        if proj is None:
            self.projection = None
            return
        osr, _, _ = self._import_gdal()
        if isinstance(proj, osr.SpatialReference):
            self.projection = proj
        elif isinstance(proj, str):
            self.projection = osr.SpatialReference(proj)
        else:
            raise TypeError("Invalid project %s" % (proj,))

    def set_projection_EPSG(self, EPSG):
        """
        Sets this image projection using an EPSG code.

        Args:
            EPSG (str, int): string EPSG code that can be passed to SpatialReference.SetFromUserInput(...) (e.g.
                             "EPSG:32632"), or an integer EPSG code.
        """
        osr, _, _ = self._import_gdal()
        if isinstance(EPSG, (int, np.integer)):
            EPSG = "EPSG:%d" % EPSG
        self.projection = osr.SpatialReference()
        self.projection.SetFromUserInput(EPSG)

    def get_projection_EPSG(self):
        """
        Gets a string describing this projection's EPSG code (if it is an EPSG projection).

        Returns:
            an authority code string of the format "EPSG:XXXX", or None if no projection (or no authority) is defined.
        """
        if self.projection is None:
            return None
        authority = self.projection.GetAttrValue("AUTHORITY", 0)
        code = self.projection.GetAttrValue("AUTHORITY", 1)
        if authority is None or code is None:
            return None
        return "%s:%s" % (authority, code)

    def pix_to_world(self, px, py, proj=None):
        """
        Take pixel coordinates and return world coordinates.

        Args:
            px (int): the pixel x-coord.
            py (int): the pixel y-coord.
            proj (str, int, osr.SpatialReference): the coordinate system to use. Default (None) uses the same system as
                  this image. Otherwise an osr.SpatialReference can be passed (HyImage.projection), or an EPSG code
                  (e.g. get_projection_EPSG(...)).

        Returns:
            the world coordinates (x, y) in the requested coordinate system. For geographic systems these are
            (lon, lat).
        """
        _, gdal, _ = self._import_gdal()
        proj = self._parse_projection(proj)

        # project to world coordinates in this image's projection
        x, y = gdal.ApplyGeoTransform(self.affine, px, py)

        # project to target coords (if different)
        if not proj.IsSame(self.projection):
            x, y = self._transform_point(x, y, self.projection, proj)
        return x, y

    def world_to_pix(self, x, y, proj=None):
        """
        Take world coordinates and return pixel coordinates.

        Args:
            x (float): the world x-coord (longitude for geographic systems).
            y (float): the world y-coord (latitude for geographic systems).
            proj (str, int, osr.SpatialReference): the coordinate system of the input coordinates. Default (None) uses
                  the same system as this image. Otherwise an osr.SpatialReference can be passed (HyImage.projection),
                  or an EPSG code (e.g. get_projection_EPSG(...)).

        Returns:
            the pixel coordinates based on the affine transform stored in self.affine.
        """
        _, gdal, _ = self._import_gdal()
        proj = self._parse_projection(proj)

        # project to this image's CS (if different)
        if not proj.IsSame(self.projection):
            x, y = self._transform_point(x, y, proj, self.projection)

        inv = gdal.InvGeoTransform(self.affine)
        assert inv is not None, "Error - could not invert affine transform?"
        return gdal.ApplyGeoTransform(inv, x, y)

    def flip(self, axis='x'):
        """
        Flip the image on the x or y axis.

        Args:
            axis (str): 'x' or 'y' or both 'xy'.
        """
        if 'x' in axis.lower():
            self.data = np.flip(self.data, axis=0)
        if 'y' in axis.lower():
            self.data = np.flip(self.data, axis=1)

    def rot90(self):
        """
        Rotate this image by 90 degrees by transposing the underlying data array. Combine with flip('x') or flip('y')
        to achieve positive/negative rotations.
        """
        self.data = np.transpose(self.data, (1, 0, 2))
        self.push_to_header()

    #####################################
    ## IMAGE FILTERING
    #####################################
    def fill_holes(self):
        """
        Replace non-finite (nan/inf) values with the maximum of their 3x3 spatial neighbourhood (a greyscale
        dilation, with holes counted as 0). This removes 1-pixel holes from an image. Each band is handled
        independently.
        """
        from scipy import ndimage  # explicit submodule import; `import scipy` does not always load ndimage

        dilate = self.data.copy()
        mask = np.logical_not(np.isfinite(dilate))
        dilate[mask] = 0
        # size (3, 3, 1) dilates spatially only, i.e. each band independently
        dilate = ndimage.grey_dilation(dilate, size=(3, 3, 1))

        # map back to holes in dataset
        self.data[mask] = dilate[mask]

    def blur(self, n=3):
        """
        Applies an n x n box (mean) filter to the image using OpenCV. Pixels that were nan remain nan.

        Args:
            n (int): the dimensions of the kernel to convolve. Default is 3. Increase for more blurry results.
        """
        import cv2  # import this here to avoid errors if opencv is not installed properly

        assert isinstance(n, int) and n >= 3, "Error - invalid kernel. N must be an integer >= 3. "
        nanmask = np.isnan(self.data)
        kernel = np.ones((n, n), np.float32) / (n ** 2)
        # opencv drops a trailing singleton band axis, so restore the original shape
        self.data = cv2.filter2D(self.data, -1, kernel).reshape(nanmask.shape)
        if nanmask.any():  # assigning nan into an integer array would raise
            self.data[nanmask] = np.nan  # remove mask

    def erode(self, size=3, iterations=1):
        """
        Apply an erode filter to this image to expand background (nan) pixels. Refer to open-cv's erode
        function for more details.

        Args:
            size (int): the size of the erode filter. Default is a 3x3 kernel.
            iterations (int): the number of erode iterations. Default is 1.
        """
        import cv2  # import this here to avoid errors if opencv is not installed properly

        kernel = np.ones((size, size), np.uint8)
        if self.is_float():
            mask = np.isfinite(self.data).any(axis=-1)
            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
            self.data[mask == 0, :] = np.nan
        else:
            mask = (self.data != 0).any(axis=-1)
            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
            self.data[mask == 0, :] = 0

    def resize(self, newdims: tuple, interpolation: int = 1):
        """
        Resize this image with opencv.

        Args:
            newdims (tuple): the new image dimensions (xdim, ydim).
            interpolation (int): opencv interpolation method. Default is cv2.INTER_LINEAR.
        """
        import cv2  # import this here to avoid errors if opencv is not installed properly
        resized = cv2.resize(self.data, (newdims[1], newdims[0]), interpolation=interpolation)
        if resized.ndim == 2:  # opencv drops a singleton band axis
            resized = resized[:, :, None]
        self.data = resized

    def despeckle(self, size=5):
        """
        Despeckle each band of this image (independently) using a median filter.

        Args:
            size (int): the size of the median filter kernel. Default is 5. Must be an odd number (opencv only
                        supports 3 or 5 for non-uint8 data).
        """
        assert (size % 2) == 1, "Error - size must be an odd integer"
        import cv2  # import this here to avoid errors if opencv is not installed properly

        src = self.data.astype(np.float32) if self.is_float() else self.data
        out = np.empty_like(src)
        # cv2.medianBlur only accepts 1, 3 or 4 channels, so filter each band separately
        for b in range(src.shape[-1]):
            out[..., b] = cv2.medianBlur(np.ascontiguousarray(src[..., b]), size)
        self.data = out

    #####################################
    ## FEATURES AND FEATURE MATCHING
    #####################################
    def get_keypoints(self, band, eq=False, mask=True, method='sift', cfac=0.0, bfac=0.0, **kwds):
        """
        Get feature descriptors from the specified band. The image data itself is not modified.

        Args:
            band (int,float,str,tuple): the band index (int), wavelength (float) or name (str) to extract features
                  from. Alternatively, a tuple can be passed containing a range of bands (min : max) to average before
                  feature matching.
            eq (bool): True if the image should be histogram equalized first. Default is False.
            mask (bool): True if 0 value pixels should be masked. Default is True.
            method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'.
            cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
            bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.
            **kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are:

                 - contrastThreshold: default is 0.01.
                 - edgeThreshold: default is 10.
                 - sigma: default is 1.0

                For ORB these are:

                 - nfeatures = the number of features to detect. Default is 5000.

        Returns:
            Tuple containing

            - k (ndarray): the keypoints detected
            - d (ndarray): corresponding feature descriptors
        """
        import cv2  # import this here to avoid errors if opencv is not installed properly

        # get image
        if isinstance(band, (int, float, str, np.integer, np.floating)):  # single band
            image = self.data[:, :, self.get_band_index(band)]
        elif isinstance(band, tuple):  # range of bands (averaged)
            idx0 = self.get_band_index(band[0])
            idx1 = self.get_band_index(band[1])

            # deal with out of range errors
            if idx0 is None:
                idx0 = 0
            if idx1 is None:
                idx1 = self.band_count()

            image = np.nanmean(self.data[:, :, idx0:idx1], axis=2)
        else:
            assert False, "Error, unrecognised band %s" % (band,)

        # normalise image to range 0 - 1 (on a float copy, so self.data is never modified in place)
        image = np.array(image, dtype=float)
        image -= np.nanmin(image)
        peak = np.nanmax(image)
        if peak > 0:  # avoid 0/0 on constant images
            image /= peak

        # apply brightness/contrast adjustment
        image = np.clip((1.0 + cfac) * image + bfac, 0.0, 1.0)
        image[~np.isfinite(image)] = 0.0  # nan has no uint8 value; treat as background

        # convert image to uint8 for opencv
        image = np.uint8(255 * image)
        if eq:
            image = cv2.equalizeHist(image)

        if mask:
            detect_mask = np.zeros(image.shape, dtype=np.uint8)
            detect_mask[image != 0] = 255  # include only non-zero pixels
        else:
            detect_mask = None

        if 'sift' in method.lower():  # SIFT
            kwds.setdefault("contrastThreshold", 0.01)
            kwds.setdefault("edgeThreshold", 10)
            kwds.setdefault("sigma", 1.0)
            alg = cv2.SIFT_create(**kwds)
        elif 'orb' in method.lower():  # ORB
            kwds.setdefault('nfeatures', 5000)
            alg = cv2.ORB_create(scoreType=cv2.ORB_FAST_SCORE, **kwds)
        else:
            assert False, "Error - %s is not a recognised feature detector." % method

        # detect keypoints, then extract and return feature vectors
        kp = alg.detect(image, detect_mask)
        return alg.compute(image, kp)

    @classmethod
    def match_keypoints(cls, kp1, kp2, d1, d2, method='SIFT', dist=0.7, tree=5, check=100, min_count=5):
        """
        Compares keypoint feature vectors from two images and returns matching pairs.

        Args:
            kp1 (ndarray): keypoints from the first image
            kp2 (ndarray): keypoints from the second image
            d1 (ndarray): descriptors for the keypoints from the first image
            d2 (ndarray): descriptors for the keypoints from the second image
            method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'.
            dist (float): Lowe's ratio-test threshold (0 to 1). A match is kept only if its distance is less than
                          dist times the distance of the second-best match. Default is 0.7.
            tree (int): number of randomised kd-trees used by the FLANN index. Default is 5. See open-cv docs.
            check (int): number of times the FLANN trees are traversed during search (higher = more accurate but
                         slower). Default is 100.
            min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are found,
                             then the function returns None, None. Default is 5.

        Returns:
            Tuple (src_pts, dst_pts) of matched point coordinates with shape (N, 1, 2), or (None, None).
        """
        import cv2  # import this here to avoid errors if opencv is not installed properly

        # FLANN index identifiers (these equal cv2.NORM_INF and cv2.NORM_HAMMING, which were used previously)
        FLANN_INDEX_KDTREE = 1
        FLANN_INDEX_LSH = 6
        if 'sift' in method.lower():
            algorithm = FLANN_INDEX_KDTREE
        elif 'orb' in method.lower():
            algorithm = FLANN_INDEX_LSH
        else:
            assert False, "Error - unknown matching algorithm %s" % method

        # no descriptors means no matches (FLANN raises on empty input)
        if d1 is None or d2 is None or len(d1) == 0 or len(d2) == 0:
            return None, None

        # calculate flann matches
        index_params = dict(algorithm=algorithm, trees=tree)
        search_params = dict(checks=check)
        flann = cv2.FlannBasedMatcher(index_params, search_params)
        matches = flann.knnMatch(d1, d2, k=2)

        # store all the good matches as per Lowe's ratio test (queries with < 2 neighbours cannot be tested).
        good = [pair[0] for pair in matches if len(pair) == 2 and pair[0].distance < dist * pair[1].distance]

        if len(good) < min_count:
            return None, None
        src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
        dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
        return src_pts, dst_pts

    ############################
    ## Visualisation methods
    ############################
    @staticmethod
    def _save_plot_image(out_path, image, **kwds):
        """Save `image` with matplotlib's imsave, creating the output directory if needed."""
        from matplotlib.pyplot import imsave
        directory = os.path.dirname(out_path)
        if directory:  # os.makedirs('') would raise for paths in the working directory
            os.makedirs(directory, exist_ok=True)
        imsave(out_path, image, **kwds)

    @staticmethod
    def _resolve_limits(values, vmin, vmax):
        """
        Return the (min, max) used to rescale `values`. None means the data min/max. Integers are treated as
        percentiles and floats as constant values.
        """
        mn = float(np.nanmin(values)) if vmin is None else vmin
        mx = float(np.nanmax(values)) if vmax is None else vmax
        if isinstance(mn, int):
            assert 0 <= mn <= 100, "Error - integer vmin values must be a percentile."
            mn = float(np.nanpercentile(values, mn))
        if isinstance(mx, int):
            assert 0 <= mx <= 100, "Error - integer vmax values must be a percentile."
            mx = float(np.nanpercentile(values, mx))
        return mn, mx

    def _plot_samples(self, ax, samples, ps, pc):
        """Scatter explicit [(x, y), ...] sample points, or (if samples is truthy) the header's sample points."""
        if isinstance(samples, (list, np.ndarray)):  # checked first: `if ndarray` is ambiguous
            ax.scatter([s[0] for s in samples], [s[1] for s in samples], s=ps, c=pc)
        elif samples:
            class_names = self.header.get_class_names()
            for n in (class_names if class_names is not None else []):
                points = np.array(self.header.get_sample_points(n))
                ax.scatter(points[:, 0], points[:, 1], s=ps, c=pc)

    def quick_plot(self, band=0, ax=None, bfac=0.0, cfac=0.0, samples=False, tscale=False, invert=False, rot=False,
                   flipX=False, flipY=False, **kwds):
        """
        Plot a band using matplotlib.imshow(...).

        Args:
            band (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple is passed then
                  each band in the tuple (string or index) will be mapped to rgb. Bands with negative wavelengths or indices will be inverted before plotting.
            ax: an axis object to plot to. If none, a new figure and axes are created.
            bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
            cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
            samples (bool): True if sample points (defined in the header file) should be plotted. Default is False. Otherwise, a list of
                     [ (x,y), ... ] points can be passed.
            tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False.
                    When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or
                    (constant) values (float).
            invert (bool) : True if each band should be inverted before plotting. Only works for multiband (ternary) images.
            rot (bool): if True, the x and y axis will be flipped (90 degree rotation) before plotting. Default is False.
            flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
            flipY (bool): if True, the y axis will be flipped before plotting (after applying rotations).
            **kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:

                 - mask = a 2D boolean mask containing True for pixels that should be hidden (masked) and False otherwise.
                 - path = a file path to save the image to (at matching resolution; use fig.savefig(..) if you want to save the figure).
                 - ticks = True if x- and y- ticks should be plotted. Default is False.
                 - ps, pc = the size and color of sample points to plot. Can be constant or list.
                 - figsize = a figsize for the figure to create (if ax is None).

        Returns:
            Tuple containing

            - fig: matplotlib figure object
            - ax: matplotlib axes object. If a colorbar is created, (band is an integer or a float), then this will be stored in ax.cbar.
        """
        # pop keywords that must never reach imshow
        figsize = kwds.pop('figsize', None)
        ticks = kwds.pop('ticks', False)
        ps = kwds.pop('ps', 5)
        pc = kwds.pop('pc', 'r')
        out_path = kwds.pop('path', None)

        # create new axes?
        if ax is None:
            if figsize is None:
                figsize = (18, 18 * self.ydim() / self.xdim())
            fig, ax = plt.subplots(figsize=figsize)

        # deal with ticks
        if not ticks:
            ax.set_xticks([])
            ax.set_yticks([])

        if isinstance(band, (str, int, float, np.integer, np.floating)):
            # map individual band using colourmap
            if isinstance(band, str):
                data = self.data[:, :, self.get_band_index(band)]
            else:
                data = self.data[:, :, self.get_band_index(np.abs(band))]
                if band < 0:
                    data = np.nanmax(data) - data  # flip

            # convert integer vmin and vmax values to percentiles
            for key in ('vmin', 'vmax'):
                if isinstance(kwds.get(key), int):
                    kwds[key] = np.nanpercentile(data, kwds[key])

            # mask nans, the header's data-ignore value and any custom mask
            mask = np.isnan(data)
            ignore = self.header.get_data_ignore_value()
            if not np.isnan(ignore):
                mask = mask | (data == ignore)
            if 'mask' in kwds:
                mask = mask | np.asarray(kwds.pop('mask'), dtype=bool)
            data = np.ma.array(data, mask=mask)

            # apply rotations and flipping
            if rot:
                data = data.T
            if flipX:
                data = data[::-1, :]
            if flipY:
                data = data[:, ::-1]

            # save? (only pass keywords that imsave understands)
            if out_path is not None:
                save_kwds = {k: kwds[k] for k in ('cmap', 'vmin', 'vmax', 'origin') if k in kwds}
                self._save_plot_image(out_path, data.T, **save_kwds)

            self._plot_samples(ax, samples, ps, pc)
            ax.cbar = ax.imshow(data.T, interpolation=kwds.pop('interpolation', 'none'), **kwds)

        elif isinstance(band, (tuple, list)):
            # map bands to RGB
            rgb = [self.get_band_index(b) if isinstance(b, str) else self.get_band_index(np.abs(b)) for b in band]

            # slice image (fancy indexing copies) as float so per-band scaling is not truncated
            img = self.data[:, :, rgb]
            if not np.issubdtype(img.dtype, np.floating):
                img = img.astype(np.float32)
            if np.isnan(img).all():
                print("Warning - image contains no data.")
                return ax.get_figure(), ax

            # invert negative bands; `invert` toggles this for every band
            for i, b in enumerate(band):
                negative = (not isinstance(b, str)) and b < 0
                if negative != bool(invert):
                    img[..., i] = np.nanmax(img[..., i]) - img[..., i]

            # scale to 0 - 1 (vmin / vmax are consumed here; imshow ignores them for RGB data)
            vmin = kwds.pop('vmin', None)
            vmax = kwds.pop('vmax', None)
            if tscale:  # scale bands independently
                for b in range(img.shape[-1]):
                    mn, mx = self._resolve_limits(img[..., b], vmin, vmax)
                    img[..., b] = (img[..., b] - mn) / (mx - mn)
            else:  # scale bands together
                mn, mx = self._resolve_limits(img, vmin, vmax)
                img = (img - mn) / (mx - mn)

            # apply brightness/contrast mapping
            img = np.clip((1.0 + cfac) * img + bfac, 0, 1.0)

            # apply masking so background is white
            img[np.logical_not(np.isfinite(img))] = 1.0
            if 'mask' in kwds:
                img[np.asarray(kwds.pop('mask'), dtype=bool), :] = 1.0

            # apply rotations and flipping
            if rot:
                img = np.transpose(img, (1, 0, 2))
            if flipX:
                img = img[::-1, :, :]
            if flipY:
                img = img[:, ::-1, :]

            # save?
            if out_path is not None:
                self._save_plot_image(out_path, np.transpose(np.clip(img * 255, 0, 255).astype(np.uint8), (1, 0, 2)))

            self._plot_samples(ax, samples, ps, pc)
            ax.imshow(np.transpose(img, (1, 0, 2)), interpolation=kwds.pop('interpolation', 'none'), **kwds)
            ax.cbar = None  # no colorbar
        else:
            assert False, "Error - unrecognised band %s" % (band,)

        return ax.get_figure(), ax

    def createGIF(self, path, bands=None, figsize=(10, 10), fps=10, **kwds):
        """
        Create and save an animated gif that loops through the bands of the image.

        Args:
            path (str): the path to save the .gif
            bands (tuple): Tuple (start, end) containing the range of band indices to draw (end is exclusive). Default
                           is the whole range.
            figsize (tuple): the size of the image to draw. Default is (10,10).
            fps (int): the framerate (frames per second) of the gif. Default is 10.
            **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
        """
        if bands is None:
            bands = (0, self.band_count())
        else:
            assert 0 <= bands[0] < self.band_count(), "Error - invalid range."
            assert 0 < bands[1] <= self.band_count(), "Error - invalid range."
            assert bands[1] > bands[0], "Error - invalid range."

        # plot frames
        frames = []
        for i in range(bands[0], bands[1]):
            fig, ax = plt.subplots(figsize=figsize)
            ax.imshow(self.data[:, :, i], **kwds)
            fig.canvas.draw()
            # buffer_rgba gives an (h, w, 4) view; drop alpha and copy before the figure is closed
            frames.append(np.asarray(fig.canvas.buffer_rgba())[..., :3].copy())
            plt.close(fig)

        # save gif
        imageio.mimsave(os.path.splitext(path)[0] + ".gif", frames, fps=fps)

    ## masking
    def drop_bbl(self, drop=True):
        """
        Remove bad bands as stored in the 'bbl' key in the image header. Note that this operates in-place.

        Args:
            drop (bool): True if bad bands should be completely dropped. If False, these bands will be kept but replaced with nans.
        """
        assert 'bbl' in self.header, "Please specify a bad band list ('bbl') in the image header, as per the ENVI format definition."
        mask = np.asarray(self.header.get_list('bbl')) == 0  # asarray so a plain list compares element-wise
        self.data[..., mask] = np.nan
        if drop:
            self.delete_nan_bands(inplace=True)

    def mask(self, mask=None, flag=np.nan, invert=False, crop=False, bands=None):
        """
        Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the
        image in-situ.

        Args:
            flag (float): the value to use for masked pixels. Default is np.nan
            mask (ndarray, str): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None is passed then
                 pickPolygons( ... ) is used to interactively define a polygon. If a file path is passed then the array
                 will be loaded using np.load( ... ). Alternatively, if mask.shape == image.shape[0,1] then it is treated as a
                 binary image mask and True (non-zero) values will be masked across all bands. Default is None.
            invert (bool): polygon masks only. If True, pixels within the polygon will be masked. If False, pixels outside the polygon are masked. Default is False.
            crop (bool): True if rows/columns containing only masked pixels should be removed. Default is False.
            bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used.

        Returns:
            mask (ndarray): a boolean array with True where pixels are masked and False elsewhere, or None if no
            polygon was picked interactively.
        """
        if mask is None:  # pick mask interactively
            if bands is None:
                bands = int(self.band_count() / 2)

            regions = self.pickPolygons(region_names=["mask"], bands=bands)

            # the user bailed without picking a mask?
            if len(regions) == 0:
                print("Warning - no mask picked/applied.")
                return None

            mask = regions[0]
        elif isinstance(mask, str):
            mask = np.load(mask)
        mask = np.asarray(mask)

        if mask.dtype != bool and mask.ndim == 2 and mask.shape[1] == 2:
            # convert polygon mask to binary mask: test the (x, y) coordinates of every pixel
            xx, yy = np.meshgrid(np.arange(self.xdim()), np.arange(self.ydim()), indexing='ij')
            points = np.column_stack([xx.ravel(), yy.ravel()])
            mask = path.Path(mask).contains_points(points).reshape((self.xdim(), self.ydim()))

            # flip as we want to mask (==True) outside points (unless invert is true)
            if not invert:
                mask = np.logical_not(mask)
        else:
            mask = mask.astype(bool)  # non-bool masks would otherwise be used as integer indices

        # apply binary image mask
        assert mask.shape[0] == self.data.shape[0] and mask.shape[1] == self.data.shape[1], \
            "Error - mask shape %s does not match image shape %s" % (mask.shape, self.data.shape)
        self.data[mask, :] = flag

        # crop image
        if crop:
            valid = np.logical_not(mask)

            # integrate along axes
            xdata = np.sum(valid, axis=1) > 0.0
            ydata = np.sum(valid, axis=0) > 0.0

            # calculate domain containing valid pixels
            xmin = np.argmax(xdata)
            xmax = xdata.shape[0] - np.argmax(xdata[::-1])
            ymin = np.argmax(ydata)
            ymax = ydata.shape[0] - np.argmax(ydata[::-1])

            self.data = self.data[xmin:xmax, ymin:ymax, :]

        return mask

    def crop_to_data(self):
        """
        Remove padding of nan or zero pixels from image. Note that this is performed in place. Images without any
        valid pixels are left unchanged.
        """
        valid = np.isfinite(self.data).any(axis=-1) & (self.data != 0).any(axis=-1)
        rows = np.flatnonzero(valid.any(axis=1))
        cols = np.flatnonzero(valid.any(axis=0))
        if rows.size == 0:
            return
        # +1 because the last valid row/column is inclusive
        self.data = self.data[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1, :]

    ##################################################
    ## Interactive tools for picking regions/pixels
    ##################################################
    def pickPolygons(self, region_names, bands=0):
        """
        Creates a matplotlib gui for selecting polygon regions in an image.

        Args:
            region_names (list, str): a list containing the names of the regions to pick. If a string is passed only one name is used.
            bands (tuple): the bands of the image to plot.

        Returns:
            a list of (N, 2) arrays holding the (x, y) vertices of each picked region.
        """
        if isinstance(region_names, str):
            region_names = [region_names]

        assert isinstance(region_names, list), "Error - names must be a list or a string."

        # set matplotlib backend
        backend = matplotlib.get_backend()
        matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work
        try:
            # plot image and extract roi's
            fig, ax = self.quick_plot(bands)
            roi = MultiRoi(roi_names=region_names)
            plt.close(fig)  # close figure
            regions = [np.vstack([r.x, r.y]).T for r in roi.rois.values()]
        finally:
            # restore matplotlib backend (if possible), even if picking failed
            try:
                matplotlib.use(backend)
            except Exception:
                print("Warning: could not reset matplotlib backend. Plots will remain interactive...")

        return regions

    def pickPoints(self, n=-1, bands=hylite.RGB, integer=True, title="Pick Points", **kwds):
        """
        Creates a matplotlib gui for picking pixels from an image.

        Args:
            n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
            bands (tuple): the bands of the image to plot. Default is hylite.RGB
            integer (bool): True if points coordinates should be cast to integers (for use as indices). Default is True.
            title (str): The title of the point picking window.
            **kwds: Keywords are passed to HyImage.quick_plot( ... ).

        Returns:
            A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].
        """
        # set matplotlib backend
        backend = matplotlib.get_backend()
        matplotlib.use('Qt5Agg')  # need this backend for interactive picking
        try:
            fig, ax = self.quick_plot(bands, **kwds)
            ax.set_title(title)
            points = fig.ginput(n)
        finally:
            # restore matplotlib backend (if possible), even if picking failed
            try:
                matplotlib.use(backend)
            except Exception:
                print("Warning: could not reset matplotlib backend. Plots will remain interactive...")

        if integer:
            points = [(int(p[0]), int(p[1])) for p in points]
        return points

    def pickSamples(self, names=None, store=True, **kwds):
        """
        Pick sample probe points and store these in the image header file.

        Args:
            names (str, list): the name of the sample to pick, or a list of names to pick multiple. Required.
            store (bool): True if sample should be stored in the image header file (for later access). Default is True.
            **kwds: Keywords are passed to HyImage.quick_plot( ... )

        Returns:
            a list containing a list of points for each sample.
        """
        assert names is not None, "Error - please specify the name(s) of the sample(s) to pick."
        names = [names] if isinstance(names, str) else list(names)

        # pick points
        points = []
        for s in names:
            pnts = self.pickPoints(title="%s" % s, **kwds)
            if store:
                self.header['sample %s' % s] = pnts  # store in header
            points.append(pnts)

        # add class to header file
        if store:
            cls_names = self.header.get_class_names()
            if cls_names is None:
                cls_names = []
            self.header['class names'] = list(cls_names) + names

        return points
class HyImage( HyData ):
    """
    A class for hyperspectral image data. These can be individual scenes or hyperspectral orthoimages.

    Pixel values are stored in self.data using [x, y, band] indexing. Use HyImage.T() for the transposed
    [y, x, band] view expected by matplotlib, opencv etc.
    """

    def __init__(self, data, **kwds):
        """
        Args:
            data (ndarray): a numpy array such that data[x][y][band] gives each pixel value. 1-D arrays are treated
                            as a single pixel and 2-D arrays as a single-band image.
            **kwds:
                wav = A numpy array containing band wavelengths for this image.
                affine = an affine transform of the format returned by GDAL.GetGeoTransform(). Default is [0,1,0,0,0,1].
                projection = string defining the project. Default is None.
                sensor = sensor name. Default is "unknown".
                header = path to associated header file. Default is None.
        """

        # call constructor for HyData
        super().__init__(data, **kwds)

        # special case - if the dataset has fewer than three dimensions, add axes so it still has
        # the format data[x,y,b].
        if self.data is not None:
            if len(self.data.shape) == 1:
                self.data = self.data[None, None, :]  # single pixel image
            if len(self.data.shape) == 2:
                self.data = self.data[:, :, None]  # single band image

        # load any additional project information (specific to images)
        self.set_projection(kwds.get("projection", None))
        self.affine = kwds.get("affine", [0, 1, 0, 0, 0, 1])

        # wavelengths
        if 'wav' in kwds:
            self.set_wavelengths(kwds['wav'])

        # special header formatting
        self.header['file type'] = 'ENVI Standard'

    def copy(self, data=True):
        """
        Make a deep copy of this image instance.

        Args:
            data (bool): True if a copy of the data should be made, otherwise only copy header.

        Returns:
            a new HyImage instance.
        """
        if not data:
            return HyImage(None, header=self.header.copy(), projection=self.projection, affine=self.affine)
        else:
            return HyImage(self.data.copy(), header=self.header.copy(), projection=self.projection, affine=self.affine)

    def T(self):
        """
        Return a transposed view of the data matrix (corresponding with the [y,x] indexing used by matplotlib, opencv etc.).
        """
        return np.transpose(self.data, (1, 0, 2))

    def xdim(self):
        """
        Return number of pixels in x (first dimension of data array)
        """
        return self.data.shape[0]

    def ydim(self):
        """
        Return number of pixels in y (second dimension of data array)
        """
        return self.data.shape[1]

    def aspx(self):
        """
        Return the aspect ratio of this image, computed as ydim() / xdim().

        Note that quick_plot draws data.T, so x is the horizontal axis in plots and this value corresponds to
        (plotted height) / (plotted width).
        """
        return self.ydim() / self.xdim()

    def get_extent(self):
        """
        Returns the width and height of this image in world coordinates.

        Returns:
            tuple with (width, height).
        """
        # NOTE(review): self.pixel_size is not defined in this class; it is presumably provided elsewhere
        # (e.g. by HyData) - confirm it exists before relying on this method.
        return self.xdim() * self.pixel_size[0], self.ydim() * self.pixel_size[1]

    def set_projection(self, proj):
        """
        Set this project to an existing osgeo.osr.SpatialReference or GDAL georeference string.

        Args:
            proj (str, osgeo.osr.SpatialReference): the project to use as osgeo.osr.SpatialReference or GDAL georeference string.
                  None clears the projection.

        Raises:
            TypeError: if proj is not None, a string or a SpatialReference.
        """
        if proj is None:
            self.projection = None
            return

        try:
            from osgeo.osr import SpatialReference
        except ImportError:
            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."

        if isinstance(proj, SpatialReference):
            self.projection = proj
        elif isinstance(proj, str):
            self.projection = SpatialReference(proj)
        else:
            raise TypeError("Invalid project %s" % (proj,))

    def set_projection_EPSG(self, EPSG):
        """
        Sets this image project using an EPSG code.

        Args:
            EPSG (str): string EPSG code that can be passed to SpatialReference.SetFromUserInput(...).
        """

        try:
            from osgeo.osr import SpatialReference
        except ImportError:
            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."

        self.projection = SpatialReference()
        self.projection.SetFromUserInput(EPSG)

    def get_projection_EPSG(self):
        """
        Gets a string describing this projections EPSG code (if it is an EPSG project).

        Returns:
            an EPSG code string of the format "EPSG:XXXX", or None if no projection is set.
        """
        if self.projection is None:
            return None
        else:
            return "%s:%s" % (self.projection.GetAttrValue("AUTHORITY", 0), self.projection.GetAttrValue("AUTHORITY", 1))

    def _as_spatial_reference(self, proj):
        """
        Resolve a projection argument (as accepted by pix_to_world / world_to_pix) into a spatial reference.

        Args:
            proj: None (use this image's projection), an osr.SpatialReference (returned unchanged), an integer
                  EPSG code, or an EPSG string such as "EPSG:4326" (a bare "4326" is also accepted).

        Returns:
            the resolved projection object.
        """
        try:
            from osgeo import osr
        except ImportError:
            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."

        if proj is None:
            return self.projection
        if isinstance(proj, (str, int)):
            epsg = proj
            if isinstance(epsg, str):
                try:
                    epsg = int(epsg.split(':')[-1])  # accepts "EPSG:XXXX" as well as "XXXX"
                except ValueError:
                    assert False, "Error - %s is an invalid EPSG code." % proj
            sr = osr.SpatialReference()
            sr.ImportFromEPSG(epsg)
            return sr
        return proj

    def pix_to_world(self, px, py, proj=None):
        """
        Take pixel coordinates and return world coordinates

        Args:
            px (int): the pixel x-coord.
            py (int): the pixel y-coord.
            proj (str, int, osr.SpatialReference): the coordinate system to use. Default (None) uses the same system as this image. Otherwise
                 an osr.SpatialReference can be passed (HyImage.project), or an EPSG code as an integer or string (e.g. get_projection_EPSG(...)).
        Returns:
            the world coordinates in the coordinate system defined by get_projection_EPSG(...).
        """

        try:
            from osgeo import osr
            import osgeo.gdal as gdal
            from osgeo import ogr
        except ImportError:
            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."

        # parse project
        proj = self._as_spatial_reference(proj)

        # check we have all the required info
        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
        assert (self.affine is not None) and (
            self.projection is not None), "Error - project information is undefined."

        # project to world coordinates in this images project/world coords
        x, y = gdal.ApplyGeoTransform(self.affine, px, py)

        # project to target coords (if different)
        if not proj.IsSameGeogCS(self.projection):
            P = ogr.Geometry(ogr.wkbPoint)
            if proj.EPSGTreatsAsNorthingEasting():
                P.AddPoint(x, y)
            else:
                P.AddPoint(y, x)
            P.AssignSpatialReference(self.projection)  # tell the point what coordinates it's in
            P.TransformTo(proj)  # reproject it to the out spatial reference
            x, y = P.GetX(), P.GetY()

            # do we need to transpose?
            if proj.EPSGTreatsAsLatLong():
                x, y = y, x  # we want lon,lat not lat,lon
        return x, y

    def world_to_pix(self, x, y, proj=None):
        """
        Take world coordinates and return pixel coordinates

        Args:
            x (float): the world x-coord.
            y (float): the world y-coord.
            proj (str, int, osr.SpatialReference): the coordinate system of the input coordinates. Default (None) uses the same system as this image. Otherwise
                 an osr.SpatialReference can be passed (HyImage.project), or an EPSG code as an integer or string (e.g. get_projection_EPSG(...)).

        Returns:
            the pixel coordinates based on the affine transform stored in self.affine.
        """

        try:
            from osgeo import osr
            import osgeo.gdal as gdal
            from osgeo import ogr
        except ImportError:
            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."

        # parse project
        proj = self._as_spatial_reference(proj)

        # check we have all the required info
        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
        assert (self.affine is not None) and (self.projection is not None), "Error - project information is undefined."

        # project to this images CS (if different)
        if not proj.IsSameGeogCS(self.projection):
            P = ogr.Geometry(ogr.wkbPoint)
            if proj.EPSGTreatsAsNorthingEasting():
                P.AddPoint(x, y)
            else:
                P.AddPoint(y, x)
            P.AssignSpatialReference(proj)  # tell the point what coordinates it's in
            # NOTE(review): this unconditional AddPoint overwrites the axis-ordered point set just above, so that
            # branch currently has no effect (unlike pix_to_world). Kept to preserve existing results - confirm intent.
            P.AddPoint(x, y)
            P.TransformTo(self.projection)  # reproject it to the out spatial reference
            x, y = P.GetX(), P.GetY()
            if self.projection.EPSGTreatsAsLatLong():  # do we need to transpose?
                x, y = y, x  # we want lon,lat not lat,lon

        inv = gdal.InvGeoTransform(self.affine)
        assert inv is not None, "Error - could not invert affine transform?"

        # apply
        return gdal.ApplyGeoTransform(inv, x, y)

    def flip(self, axis='x'):
        """
        Flip the image on the x or y axis.

        Args:
            axis (str): 'x' or 'y' or both 'xy'.
        """

        if 'x' in axis.lower():
            self.data = np.flip(self.data, axis=0)
        if 'y' in axis.lower():
            self.data = np.flip(self.data, axis=1)

    def rot90(self):
        """
        Rotate this image by 90 degrees by transposing the underlying data array. Combine with flip('x') or flip('y')
        to achieve positive/negative rotations.
        """
        self.data = np.transpose(self.data, (1, 0, 2))
        self.push_to_header()

    #####################################
    ## IMAGE FILTERING
    #####################################
    def fill_holes(self):
        """
        Replace non-finite (nan/inf) values with the maximum of their 3x3 neighbourhood (greyscale dilation), thus
        removing 1-pixel large holes from an image. The dilation is computed per band with non-finite values
        treated as 0, so for data with negative values the filled value may be 0.
        """
        from scipy import ndimage  # explicit submodule import; `import scipy` alone does not guarantee it is loaded

        # perform greyscale dilation
        dilate = self.data.copy()
        holes = np.logical_not(np.isfinite(dilate))
        dilate[holes] = 0
        for b in range(self.band_count()):
            dilate[:, :, b] = ndimage.grey_dilation(dilate[:, :, b], size=(3, 3))

        # map back to holes in dataset
        self.data[holes] = dilate[holes]

    def blur(self, n=3):
        """
        Applies an n x n mean (box) filter to the image using OpenCV. Pixels that were nan before filtering
        are reset to nan afterwards.

        Args:
            n (int): the dimensions of the averaging kernel to convolve. Default is 3. Increase for more blurry results.
        """
        import cv2  # import this here to avoid errors if opencv is not installed properly

        assert isinstance(n, int) and n >= 3, "Error - invalid kernel. N must be an integer >= 3. "
        nanmask = np.isnan(self.data)
        kernel = np.ones((n, n), np.float32) / (n ** 2)
        # opencv returns a 2-D array for single-band input, so restore the original [x, y, band] shape
        self.data = cv2.filter2D(self.data, -1, kernel).reshape(self.data.shape)
        self.data[nanmask] = np.nan  # remove mask

    def erode(self, size=3, iterations=1):
        """
        Apply an erode filter to this image to expand background (nan) pixels. Refer to open-cv's erode
        function for more details.

        Args:
            size (int): the size of the erode filter. Default is a 3x3 kernel.
            iterations (int): the number of erode iterations. Default is 1.
        """
        import cv2  # import this here to avoid errors if opencv is not installed properly

        # erode
        kernel = np.ones((size, size), np.uint8)
        if self.is_float():
            mask = np.isfinite(self.data).any(axis=-1)
            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
            self.data[mask == 0, :] = np.nan
        else:
            mask = (self.data != 0).any(axis=-1)
            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
            self.data[mask == 0, :] = 0

    def resize(self, newdims: tuple, interpolation: int = 1):
        """
        Resize this image with opencv.

        Args:
            newdims (tuple): the new image dimensions (xdim, ydim).
            interpolation (int): opencv interpolation method. Default is cv2.INTER_LINEAR.
        """
        import cv2  # import this here to avoid errors if opencv is not installed properly
        resized = cv2.resize(self.data, (newdims[1], newdims[0]), interpolation=interpolation)
        if resized.ndim == 2:  # opencv drops the band axis of single-band images
            resized = resized[:, :, None]
        self.data = resized

    def despeckle(self, size=5):
        """
        Despeckle each band of this image (independently) using a median filter. Float images are converted
        to float32 (as required by opencv).

        Args:
            size (int): the size of the median filter kernel. Default is 5. Must be an odd number.
        """

        assert (size % 2) == 1, "Error - size must be an odd integer"
        import cv2  # import this here to avoid errors if opencv is not installed properly

        src = self.data.astype(np.float32) if self.is_float() else self.data
        out = np.empty_like(src)
        # cv2.medianBlur only accepts 1, 3 or 4 channel images, so filter band by band
        for b in range(src.shape[-1]):
            out[:, :, b] = cv2.medianBlur(np.ascontiguousarray(src[:, :, b]), size)
        self.data = out

    #####################################
    ## FEATURES AND FEATURE MATCHING
    ######################################
    def get_keypoints(self, band, eq=False, mask=True, method='sift', cfac=0.0, bfac=0.0, **kwds):
        """
        Get feature descriptors from the specified band.

        Args:
            band (int,float,str,tuple): the band index (int), wavelength (float) or name (str) to extract features from. Alternatively, a tuple can be passed
                        containing a range of bands (min : max) to average before feature matching.
            eq (bool): True if the image should be histogram equalized first. Default is False.
            mask (bool): True if 0 value pixels should be masked. Default is True.
            method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'.
            cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
            bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.
            **kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are:

                 - contrastThreshold: default is 0.01.
                 - edgeThreshold: default is 10.
                 - sigma: default is 1.0

                For ORB these are:

                 - nfeatures = the number of features to detect. Default is 5000.

        Returns:
            Tuple containing

                - k (ndarray): the keypoints detected
                - d (ndarray): corresponding feature descriptors
        """
        import cv2  # import this here to avoid errors if opencv is not installed properly

        # numpy scalars are converted to python types so they are treated as indices / wavelengths as usual
        if isinstance(band, np.integer):
            band = int(band)
        elif isinstance(band, np.floating):
            band = float(band)

        # get image (always as a float copy, so the normalisation below never modifies self.data in-place)
        if isinstance(band, (int, float, str)):  # single band
            image = self.data[:, :, self.get_band_index(band)].astype(np.float64)
        elif isinstance(band, tuple):  # range of bands (averaged)
            idx0 = self.get_band_index(band[0])
            idx1 = self.get_band_index(band[1])

            # deal with out of range errors
            if idx0 is None:
                idx0 = 0
            if idx1 is None:
                idx1 = self.band_count()

            # average bands
            image = np.nanmean(self.data[:, :, idx0:idx1], axis=2).astype(np.float64)
        else:
            assert False, "Error, unrecognised band %s" % band

        # normalise image to range 0 - 1 (constant images are left at 0 rather than divided by zero)
        image -= np.nanmin(image)
        vmax = np.nanmax(image)
        if vmax > 0:
            image /= vmax

        # apply brightness/contrast adjustment
        image = np.clip((1.0 + cfac) * image + bfac, 0.0, 1.0)

        # convert image to uint8 for opencv (nan -> 0 explicitly, so it is excluded by the mask below)
        image = np.uint8(255 * np.nan_to_num(image, nan=0.0))
        if eq:
            image = cv2.equalizeHist(image)

        if mask:
            mask = np.zeros(image.shape, dtype=np.uint8)
            mask[image != 0] = 255  # include only non-zero pixels
        else:
            mask = None

        if 'sift' in method.lower():  # SIFT

            # setup default keywords
            kwds["contrastThreshold"] = kwds.get("contrastThreshold", 0.01)
            kwds["edgeThreshold"] = kwds.get("edgeThreshold", 10)
            kwds["sigma"] = kwds.get("sigma", 1.0)

            # make feature detector (passing the documented keywords through)
            alg = cv2.SIFT_create(**kwds)
        elif 'orb' in method.lower():  # orb
            kwds['nfeatures'] = kwds.get('nfeatures', 5000)
            alg = cv2.ORB_create(scoreType=cv2.ORB_FAST_SCORE, **kwds)
        else:
            assert False, "Error - %s is not a recognised feature detector." % method

        # detect keypoints
        kp = alg.detect(image, mask)

        # extract and return feature vectors
        return alg.compute(image, kp)

    @classmethod
    def match_keypoints(cls, kp1, kp2, d1, d2, method='SIFT', dist=0.7, tree=5, check=100, min_count=5):
        """
        Compares keypoint feature vectors from two images and returns matching pairs.

        Args:
            kp1 (ndarray): keypoints from the first image
            kp2 (ndarray): keypoints from the second image
            d1 (ndarray): descriptors for the keypoints from the first image
            d2 (ndarray): descriptors for the keypoints from the second image
            method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'.
            dist (float): Lowe's ratio threshold (0 to 1); a match is kept if its distance is less than dist times
                          the distance of the second-best match. Default is 0.7.
            tree (int): the FLANN 'trees' index parameter. Default is 5. See open-cv docs.
            check (int): the FLANN 'checks' search parameter. Default is 100. See open-cv docs.
            min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are found,
                             (or either descriptor set is empty) then the function returns None, None. Default is 5.

        Returns:
            Tuple (src_pts, dst_pts) of float32 arrays with shape (N, 1, 2), or (None, None).
        """
        import cv2  # import this here to avoid errors if opencv is not installed properly

        # FLANN index types (cv2.NORM_* constants are distance norms, not FLANN index identifiers)
        FLANN_INDEX_KDTREE = 1
        FLANN_INDEX_LSH = 6
        if 'sift' in method.lower():
            algorithm = FLANN_INDEX_KDTREE
        elif 'orb' in method.lower():
            algorithm = FLANN_INDEX_LSH
        else:
            assert False, "Error - unknown matching algorithm %s" % method

        if d1 is None or d2 is None or len(d1) == 0 or len(d2) == 0:
            return None, None

        # calculate flann matches
        index_params = dict(algorithm=algorithm, trees=tree)
        search_params = dict(checks=check)
        flann = cv2.FlannBasedMatcher(index_params, search_params)
        matches = flann.knnMatch(d1, d2, k=2)

        # store all the good matches as per Lowe's ratio test (queries with fewer than two neighbours are skipped)
        good = [pair[0] for pair in matches
                if len(pair) == 2 and pair[0].distance < dist * pair[1].distance]

        if len(good) < min_count:
            return None, None
        src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
        dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
        return src_pts, dst_pts

    ############################
    ## Visualisation methods
    ############################
    def _plot_samples(self, ax, samples, ps, pc):
        """
        Scatter sample points onto ax: either an explicit [(x,y), ...] list/array, or (if samples is otherwise
        truthy) the sample points of every class stored in the header.
        """
        if isinstance(samples, (list, np.ndarray)):
            if len(samples) > 0:
                ax.scatter([s[0] for s in samples], [s[1] for s in samples], s=ps, c=pc)
        elif samples:
            names = self.header.get_class_names()
            for n in ([] if names is None else names):
                points = np.array(self.header.get_sample_points(n))
                ax.scatter(points[:, 0], points[:, 1], s=ps, c=pc)

    def quick_plot(self, band=0, ax=None, bfac=0.0, cfac=0.0, samples=False, tscale=False, invert=False, rot=False, flipX=False, flipY=False,
                   **kwds):
        """
        Plot a band using matplotlib.imshow(...).

        Args:
            band (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple is passed then
                  each band in the tuple (string or index) will be mapped to rgb. Bands with negative wavelengths or indices will be inverted before plotting.
            ax: an axis object to plot to. If none, plt.imshow( ... ) is used.
            bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
            cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
            samples (bool): True if sample points (defined in the header file) should be plotted. Default is False. Otherwise, a list of
                      [ (x,y), ... ] points can be passed.
            tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False.
                    When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or
                    (constant) values (float).
            invert (bool) : True if each band should be inverted before plotting. Only works for multiband (ternary) images.
            rot (bool): if True, the x and y axis will be flipped (90 degree rotation) before plotting. Default is False.
            flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
            flipY (bool): if True, the y axis will be flipped before plotting (after applying rotations).
            **kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:

                 - mask = a 2D boolean mask containing True for pixels that should be hidden and False for pixels to draw.
                 - path = a file path to save the image too (at matching resolution; use fig.savefig(..) if you want to save the figure).
                 - ticks = True if x- and y- ticks should be plotted. Default is False.
                 - ps, pc = the size and color of sample points to plot. Can be constant or list.
                 - figsize = a figsize for the figure to create (if ax is None).
                 - interpolation = the imshow interpolation method. Default is 'none'.

        Returns:
            Tuple containing

                - fig: matplotlib figure object
                - ax: matplotlib axes object. If a colorbar is created, (band is an integer or a float), then this will be stored in ax.cbar.
        """

        # pop keywords handled here so they are never forwarded to imshow / imsave
        figsize = kwds.pop('figsize', None)
        show_ticks = kwds.pop('ticks', False)
        interpolation = kwds.pop('interpolation', 'none')  # change default interpolation to None
        ps = kwds.pop('ps', 5)
        pc = kwds.pop('pc', 'r')

        # create new axes?
        if ax is None:
            if figsize is None:
                figsize = (18, 18 * self.ydim() / self.xdim())
            fig, ax = plt.subplots(figsize=figsize)

        # deal with ticks
        if not show_ticks:
            ax.set_xticks([])
            ax.set_yticks([])

        # map individual band using colourmap
        if isinstance(band, (str, int, float)):
            # get band
            if isinstance(band, str):
                data = self.data[:, :, self.get_band_index(band)]
            else:
                data = self.data[:, :, self.get_band_index(np.abs(band))]
                if band < 0:
                    data = np.nanmax(data) - data  # flip

            # convert integer vmin and vmax values to percentiles
            if 'vmin' in kwds:
                if isinstance(kwds['vmin'], int):
                    kwds['vmin'] = np.nanpercentile(data, kwds['vmin'])
            if 'vmax' in kwds:
                if isinstance(kwds['vmax'], int):
                    kwds['vmax'] = np.nanpercentile(data, kwds['vmax'])

            # mask nans, ignored values and any custom mask
            mask = np.isnan(data)
            ignore = self.header.get_data_ignore_value()
            if not np.isnan(ignore):
                mask = mask | (data == ignore)
            if 'mask' in kwds:
                mask = mask | np.asarray(kwds.pop('mask'), dtype=bool)
            data = np.ma.array(data, mask=mask)

            # apply rotations and flipping
            if rot:
                data = data.T
            if flipX:
                data = data[::-1, :]
            if flipY:
                data = data[:, ::-1]

            # save?
            if 'path' in kwds:
                outpath = kwds.pop('path')
                outdir = os.path.dirname(outpath)
                if outdir:
                    os.makedirs(outdir, exist_ok=True)  # ensure output directory exists
                plt.imsave(outpath, data.T, **kwds)  # save the image

            # plot samples?
            self._plot_samples(ax, samples, ps, pc)

            ax.cbar = ax.imshow(data.T, interpolation=interpolation, **kwds)

        # map several (usually 3) bands to RGB
        elif isinstance(band, (tuple, list)):
            # get band indices
            rgb = [self.get_band_index(b) if isinstance(b, str) else self.get_band_index(np.abs(b)) for b in band]

            # slice image (fancy indexing copies) as floats so per-band scaling is never truncated to integers
            img = self.data[:, :, rgb].astype(np.result_type(self.data.dtype, np.float32), copy=False)
            if np.isnan(img).all():
                print("Warning - image contains no data.")
                return ax.get_figure(), ax

            # invert where needed: negative numeric bands are inverted, and `invert` toggles every band
            for i, b in enumerate(band):
                negative = (not isinstance(b, str)) and (b < 0)
                if negative != bool(invert):
                    img[..., i] = np.nanmax(img[..., i]) - img[..., i]

            # do scaling
            if tscale:  # scale bands independently
                for b in range(img.shape[-1]):
                    mn = kwds.get("vmin", float(np.nanmin(img[..., b])))
                    mx = kwds.get("vmax", float(np.nanmax(img[..., b])))
                    if isinstance(mn, int):
                        assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
                        mn = float(np.nanpercentile(img[..., b], mn))
                    if isinstance(mx, int):
                        assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
                        mx = float(np.nanpercentile(img[..., b], mx))
                    img[..., b] = (img[..., b] - mn) / (mx - mn)
            else:  # scale bands together
                mn = kwds.get("vmin", float(np.nanmin(img)))
                mx = kwds.get("vmax", float(np.nanmax(img)))
                if isinstance(mn, int):
                    assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
                    mn = float(np.nanpercentile(img, mn))
                if isinstance(mx, int):
                    assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
                    mx = float(np.nanpercentile(img, mx))
                img = (img - mn) / (mx - mn)

            # apply brightness/contrast mapping
            img = np.clip((1.0 + cfac) * img + bfac, 0, 1.0)

            # apply masking so background is white
            img[np.logical_not(np.isfinite(img))] = 1.0
            if 'mask' in kwds:
                img[kwds.pop("mask"), :] = 1.0

            # apply rotations and flipping
            if rot:
                img = np.transpose(img, (1, 0, 2))
            if flipX:
                img = img[::-1, :, :]
            if flipY:
                img = img[:, ::-1, :]

            # save?
            if 'path' in kwds:
                outpath = kwds.pop('path')
                outdir = os.path.dirname(outpath)
                if outdir:
                    os.makedirs(outdir, exist_ok=True)  # ensure output directory exists
                plt.imsave(outpath, np.transpose(np.clip(img * 255, 0, 255).astype(np.uint8), (1, 0, 2)))  # save the image

            # plot samples?
            self._plot_samples(ax, samples, ps, pc)

            # plot
            ax.imshow(np.transpose(img, (1, 0, 2)), interpolation=interpolation, **kwds)
            ax.cbar = None  # no colorbar

        return ax.get_figure(), ax

    def createGIF(self, path, bands=None, figsize=(10, 10), fps=10, **kwds):
        """
        Create and save an animated gif that loops through the bands of the image.

        Args:
            path (str): the path to save the .gif (any other extension is replaced by .gif).
            bands (tuple): Tuple (start, end) containing the half-open range of band indices to draw. Default is the whole range.
            figsize (tuple): the size of the image to draw. Default is (10,10).
            fps (int): the framerate (frames per second) of the gif. Default is 10.
            **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
        """

        frames = []
        if bands is None:
            bands = (0, self.band_count())
        else:
            assert 0 <= bands[0] < self.band_count(), "Error - invalid range."
            assert 0 < bands[1] <= self.band_count(), "Error - invalid range."
            assert bands[1] > bands[0], "Error - invalid range."

        # plot frames
        for i in range(bands[0], bands[1]):
            fig, ax = plt.subplots(figsize=figsize)
            ax.imshow(self.data[:, :, i], **kwds)
            fig.canvas.draw()
            # buffer_rgba() gives an (H, W, 4) uint8 view of the rendered canvas; drop alpha and copy before closing
            frames.append(np.asarray(fig.canvas.buffer_rgba())[..., :3].copy())
            plt.close(fig)

        # save gif
        imageio.mimsave(os.path.splitext(path)[0] + ".gif", frames, fps=fps)

    ## masking
    def drop_bbl(self, drop=True):
        """
        Remove bad bands as stored in the 'bbl' key in the image header. Note that this operates in-place.

        Args:
            drop (bool): True if bad bands should be completely dropped. If False, these bands will be kept but replaced with nans.
        """
        assert 'bbl' in self.header, "Please specify a bad band list ('bbl') in the image header, as per the ENVI format definition."
        mask = self.header.get_list('bbl') == 0
        self.data[..., mask] = np.nan
        if drop:
            self.delete_nan_bands(inplace=True)

    def mask(self, mask=None, flag=np.nan, invert=False, crop=False, bands=None):
        """
        Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the
        image in-situ.

        Args:
            flag (float): the value to use for masked pixels. Default is np.nan
            mask (ndarray): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None is passed then
                  pickPolygon( ... ) is used to interactively define a polygon. If a file path is passed then the polygon (or mask)
                  will be loaded using np.load( ... ). Alternatively if mask.shape == image.shape[0,1] then it is treated as a
                  binary image mask (must be boolean) and True values will be masked across all bands. Default is None.
            invert (bool): if True, pixels within the polygon will be masked. If False, pixels outside the polygon are masked. Default is False.
                    Only applies to polygon masks.
            crop (bool): True if rows/columns containing only masked pixels should be removed. Default is False.
            bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used.

        Returns:
            mask (ndarray): a boolean array with True where pixels are masked and False elsewhere, or None if the user
            did not pick a polygon interactively.
        """

        if mask is None:  # pick mask interactively
            if bands is None:
                bands = int(self.band_count() / 2)

            regions = self.pickPolygons(region_names=["mask"], bands=bands)

            # the user bailed without picking a mask?
            if len(regions) == 0:
                print("Warning - no mask picked/applied.")
                return

            # extract polygon mask
            mask = regions[0]
        elif isinstance(mask, str):
            mask = np.load(mask)  # load a saved polygon / binary mask from disk

        mask = np.asarray(mask)

        # convert polygon mask to binary mask (boolean arrays are always treated as binary image masks)
        if mask.ndim == 2 and mask.shape[1] == 2 and mask.dtype != bool:

            # build meshgrid with pixel coords
            xx, yy = np.meshgrid(np.arange(self.xdim()), np.arange(self.ydim()))
            points = np.column_stack([xx.ravel(), yy.ravel()])  # coordinates of each pixel

            # calculate per-pixel mask
            inside = path.Path(mask).contains_points(points)
            mask = inside.reshape((self.ydim(), self.xdim())).T

            # flip as we want to mask (==True) outside points (unless invert is true)
            if not invert:
                mask = np.logical_not(mask)

        # apply binary image mask
        assert mask.shape[0] == self.data.shape[0] and mask.shape[1] == self.data.shape[1], \
            "Error - mask shape %s does not match image shape %s" % (mask.shape, self.data.shape)
        self.data[mask, :] = flag

        # crop image
        if crop:
            # calculate non-masked pixels
            valid = np.logical_not(mask)

            # integrate along axes
            xdata = np.sum(valid, axis=1) > 0.0
            ydata = np.sum(valid, axis=0) > 0.0

            # calculate domain containing valid pixels
            xmin = np.argmax(xdata)
            xmax = xdata.shape[0] - np.argmax(xdata[::-1])
            ymin = np.argmax(ydata)
            ymax = ydata.shape[0] - np.argmax(ydata[::-1])

            # crop
            self.data = self.data[xmin:xmax, ymin:ymax, :]

        return mask

    def crop_to_data(self):
        """
        Remove padding of nan or zero pixels from image. Note that this is performed in place. If the image
        contains no valid pixels at all it is left unchanged.
        """

        valid = np.isfinite(self.data).any(axis=-1) & (self.data != 0).any(axis=-1)
        rows = np.flatnonzero(valid.any(axis=1))  # x indices containing data
        cols = np.flatnonzero(valid.any(axis=0))  # y indices containing data
        if rows.size == 0:
            return  # nothing but padding
        # +1 so the last row / column containing data is kept
        self.data = self.data[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1, :]  # do clipping

    ##################################################
    ## Interactive tools for picking regions/pixels
    ##################################################
    @staticmethod
    def _restore_backend(backend):
        """
        Try to switch matplotlib back to `backend`, printing a warning (rather than failing) if that is not possible.
        """
        try:
            matplotlib.use(backend)
        except Exception:
            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")

    def pickPolygons(self, region_names, bands=0):
        """
        Creates a matplotlib gui for selecting polygon regions in an image.

        Args:
            region_names (list, str): a list containing the names of the regions to pick. If a string is passed only one name is used.
            bands (tuple): the bands of the image to plot.

        Returns:
            a list containing an (N, 2) array of polygon vertices for each picked region.
        """

        if isinstance(region_names, str):
            region_names = [region_names]

        assert isinstance(region_names, list), "Error - names must be a list or a string."

        # set matplotlib backend
        backend = matplotlib.get_backend()
        matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work
        try:
            # plot image and extract roi's
            fig, ax = self.quick_plot(bands)
            roi = MultiRoi(roi_names=region_names)
            plt.close(fig)  # close figure

            # extract regions
            regions = [np.vstack([r.x, r.y]).T for r in roi.rois.values()]
        finally:
            # restore matplotlib backend (if possible), even if picking failed
            self._restore_backend(backend)

        return regions

    def pickPoints(self, n=-1, bands=hylite.RGB, integer=True, title="Pick Points", **kwds):
        """
        Creates a matplotlib gui for picking pixels from an image.

        Args:
            n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
            bands (tuple): the bands of the image to plot. Default is HyImage.RGB
            integer (bool): True if points coordinates should be cast to integers (for use as indices). Default is True.
            title (str): The title of the point picking window.
            **kwds: Keywords are passed to HyImage.quick_plot( ... ).

        Returns:
            A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].
        """

        # set matplotlib backend
        backend = matplotlib.get_backend()
        matplotlib.use('Qt5Agg')  # need an interactive backend for ginput
        try:
            # create figure
            fig, ax = self.quick_plot(bands, **kwds)
            ax.set_title(title)

            # get points
            points = fig.ginput(n)
            plt.close(fig)  # close figure once picking is done
        finally:
            # restore matplotlib backend (if possible), even if picking failed
            self._restore_backend(backend)

        if integer:
            points = [(int(p[0]), int(p[1])) for p in points]

        return points

    def pickSamples(self, names=None, store=True, **kwds):
        """
        Pick sample probe points and store these in the image header file.

        Args:
            names (str, list): the name of the sample to pick, or a list of names to pick multiple.
            store (bool): True if sample should be stored in the image header file (for later access). Default is True.
            **kwds: Keywords are passed to HyImage.quick_plot( ... )

        Returns:
            a list containing a list of points for each sample.

        Raises:
            TypeError: if no sample names are given.
        """

        if isinstance(names, str):
            names = [names]
        if names is None:
            raise TypeError("Error - names must be a string or a list of strings.")

        # pick points
        points = []
        for s in names:
            pnts = self.pickPoints(title="%s" % s, **kwds)
            if store:
                self.header['sample %s' % s] = pnts  # store in header
            points.append(pnts)

        # add classes to header file (without duplicating names that are already present)
        if store:
            cls_names = self.header.get_class_names()
            cls_names = [] if cls_names is None else list(cls_names)
            self.header['class names'] = cls_names + [n for n in names if n not in cls_names]

        return points
A class for hyperspectral image data. These can be individual scenes or hyperspectral orthoimages.
def __init__(self, data, **kwds):
    """
    Create a hyperspectral image.

    Args:
        data (ndarray): array indexed as data[x][y][band]. A 1-D array is treated as a single pixel and a 2-D
            array as a single band; both are expanded so the data is always three-dimensional.
        **kwds: forwarded to HyData.__init__; the image-specific keywords are:
            wav = array of band wavelengths for this image (passed to set_wavelengths).
            affine = affine transform in the format returned by GDAL's GetGeoTransform(). Default is [0,1,0,0,0,1].
            projection = GDAL georeference string or osgeo.osr.SpatialReference (passed to set_projection). Default is None.
            sensor = sensor name (handled by HyData).
            header = header for this image (handled by HyData; copy() passes a header object here).
    """
    # HyData stores the array and header
    super().__init__(data, **kwds)

    # promote low-rank arrays so indexing is always data[x, y, band]
    if self.data is not None:
        ndim = len(self.data.shape)
        if ndim == 1:
            self.data = self.data[None, None, :]  # one pixel, many bands
        elif ndim == 2:
            self.data = self.data[:, :, None]  # one band

    # georeferencing specific to images
    self.set_projection(kwds.get("projection", None))
    self.affine = kwds.get("affine", [0, 1, 0, 0, 0, 1])

    if 'wav' in kwds:
        self.set_wavelengths(kwds['wav'])

    # ENVI-specific header field
    self.header['file type'] = 'ENVI Standard'
Arguments:
- data (ndarray): a numpy array such that data[x][y][band] gives each pixel value.
- **kwds: wav = A numpy array containing band wavelengths for this image. affine = an affine transform of the format returned by GDAL.GetGeoTransform(). projection = string defining the projection. Default is None. sensor = sensor name. Default is "unknown". header = the header associated with this image (copy() passes a header object). Default is None.
def copy(self, data=True):
    """
    Make a copy of this image instance.

    The header is copied; the projection and affine objects are passed through unchanged, so they are
    shared with this image rather than duplicated.

    Args:
        data (bool): if True the pixel array is copied as well; if False the new image has no data.

    Returns:
        a new HyImage instance.
    """
    pixels = self.data.copy() if data else None
    return HyImage(pixels, header=self.header.copy(), projection=self.projection, affine=self.affine)
Make a deep copy of this image instance.
Arguments:
- data (bool): True if a copy of the data should be made, otherwise only copy header.
Returns:
a new HyImage instance.
def T(self):
    """
    Return a transposed view of the data array, indexed [y, x, band] (the row/column order used by
    matplotlib, opencv etc.).
    """
    return self.data.transpose(1, 0, 2)
Return a transposed view of the data matrix (corresponding with the [y,x] indexing used by matplotlib, opencv etc.
def xdim(self):
    """
    Return the number of pixels along x, i.e. the length of the first axis of the data array.
    """
    shape = self.data.shape
    return shape[0]
Return number of pixels in x (first dimension of data array)
def ydim(self):
    """
    Return the number of pixels along y, i.e. the length of the second axis of the data array.
    """
    shape = self.data.shape
    return shape[1]
Return number of pixels in y (second dimension of data array)
def aspx(self):
    """
    Return the aspect ratio of this image, computed as ydim() / xdim().

    x is the first data axis and quick_plot draws it horizontally (the data is transposed before imshow),
    so this ratio is height / width rather than width / height.
    """
    return self.ydim() / self.xdim()
Return the aspect ratio of this image, computed as ydim() / xdim() (i.e. height / width, since quick_plot draws the x axis horizontally).
def get_extent(self):
    """
    Returns the width and height of this image in world coordinates.

    Returns:
        tuple with (width, height), computed as (xdim() * pixel_size[0], ydim() * pixel_size[1]).
    """
    # xdim / ydim are methods and must be called; the original multiplied the bound methods (TypeError).
    # NOTE(review): pixel_size is assumed to be an (x, y) pair provided elsewhere (e.g. by HyData) - confirm.
    return self.xdim() * self.pixel_size[0], self.ydim() * self.pixel_size[1]
Returns the width and height of this image in world coordinates.
Returns:
tuple with (width, height).
def set_projection(self, proj):
    """
    Set this image's projection from an osgeo.osr.SpatialReference or a GDAL georeference string.

    Args:
        proj (str, osgeo.osr.SpatialReference, None): the projection to use. A string is passed to the
            SpatialReference constructor; None clears the projection (and does not require GDAL).

    Raises:
        AssertionError: if proj is not None and GDAL (osgeo) cannot be imported.
        TypeError: if proj is not None, a string or a SpatialReference.
    """
    if proj is None:
        self.projection = None
        return

    try:
        from osgeo.osr import SpatialReference
    except ImportError as err:
        # raised explicitly (not `assert False`) so the check survives python -O
        raise AssertionError("Error - GDAL must be installed to work with spatial projections in hylite.") from err

    if isinstance(proj, SpatialReference):
        self.projection = proj
    elif isinstance(proj, str):
        self.projection = SpatialReference(proj)
    else:
        # the original used a bare `raise` here, which only produced "No active exception to reraise"
        raise TypeError("Invalid projection %s" % (proj,))
Set this image's projection to an existing osgeo.osr.SpatialReference or GDAL georeference string.
Arguments:
- proj (str, osgeo.osr.SpatialReference): the projection to use, as an osgeo.osr.SpatialReference or GDAL georeference string.
def set_projection_EPSG(self, EPSG):
    """
    Set this image's projection from an EPSG code.

    Args:
        EPSG (str): EPSG code string (e.g. "EPSG:4326") passed to SpatialReference.SetFromUserInput(...).

    Raises:
        AssertionError: if GDAL (osgeo) cannot be imported.
    """
    try:
        from osgeo.osr import SpatialReference
    except ImportError as err:
        # narrowed from a bare `except:` and raised explicitly so it survives python -O
        raise AssertionError("Error - GDAL must be installed to work with spatial projections in hylite.") from err

    # NOTE(review): the return code of SetFromUserInput is not checked, so an unparsable code may leave an
    # empty SpatialReference (depending on GDAL's exception settings) - confirm desired behaviour.
    self.projection = SpatialReference()
    self.projection.SetFromUserInput(EPSG)
Sets this image's projection using an EPSG code.
Arguments:
- EPSG (str): string EPSG code that can be passed to SpatialReference.SetFromUserInput(...).
def get_projection_EPSG(self):
    """
    Return this image's projection as an "AUTHORITY:CODE" string (e.g. "EPSG:4326"), read from the
    AUTHORITY node of the spatial reference.

    Returns:
        the code string, or None if no projection is set.
    """
    srs = self.projection
    if srs is None:
        return None
    authority, code = (srs.GetAttrValue("AUTHORITY", i) for i in (0, 1))
    return "%s:%s" % (authority, code)
Gets a string describing this projection's EPSG code (if it is an EPSG projection).
Returns:
an EPSG code string of the format "EPSG:XXXX".
def pix_to_world(self, px, py, proj=None):
    """
    Take pixel coordinates and return world coordinates.

    Args:
        px (int): the pixel x-coord.
        py (int): the pixel y-coord.
        proj (str, int, osr.SpatialReference): the coordinate system to return coordinates in. Default (None)
            uses this image's projection. Otherwise an osr.SpatialReference (e.g. HyImage.projection), an EPSG
            string such as "EPSG:4326" (e.g. from get_projection_EPSG(...)) or an integer EPSG code can be passed.

    Returns:
        the world coordinates (x, y) in the coordinate system given by proj.

    Raises:
        AssertionError: if GDAL is missing, the EPSG code cannot be parsed, or projection/affine information
            is undefined.
    """
    try:
        from osgeo import osr
        import osgeo.gdal as gdal
        from osgeo import ogr
    except ImportError as err:
        raise AssertionError("Error - GDAL must be installed to work with spatial projections in hylite.") from err

    # parse projection
    if proj is None:
        proj = self.projection
    elif isinstance(proj, (str, int)):
        epsg = proj
        if isinstance(epsg, str):
            try:
                # parse "EPSG:XXXX" (or a bare "XXXX"); the original called str.split on the type itself,
                # so every string was rejected as invalid
                epsg = int(epsg.split(':')[-1])
            except ValueError:
                raise AssertionError("Error - %s is an invalid EPSG code." % proj) from None
        proj = osr.SpatialReference()
        proj.ImportFromEPSG(epsg)

    # check we have all the required info (image georeferencing first, for a clearer message)
    assert (self.affine is not None) and (self.projection is not None), "Error - project information is undefined."
    assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % (proj,)

    # pixel -> world coordinates in this image's projection
    x, y = gdal.ApplyGeoTransform(self.affine, px, py)

    # project to target coords (if different)
    # NOTE(review): IsSameGeogCS compares only the geographic (datum) part, so e.g. a UTM zone and its WGS84
    # lat/lon base count as "same" and no reprojection happens - confirm whether IsSame() was intended.
    if not proj.IsSameGeogCS(self.projection):
        P = ogr.Geometry(ogr.wkbPoint)
        if proj.EPSGTreatsAsNorthingEasting():
            P.AddPoint(x, y)
        else:
            P.AddPoint(y, x)
        P.AssignSpatialReference(self.projection)  # tell the point what coordinates it's in
        P.TransformTo(proj)  # reproject it to the output spatial reference
        x, y = P.GetX(), P.GetY()

        # do we need to transpose?
        if proj.EPSGTreatsAsLatLong():
            x, y = y, x  # we want lon,lat not lat,lon
    return x, y
Take pixel coordinates and return world coordinates
Arguments:
- px (int): the pixel x-coord.
- py (int): the pixel y-coord.
- proj (str, int, osr.SpatialReference): the coordinate system to return coordinates in. Default (None) uses the same system as this image. Otherwise an osr.SpatialReference can be passed (e.g. HyImage.projection), an EPSG string such as "EPSG:4326" (e.g. from get_projection_EPSG(...)), or an integer EPSG code.
Returns:
the world coordinates in the coordinate system given by proj (by default, this image's projection).
def world_to_pix(self, x, y, proj=None):
    """
    Take world coordinates and return pixel coordinates.

    Args:
        x (float): the world x-coord.
        y (float): the world y-coord.
        proj (str, int, osr.SpatialReference): the coordinate system of the input coordinates. Default (None)
            uses this image's projection. Otherwise an osr.SpatialReference (e.g. HyImage.projection), an EPSG
            string such as "EPSG:4326" (e.g. from get_projection_EPSG(...)) or an integer EPSG code can be passed.

    Returns:
        the pixel coordinates obtained by applying the inverse of self.affine.

    Raises:
        AssertionError: if GDAL is missing, the EPSG code cannot be parsed, projection/affine information is
            undefined, or the affine transform cannot be inverted.
    """
    try:
        from osgeo import osr
        import osgeo.gdal as gdal
        from osgeo import ogr
    except ImportError as err:
        raise AssertionError("Error - GDAL must be installed to work with spatial projections in hylite.") from err

    # parse projection
    if proj is None:
        proj = self.projection
    elif isinstance(proj, (str, int)):
        epsg = proj
        if isinstance(epsg, str):
            try:
                # parse "EPSG:XXXX" (or a bare "XXXX"); the original called str.split on the type itself,
                # so every string was rejected as invalid
                epsg = int(epsg.split(':')[-1])
            except ValueError:
                raise AssertionError("Error - %s is an invalid EPSG code." % proj) from None
        proj = osr.SpatialReference()
        proj.ImportFromEPSG(epsg)

    # check we have all the required info (image georeferencing first, for a clearer message)
    assert (self.affine is not None) and (self.projection is not None), "Error - project information is undefined."
    assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % (proj,)

    # project to this image's CS (if different)
    # NOTE(review): IsSameGeogCS compares only the geographic (datum) part - see pix_to_world; confirm intent.
    if not proj.IsSameGeogCS(self.projection):
        P = ogr.Geometry(ogr.wkbPoint)
        if proj.EPSGTreatsAsNorthingEasting():
            P.AddPoint(x, y)
        else:
            P.AddPoint(y, x)
        P.AssignSpatialReference(proj)  # tell the point what coordinates it's in
        # NOTE(review): a point geometry holds one vertex, so this unconditional AddPoint(x, y) appears to replace
        # the axis-ordered point set just above - confirm which ordering is intended before removing either.
        P.AddPoint(x, y)
        P.TransformTo(self.projection)  # reproject it to this image's spatial reference
        x, y = P.GetX(), P.GetY()
        if self.projection.EPSGTreatsAsLatLong():  # do we need to transpose?
            x, y = y, x  # we want lon,lat not lat,lon

    inv = gdal.InvGeoTransform(self.affine)
    assert inv is not None, "Error - could not invert affine transform?"

    # apply
    return gdal.ApplyGeoTransform(inv, x, y)
Take world coordinates and return pixel coordinates
Arguments:
- x (float): the world x-coord.
- y (float): the world y-coord.
- proj (str, int, osr.SpatialReference): the coordinate system of the input coordinates. Default (None) uses the same system as this image. Otherwise an osr.SpatialReference can be passed (e.g. HyImage.projection), an EPSG string such as "EPSG:4326" (e.g. from get_projection_EPSG(...)), or an integer EPSG code.
Returns:
the pixel coordinates based on the affine transform stored in self.affine.
def flip(self, axis='x'):
    """
    Flip the image in place along x (first data axis), y (second data axis) or both.

    Args:
        axis (str): 'x', 'y' or 'xy' (case-insensitive). Default is 'x'.
    """
    spec = axis.lower()
    for letter, dim in (('x', 0), ('y', 1)):
        if letter in spec:
            self.data = np.flip(self.data, axis=dim)
Flip the image on the x or y axis.
Arguments:
- axis (str): 'x' or 'y' or both 'xy'.
def rot90(self):
    """
    Swap the x and y axes of the data array in place (a transpose), then refresh the header with
    push_to_header(). Combine with flip('x') or flip('y') to obtain a positive or negative 90 degree rotation.
    """
    self.data = self.data.transpose(1, 0, 2)
    self.push_to_header()
Transpose this image (swap the x and y axes of the underlying data array). Combine with flip('x') or flip('y') to achieve positive/negative 90 degree rotations.
def fill_holes(self):
    """
    Fill non-finite (nan / inf) values in place, removing 1-pixel large holes from an image.

    Non-finite values are first set to 0, then each band is greyscale-dilated with a 3x3 window and the
    dilated value (the maximum of the 3x3 neighbourhood, with non-finite neighbours counting as 0) is copied
    into every non-finite element. Holes are handled per element, so they need not line up across bands.
    A hole whose neighbours are all negative or non-finite is therefore filled with 0.
    """
    # explicit submodule import: `import scipy` alone is not guaranteed to expose scipy.ndimage
    from scipy import ndimage

    holes = ~np.isfinite(self.data)
    dilate = np.where(holes, 0, self.data)

    # size (3, 3, 1) dilates every band independently over a 3x3 spatial window
    dilate = ndimage.grey_dilation(dilate, size=(3, 3, 1))

    # map back to holes in dataset
    self.data[holes] = dilate[holes]
Replaces non-finite (nan) pixel values with the maximum of their 3x3 neighbourhood in the same band (non-finite neighbours count as 0), thus removing 1-pixel large holes from an image. Holes are handled per band, so they do not need to line up across bands.
def blur(self, n=3):
    """
    Smooth the image in place with an n x n averaging (box) filter using OpenCV's filter2D.

    Every kernel weight is 1 / n**2, i.e. a mean filter rather than a gaussian. Values that were nan before
    filtering are set back to nan afterwards.

    Args:
        n (int): width/height of the square kernel. Must be an integer >= 3. Default is 3.
                 Increase for more blurry results.

    Raises:
        AssertionError: if n is not an integer >= 3.
    """
    # validate before importing opencv so bad arguments are reported even when cv2 is unavailable
    assert isinstance(n, int) and n >= 3, "Error - invalid kernel. N must be an integer >= 3. "
    import cv2  # import this here to avoid errors if opencv is not installed properly

    nanmask = np.isnan(self.data)
    kernel = np.ones((n, n), np.float32) / (n ** 2)
    self.data = cv2.filter2D(self.data, -1, kernel)
    self.data[nanmask] = np.nan  # restore the original nan values
Applies an n x n averaging (box) filter to the image using OpenCV. Pixels that were nan before filtering are set back to nan afterwards.
Arguments:
- n (int): the width/height of the square averaging kernel. Must be an integer >= 3. Default is 3. Increase for more blurry results.
def erode(self, size=3, iterations=1):
    """
    Erode the valid region of this image in place, expanding background pixels.

    A pixel is valid when any band is finite (float images) or non-zero (other dtypes). That single-channel
    mask is eroded by cv2.erode with an all-ones size x size kernel, and every pixel that leaves the mask has
    all its bands set to nan (float images) or 0 (otherwise).

    Args:
        size (int): width/height of the square erosion kernel. Default is 3.
        iterations (int): number of times the erosion is applied. Default is 1.
    """
    import cv2  # import this here to avoid errors if opencv is not installed properly

    footprint = np.ones((size, size), np.uint8)
    if self.is_float():
        valid = np.isfinite(self.data).any(axis=-1)
        background = np.nan
    else:
        valid = (self.data != 0).any(axis=-1)
        background = 0
    eroded = cv2.erode(valid.astype(np.uint8), footprint, iterations=iterations)
    self.data[eroded == 0, :] = background
Apply an erode filter to this image to expand background pixels (nan for float images, 0 otherwise). Refer to open-cv's erode function for more details.
Arguments:
- size (int): the size of the erode filter. Default is a 3x3 kernel.
- iterations (int): the number of erode iterations. Default is 1.
def resize(self, newdims : tuple, interpolation : int = 1):
    """
    Resize this image in place using cv2.resize.

    Args:
        newdims (tuple): the new (xdim, ydim) size, in the same axis order as the data array.
        interpolation (int): opencv interpolation flag. Default is 1 (cv2.INTER_LINEAR).
    """
    import cv2  # import this here to avoid errors if opencv is not installed properly
    nx, ny = newdims[0], newdims[1]
    # cv2 takes the target size as (columns, rows), which is (ydim, xdim) for [x, y, band] arrays
    self.data = cv2.resize(self.data, (ny, nx), interpolation=interpolation)
Resize this image with opencv.
Arguments:
- newdims (tuple): the new image dimensions.
- interpolation (int): opencv interpolation method. Default is cv2.INTER_LINEAR.
def despeckle(self, size=5):
    """
    Median-filter this image in place with cv2.medianBlur (channels are filtered independently).
    Float images are converted to float32 first.

    Args:
        size (int): kernel size of the median filter. Must be odd. Default is 5.
    """
    assert (size % 2) == 1, "Error - size must be an odd integer"
    import cv2  # import this here to avoid errors if opencv is not installed properly
    source = self.data.astype(np.float32) if self.is_float() else self.data
    self.data = cv2.medianBlur(source, size)
Despeckle each band of this image (independently) using a median filter.
Arguments:
- size (int): the size of the median filter kernel. Default is 5. Must be an odd number.
def get_keypoints(self, band, eq=False, mask=True, method='sift', cfac=0.0, bfac=0.0, **kwds):
    """
    Get feature descriptors from the specified band.

    Args:
        band (int,float,str,tuple): the band index (int), wavelength (float) or name (str) to extract features
            from, resolved with get_band_index(...). Alternatively, a tuple (min, max) gives a range of bands that
            is averaged before feature detection; a bound that get_band_index cannot resolve (returns None) falls
            back to the first / last band.
        eq (bool): True if the image should be histogram equalized first. Default is False.
        mask (bool): True if pixels that are 0 after conversion to uint8 should be excluded. Default is True.
        method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'.
        cfac (float): contrast adjustment applied after normalising to 0 - 1. Default is 0.0.
        bfac (float): brightness adjustment applied after normalising to 0 - 1. Default is 0.0.
        **kwds: keyword arguments passed to the opencv feature detector. For SIFT (cv2.SIFT_create) the defaults are:

            - contrastThreshold: 0.01.
            - edgeThreshold: 10.
            - sigma: 1.0

            For ORB (cv2.ORB_create):

            - nfeatures = the number of features to detect. Default is 5000.

    Returns:
        Tuple containing

        - k: the detected keypoints
        - d (ndarray): corresponding feature descriptors
    """
    import cv2  # import this here to avoid errors if opencv is not installed properly

    # get image as a float copy, so the in-place normalisation below never writes into self.data
    # (the original normalised a view of self.data and so modified the image itself)
    if isinstance(band, (int, float, str)):  # single band
        image = np.array(self.data[:, :, self.get_band_index(band)], dtype=np.float64)
    elif isinstance(band, tuple):  # range of bands (averaged)
        idx0 = self.get_band_index(band[0])
        idx1 = self.get_band_index(band[1])

        # deal with out of range errors
        if idx0 is None:
            idx0 = 0
        if idx1 is None:
            idx1 = self.band_count()

        image = np.nanmean(self.data[:, :, idx0:idx1], axis=2)
    else:
        assert False, "Error, unrecognised band %s" % band

    # normalise image to range 0 - 1 (skip the division for constant images to avoid 0/0)
    image -= np.nanmin(image)
    peak = np.nanmax(image)
    if peak > 0:
        image /= peak

    # apply brightness/contrast adjustment, clipped to 0 - 1
    image = np.clip((1.0 + cfac) * image + bfac, 0.0, 1.0)

    # convert image to uint8 for opencv (nan pixels become 0 and are therefore masked below)
    image = np.uint8(255 * np.nan_to_num(image, nan=0.0))
    if eq:
        image = cv2.equalizeHist(image)

    if mask:
        mask = np.zeros(image.shape, dtype=np.uint8)
        mask[image != 0] = 255  # include only non-zero pixels
    else:
        mask = None

    if 'sift' in method.lower():  # SIFT
        kwds.setdefault("contrastThreshold", 0.01)
        kwds.setdefault("edgeThreshold", 10)
        kwds.setdefault("sigma", 1.0)
        # pass the keywords on; the original built them but called SIFT_create() without arguments
        alg = cv2.SIFT_create(**kwds)
    elif 'orb' in method.lower():  # ORB
        kwds.setdefault('nfeatures', 5000)
        alg = cv2.ORB_create(scoreType=cv2.ORB_FAST_SCORE, **kwds)
    else:
        assert False, "Error - %s is not a recognised feature detector." % method

    # detect keypoints, then extract and return feature vectors
    kp = alg.detect(image, mask)
    return alg.compute(image, kp)
Get feature descriptors from the specified band.
Arguments:
- band (int,float,str,tuple): the band index (int) or wavelength (float) to extract features from. Alternatively, a tuple can be passed containing a range of bands (min : max) to average before feature matching.
- eq (bool): True if the image should be histogram equalized first. Default is False.
- mask (bool): True if 0 value pixels should be masked. Default is True.
- method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'.
- cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
- bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.
**kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are:
- contrastThreshold: default is 0.01.
- edgeThreshold: default is 10.
- sigma: default is 1.0
For ORB these are:
- nfeatures = the number of features to detect. Default is 5000.
Returns: Tuple containing
- k (ndarray): the keypoints detected
- d (ndarray): corresponding feature descriptors
@classmethod
def match_keypoints(cls, kp1, kp2, d1, d2, method='SIFT', dist=0.7, tree = 5, check = 100, min_count=5):
    """
    Match keypoint descriptors from two images with a FLANN matcher and Lowe's ratio test.

    Args:
        kp1 (ndarray): keypoints from the first image
        kp2 (ndarray): keypoints from the second image
        d1 (ndarray): descriptors for the keypoints from the first image
        d2 (ndarray): descriptors for the keypoints from the second image
        method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'.
        dist (float): ratio-test threshold (0 to 1). A match is kept when its distance is below dist times the
            distance of the second-best candidate. Default is 0.7.
        tree (int): number of randomised kd-trees in the FLANN index (passed as `trees`). Default is 5.
        check (int): FLANN `checks` search parameter; higher values search more of the index (slower but more
            accurate). Default is 100.
        min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are found,
            then the function returns None, None. Default is 5.

    Returns:
        (src_pts, dst_pts): float32 arrays of shape (n, 1, 2) with the matched kp1 and kp2 coordinates, or
        (None, None) if fewer than min_count matches pass the ratio test.
    """
    import cv2  # import this here to avoid errors if opencv is not installed properly

    # FLANN index ids. These are not cv2.NORM_* constants; the original's NORM_INF / NORM_HAMMING only
    # worked because they happen to share these numeric values.
    FLANN_INDEX_KDTREE = 1
    FLANN_INDEX_LSH = 6
    if 'sift' in method.lower():
        algorithm = FLANN_INDEX_KDTREE
    elif 'orb' in method.lower():
        algorithm = FLANN_INDEX_LSH
    else:
        assert False, "Error - unknown matching algorithm %s" % method

    # calculate flann matches
    index_params = dict(algorithm=algorithm, trees=tree)
    search_params = dict(checks=check)
    flann = cv2.FlannBasedMatcher(index_params, search_params)
    matches = flann.knnMatch(d1, d2, k=2)

    # Lowe's ratio test. knnMatch can return fewer than two neighbours for a query, which the original
    # `for m, n in matches` unpacking crashed on; such queries are skipped.
    good = [pair[0] for pair in matches
            if len(pair) == 2 and pair[0].distance < dist * pair[1].distance]

    if len(good) < min_count:
        return None, None
    src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
    dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
    return src_pts, dst_pts
Compares keypoint feature vectors from two images and returns matching pairs.
Arguments:
- kp1 (ndarray): keypoints from the first image
- kp2 (ndarray): keypoints from the second image
- d1 (ndarray): descriptors for the keypoints from the first image
- d2 (ndarray): descriptors for the keypoints from the second image
- method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'.
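- dist (float): ratio-test threshold (0 to 1); a match is kept when its distance is below dist times the distance of the second-best candidate (Lowe's ratio test). Default is 0.7.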
- tree (int): the number of randomised kd-trees used by the FLANN index (passed as `trees`). Default is 5. See open-cv docs.
- check (int): the FLANN `checks` search parameter; higher values search more of the index (slower but more accurate). Default is 100.
- min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are found, then the function returns None, None. Default is 5.
def quick_plot(self, band=0, ax=None, bfac=0.0, cfac=0.0, samples=False, tscale=False, invert=False, rot=False, flipX=False, flipY=False,
               **kwds):
    """
    Plot a band (or several bands mapped to RGB) using matplotlib.imshow(...).

    Args:
        band (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0.
            If a tuple (or list) is passed then each band in it (string or index) is mapped to rgb. Bands with
            negative wavelengths or indices are inverted before plotting.
        ax: an axis object to plot to. If None, a new figure and axes are created with plt.subplots(...).
        bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
        cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
        samples (bool, list, ndarray): True if sample points (defined in the header file) should be plotted.
            Default is False. Otherwise, a list of [ (x,y), ... ] points can be passed.
        tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False.
            When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or
            (constant) values (float).
        invert (bool): True if each band should be inverted before plotting. Only works for multiband (ternary) images.
        rot (bool): if True, the x and y axes are swapped (transposed) before plotting. Default is False.
        flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
        flipY (bool): if True, the y axis will be flipped before plotting (after applying rotations).
        **kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:

            - mask = a 2D boolean array that is True for pixels that should be hidden (masked in single-band
              plots, drawn white in RGB plots).
            - path = a file path to save the image to (at matching resolution; use fig.savefig(..) if you want to save the figure).
            - ticks = True if x- and y- ticks should be plotted. Default is False.
            - ps, pc = the size and color of sample points to plot. Can be constant or list.
            - figsize = a figsize for the figure to create (if ax is None).

    Returns:
        Tuple containing

        - fig: matplotlib figure object
        - ax: matplotlib axes object. For single-band plots ax.cbar holds the image returned by imshow
          (e.g. for fig.colorbar(ax.cbar)); for RGB plots it is None.
    """
    # pop the keywords handled here so they are never forwarded to imshow / imsave
    figsize = kwds.pop('figsize', None)
    ticks = kwds.pop('ticks', False)
    ps = kwds.pop('ps', 5)
    pc = kwds.pop('pc', 'r')
    interpolation = kwds.pop('interpolation', 'none')  # default to no interpolation
    outfile = kwds.pop('path', None)
    user_mask = kwds.pop('mask', None)

    # an ndarray of sample points has no single truth value, so test its size instead
    if isinstance(samples, np.ndarray):
        draw_samples = samples.size > 0
    else:
        draw_samples = bool(samples)

    # create new axes?
    if ax is None:
        if figsize is None:
            figsize = (18, 18 * self.ydim() / self.xdim())
        _, ax = plt.subplots(figsize=figsize)

    # deal with ticks
    if not ticks:
        ax.set_xticks([])
        ax.set_yticks([])

    def _ensure_folder(filename):
        # dirname is '' for bare file names, which os.makedirs rejects
        folder = os.path.dirname(filename)
        if folder:
            os.makedirs(folder, exist_ok=True)

    def _plot_samples():
        # draw the explicit [(x, y), ...] points, or every sample class stored in the header
        if isinstance(samples, (list, np.ndarray)):
            ax.scatter([s[0] for s in samples], [s[1] for s in samples], s=ps, c=pc)
        else:
            for name in self.header.get_class_names() or []:  # get_class_names() may return None
                points = np.array(self.header.get_sample_points(name))
                ax.scatter(points[:, 0], points[:, 1], s=ps, c=pc)

    # map individual band using colourmap
    if isinstance(band, (str, int, float, np.integer)):
        if isinstance(band, str):
            data = self.data[:, :, self.get_band_index(band)]
        else:
            data = self.data[:, :, self.get_band_index(np.abs(band))]
            if band < 0:
                data = np.nanmax(data) - data  # flip

        # convert integer vmin and vmax values to percentiles
        for key in ('vmin', 'vmax'):
            if isinstance(kwds.get(key), int):
                kwds[key] = np.nanpercentile(data, kwds[key])

        # mask nans, the header's ignore value and any custom mask
        mask = np.isnan(data)
        ignore = self.header.get_data_ignore_value()
        if not np.isnan(ignore):
            # parenthesised: the original `mask + data == ignore` compared (mask + data), dropping the nan mask
            mask = mask | (data == ignore)
        if user_mask is not None:
            mask = np.logical_or(mask, user_mask)
        data = np.ma.array(data, mask=mask)

        # apply rotations and flipping
        if rot:
            data = data.T
        if flipX:
            data = data[::-1, :]
        if flipY:
            data = data[:, ::-1]

        # save?
        if outfile is not None:
            from matplotlib.pyplot import imsave
            _ensure_folder(outfile)
            imsave(outfile, data.T, **kwds)

        ax.cbar = ax.imshow(data.T, interpolation=interpolation, **kwds)
        if draw_samples:
            _plot_samples()

    # map 3 bands to RGB
    elif isinstance(band, (tuple, list)):
        rgb = [self.get_band_index(b) if isinstance(b, str) else self.get_band_index(np.abs(b)) for b in band]

        # slice image (fancy indexing returns a copy); float so per-band scaling is not truncated
        img = np.array(self.data[:, :, rgb], dtype=float)
        if np.isnan(img).all():
            print("Warning - image contains no data.")
            return ax.get_figure(), ax

        # invert bands given as negative numbers; `invert` toggles this for every band
        # (string bands can be inverted too, rather than failing on `-b`)
        for i, b in enumerate(band):
            negative = (not isinstance(b, str)) and b < 0
            if negative != bool(invert):
                img[..., i] = np.nanmax(img[..., i]) - img[..., i]

        # vmin / vmax are consumed by the scaling below (imshow ignores them for RGB data)
        vmin = kwds.pop('vmin', None)
        vmax = kwds.pop('vmax', None)

        def _limits(values):
            # defaults are the data range of `values`; integers are percentiles of `values`
            mn = float(np.nanmin(values)) if vmin is None else vmin
            mx = float(np.nanmax(values)) if vmax is None else vmax
            if isinstance(mn, int):
                assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
                mn = float(np.nanpercentile(values, mn))
            if isinstance(mx, int):
                assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
                mx = float(np.nanpercentile(values, mx))
            return mn, mx

        if tscale:  # scale bands independently (defaults now use each band's own range)
            for i in range(img.shape[-1]):
                mn, mx = _limits(img[..., i])
                img[..., i] = (img[..., i] - mn) / (mx - mn)
        else:  # scale bands together
            mn, mx = _limits(img)
            img = (img - mn) / (mx - mn)

        # apply brightness/contrast mapping
        img = np.clip((1.0 + cfac) * img + bfac, 0, 1.0)

        # apply masking so background is white
        img[~np.isfinite(img)] = 1.0
        if user_mask is not None:
            img[np.asarray(user_mask, dtype=bool), :] = 1.0

        # apply rotations and flipping
        if rot:
            img = np.transpose(img, (1, 0, 2))
        if flipX:
            img = img[::-1, :, :]
        if flipY:
            img = img[:, ::-1, :]

        # save?
        if outfile is not None:
            from matplotlib.pyplot import imsave
            _ensure_folder(outfile)
            imsave(outfile, np.transpose(np.clip(img * 255, 0, 255).astype(np.uint8), (1, 0, 2)))

        if draw_samples:
            _plot_samples()

        ax.imshow(np.transpose(img, (1, 0, 2)), interpolation=interpolation, **kwds)
        ax.cbar = None  # no colorbar

    return ax.get_figure(), ax
Plot a band using matplotlib.imshow(...).
Arguments:
- band (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple is passed then each band in the tuple (string or index) will be mapped to rgb. Bands with negative wavelengths or indices will be inverted before plotting.
- ax: an axis object to plot to. If None, plt.imshow( ... ) is used.
- bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
- cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
- samples (bool): True if sample points (defined in the header file) should be plotted. Default is False. Otherwise, a list of [ (x,y), ... ] points can be passed.
- tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False. When scaling, vmin and vmax can be given either as integers (interpreted as percentiles, 0-100) or as floats (interpreted as absolute clipping values).
- invert (bool) : True if each band should be inverted before plotting. Only works for multiband (ternary) images.
- rot (bool): if True, the x and y axes will be swapped (transposed) before plotting. Default is False.
- flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
- flipY (bool): if True, the y axis will be flipped before plotting (after applying rotations).
**kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:
- mask = a 2D boolean mask containing True for pixels that should be masked (drawn as white background in RGB mappings) and False for pixels that should be drawn.
- path = a file path to save the image to (at matching resolution; use fig.savefig(..) if you want to save the figure).
- ticks = True if x- and y- ticks should be plotted. Default is False.
- ps, pc = the size and color of sample points to plot. Can be constant or list.
- figsize = a figsize for the figure to create (if ax is None).
Returns:
Tuple containing
- fig: matplotlib figure object
- ax: matplotlib axes object. If a colorbar is created, (band is an integer or a float), then this will be stored in ax.cbar.
def createGIF(self, path, bands=None, figsize=(10,10), fps=10, **kwds):
    """
    Create and save an animated gif that loops through the bands of the image.

    Args:
        path (str): the path to save the .gif. Any existing file extension is replaced with ".gif".
        bands (tuple): (start, end) tuple giving the range of band indices to draw. The end index is exclusive,
                       so 0 <= start < end <= band_count() must hold. Default is the whole range.
        figsize (tuple): the size of the image to draw. Default is (10,10).
        fps (int): the framerate (frames per second) of the gif. Default is 10.
        **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.

    Raises:
        AssertionError: if the requested band range is invalid.
    """
    if bands is None:
        bands = (0, self.band_count())
    else:
        # end is exclusive (used with range below), so it may equal band_count()
        assert 0 <= bands[0] < bands[1] <= self.band_count(), "Error - invalid range."

    # plot frames
    frames = []
    for i in range(bands[0], bands[1]):
        fig, ax = plt.subplots(figsize=figsize)
        try:
            ax.imshow(self.data[:, :, i], **kwds)
            fig.canvas.draw()
            # buffer_rgba() replaces tostring_rgb() (removed in Matplotlib 3.10); it is already
            # shaped (height, width, 4), so just drop the alpha channel and copy out of the buffer.
            frames.append(np.asarray(fig.canvas.buffer_rgba())[..., :3].copy())
        finally:
            plt.close(fig)  # never leak figures, even if drawing fails

    # save gif
    imageio.mimsave(os.path.splitext(path)[0] + ".gif", frames, fps=fps)
Create and save an animated gif that loops through the bands of the image.
Arguments:
- path (str): the path to save the .gif
- bands (tuple): (start, end) tuple giving the range of band indices to draw; the end index is exclusive. Default is the whole range.
- figsize (tuple): the size of the image to draw. Default is (10,10).
- fps (int): the framerate (frames per second) of the gif. Default is 10.
- **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
def drop_bbl(self, drop=True):
    """
    Remove the bad bands listed under the 'bbl' key of the image header. The image is modified in-place.

    Args:
        drop (bool): if True (default) the bad bands are deleted entirely. If False, they are kept but filled with nans.
    """
    assert 'bbl' in self.header, "Please specify a bad band list ('bbl') in the image header, as per the ENVI format definition."

    # bands whose bad-band-list entry equals 0 are the ones to blank out
    is_bad = self.header.get_list('bbl') == 0
    self.data[..., is_bad] = np.nan

    if not drop:
        return
    self.delete_nan_bands(inplace=True)
Remove bad bands as stored in the 'bbl' key in the image header. Note that this operates in-place.
Arguments:
- drop (bool): True if bad bands should be completely dropped. If False, these bands will be kept but replaced with nans.
def mask(self, mask=None, flag=np.nan, invert=False, crop=False, bands=None):
    """
    Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the
    image in place.

    Args:
        mask (ndarray, list, str): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None
             is passed then pickPolygons( ... ) is used to interactively define a polygon. If a file path is passed
             then the polygon (or binary mask) is loaded using np.load( ... ). Alternatively, a boolean array with
             shape (xdim, ydim) is treated as a binary image mask and True values are masked across all bands.
             Default is None.
        flag (float): the value to use for masked pixels. Default is np.nan
        invert (bool): if True, pixels within the polygon will be masked. If False, pixels outside the polygon are
             masked. Only applies to polygon masks. Default is False.
        crop (bool): True if rows/columns containing only masked pixels should be removed. Default is False.
        bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used.

    Returns:
        mask (ndarray): a boolean array (matching the image shape before any cropping) with True where pixels are
        masked and False elsewhere, or None if the mask was picked interactively and no polygon was drawn.
    """
    if mask is None:  # pick mask interactively
        if bands is None:
            bands = int(self.band_count() / 2)

        regions = self.pickPolygons(region_names=["mask"], bands=bands)

        # the user bailed without picking a mask?
        if len(regions) == 0:
            print("Warning - no mask picked/applied.")
            return

        mask = regions[0]  # extract polygon mask
    elif isinstance(mask, str):
        mask = np.load(mask)  # documented behaviour: load the mask from file

    mask = np.asarray(mask)

    # a non-boolean (n, 2) array is a polygon; boolean arrays are always binary image masks (this keeps a
    # boolean mask of an image with ydim == 2 from being misread as a polygon).
    if mask.dtype != bool and mask.ndim == 2 and mask.shape[1] == 2:

        # build meshgrid with pixel coords
        xx, yy = np.meshgrid(np.arange(self.xdim()), np.arange(self.ydim()))
        points = np.vstack([xx.flatten(), yy.flatten()]).T  # coordinates of each pixel

        # calculate per-pixel mask
        mask = path.Path(mask).contains_points(points)
        mask = mask.reshape((self.ydim(), self.xdim())).T

        # flip as we want to mask (==True) outside points (unless invert is true)
        if not invert:
            mask = np.logical_not(mask)
    else:
        # ensure boolean indexing below (integer arrays would otherwise be used as fancy indices)
        mask = mask.astype(bool, copy=False)

    # apply binary image mask to every band at once
    assert mask.shape[0] == self.data.shape[0] and mask.shape[1] == self.data.shape[1], \
        "Error - mask shape %s does not match image shape %s" % (mask.shape, self.data.shape)
    self.data[mask] = flag

    # crop image
    if crop:
        # calculate non-masked pixels
        valid = np.logical_not(mask)

        # integrate along axes
        xdata = np.sum(valid, axis=1) > 0.0
        ydata = np.sum(valid, axis=0) > 0.0

        # calculate domain containing valid pixels
        xmin = np.argmax(xdata)
        xmax = xdata.shape[0] - np.argmax(xdata[::-1])
        ymin = np.argmax(ydata)
        ymax = ydata.shape[0] - np.argmax(ydata[::-1])

        # crop
        self.data = self.data[xmin:xmax, ymin:ymax, :]

    return mask
Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the image in-situ.
Arguments:
- flag (float): the value to use for masked pixels. Default is np.nan
- mask (ndarray): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None is passed then pickPolygon( ... ) is used to interactively define a polygon. If a file path is passed then the polygon will be loaded using np.load( ... ). Alternatively if mask.shape == image.shape[0,1] then it is treated as a binary image mask (must be boolean) and True values will be masked across all bands. Default is None.
- invert (bool): if True, pixels within the polygon will be masked. If False, pixels outside the polygon are masked. Default is False.
- crop (bool): True if rows/columns containing only masked pixels should be removed. Default is False.
- bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used.
Returns:
- mask (ndarray): a boolean array with True where pixels are masked and False elsewhere. If the mask is picked interactively and no polygon is drawn, a warning is printed, nothing is applied and None is returned.
def crop_to_data(self):
    """
    Remove padding of nan or zero pixels from image. Note that this is performed in place.

    A pixel counts as data if at least one of its bands is both finite and non-zero. The image is cropped to the
    bounding box of such pixels (inclusive of the last row/column containing data). If the image contains no data
    pixels at all, a warning is printed and the image is left unchanged.
    """
    # combine the tests per band, so a pixel made only of nans and zeros (e.g. [nan, 0]) is still padding
    valid = (np.isfinite(self.data) & (self.data != 0)).any(axis=-1)
    rows = np.flatnonzero(valid.any(axis=1))  # x indices that contain data
    cols = np.flatnonzero(valid.any(axis=0))  # y indices that contain data
    if rows.size == 0:
        print("Warning - image contains no data; nothing cropped.")
        return

    # +1 because slice ends are exclusive and the last row/column with data must be kept
    self.data = self.data[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1, :]  # do clipping
Remove padding of nan or zero pixels from image. Note that this is performed in place.
def pickPolygons(self, region_names, bands=0):
    """
    Creates a matplotlib gui for selecting polygon regions in an image.

    Args:
        region_names (list, str): a list containing the names of the regions to pick. If a string is passed only one name is used.
        bands (tuple): the bands of the image to plot.

    Returns:
        A list containing one (n, 2) array of [x, y] polygon vertices for each picked region.
    """
    if isinstance(region_names, str):
        region_names = [region_names]

    assert isinstance(region_names, list), "Error - names must be a list or a string."

    # set matplotlib backend (remember the current one so it can be restored)
    backend = matplotlib.get_backend()
    matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work
    try:
        # plot image and extract roi's
        fig, ax = self.quick_plot(bands)
        roi = MultiRoi(roi_names=region_names)
        plt.close(fig)  # close figure

        # extract regions as (n, 2) vertex arrays
        regions = [np.vstack([r.x, r.y]).T for r in roi.rois.values()]
    finally:
        # restore matplotlib backend (if possible), even when picking failed
        try:
            matplotlib.use(backend)
        except Exception:
            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")

    return regions
Creates a matplotlib gui for selecting polygon regions in an image.
Arguments:
- region_names (list, str): a list containing the names of the regions to pick. If a string is passed only one name is used.
- bands (tuple): the bands of the image to plot.
def pickPoints(self, n=-1, bands=hylite.RGB, integer=True, title="Pick Points", **kwds):
    """
    Creates a matplotlib gui for picking pixels from an image.

    Args:
        n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
        bands (tuple): the bands of the image to plot. Default is hylite.RGB
        integer (bool): True if point coordinates should be rounded to the nearest integer pixel index (for use as
                        indices). Default is True.
        title (str): The title of the point picking window.
        **kwds: Keywords are passed to HyImage.quick_plot( ... ).

    Returns:
        A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].
    """
    # set matplotlib backend (remember the current one so it can be restored)
    backend = matplotlib.get_backend()
    matplotlib.use('Qt5Agg')  # need an interactive (Qt) backend for picking
    try:
        # create figure
        fig, ax = self.quick_plot(bands, **kwds)
        ax.set_title(title)

        # get points
        points = fig.ginput(n)
        plt.close(fig)  # the picking figure is internal to this method
    finally:
        # restore matplotlib backend (if possible), even when picking failed
        try:
            matplotlib.use(backend)
        except Exception:
            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")

    if integer:
        # imshow centres pixel i on coordinate i (it spans i-0.5 .. i+0.5), so round rather than truncate
        points = [(int(np.floor(p[0] + 0.5)), int(np.floor(p[1] + 0.5))) for p in points]

    return points
Creates a matplotlib gui for picking pixels from an image.
Arguments:
- n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
- bands (tuple): the bands of the image to plot. Default is hylite.RGB
- integer (bool): True if points coordinates should be cast to integers (for use as indices). Default is True.
- title (str): The title of the point picking window.
- **kwds: Keywords are passed to HyImage.quick_plot( ... ).
Returns:
A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].
def pickSamples(self, names=None, store=True, **kwds):
    """
    Pick sample probe points and store these in the image header file.

    Args:
        names (str, list): the name of the sample to pick, or a list of names to pick multiple.
        store (bool): True if sample should be stored in the image header file (for later access). Default is True.
        **kwds: Keywords are passed to HyImage.quick_plot( ... )

    Returns:
        a list containing a list of points for each sample.

    Raises:
        AssertionError: if no sample names are given.
    """
    if isinstance(names, str):
        names = [names]
    assert names is not None, "Error - please specify at least one sample name."
    names = list(names)

    # pick points
    points = []
    for name in names:
        pnts = self.pickPoints(title="%s" % name, **kwds)
        if store:
            self.header['sample %s' % name] = pnts  # store in header
        points.append(pnts)

    # add class to header file
    if store:
        cls_names = self.header.get_class_names()
        cls_names = [] if cls_names is None else list(cls_names)
        # only append names not already listed, so re-picking a sample does not duplicate its class entry
        new_names = [n for n in dict.fromkeys(names) if n not in cls_names]
        self.header['class names'] = cls_names + new_names

    return points
Pick sample probe points and store these in the image header file.
Arguments:
- names (str, list): the name of the sample to pick, or a list of names to pick multiple.
- store (bool): True if sample should be stored in the image header file (for later access). Default is True.
- **kwds: Keywords are passed to HyImage.quick_plot( ... )
Returns:
a list containing a list of points for each sample.
Inherited Members
- hylite.hydata.HyData
- to_grey
- set_header
- push_to_header
- has_wavelengths
- get_wavelengths
- has_band_names
- get_band_names
- has_fwhm
- get_fwhm
- set_wavelengths
- set_band_names
- set_fwhm
- is_image
- is_point
- is_classification
- band_count
- samples
- lines
- is_int
- is_float
- export_bands
- delete_nan_bands
- set_as_nan
- mask_bands
- mask_water_features
- get_band
- get_band_grey
- get_raveled
- X
- eval
- set_raveled
- get_band_index
- resample
- contiguous_chunks
- smooth_median
- smooth_savgol
- fill_gaps
- plot_spectra
- compress
- decompress
- getQuantized
- fromQuanta
- normalise
- percent_clip
- correct_spectral_shift