hylite.hyimage
Store and manipulate hyperspectral image data.
"""
Store and manipulate hyperspectral image data.
"""

import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import path
from roipoly import MultiRoi
import imageio
import scipy as sp
import hylite
from hylite.hydata import HyData
from hylite.hylibrary import HyLibrary

# message used whenever the optional GDAL (osgeo) bindings cannot be imported
_GDAL_ERROR = "Error - GDAL must be installed to work with spatial projections in hylite."


def _import_osgeo():
    """
    Lazily import GDAL's python bindings (GDAL is an optional dependency of hylite).

    Returns:
        a tuple containing the (osr, gdal, ogr) modules.

    Raises:
        ImportError: if GDAL is not installed.
    """
    try:
        from osgeo import osr, ogr
        import osgeo.gdal as gdal
    except ImportError as err:
        raise ImportError(_GDAL_ERROR) from err
    return osr, gdal, ogr


def _as_spatial_reference(proj, default, osr):
    """
    Convert a user supplied coordinate-system argument to an osr.SpatialReference.

    Args:
        proj: None (return `default`), an integer EPSG code, an EPSG string such as "EPSG:4326" or "4326",
              or any other object (returned unchanged; callers validate it).
        default: the value returned if proj is None.
        osr: the osgeo.osr module.

    Returns:
        the resulting spatial reference.

    Raises:
        ValueError: if a string cannot be parsed as an EPSG code.
    """
    if proj is None:
        return default
    if isinstance(proj, (str, int)):
        epsg = proj
        if isinstance(epsg, str):
            try:
                # accept both "EPSG:XXXX" and bare "XXXX"
                epsg = int(epsg.split(':')[-1])
            except ValueError:
                raise ValueError("Error - %s is an invalid EPSG code." % proj) from None
        srs = osr.SpatialReference()
        srs.ImportFromEPSG(epsg)
        return srs
    return proj


def _ensure_parent_dir(file_path):
    """Create the directory that will contain `file_path`, if it has one and it does not exist yet."""
    directory = os.path.dirname(file_path)
    if directory:  # dirname is '' for bare file names, which os.makedirs cannot create
        os.makedirs(directory, exist_ok=True)


def _restore_backend(backend):
    """Best-effort restore of a previously active matplotlib backend (warns instead of raising)."""
    try:
        matplotlib.use(backend)
    except Exception:  # switching backends can fail for many backend-specific reasons
        print("Warning: could not reset matplotlib backend. Plots will remain interactive...")


def _rescale(arr, vmin=None, vmax=None):
    """
    Linearly stretch `arr` so that vmin maps to 0 and vmax maps to 1 (values are not clipped).

    Args:
        arr (ndarray): the floating point data to stretch.
        vmin: None (use the nan-minimum of arr), an int percentile (0 - 100) or a float value.
        vmax: None (use the nan-maximum of arr), an int percentile (0 - 100) or a float value.

    Returns:
        the stretched array.
    """
    mn = float(np.nanmin(arr)) if vmin is None else vmin
    mx = float(np.nanmax(arr)) if vmax is None else vmax
    if isinstance(mn, int):
        assert 0 <= mn <= 100, "Error - integer vmin values must be a percentile."
        mn = float(np.nanpercentile(arr, mn))
    if isinstance(mx, int):
        assert 0 <= mx <= 100, "Error - integer vmax values must be a percentile."
        mx = float(np.nanpercentile(arr, mx))
    return (arr - mn) / (mx - mn)


class HyImage(HyData):
    """
    A class for hyperspectral image data. These can be individual scenes or hyperspectral orthoimages.

    Pixel values are stored such that data[x, y, band] gives each pixel value.
    """

    def __init__(self, data, **kwds):
        """
        Args:
            data (ndarray): a numpy array such that data[x][y][band] gives each pixel value. 1D arrays are
                            treated as a single pixel and 2D arrays as a single band image.
            **kwds:
                wav = A numpy array containing band wavelengths for this image.
                affine = an affine transform of the format returned by GDAL.GetGeoTransform().
                         Default is [0, 1, 0, 0, 0, 1] (world coordinates equal pixel coordinates).
                projection = string or osgeo.osr.SpatialReference defining the projection. Default is None.
                sensor = sensor name. Default is "unknown".
                header = path to associated header file. Default is None.
        """

        # call constructor for HyData
        super().__init__(data, **kwds)

        # special case - if dataset only has one band, slice it so it still has the format data[x,y,b].
        if self.data is not None:
            if len(self.data.shape) == 1:
                self.data = self.data[None, None, :]  # single pixel image
            if len(self.data.shape) == 2:
                self.data = self.data[:, :, None]  # single band image

        # load any additional projection information (specific to images)
        self.set_projection(kwds.get("projection", None))
        self.affine = kwds.get("affine", [0, 1, 0, 0, 0, 1])

        # wavelengths
        if 'wav' in kwds:
            self.set_wavelengths(kwds['wav'])

        # special header formatting
        self.header['file type'] = 'ENVI Standard'

    def copy(self, data=True):
        """
        Make a copy of this image instance. The header and (optionally) the data array are copied; the
        projection and affine objects are shared with the original.

        Args:
            data (bool): True if a copy of the data should be made, otherwise only copy header.

        Returns:
            a new HyImage instance.
        """
        if not data:
            return HyImage(None, header=self.header.copy(), projection=self.projection, affine=self.affine)
        return HyImage(self.data.copy(), header=self.header.copy(), projection=self.projection, affine=self.affine)

    def T(self):
        """
        Return a transposed view of the data matrix (corresponding with the [y,x] indexing used by
        matplotlib, opencv etc.).
        """
        return np.transpose(self.data, (1, 0, 2))

    def xdim(self):
        """
        Return number of pixels in x (first dimension of data array)
        """
        return self.data.shape[0]

    def ydim(self):
        """
        Return number of pixels in y (second dimension of data array)
        """
        return self.data.shape[1]

    def aspx(self):
        """
        Return the aspect ratio ydim() / xdim() of this image.
        """
        return self.ydim() / self.xdim()

    def get_extent(self):
        """
        Returns the width and height of this image in world coordinates.

        The pixel size is taken from self.pixel_size if that attribute exists, otherwise from the absolute
        pixel width (affine[1]) and pixel height (affine[5]) terms of the GDAL-style affine transform.

        Returns:
            tuple with (width, height).
        """
        if hasattr(self, 'pixel_size'):
            sx, sy = self.pixel_size[0], self.pixel_size[1]
        else:
            sx, sy = abs(self.affine[1]), abs(self.affine[5])
        return self.xdim() * sx, self.ydim() * sy

    def set_projection(self, proj):
        """
        Set this projection to an existing osgeo.osr.SpatialReference or GDAL georeference string.

        Args:
            proj (str, osgeo.osr.SpatialReference): the projection to use as osgeo.osr.SpatialReference or
                GDAL georeference string. None clears the projection.

        Raises:
            ImportError: if GDAL is not installed (and proj is not None).
            TypeError: if proj is of an unsupported type.
        """
        if proj is None:
            self.projection = None
            return
        try:
            from osgeo.osr import SpatialReference
        except ImportError as err:
            raise ImportError(_GDAL_ERROR) from err
        if isinstance(proj, SpatialReference):
            self.projection = proj
        elif isinstance(proj, str):
            self.projection = SpatialReference(proj)
        else:
            raise TypeError("Invalid project %s" % proj)

    def set_projection_EPSG(self, EPSG):
        """
        Sets this image projection using an EPSG code.

        Args:
            EPSG (str): string EPSG code that can be passed to SpatialReference.SetFromUserInput(...).
        """
        try:
            from osgeo.osr import SpatialReference
        except ImportError as err:
            raise ImportError(_GDAL_ERROR) from err

        self.projection = SpatialReference()
        self.projection.SetFromUserInput(EPSG)

    def get_projection_EPSG(self):
        """
        Gets a string describing this projection's EPSG code (if it is an EPSG projection).

        Returns:
            an EPSG code string of the format "EPSG:XXXX", or None if no projection is set.
        """
        if self.projection is None:
            return None
        return "%s:%s" % (self.projection.GetAttrValue("AUTHORITY", 0), self.projection.GetAttrValue("AUTHORITY", 1))

    def pix_to_world(self, px, py, proj=None):
        """
        Take pixel coordinates and return world coordinates

        Args:
            px (int): the pixel x-coord.
            py (int): the pixel y-coord.
            proj (str, int, osr.SpatialReference): the coordinate system to use. Default (None) uses the same
                system as this image. Otherwise an osr.SpatialReference can be passed (HyImage.projection), or an
                EPSG code as integer or string (e.g. get_projection_EPSG(...)).

        Returns:
            the world coordinates in the requested coordinate system.
        """
        osr, gdal, ogr = _import_osgeo()
        proj = _as_spatial_reference(proj, self.projection, osr)

        # check we have all the required info
        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
        assert (self.affine is not None) and (self.projection is not None), \
            "Error - project information is undefined."

        # project to world coordinates in this image's projection
        x, y = gdal.ApplyGeoTransform(self.affine, px, py)

        # project to target coords (if different)
        if not proj.IsSameGeogCS(self.projection):
            P = ogr.Geometry(ogr.wkbPoint)
            if proj.EPSGTreatsAsNorthingEasting():
                P.AddPoint(x, y)
            else:
                P.AddPoint(y, x)
            P.AssignSpatialReference(self.projection)  # tell the point what coordinates it's in
            P.TransformTo(proj)  # reproject it to the output spatial reference
            x, y = P.GetX(), P.GetY()

        # do we need to transpose?
        if proj.EPSGTreatsAsLatLong():
            x, y = y, x  # we want lon,lat not lat,lon
        return x, y

    def world_to_pix(self, x, y, proj=None):
        """
        Take world coordinates and return pixel coordinates

        Args:
            x (float): the world x-coord.
            y (float): the world y-coord.
            proj (str, int, osr.SpatialReference): the coordinate system of the input coordinates. Default (None)
                uses the same system as this image. Otherwise an osr.SpatialReference can be passed
                (HyImage.projection), or an EPSG code as integer or string (e.g. get_projection_EPSG(...)).

        Returns:
            the pixel coordinates based on the affine transform stored in self.affine.
        """
        osr, gdal, ogr = _import_osgeo()
        proj = _as_spatial_reference(proj, self.projection, osr)

        # check we have all the required info
        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
        assert (self.affine is not None) and (self.projection is not None), \
            "Error - project information is undefined."

        # project to this image's CS (if different)
        if not proj.IsSameGeogCS(self.projection):
            P = ogr.Geometry(ogr.wkbPoint)
            # NOTE: the point is added exactly once; a second AddPoint would overwrite this axis-order handling
            if proj.EPSGTreatsAsNorthingEasting():
                P.AddPoint(x, y)
            else:
                P.AddPoint(y, x)
            P.AssignSpatialReference(proj)  # tell the point what coordinates it's in
            P.TransformTo(self.projection)  # reproject it to this image's spatial reference
            x, y = P.GetX(), P.GetY()
            if self.projection.EPSGTreatsAsLatLong():  # do we need to transpose?
                x, y = y, x  # we want lon,lat not lat,lon

        inv = gdal.InvGeoTransform(self.affine)
        assert inv is not None, "Error - could not invert affine transform?"

        # apply
        return gdal.ApplyGeoTransform(inv, x, y)

    def flip(self, axis='x'):
        """
        Flip the image on the x or y axis.

        Args:
            axis (str): 'x' or 'y' or both 'xy'.
        """
        if 'x' in axis.lower():
            self.data = np.flip(self.data, axis=0)
        if 'y' in axis.lower():
            self.data = np.flip(self.data, axis=1)

    def rot90(self):
        """
        Rotate this image by 90 degrees by transposing the underlying data array. Combine with flip('x') or
        flip('y') to achieve positive/negative rotations.
        """
        self.data = np.transpose(self.data, (1, 0, 2))
        self.push_to_header()

    #####################################
    ## IMAGE FILTERING
    #####################################
    def fill_holes(self):
        """
        Replace non-finite (nan/inf) pixels with the maximum of their 3x3 neighbourhood (greyscale dilation,
        with holes counted as 0), thus removing 1-pixel large holes from an image. Note that for performance
        reasons this assumes that holes line up across bands. Bands are filtered one at a time.
        """
        from scipy import ndimage  # explicit submodule import (scipy does not always expose it eagerly)

        # perform greyscale dilation
        dilate = self.data.copy()
        mask = np.logical_not(np.isfinite(dilate))
        dilate[mask] = 0
        for b in range(self.band_count()):
            dilate[:, :, b] = ndimage.grey_dilation(dilate[:, :, b], size=(3, 3))

        # map back to holes in dataset
        self.data[mask] = dilate[mask]

    def blur(self, n=3):
        """
        Applies an n x n box (mean) filter to the image using OpenCV. Nan pixels remain nan.

        Args:
            n (int): the dimensions of the kernel to convolve. Default is 3. Increase for more blurry results.
        """
        import cv2  # import this here to avoid errors if opencv is not installed properly

        nanmask = np.isnan(self.data)
        assert isinstance(n, int) and n >= 3, "Error - invalid kernel. N must be an integer >= 3. "
        kernel = np.ones((n, n), np.float32) / (n ** 2)
        self.data = cv2.filter2D(self.data, -1, kernel)
        self.data[nanmask] = np.nan  # restore mask

    def erode(self, size=3, iterations=1):
        """
        Apply an erode filter to this image to expand background (nan, or 0 for non-float images) pixels.
        Refer to open-cv's erode function for more details.

        Args:
            size (int): the size of the erode filter. Default is a 3x3 kernel.
            iterations (int): the number of erode iterations. Default is 1.
        """
        import cv2  # import this here to avoid errors if opencv is not installed properly

        kernel = np.ones((size, size), np.uint8)
        if self.is_float():
            mask = np.isfinite(self.data).any(axis=-1)
            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
            self.data[mask == 0, :] = np.nan
        else:
            mask = (self.data != 0).any(axis=-1)
            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
            self.data[mask == 0, :] = 0

    def resize(self, newdims: tuple, interpolation: int = 1):
        """
        Resize this image with opencv.

        Args:
            newdims (tuple): the new image dimensions (xdim, ydim).
            interpolation (int): opencv interpolation method. Default is cv2.INTER_LINEAR.
        """
        import cv2  # import this here to avoid errors if opencv is not installed properly
        self.data = cv2.resize(self.data, (newdims[1], newdims[0]), interpolation=interpolation)
        if self.data.ndim == 2:  # opencv drops the channel axis of single-band images
            self.data = self.data[:, :, None]

    def despeckle(self, size=5):
        """
        Despeckle each band of this image (independently) using a median filter.

        Args:
            size (int): the size of the median filter kernel. Default is 5. Must be an odd number. Note that
                        opencv only supports sizes 3 and 5 for float (converted to float32) and 16-bit data.
        """
        assert (size % 2) == 1, "Error - size must be an odd integer"
        import cv2  # import this here to avoid errors if opencv is not installed properly

        src = self.data.astype(np.float32) if self.is_float() else self.data
        out = np.empty_like(src)
        for b in range(src.shape[-1]):
            # cv2.medianBlur only accepts 1, 3 or 4 channel images, so filter bands one at a time
            out[:, :, b] = cv2.medianBlur(np.ascontiguousarray(src[:, :, b]), size)
        self.data = out

    #####################################
    ## FEATURES AND FEATURE MATCHING
    #####################################
    def get_keypoints(self, band, eq=False, mask=True, method='sift', cfac=0.0, bfac=0.0, **kwds):
        """
        Get feature descriptors from the specified band.

        Args:
            band (int,float,tuple): the band index (int) or wavelength (float) to extract features from.
                Alternatively, a tuple can be passed containing a range of bands (min : max) to average before
                feature matching.
            eq (bool): True if the image should be histogram equalized first. Default is False.
            mask (bool): True if 0 value pixels should be masked. Default is True.
            method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate).
                Default is 'SIFT'.
            cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
            bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.
            **kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are:

                - contrastThreshold: default is 0.01.
                - edgeThreshold: default is 10.
                - sigma: default is 1.0

                For ORB these are:

                - nfeatures = the number of features to detect. Default is 5000.

        Returns:
            Tuple containing

            - k (ndarray): the keypoints detected
            - d (ndarray): corresponding feature descriptors
        """
        import cv2  # import this here to avoid errors if opencv is not installed properly

        # get image as a float copy (so that normalisation below never modifies self.data)
        if isinstance(band, (int, float, np.integer, np.floating)):  # single band
            image = np.array(self.data[:, :, self.get_band_index(band)], dtype=float)
        elif isinstance(band, tuple):  # range of bands (averaged)
            idx0 = self.get_band_index(band[0])
            idx1 = self.get_band_index(band[1])

            # deal with out of range errors
            if idx0 is None:
                idx0 = 0
            if idx1 is None:
                idx1 = self.band_count()

            # average bands
            image = np.nanmean(self.data[:, :, idx0:idx1], axis=2)
        else:
            assert False, "Error, unrecognised band %s" % band

        # normalise image to range 0 - 1
        image = image - np.nanmin(image)
        image = image / np.nanmax(image)

        # apply brightness/contrast adjustment
        image = np.clip((1.0 + cfac) * image + bfac, 0.0, 1.0)

        # convert image to uint8 for opencv (nan pixels become 0 and are hence masked below)
        image = np.uint8(255 * np.nan_to_num(image, nan=0.0))
        if eq:
            image = cv2.equalizeHist(image)

        if mask:
            mask = np.zeros(image.shape, dtype=np.uint8)
            mask[image != 0] = 255  # include only non-zero pixels
        else:
            mask = None

        if 'sift' in method.lower():  # SIFT
            # setup default keywords
            kwds["contrastThreshold"] = kwds.get("contrastThreshold", 0.01)
            kwds["edgeThreshold"] = kwds.get("edgeThreshold", 10)
            kwds["sigma"] = kwds.get("sigma", 1.0)
            alg = cv2.SIFT_create(**kwds)
        elif 'orb' in method.lower():  # ORB
            kwds['nfeatures'] = kwds.get('nfeatures', 5000)
            alg = cv2.ORB_create(scoreType=cv2.ORB_FAST_SCORE, **kwds)
        else:
            assert False, "Error - %s is not a recognised feature detector." % method

        # detect keypoints
        kp = alg.detect(image, mask)

        # extract and return feature vectors
        return alg.compute(image, kp)

    @classmethod
    def match_keypoints(cls, kp1, kp2, d1, d2, method='SIFT', dist=0.7, tree=5, check=100, min_count=5):
        """
        Compares keypoint feature vectors from two images and returns matching pairs.

        Args:
            kp1 (ndarray): keypoints from the first image
            kp2 (ndarray): keypoints from the second image
            d1 (ndarray): descriptors for the keypoints from the first image
            d2 (ndarray): descriptors for the keypoints from the second image
            method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'.
                Default is 'sift'.
            dist (float): maximum ratio between the best and second-best match distance (Lowe's ratio test,
                0 to 1). Default is 0.7.
            tree (int): number of trees used by the FLANN index. Default is 5. See open-cv docs.
            check (int): number of times FLANN recursively traverses the trees when searching. Default is 100.
            min_count (int): the minimum number of matches to consider a valid matching operation. If fewer
                matches are found, then the function returns None, None. Default is 5.

        Returns:
            Tuple containing the matched source and destination points (each of shape (n, 1, 2)), or
            (None, None) if too few matches were found.
        """
        import cv2  # import this here to avoid errors if opencv is not installed properly
        FLANN_INDEX_KDTREE = 1  # kd-tree index, for float descriptors (SIFT)
        FLANN_INDEX_LSH = 6  # locality sensitive hashing, for binary descriptors (ORB)
        if 'sift' in method.lower():
            algorithm = FLANN_INDEX_KDTREE
        elif 'orb' in method.lower():
            algorithm = FLANN_INDEX_LSH
        else:
            assert False, "Error - unknown matching algorithm %s" % method

        # no descriptors means no possible matches
        if d1 is None or d2 is None or len(d1) == 0 or len(d2) == 0:
            return None, None

        # calculate flann matches
        index_params = dict(algorithm=algorithm, trees=tree)
        search_params = dict(checks=check)
        flann = cv2.FlannBasedMatcher(index_params, search_params)
        matches = flann.knnMatch(d1, d2, k=2)

        # store all the good matches as per Lowe's ratio test (LSH can return fewer than 2 neighbours)
        good = [pair[0] for pair in matches
                if len(pair) == 2 and pair[0].distance < dist * pair[1].distance]

        if len(good) < min_count:
            return None, None
        src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
        dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
        return src_pts, dst_pts

    ############################
    ## Visualisation methods
    ############################
    def quick_plot(self, band=0, ax=None, bfac=0.0, cfac=0.0, samples=False, tscale=False, rot=False,
                   flipX=False, flipY=False, **kwds):
        """
        Plot a band using matplotlib.imshow(...).

        Args:
            band (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot.
                Default is 0. If a tuple is passed then each band in the tuple (string or index) will be mapped
                to rgb. Bands with negative wavelengths or indices will be inverted before plotting.
            ax: an axis object to plot to. If none, a new figure and axes are created.
            bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
            cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
            samples (bool): True if sample points (defined in the header file) should be plotted. Default is
                False. Otherwise, a list of [ (x,y), ... ] points can be passed.
            tscale (bool): True if each band (for ternary images) should be scaled independently. Default is
                False. When using scaling, vmin and vmax can be used to set the clipping percentiles (integers)
                or (constant) values (float).
            rot (bool): if True, the x and y axis will be flipped (90 degree rotation) before plotting.
                Default is False.
            flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
            flipY (bool): if True, the y axis will be flipped before plotting (after applying rotations).
            **kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:

                - mask = a 2D boolean mask containing True for pixels that should be hidden (masked) and False
                  for pixels that should be drawn.
                - path = a file path to save the image to (at matching resolution; use fig.savefig(..) if you
                  want to save the figure). For single bands only the cmap, vmin and vmax keywords are used.
                - ticks = True if x- and y- ticks should be plotted. Default is False.
                - ps, pc = the size and color of sample points to plot. Can be constant or list.
                - figsize = a figsize for the figure to create (if ax is None).

        Returns:
            Tuple containing

            - fig: matplotlib figure object
            - ax: matplotlib axes object. If a colorbar is created, (band is an integer or a float), then this
              will be stored in ax.cbar.
        """
        # consume the keywords handled here so that they are never forwarded to imshow(...)
        figsize = kwds.pop('figsize', None)
        ticks = kwds.pop('ticks', False)
        ps = kwds.pop('ps', 5)
        pc = kwds.pop('pc', 'r')
        user_mask = kwds.pop('mask', None)
        out_path = kwds.pop('path', None)
        interpolation = kwds.pop('interpolation', 'none')  # default to no interpolation

        # create new axes?
        if ax is None:
            if figsize is None:
                figsize = (18, 18 * self.ydim() / self.xdim())
            fig, ax = plt.subplots(figsize=figsize)

        # deal with ticks
        if not ticks:
            ax.set_xticks([])
            ax.set_yticks([])

        if isinstance(band, (str, int, float, np.integer, np.floating)):
            # map individual band using colourmap
            if isinstance(band, str):
                data = self.data[:, :, self.get_band_index(band)]
            else:
                data = self.data[:, :, self.get_band_index(np.abs(band))]
                if band < 0:
                    data = np.nanmax(data) - data  # invert

            # convert integer vmin and vmax values to percentiles
            for key in ('vmin', 'vmax'):
                if isinstance(kwds.get(key), int):
                    kwds[key] = np.nanpercentile(data, kwds[key])

            # mask nans, the header's data ignore value and any custom mask
            mask = np.isnan(data)
            ignore = self.header.get_data_ignore_value()
            if not np.isnan(ignore):
                mask = mask | (data == ignore)
            if user_mask is not None:
                mask = mask | np.asarray(user_mask, dtype=bool)
            data = np.ma.array(data, mask=mask)

            # apply rotations and flipping
            if rot:
                data = data.T
            if flipX:
                data = data[::-1, :]
            if flipY:
                data = data[:, ::-1]

            # save?
            if out_path is not None:
                _ensure_parent_dir(out_path)
                save_kwds = {k: kwds[k] for k in ('cmap', 'vmin', 'vmax') if k in kwds}
                plt.imsave(out_path, data.T, **save_kwds)

            ax.cbar = ax.imshow(data.T, interpolation=interpolation, **kwds)

        elif isinstance(band, (tuple, list)):
            # map bands to RGB; get band indices
            rgb = [self.get_band_index(b) if isinstance(b, str) else self.get_band_index(np.abs(b))
                   for b in band]

            # slice image (fancy indexing copies) as floats so that scaling is not truncated for integer data
            img = self.data[:, :, rgb].astype(np.result_type(self.data.dtype, np.float32))
            if np.isnan(img).all():
                print("Warning - image contains no data.")
                return ax.get_figure(), ax

            # invert if needed
            for i, b in enumerate(band):
                if not isinstance(b, str) and (b < 0):
                    img[..., i] = np.nanmax(img[..., i]) - img[..., i]

            # do scaling (vmin / vmax are consumed here; they are meaningless for RGB imshow)
            vmin = kwds.pop('vmin', None)
            vmax = kwds.pop('vmax', None)
            if tscale:  # scale bands independently
                for i in range(img.shape[-1]):
                    img[..., i] = _rescale(img[..., i], vmin, vmax)
            else:  # scale bands together
                img = _rescale(img, vmin, vmax)

            # apply brightness/contrast mapping
            img = np.clip((1.0 + cfac) * img + bfac, 0, 1.0)

            # apply masking so background is white
            img[np.logical_not(np.isfinite(img))] = 1.0
            if user_mask is not None:
                img[np.asarray(user_mask, dtype=bool), :] = 1.0

            # apply rotations and flipping
            if rot:
                img = np.transpose(img, (1, 0, 2))
            if flipX:
                img = img[::-1, :, :]
            if flipY:
                img = img[:, ::-1, :]

            # save?
            if out_path is not None:
                _ensure_parent_dir(out_path)
                plt.imsave(out_path, np.transpose(np.clip(img * 255, 0, 255).astype(np.uint8), (1, 0, 2)))

            ax.imshow(np.transpose(img, (1, 0, 2)), interpolation=interpolation, **kwds)
            ax.cbar = None  # no colorbar

        # plot samples? (explicit type check; truth-testing an ndarray would raise)
        if isinstance(samples, (list, tuple, np.ndarray)):
            if len(samples) > 0:
                ax.scatter([s[0] for s in samples], [s[1] for s in samples], s=ps, c=pc)
        elif samples:
            class_names = self.header.get_class_names()
            for n in ([] if class_names is None else class_names):
                points = np.array(self.header.get_sample_points(n))
                ax.scatter(points[:, 0], points[:, 1], s=ps, c=pc)

        return ax.get_figure(), ax

    def createGIF(self, path, bands=None, figsize=(10, 10), fps=10, **kwds):
        """
        Create and save an animated gif that loops through the bands of the image.

        Args:
            path (str): the path to save the .gif
            bands (tuple): Tuple containing the range (start, end) of band indices to draw; end is exclusive.
                Default is the whole range.
            figsize (tuple): the size of the image to draw. Default is (10,10).
            fps (int): the framerate (frames per second) of the gif. Default is 10.
            **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
        """
        if bands is None:
            bands = (0, self.band_count())
        else:
            assert 0 <= bands[0] < self.band_count(), "Error - invalid range."
            assert 0 < bands[1] <= self.band_count(), "Error - invalid range."
            assert bands[1] > bands[0], "Error - invalid range."

        # plot frames
        frames = []
        for i in range(bands[0], bands[1]):
            fig, ax = plt.subplots(figsize=figsize)
            ax.imshow(self.data[:, :, i], **kwds)
            fig.canvas.draw()
            # buffer_rgba has shape (height, width, 4); drop alpha and copy before the figure is closed
            frames.append(np.asarray(fig.canvas.buffer_rgba())[..., :3].copy())
            plt.close(fig)

        # save gif
        imageio.mimsave(os.path.splitext(path)[0] + ".gif", frames, fps=fps)

    ## masking
    def mask(self, mask=None, flag=np.nan, invert=False, crop=False, bands=None):
        """
        Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the
        mask to the image in-situ.

        Args:
            mask (ndarray, str): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...].
                If None is passed then pickPolygons( ... ) is used to interactively define a polygon. If a file
                path is passed then the polygon will be loaded using np.load( ... ). Alternatively if
                mask.shape == image.shape[0,1] then it is treated as a binary image mask (boolean, or cast to
                boolean) and True values will be masked across all bands. Default is None.
            flag (float): the value to use for masked pixels. Default is np.nan
            invert (bool): (polygons only) if True, pixels within the polygon will be masked. If False, pixels
                outside the polygon are masked. Default is False.
            crop (bool): True if rows/columns containing only masked pixels should be removed. Default is False.
            bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is
                used.

        Returns:
            mask (ndarray): a boolean array with True where pixels are masked and False elsewhere, or None if
            the user did not pick a polygon interactively.
        """
        if mask is None:  # pick mask interactively
            if bands is None:
                bands = int(self.band_count() / 2)

            regions = self.pickPolygons(region_names=["mask"], bands=bands)

            # the user bailed without picking a mask?
            if len(regions) == 0:
                print("Warning - no mask picked/applied.")
                return

            mask = regions[0]  # extract polygon mask
        elif isinstance(mask, str):
            mask = np.load(mask)  # polygon stored with np.save( ... )
        mask = np.asarray(mask)

        if mask.dtype != bool and mask.ndim == 2 and mask.shape[1] == 2:
            # convert polygon mask to binary mask; build pixel coordinates
            xx, yy = np.meshgrid(np.arange(self.xdim()), np.arange(self.ydim()))
            points = np.vstack([xx.ravel(), yy.ravel()]).T  # coordinates of each pixel

            # calculate per-pixel mask
            inside = path.Path(mask).contains_points(points)
            inside = inside.reshape((self.ydim(), self.xdim())).T

            # we want to mask (==True) outside points (unless invert is true)
            mask = inside if invert else np.logical_not(inside)
        else:
            mask = mask.astype(bool)

        # apply binary image mask to all bands
        assert mask.ndim == 2 and mask.shape[0] == self.data.shape[0] and mask.shape[1] == self.data.shape[1], \
            "Error - mask shape %s does not match image shape %s" % (mask.shape, self.data.shape)
        self.data[mask] = flag

        # crop image
        if crop:
            # calculate non-masked pixels
            valid = np.logical_not(mask)

            # integrate along axes
            xdata = np.sum(valid, axis=1) > 0.0
            ydata = np.sum(valid, axis=0) > 0.0

            # calculate domain containing valid pixels
            xmin = np.argmax(xdata)
            xmax = xdata.shape[0] - np.argmax(xdata[::-1])
            ymin = np.argmax(ydata)
            ymax = ydata.shape[0] - np.argmax(ydata[::-1])

            self.data = self.data[xmin:xmax, ymin:ymax, :]

        return mask

    def crop_to_data(self):
        """
        Remove padding of nan or zero pixels from image. Note that this is performed in place. Images containing
        no valid pixels are left unchanged.
        """
        valid = np.isfinite(self.data).any(axis=-1) & (self.data != 0).any(axis=-1)
        rows = np.flatnonzero(valid.any(axis=1))  # x indices containing valid pixels
        cols = np.flatnonzero(valid.any(axis=0))  # y indices containing valid pixels
        if rows.size == 0 or cols.size == 0:
            return
        # slice ends are exclusive, so +1 keeps the last valid row / column
        self.data = self.data[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1, :]

    ##################################################
    ## Interactive tools for picking regions/pixels
    ##################################################
    def pickPolygons(self, region_names, bands=0):
        """
        Creates a matplotlib gui for selecting polygon regions in an image.

        Args:
            region_names (list, str): a list containing the names of the regions to pick. If a string is passed
                only one name is used.
            bands (tuple): the bands of the image to plot.

        Returns:
            a list of (n, 2) arrays containing the vertices of each picked polygon.
        """
        if isinstance(region_names, str):
            region_names = [region_names]

        assert isinstance(region_names, list), "Error - names must be a list or a string."

        # set matplotlib backend (and restore it even if picking fails)
        backend = matplotlib.get_backend()
        matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work
        try:
            # plot image and extract roi's
            fig, ax = self.quick_plot(bands)
            roi = MultiRoi(roi_names=region_names)
            plt.close(fig)  # close figure

            regions = [np.vstack([r.x, r.y]).T for r in roi.rois.values()]
        finally:
            _restore_backend(backend)

        return regions

    def pickPoints(self, n=-1, bands=hylite.RGB, integer=True, title="Pick Points", **kwds):
        """
        Creates a matplotlib gui for picking pixels from an image.

        Args:
            n (int): the number of pixels to pick, or -1 if the user can select as many as they wish.
                Default is -1.
            bands (tuple): the bands of the image to plot. Default is hylite.RGB
            integer (bool): True if points coordinates should be cast to integers (for use as indices).
                Default is True.
            title (str): The title of the point picking window.
            **kwds: Keywords are passed to HyImage.quick_plot( ... ).

        Returns:
            A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].
        """
        # set matplotlib backend (and restore it even if picking fails)
        backend = matplotlib.get_backend()
        matplotlib.use('Qt5Agg')  # need an interactive backend for ginput
        try:
            fig, ax = self.quick_plot(bands, **kwds)
            ax.set_title(title)
            points = fig.ginput(n)
        finally:
            _restore_backend(backend)

        if integer:
            points = [(int(p[0]), int(p[1])) for p in points]
        return points

    def pickSamples(self, names=None, store=True, **kwds):
        """
        Pick sample probe points and store these in the image header file.

        Args:
            names (str, list): the name of the sample to pick, or a list of names to pick multiple.
            store (bool): True if sample should be stored in the image header file (for later access).
                Default is True.
            **kwds: Keywords are passed to HyImage.quick_plot( ... )

        Returns:
            a list containing a list of points for each sample.
        """
        if isinstance(names, str):
            names = [names]
        assert names is not None, "Error - at least one sample name must be specified."
        names = list(names)

        # pick points
        points = []
        for s in names:
            pnts = self.pickPoints(title="%s" % s, **kwds)
            if store:
                self.header['sample %s' % s] = pnts  # store in header
            points.append(pnts)

        # add classes to header file (without duplicating existing class names)
        if store:
            cls_names = self.header.get_class_names()
            cls_names = [] if cls_names is None else list(cls_names)
            self.header['class names'] = cls_names + [s for s in names if s not in cls_names]

        return points
class HyImage(HyData):
    """
    A class for hyperspectral image data. These can be individual scenes or hyperspectral orthoimages.

    Pixel values are stored in ``self.data`` and indexed as ``data[x, y, band]``.
    """

    def __init__(self, data, **kwds):
        """
        Args:
            data (ndarray): a numpy array such that data[x][y][band] gives each pixel value. 1-D arrays are treated
                as a single pixel and 2-D arrays as a single-band image.
            **kwds: passed to HyData.__init__( ... ). Image-specific keywords are:

                - wav = A numpy array containing band wavelengths for this image.
                - affine = an affine transform of the format returned by GDAL.GetGeoTransform(). Default is [0,1,0,0,0,1].
                - projection = a GDAL georeference string or osgeo.osr.SpatialReference. Default is None.
                - sensor = sensor name. Default is "unknown".
                - header = the associated header. Default is None.
        """

        # call constructor for HyData
        super().__init__(data, **kwds)

        # special case - if dataset only has one pixel or one band, reshape it so it still has
        # the format data[x,y,b].
        if self.data is not None:
            if len(self.data.shape) == 1:
                self.data = self.data[None, None, :]  # single pixel image
            if len(self.data.shape) == 2:
                self.data = self.data[:, :, None]  # single band image

        # load any additional projection information (specific to images)
        self.set_projection(kwds.get("projection", None))
        self.affine = kwds.get("affine", [0, 1, 0, 0, 0, 1])

        # wavelengths
        if 'wav' in kwds:
            self.set_wavelengths(kwds['wav'])

        # special header formatting
        self.header['file type'] = 'ENVI Standard'

    def copy(self, data=True):
        """
        Make a deep copy of this image instance.

        Args:
            data (bool): True if a copy of the data should be made, otherwise only copy header.

        Returns:
            a new HyImage instance.
        """
        if not data:
            return HyImage(None, header=self.header.copy(), projection=self.projection, affine=self.affine)
        return HyImage(self.data.copy(), header=self.header.copy(), projection=self.projection, affine=self.affine)

    def T(self):
        """
        Return a transposed view of the data matrix (corresponding with the [y,x] indexing used by matplotlib, opencv etc.).
        """
        return np.transpose(self.data, (1, 0, 2))

    def xdim(self):
        """
        Return number of pixels in x (first dimension of data array)
        """
        return self.data.shape[0]

    def ydim(self):
        """
        Return number of pixels in y (second dimension of data array)
        """
        return self.data.shape[1]

    def aspx(self):
        """
        Return the aspect ratio of this image as ydim / xdim (i.e. plotted height / plotted width, since images
        are displayed transposed).
        """
        return self.ydim() / self.xdim()

    def get_extent(self):
        """
        Returns the width and height of this image in world coordinates.

        The pixel size is taken from ``self.pixel_size`` if such an attribute exists, otherwise from the pixel
        width / height entries (affine[1], affine[5]) of the GDAL-style affine transform.

        Returns:
            tuple with (width, height).
        """
        pixel_size = getattr(self, 'pixel_size', None)
        if pixel_size is None:
            pixel_size = (abs(self.affine[1]), abs(self.affine[5]))
        return self.xdim() * pixel_size[0], self.ydim() * pixel_size[1]

    def set_projection(self, proj):
        """
        Set this projection to an existing osgeo.osr.SpatialReference or GDAL georeference string.

        Args:
            proj (str, osgeo.osr.SpatialReference): the projection to use as osgeo.osr.SpatialReference or GDAL
                georeference string, or None to clear it.

        Raises:
            TypeError: if proj is of any other type.
        """
        if proj is None:
            self.projection = None
            return

        try:
            from osgeo.osr import SpatialReference
        except ImportError:
            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."

        if isinstance(proj, SpatialReference):
            self.projection = proj
        elif isinstance(proj, str):
            self.projection = SpatialReference(proj)
        else:
            raise TypeError("Invalid project %s" % proj)

    def set_projection_EPSG(self, EPSG):
        """
        Sets this image projection using an EPSG code.

        Args:
            EPSG (str): string EPSG code that can be passed to SpatialReference.SetFromUserInput(...).
        """
        try:
            from osgeo.osr import SpatialReference
        except ImportError:
            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."

        self.projection = SpatialReference()
        self.projection.SetFromUserInput(EPSG)

    def get_projection_EPSG(self):
        """
        Gets a string describing this projection's EPSG code (if it is an EPSG projection).

        Returns:
            an EPSG code string of the format "EPSG:XXXX", or None if no projection is set.
        """
        if self.projection is None:
            return None
        return "%s:%s" % (self.projection.GetAttrValue("AUTHORITY", 0), self.projection.GetAttrValue("AUTHORITY", 1))

    def _parse_proj(self, proj):
        """
        Convert a projection argument to an osr.SpatialReference. None returns this image's projection; an int or
        an "EPSG:XXXX" / "XXXX" string is imported as an EPSG code; anything else is returned unchanged (callers
        validate it). Requires GDAL to be importable.
        """
        from osgeo import osr

        if proj is None:
            return self.projection
        if isinstance(proj, (str, int)):
            epsg = proj
            if isinstance(epsg, str):
                try:
                    epsg = int(epsg.split(':')[-1])  # accept both "EPSG:4326" and "4326"
                except ValueError:
                    assert False, "Error - %s is an invalid EPSG code." % proj
            proj = osr.SpatialReference()
            proj.ImportFromEPSG(epsg)
        return proj

    def pix_to_world(self, px, py, proj=None):
        """
        Take pixel coordinates and return world coordinates

        Args:
            px (int): the pixel x-coord.
            py (int): the pixel y-coord.
            proj (str, int, osr.SpatialReference): the coordinate system to use. Default (None) uses the same system as
                this image. Otherwise an osr.SpatialReference can be passed (HyImage.projection), or an EPSG code
                (e.g. get_projection_EPSG(...)).

        Returns:
            the world coordinates in the requested coordinate system.
        """
        try:
            from osgeo import osr
            import osgeo.gdal as gdal
            from osgeo import ogr
        except ImportError:
            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."

        proj = self._parse_proj(proj)

        # check we have all the required info
        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
        assert (self.affine is not None) and (self.projection is not None), \
            "Error - project information is undefined."

        # project to world coordinates in this image's projection
        x, y = gdal.ApplyGeoTransform(self.affine, px, py)

        # project to target coords (if different)
        # NOTE(review): IsSameGeogCS only compares the geographic CS, so e.g. two UTM zones on the same datum
        # are treated as identical and not reprojected - confirm whether IsSame(...) is intended.
        if not proj.IsSameGeogCS(self.projection):
            P = ogr.Geometry(ogr.wkbPoint)
            if proj.EPSGTreatsAsNorthingEasting():
                P.AddPoint(x, y)
            else:
                P.AddPoint(y, x)
            P.AssignSpatialReference(self.projection)  # tell the point what coordinates it's in
            P.TransformTo(proj)  # reproject it to the out spatial reference
            x, y = P.GetX(), P.GetY()

            # do we need to transpose?
            if proj.EPSGTreatsAsLatLong():
                x, y = y, x  # we want lon,lat not lat,lon
        return x, y

    def world_to_pix(self, x, y, proj=None):
        """
        Take world coordinates and return pixel coordinates

        Args:
            x (float): the world x-coord.
            y (float): the world y-coord.
            proj (str, int, osr.SpatialReference): the coordinate system of the input coordinates. Default (None) uses
                the same system as this image. Otherwise an osr.SpatialReference can be passed (HyImage.projection),
                or an EPSG code (e.g. get_projection_EPSG(...)).

        Returns:
            the pixel coordinates based on the affine transform stored in self.affine.
        """
        try:
            from osgeo import osr
            import osgeo.gdal as gdal
            from osgeo import ogr
        except ImportError:
            assert False, "Error - GDAL must be installed to work with spatial projections in hylite."

        proj = self._parse_proj(proj)

        # check we have all the required info
        assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj
        assert (self.affine is not None) and (self.projection is not None), \
            "Error - project information is undefined."

        # project to this image's CS (if different)
        if not proj.IsSameGeogCS(self.projection):
            P = ogr.Geometry(ogr.wkbPoint)
            P.AssignSpatialReference(proj)  # tell the point what coordinates it's in
            # NOTE(review): an earlier version conditionally added (x, y) or (y, x) here, but then unconditionally
            # called AddPoint(x, y), which overwrites a point geometry. The effective (preserved) behaviour is to
            # always use (x, y); confirm whether the axis-order logic of pix_to_world should apply here too.
            P.AddPoint(x, y)
            P.TransformTo(self.projection)  # reproject it to the out spatial reference
            x, y = P.GetX(), P.GetY()
            if self.projection.EPSGTreatsAsLatLong():  # do we need to transpose?
                x, y = y, x  # we want lon,lat not lat,lon

        inv = gdal.InvGeoTransform(self.affine)
        assert inv is not None, "Error - could not invert affine transform?"

        # apply
        return gdal.ApplyGeoTransform(inv, x, y)

    def flip(self, axis='x'):
        """
        Flip the image on the x or y axis.

        Args:
            axis (str): 'x' or 'y' or both 'xy'.
        """
        if 'x' in axis.lower():
            self.data = np.flip(self.data, axis=0)
        if 'y' in axis.lower():
            self.data = np.flip(self.data, axis=1)

    def rot90(self):
        """
        Rotate this image by 90 degrees by transposing the underlying data array. Combine with flip('x') or flip('y')
        to achieve positive/negative rotations. Note that the affine transform is not updated.
        """
        self.data = np.transpose(self.data, (1, 0, 2))
        self.push_to_header()

    #####################################
    ##IMAGE FILTERING
    #####################################
    def fill_holes(self):
        """
        Replaces non-finite pixels with the maximum of their 3x3 neighbourhood (a greyscale dilation computed with
        non-finite values set to 0), thus removing 1-pixel large holes from an image. The dilation is applied to
        each band separately.
        """
        from scipy import ndimage  # explicit submodule import; `scipy` alone does not always expose ndimage

        # perform greyscale dilation
        dilate = self.data.copy()
        mask = np.logical_not(np.isfinite(dilate))
        dilate[mask] = 0
        for b in range(self.band_count()):
            dilate[:, :, b] = ndimage.grey_dilation(dilate[:, :, b], size=(3, 3))

        # map back to holes in dataset
        self.data[mask] = dilate[mask]

    def blur(self, n=3):
        """
        Applies an n x n mean (box) filter to each band of the image using OpenCV's filter2D. Pixels that were nan
        before filtering are set back to nan afterwards.

        Args:
            n (int): the dimensions of the kernel to convolve. Default is 3. Increase for more blurry results.
        """
        import cv2  # import this here to avoid errors if opencv is not installed properly

        assert isinstance(n, int) and n >= 3, "Error - invalid kernel. N must be an integer >= 3. "
        nanmask = np.isnan(self.data)
        kernel = np.ones((n, n), np.float32) / (n ** 2)
        # opencv returns a 2-D array for single-channel input, so restore the original [x, y, b] shape
        self.data = cv2.filter2D(self.data, -1, kernel).reshape(nanmask.shape)
        self.data[nanmask] = np.nan  # remove mask

    def erode(self, size=3, iterations=1):
        """
        Apply an erode filter to this image to expand background (nan) pixels. Refer to open-cv's erode
        function for more details.

        Args:
            size (int): the size of the erode filter. Default is a 3x3 kernel.
            iterations (int): the number of erode iterations. Default is 1.
        """
        import cv2  # import this here to avoid errors if opencv is not installed properly

        kernel = np.ones((size, size), np.uint8)
        if self.is_float():
            mask = np.isfinite(self.data).any(axis=-1)
            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
            self.data[mask == 0, :] = np.nan
        else:
            mask = (self.data != 0).any(axis=-1)
            mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=iterations)
            self.data[mask == 0, :] = 0

    def resize(self, newdims: tuple, interpolation: int = 1):
        """
        Resize this image with opencv.

        Args:
            newdims (tuple): the new image dimensions (xdim, ydim).
            interpolation (int): opencv interpolation method. Default is cv2.INTER_LINEAR.
        """
        import cv2  # import this here to avoid errors if opencv is not installed properly

        resized = cv2.resize(self.data, (newdims[1], newdims[0]), interpolation=interpolation)
        if resized.ndim == 2:  # opencv drops the band axis of single-band images
            resized = resized[:, :, None]
        self.data = resized

    def despeckle(self, size=5):
        """
        Despeckle each band of this image (independently) using a median filter.

        Args:
            size (int): the size of the median filter kernel. Default is 5. Must be an odd number. Float images are
                converted to float32, for which opencv only supports sizes of 3 or 5.
        """
        assert (size % 2) == 1, "Error - size must be an odd integer"
        import cv2  # import this here to avoid errors if opencv is not installed properly

        src = self.data.astype(np.float32) if self.is_float() else self.data
        out = np.empty_like(src)
        # cv2.medianBlur only accepts 1-, 3- or 4-channel images, so filter each band on its own
        for b in range(src.shape[-1]):
            out[..., b] = cv2.medianBlur(np.ascontiguousarray(src[..., b]), size)
        self.data = out

    #####################################
    ##FEATURES AND FEATURE MATCHING
    ######################################
    def get_keypoints(self, band, eq=False, mask=True, method='sift', cfac=0.0, bfac=0.0, **kwds):
        """
        Get feature descriptors from the specified band.

        Args:
            band (int,float,str,tuple): the band index (int), wavelength (float) or name (str) to extract features
                from. Alternatively, a tuple can be passed containing a range of bands (min : max) to average before
                feature matching.
            eq (bool): True if the image should be histogram equalized first. Default is False.
            mask (bool): True if 0 value pixels should be masked. Default is True.
            method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'.
            cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
            bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.
            **kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are:

                - contrastThreshold: default is 0.01.
                - edgeThreshold: default is 10.
                - sigma: default is 1.0

                For ORB these are:

                - nfeatures = the number of features to detect. Default is 5000.

        Returns:
            Tuple containing

            - k (ndarray): the keypoints detected
            - d (ndarray): corresponding feature descriptors
        """
        import cv2  # import this here to avoid errors if opencv is not installed properly

        # get image
        if isinstance(band, (int, float, str)):  # single band
            image = self.data[:, :, self.get_band_index(band)]
        elif isinstance(band, tuple):  # range of bands (averaged)
            idx0 = self.get_band_index(band[0])
            idx1 = self.get_band_index(band[1])

            # deal with out of range errors
            if idx0 is None:
                idx0 = 0
            if idx1 is None:
                idx1 = self.band_count()

            # average bands
            image = np.nanmean(self.data[:, :, idx0:idx1], axis=2)
        else:
            assert False, "Error, unrecognised band %s" % band

        # normalise a float copy to range 0 - 1. The copy matters: for a single band `image` is a view of
        # self.data, and the in-place operations below would otherwise modify the image itself.
        image = np.array(image, dtype=float)
        image -= np.nanmin(image)
        vmax = np.nanmax(image)
        if vmax > 0:  # avoid division by zero for constant bands
            image /= vmax

        # apply brightness/contrast adjustment
        image = np.clip((1.0 + cfac) * image + bfac, 0.0, 1.0)
        image[~np.isfinite(image)] = 0.0  # nan pixels become background (and are masked below)

        # convert image to uint8 for opencv
        image = np.uint8(255 * image)
        if eq:
            image = cv2.equalizeHist(image)

        if mask:
            mask = np.zeros(image.shape, dtype=np.uint8)
            mask[image != 0] = 255  # include only non-zero pixels
        else:
            mask = None

        if 'sift' in method.lower():  # SIFT
            # setup default keywords
            kwds["contrastThreshold"] = kwds.get("contrastThreshold", 0.01)
            kwds["edgeThreshold"] = kwds.get("edgeThreshold", 10)
            kwds["sigma"] = kwds.get("sigma", 1.0)
            alg = cv2.SIFT_create(**kwds)
        elif 'orb' in method.lower():  # orb
            kwds['nfeatures'] = kwds.get('nfeatures', 5000)
            alg = cv2.ORB_create(scoreType=cv2.ORB_FAST_SCORE, **kwds)
        else:
            assert False, "Error - %s is not a recognised feature detector." % method

        # detect keypoints
        kp = alg.detect(image, mask)

        # extract and return feature vectors
        return alg.compute(image, kp)

    @classmethod
    def match_keypoints(cls, kp1, kp2, d1, d2, method='SIFT', dist=0.7, tree=5, check=100, min_count=5):
        """
        Compares keypoint feature vectors from two images and returns matching pairs.

        Args:
            kp1 (ndarray): keypoints from the first image
            kp2 (ndarray): keypoints from the second image
            d1 (ndarray): descriptors for the keypoints from the first image
            d2 (ndarray): descriptors for the keypoints from the second image
            method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'.
            dist (float): Lowe's ratio-test threshold (0 to 1); a match is kept if its distance is less than dist
                times the distance of the second-best match. Default is 0.7.
            tree (int): number of trees passed to the FLANN index parameters. Default is 5. See open-cv docs.
            check (int): number of checks passed to the FLANN search parameters. Default is 100.
            min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are found,
                then the function returns None, None. Default is 5.

        Returns:
            Tuple of (src_pts, dst_pts) float32 arrays of shape (n, 1, 2), or (None, None).
        """
        import cv2  # import this here to avoid errors if opencv is not installed properly

        FLANN_INDEX_KDTREE = 1  # same value as the cv2.NORM_INF previously used here
        FLANN_INDEX_LSH = 6  # same value as the cv2.NORM_HAMMING previously used here
        if 'sift' in method.lower():
            algorithm = FLANN_INDEX_KDTREE
        elif 'orb' in method.lower():
            algorithm = FLANN_INDEX_LSH
        else:
            assert False, "Error - unknown matching algorithm %s" % method

        if d1 is None or d2 is None:  # opencv returns None descriptors when no keypoints were found
            return None, None

        # calculate flann matches
        index_params = dict(algorithm=algorithm, trees=tree)
        search_params = dict(checks=check)
        flann = cv2.FlannBasedMatcher(index_params, search_params)
        matches = flann.knnMatch(d1, d2, k=2)

        # store all the good matches as per Lowe's ratio test (queries with < 2 neighbours cannot be tested)
        good = [pair[0] for pair in matches if len(pair) == 2 and pair[0].distance < dist * pair[1].distance]

        if len(good) < min_count:
            return None, None
        src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
        dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
        return src_pts, dst_pts

    ############################
    ## Visualisation methods
    ############################
    @staticmethod
    def _save_plot_image(out_path, image, **kwds):
        """Save an image array with matplotlib's imsave, creating the output directory if one is specified."""
        out_dir = os.path.dirname(out_path)
        if out_dir:  # a bare filename has no directory to create
            os.makedirs(out_dir, exist_ok=True)
        plt.imsave(out_path, image, **kwds)

    def quick_plot(self, band=0, ax=None, bfac=0.0, cfac=0.0, samples=False, tscale=False, rot=False, flipX=False, flipY=False,
                   **kwds):
        """
        Plot a band using matplotlib.imshow(...).

        Args:
            band (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple is passed then
                each band in the tuple (string or index) will be mapped to rgb. Bands with negative wavelengths or indices will be inverted before plotting.
            ax: an axis object to plot to. If none, a new figure is created.
            bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
            cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
            samples (bool): True if sample points (defined in the header file) should be plotted. Default is False. Otherwise, a list of
                [ (x,y), ... ] points can be passed. Samples are only drawn for RGB mappings.
            tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False.
                When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or
                (constant) values (float).
            rot (bool): if True, the x and y axis will be flipped (90 degree rotation) before plotting. Default is False.
            flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
            flipY (bool): if True, the y axis will be flipped before plotting (after applying rotations).
            **kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:

                - mask = a 2D boolean mask containing true if pixels should be masked (not drawn) and false otherwise.
                - path = a file path to save the image to (at matching resolution; use fig.savefig(..) if you want to save the figure).
                - ticks = True if x- and y- ticks should be plotted. Default is False.
                - ps, pc = the size and color of sample points to plot. Can be constant or list.
                - figsize = a figsize for the figure to create (if ax is None).

        Returns:
            Tuple containing

            - fig: matplotlib figure object
            - ax: matplotlib axes object. If a colorbar is created, (band is an integer or a float), then this will be stored in ax.cbar.
        """
        # create new axes? ('figsize' is always removed so it never reaches imshow)
        figsize = kwds.pop('figsize', None)
        if ax is None:
            if figsize is None:
                figsize = (18, 18 * self.ydim() / self.xdim())
            fig, ax = plt.subplots(figsize=figsize)

        # deal with ticks
        if not kwds.pop('ticks', False):
            ax.set_xticks([])
            ax.set_yticks([])

        if isinstance(band, (str, int, float)):
            # map individual band using colourmap
            if isinstance(band, str):
                data = self.data[:, :, self.get_band_index(band)]
            else:
                data = self.data[:, :, self.get_band_index(np.abs(band))]
                if band < 0:
                    data = np.nanmax(data) - data  # flip

            # convert integer vmin and vmax values to percentiles
            if 'vmin' in kwds and isinstance(kwds['vmin'], int):
                kwds['vmin'] = np.nanpercentile(data, kwds['vmin'])
            if 'vmax' in kwds and isinstance(kwds['vmax'], int):
                kwds['vmax'] = np.nanpercentile(data, kwds['vmax'])

            # mask nans, ignore values and any custom mask
            mask = np.isnan(data)
            ignore = self.header.get_data_ignore_value()
            if not np.isnan(ignore):
                mask = mask | (data == ignore)
            if 'mask' in kwds:
                mask = mask | np.asarray(kwds.pop('mask'), dtype=bool)
            data = np.ma.array(data, mask=mask)

            # apply rotations and flipping
            if rot:
                data = data.T
            if flipX:
                data = data[::-1, :]
            if flipY:
                data = data[:, ::-1]

            # save?
            if 'path' in kwds:
                self._save_plot_image(kwds.pop('path'), data.T, **kwds)

            # change default interpolation to 'none'
            ax.cbar = ax.imshow(data.T, interpolation=kwds.pop('interpolation', 'none'), **kwds)

        elif isinstance(band, (tuple, list)):
            # map 3 bands to RGB: get band indices
            rgb = [self.get_band_index(b) if isinstance(b, str) else self.get_band_index(np.abs(b)) for b in band]

            # slice image (as copy); use floats so that scaled 0-1 values are not truncated for integer data
            img = np.array(self.data[:, :, rgb])
            if not np.issubdtype(img.dtype, np.floating):
                img = img.astype(float)
            if np.isnan(img).all():
                print("Warning - image contains no data.")
                return ax.get_figure(), ax

            # invert if needed
            for i, b in enumerate(band):
                if not isinstance(b, str) and (b < 0):
                    img[..., i] = np.nanmax(img[..., i]) - img[..., i]

            # do scaling
            if tscale:  # scale bands independently (defaults use each band's own range)
                for b in range(img.shape[-1]):
                    mn = kwds.get("vmin", float(np.nanmin(img[..., b])))
                    mx = kwds.get("vmax", float(np.nanmax(img[..., b])))
                    if isinstance(mn, int):
                        assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
                        mn = float(np.nanpercentile(img[..., b], mn))
                    if isinstance(mx, int):
                        assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
                        mx = float(np.nanpercentile(img[..., b], mx))
                    img[..., b] = (img[..., b] - mn) / (mx - mn)
            else:  # scale bands together
                mn = kwds.get("vmin", float(np.nanmin(img)))
                mx = kwds.get("vmax", float(np.nanmax(img)))
                if isinstance(mn, int):
                    assert mn >= 0 and mn <= 100, "Error - integer vmin values must be a percentile."
                    mn = float(np.nanpercentile(img, mn))
                if isinstance(mx, int):
                    assert mx >= 0 and mx <= 100, "Error - integer vmax values must be a percentile."
                    mx = float(np.nanpercentile(img, mx))
                img = (img - mn) / (mx - mn)

            # apply brightness/contrast mapping
            img = np.clip((1.0 + cfac) * img + bfac, 0, 1.0)

            # apply masking so background is white
            img[np.logical_not(np.isfinite(img))] = 1.0
            if 'mask' in kwds:
                img[kwds.pop("mask"), :] = 1.0

            # apply rotations and flipping
            if rot:
                img = np.transpose(img, (1, 0, 2))
            if flipX:
                img = img[::-1, :, :]
            if flipY:
                img = img[:, ::-1, :]

            # save?
            if 'path' in kwds:
                self._save_plot_image(kwds.pop('path'),
                                      np.transpose(np.clip(img * 255, 0, 255).astype(np.uint8), (1, 0, 2)))

            # plot samples?
            ps = kwds.pop('ps', 5)
            pc = kwds.pop('pc', 'r')
            if samples:
                if isinstance(samples, (list, np.ndarray)):
                    ax.scatter([s[0] for s in samples], [s[1] for s in samples], s=ps, c=pc)
                else:
                    for n in (self.header.get_class_names() or []):
                        points = np.array(self.header.get_sample_points(n))
                        ax.scatter(points[:, 0], points[:, 1], s=ps, c=pc)

            # plot
            ax.imshow(np.transpose(img, (1, 0, 2)), interpolation=kwds.pop('interpolation', 'none'), **kwds)
            ax.cbar = None  # no colorbar

        return ax.get_figure(), ax

    def createGIF(self, path, bands=None, figsize=(10, 10), fps=10, **kwds):
        """
        Create and save an animated gif that loops through the bands of the image.

        Args:
            path (str): the path to save the .gif
            bands (tuple): Tuple containing the range of band indices (start inclusive, end exclusive) to draw.
                Default is the whole range.
            figsize (tuple): the size of the image to draw. Default is (10,10).
            fps (int): the framerate (frames per second) of the gif. Default is 10.
            **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
        """
        if bands is None:
            bands = (0, self.band_count())
        else:
            assert 0 <= bands[0] < self.band_count(), "Error - invalid range."
            assert 0 < bands[1] <= self.band_count(), "Error - invalid range."
            assert bands[1] > bands[0], "Error - invalid range."

        # plot frames
        frames = []
        for i in range(bands[0], bands[1]):
            fig, ax = plt.subplots(figsize=figsize)
            ax.imshow(self.data[:, :, i], **kwds)
            fig.canvas.draw()
            # buffer_rgba is (height, width, 4); keep RGB (tostring_rgb is deprecated in recent matplotlib)
            frames.append(np.asarray(fig.canvas.buffer_rgba())[..., :3].copy())
            plt.close(fig)

        # save gif
        imageio.mimsave(os.path.splitext(path)[0] + ".gif", frames, fps=fps)

    ## masking
    def mask(self, mask=None, flag=np.nan, invert=False, crop=False, bands=None):
        """
        Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the
        image in-situ.

        Args:
            mask (ndarray, str): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None is passed then
                pickPolygons( ... ) is used to interactively define a polygon. If a file path is passed then the polygon
                will be loaded using np.load( ... ). Alternatively if mask.shape == image.shape[0,1] then it is treated as a
                binary image mask (must be boolean) and True values will be masked across all bands. Default is None.
            flag (float): the value to use for masked pixels. Default is np.nan
            invert (bool): for polygon masks: if True, pixels within the polygon will be masked. If False, pixels outside the polygon are masked. Default is False.
            crop (bool): True if rows/columns containing only masked pixels should be removed. Default is False.
            bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used.

        Returns:
            mask (ndarray): a boolean array with True where pixels are masked and False elsewhere, or None if no
            polygon was picked interactively.
        """
        if mask is None:  # pick mask interactively
            if bands is None:
                bands = int(self.band_count() / 2)

            regions = self.pickPolygons(region_names=["mask"], bands=bands)

            # the user bailed without picking a mask?
            if len(regions) == 0:
                print("Warning - no mask picked/applied.")
                return

            # extract polygon mask
            mask = regions[0]
        elif isinstance(mask, str):  # polygon stored on disk
            mask = np.load(mask)
        mask = np.asarray(mask)

        # convert polygon mask to binary mask (a boolean array matching the image is always a binary mask)
        is_binary = mask.dtype == bool and mask.shape == self.data.shape[:2]
        if not is_binary and mask.ndim == 2 and mask.shape[1] == 2:
            # build meshgrid with pixel coords
            xx, yy = np.meshgrid(np.arange(self.xdim()), np.arange(self.ydim()))
            points = np.vstack([xx.ravel(), yy.ravel()]).T  # coordinates of each pixel

            # calculate per-pixel mask
            mask = path.Path(mask).contains_points(points)
            mask = mask.reshape((self.ydim(), self.xdim())).T

            # flip as we want to mask (==True) outside points (unless invert is true)
            if not invert:
                mask = np.logical_not(mask)

        # apply binary image mask
        assert mask.shape[0] == self.data.shape[0] and mask.shape[1] == self.data.shape[1], \
            "Error - mask shape %s does not match image shape %s" % (mask.shape, self.data.shape)
        mask = mask.astype(bool)  # integer 0/1 masks must not be used as fancy indices
        self.data[mask, :] = flag

        # crop image
        if crop:
            # calculate non-masked pixels
            valid = np.logical_not(mask)

            # integrate along axes
            xdata = np.sum(valid, axis=1) > 0.0
            ydata = np.sum(valid, axis=0) > 0.0

            # calculate domain containing valid pixels (exclusive upper bounds)
            xmin = np.argmax(xdata)
            xmax = xdata.shape[0] - np.argmax(xdata[::-1])
            ymin = np.argmax(ydata)
            ymax = ydata.shape[0] - np.argmax(ydata[::-1])

            # crop
            self.data = self.data[xmin:xmax, ymin:ymax, :]

        return mask

    def crop_to_data(self):
        """
        Remove padding of nan or zero pixels from the image. Note that this is performed in place.

        A pixel counts as data if at least one of its bands is both finite and non-zero. The image is cropped to the
        smallest [x, y] window containing every such pixel (inclusive of the last valid row/column). If the image
        contains no valid pixels it is left unchanged.
        """
        valid = (np.isfinite(self.data) & (self.data != 0)).any(axis=-1)
        if not valid.any():
            return  # nothing but padding; keep the image as-is rather than producing an empty array

        xs = np.flatnonzero(valid.any(axis=1))
        ys = np.flatnonzero(valid.any(axis=0))
        # +1 because slice ends are exclusive; the last valid index must be kept
        self.data = self.data[xs[0]:xs[-1] + 1, ys[0]:ys[-1] + 1, :]

    ##################################################
    ## Interactive tools for picking regions/pixels
    ##################################################
    @staticmethod
    def _restore_backend(backend):
        """Best-effort restore of the matplotlib backend that was active before an interactive tool was opened."""
        try:
            matplotlib.use(backend)
        except Exception:  # restoring is optional; warn rather than fail the pick operation
            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")

    def pickPolygons(self, region_names, bands=0):
        """
        Creates a matplotlib gui for selecting polygon regions in an image.

        Args:
            region_names (list, str): a list containing the names of the regions to pick. If a string is passed only one name is used.
            bands (tuple): the bands of the image to plot.

        Returns:
            a list containing one (n, 2) array of polygon vertex coordinates [[x1, y1], [x2, y2], ...] per picked region.
        """
        if isinstance(region_names, str):
            region_names = [region_names]

        assert isinstance(region_names, list), "Error - names must be a list or a string."

        # ROIPoly needs the Qt5Agg backend; remember the current one so it can be restored afterwards
        backend = matplotlib.get_backend()
        matplotlib.use('Qt5Agg')
        try:
            # plot image and extract roi's
            fig, ax = self.quick_plot(bands)
            roi = MultiRoi(roi_names=region_names)
            plt.close(fig)  # close figure

            # extract regions as (n, 2) vertex arrays
            regions = [np.vstack([r.x, r.y]).T for r in roi.rois.values()]
        finally:
            self._restore_backend(backend)  # always restore, even if plotting/picking fails

        return regions

    def pickPoints(self, n=-1, bands=hylite.RGB, integer=True, title="Pick Points", **kwds):
        """
        Creates a matplotlib gui for picking pixels from an image.

        Args:
            n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
            bands (tuple): the bands of the image to plot. Default is hylite.RGB.
            integer (bool): True if points coordinates should be cast to integers (for use as indices). Default is True.
            title (str): The title of the point picking window.
            **kwds: Keywords are passed to HyImage.quick_plot( ... ).

        Returns:
            A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].
        """
        # switch to an interactive backend; remember the current one so it can be restored afterwards
        backend = matplotlib.get_backend()
        matplotlib.use('Qt5Agg')
        try:
            fig, ax = self.quick_plot(bands, **kwds)
            ax.set_title(title)
            points = fig.ginput(n)
        finally:
            self._restore_backend(backend)  # always restore, even if plotting/picking fails

        if integer:
            points = [(int(p[0]), int(p[1])) for p in points]

        return points

    def pickSamples(self, names=None, store=True, **kwds):
        """
        Pick sample probe points and store these in the image header file.

        Args:
            names (str, list): the name of the sample to pick, or a list of names to pick multiple. Must be provided.
            store (bool): True if sample should be stored in the image header file (for later access). Default is True.
            **kwds: Keywords are passed to HyImage.quick_plot( ... )

        Returns:
            a list containing a list of points for each sample.

        Raises:
            TypeError: if no sample names are given.
        """
        if isinstance(names, str):
            names = [names]
        if names is None:
            raise TypeError("Error - names must be a string or a list of strings.")
        names = list(names)

        # pick points
        points = []
        for s in names:
            pnts = self.pickPoints(title="%s" % s, **kwds)
            if store:
                self.header['sample %s' % s] = pnts  # store in header
            points.append(pnts)

        # add classes to header file (without duplicating names that are already registered)
        if store:
            cls_names = self.header.get_class_names()
            cls_names = [] if cls_names is None else list(cls_names)
            self.header['class names'] = cls_names + [n for n in names if n not in cls_names]

        return points
A class for hyperspectral image data. These can be individual scenes or hyperspectral orthoimages.
def __init__(self, data, **kwds):
    """
    Create a new hyperspectral image.

    Args:
        data (ndarray): array indexed such that data[x][y][band] gives each pixel value. A 1-D array is treated as
            a single pixel and a 2-D array as a single band; both are expanded to three dimensions.
        **kwds:
            wav = A numpy array containing band wavelengths for this image.
            affine = an affine transform of the format returned by GDAL.GetGeoTransform(). Default is [0,1,0,0,0,1].
            projection = string (or osgeo.osr.SpatialReference) defining the projection. Default is None.
            sensor = sensor name. Default is "unknown".
            header = path to associated header file. Default is None.
    """
    # generic initialisation is delegated to HyData
    super().__init__(data, **kwds)

    # promote low-dimensional arrays so that indexing is always data[x, y, band]
    if self.data is not None:
        ndim = len(self.data.shape)
        if ndim == 1:
            self.data = self.data[None, None, :]  # a single pixel
        elif ndim == 2:
            self.data = self.data[:, :, None]  # a single band

    # georeferencing information specific to images
    self.set_projection(kwds.get("projection", None))
    self.affine = kwds.get("affine", [0, 1, 0, 0, 0, 1])

    # band wavelengths (only when explicitly supplied)
    if 'wav' in kwds:
        self.set_wavelengths(kwds['wav'])

    # mark the header as describing an ENVI Standard file
    self.header['file type'] = 'ENVI Standard'
Arguments:
- data (ndarray): a numpy array such that data[x][y][band] gives each pixel value.
- **kwds: wav = A numpy array containing band wavelengths for this image. affine = an affine transform of the format returned by GDAL.GetGeoTransform(). projection = string defining the projection. Default is None. sensor = sensor name. Default is "unknown". header = path to associated header file. Default is None.
def copy(self, data=True):
    """
    Make a copy of this image instance.

    The header is copied with header.copy(); the projection and affine transform objects are passed to the new
    image as-is (shared, not duplicated).

    Args:
        data (bool): True if the pixel data should also be copied. If False, the new image is created without data.

    Returns:
        a new HyImage instance.
    """
    pixels = self.data.copy() if data else None
    return HyImage(pixels, header=self.header.copy(), projection=self.projection, affine=self.affine)
Make a deep copy of this image instance.
Arguments:
- data (bool): True if a copy of the data should be made, otherwise only copy header.
Returns:
a new HyImage instance.
def T(self):
    """
    Return a transposed view of the data array, indexed as [y, x, band] (the row/column order used by
    matplotlib, OpenCV etc.).
    """
    return self.data.transpose(1, 0, 2)
Return a transposed view of the data matrix (corresponding with the [y,x] indexing used by matplotlib, OpenCV etc.).
def xdim(self):
    """
    Return the number of pixels along x, i.e. the length of the first axis of the data array.
    """
    dims = self.data.shape
    return dims[0]
Return number of pixels in x (first dimension of data array)
def ydim(self):
    """
    Return the number of pixels along y, i.e. the length of the second axis of the data array.
    """
    dims = self.data.shape
    return dims[1]
Return number of pixels in y (second dimension of data array)
def aspx(self):
    """
    Return the aspect ratio of this image (width / height), computed as ydim() / xdim().
    """
    width = self.ydim()
    height = self.xdim()
    return width / height
Return the aspect ratio of this image (width/height).
def get_extent(self):
    """
    Returns the width and height of this image in world coordinates.

    Returns:
        tuple with (width, height), computed as (xdim() * pixel_size[0], ydim() * pixel_size[1]).
    """
    # xdim / ydim are methods and must be called; multiplying the bound methods raised a TypeError
    return self.xdim() * self.pixel_size[0], self.ydim() * self.pixel_size[1]
Returns the width and height of this image in world coordinates.
Returns:
tuple with (width, height).
def set_projection(self, proj):
    """
    Set the projection of this image.

    Args:
        proj (str, osgeo.osr.SpatialReference, None): an existing osgeo.osr.SpatialReference, a GDAL georeference
            string (passed to the SpatialReference constructor), or None to clear the projection.

    Raises:
        ImportError: if proj is not None and GDAL is not installed.
        TypeError: if proj is not None, a string or a SpatialReference.
    """
    if proj is None:
        self.projection = None
        return

    try:
        from osgeo.osr import SpatialReference
    except ImportError as err:
        raise ImportError("Error - GDAL must be installed to work with spatial projections in hylite.") from err

    if isinstance(proj, SpatialReference):
        self.projection = proj
    elif isinstance(proj, str):
        self.projection = SpatialReference(proj)
    else:
        # previously this printed a message and used a bare `raise` with no active exception
        raise TypeError("Invalid project %s" % (proj,))
Set this image's projection from an existing osgeo.osr.SpatialReference or a GDAL georeference string.
Arguments:
- proj (str, osgeo.osr.SpatialReference): the projection to use, as an osgeo.osr.SpatialReference or a GDAL georeference string.
def set_projection_EPSG(self, EPSG):
    """
    Sets this image's projection using an EPSG code.

    Args:
        EPSG (str, int): an EPSG code string that can be passed to SpatialReference.SetFromUserInput(...)
            (e.g. "EPSG:32632"), or an integer EPSG code (e.g. 32632).

    Raises:
        ImportError: if GDAL is not installed.
    """
    import numbers

    try:
        from osgeo.osr import SpatialReference
    except ImportError as err:
        raise ImportError("Error - GDAL must be installed to work with spatial projections in hylite.") from err

    # SetFromUserInput expects a string, so convert bare integer codes
    if isinstance(EPSG, numbers.Integral) and not isinstance(EPSG, bool):
        EPSG = "EPSG:%d" % EPSG

    self.projection = SpatialReference()
    self.projection.SetFromUserInput(EPSG)
Sets this image's projection using an EPSG code.
Arguments:
- EPSG (str): string EPSG code that can be passed to SpatialReference.SetFromUserInput(...).
def get_projection_EPSG(self):
    """
    Return this image's projection as an "AUTHORITY:CODE" string (e.g. "EPSG:32632").

    Returns:
        the AUTHORITY name and code of self.projection joined by a colon, or None if no projection is set.
    """
    srs = self.projection
    if srs is not None:
        authority, code = (srs.GetAttrValue("AUTHORITY", i) for i in (0, 1))
        return "%s:%s" % (authority, code)
    return None
Gets a string describing this projection's EPSG code (if it is an EPSG projection).
Returns:
an EPSG code string of the format "EPSG:XXXX".
def pix_to_world(self, px, py, proj=None):
    """
    Take pixel coordinates and return world coordinates.

    Args:
        px (int): the pixel x-coord.
        py (int): the pixel y-coord.
        proj (str, int, osr.SpatialReference): the coordinate system of the returned coordinates. Default (None)
            uses this image's projection. Otherwise an osr.SpatialReference (e.g. HyImage.projection), an EPSG
            string such as "EPSG:4326" (e.g. from get_projection_EPSG(...)) or an integer EPSG code can be passed.

    Returns:
        the world coordinates (x, y) in the coordinate system given by proj.

    Raises:
        ImportError: if GDAL is not installed.
        ValueError: if proj is a string that does not end in an integer EPSG code.
    """
    try:
        from osgeo import osr
        import osgeo.gdal as gdal
        from osgeo import ogr
    except ImportError as err:
        raise ImportError("Error - GDAL must be installed to work with spatial projections in hylite.") from err

    # parse projection
    if proj is None:
        proj = self.projection
    elif isinstance(proj, (str, int)):
        epsg = proj
        if isinstance(epsg, str):
            try:
                # split the argument itself (this previously called str.split on the `str` type, so every
                # EPSG string was rejected); accepts "EPSG:XXXX" as well as a bare "XXXX"
                epsg = int(epsg.split(':')[-1])
            except ValueError:
                raise ValueError("Error - %s is an invalid EPSG code." % proj) from None
        proj = osr.SpatialReference()
        proj.ImportFromEPSG(epsg)

    # check we have all the required info
    assert (self.affine is not None) and (self.projection is not None), "Error - project information is undefined."
    assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj

    # pixel -> world coordinates in this image's own coordinate system
    x, y = gdal.ApplyGeoTransform(self.affine, px, py)

    # reproject to the target coordinate system (if different)
    # NOTE(review): IsSameGeogCS compares only the geographic (datum) part of the two systems, so two projected
    # systems on the same datum (e.g. two UTM zones) are treated as identical here - confirm this is intended.
    if not proj.IsSameGeogCS(self.projection):
        P = ogr.Geometry(ogr.wkbPoint)
        if proj.EPSGTreatsAsNorthingEasting():
            P.AddPoint(x, y)
        else:
            P.AddPoint(y, x)
        P.AssignSpatialReference(self.projection)  # tell the point what coordinates it's in
        P.TransformTo(proj)  # reproject it to the output spatial reference
        x, y = P.GetX(), P.GetY()

    # do we need to transpose?
    if proj.EPSGTreatsAsLatLong():
        x, y = y, x  # we want lon, lat not lat, lon
    return x, y
Take pixel coordinates and return world coordinates
Arguments:
- px (int): the pixel x-coord.
- py (int): the pixel y-coord.
- proj (str, osr.SpatialReference): the coordinate system to use. Default (None) uses the same system as this image. Otherwise an osr.SpatialReference can be passed (e.g. HyImage.projection), or an EPSG string such as "EPSG:4326" (e.g. from get_projection_EPSG(...)).
Returns:
the world coordinates in the coordinate system given by proj (by default, this image's own coordinate system).
def world_to_pix(self, x, y, proj=None):
    """
    Take world coordinates and return pixel coordinates.

    Args:
        x (float): the world x-coord.
        y (float): the world y-coord.
        proj (str, int, osr.SpatialReference): the coordinate system of the input coordinates. Default (None) uses
            this image's projection. Otherwise an osr.SpatialReference (e.g. HyImage.projection), an EPSG string
            such as "EPSG:4326" (e.g. from get_projection_EPSG(...)) or an integer EPSG code can be passed.

    Returns:
        the (fractional) pixel coordinates obtained by applying the inverse of self.affine.

    Raises:
        ImportError: if GDAL is not installed.
        ValueError: if proj is a string that does not end in an integer EPSG code.
    """
    try:
        from osgeo import osr
        import osgeo.gdal as gdal
        from osgeo import ogr
    except ImportError as err:
        raise ImportError("Error - GDAL must be installed to work with spatial projections in hylite.") from err

    # parse projection
    if proj is None:
        proj = self.projection
    elif isinstance(proj, (str, int)):
        epsg = proj
        if isinstance(epsg, str):
            try:
                # split the argument itself (this previously called str.split on the `str` type, so every
                # EPSG string was rejected); accepts "EPSG:XXXX" as well as a bare "XXXX"
                epsg = int(epsg.split(':')[-1])
            except ValueError:
                raise ValueError("Error - %s is an invalid EPSG code." % proj) from None
        proj = osr.SpatialReference()
        proj.ImportFromEPSG(epsg)

    # check we have all the required info
    assert (self.affine is not None) and (self.projection is not None), "Error - project information is undefined."
    assert isinstance(proj, osr.SpatialReference), "Error - invalid spatial reference %s" % proj

    # project to this image's coordinate system (if different)
    # NOTE(review): IsSameGeogCS compares only the geographic (datum) part of the two systems, so two projected
    # systems on the same datum (e.g. two UTM zones) are treated as identical here - confirm this is intended.
    if not proj.IsSameGeogCS(self.projection):
        P = ogr.Geometry(ogr.wkbPoint)
        # NOTE(review): the previous code picked an axis order via EPSGTreatsAsNorthingEasting() but then
        # unconditionally overwrote the point with (x, y); that effective behaviour is kept. Unlike pix_to_world,
        # no input axis swap is applied - confirm which convention is intended.
        P.AddPoint(x, y)
        P.AssignSpatialReference(proj)  # tell the point what coordinates it's in
        P.TransformTo(self.projection)  # reproject it to this image's spatial reference
        x, y = P.GetX(), P.GetY()
        if self.projection.EPSGTreatsAsLatLong():  # do we need to transpose?
            x, y = y, x  # we want lon, lat not lat, lon

    inv = gdal.InvGeoTransform(self.affine)
    assert inv is not None, "Error - could not invert affine transform?"

    # apply
    return gdal.ApplyGeoTransform(inv, x, y)
Take world coordinates and return pixel coordinates
Arguments:
- x (float): the world x-coord.
- y (float): the world y-coord.
- proj (str, osr.SpatialReference): the coordinate system of the input coordinates. Default (None) uses the same system as this image. Otherwise an osr.SpatialReference can be passed (e.g. HyImage.projection), or an EPSG string such as "EPSG:4326" (e.g. from get_projection_EPSG(...)).
Returns:
the pixel coordinates based on the affine transform stored in self.affine.
def flip(self, axis='x'):
    """
    Flip this image in place along x (first data axis), y (second data axis) or both.

    Args:
        axis (str): 'x', 'y' or 'xy'. Letters are matched case-insensitively.
    """
    requested = axis.lower()
    for letter, dim in (('x', 0), ('y', 1)):
        if letter in requested:
            self.data = np.flip(self.data, axis=dim)
Flip the image on the x or y axis.
Arguments:
- axis (str): 'x' or 'y' or both 'xy'.
def rot90(self):
    """
    Swap the x and y axes of this image (a transpose of the first two data axes) and then call
    push_to_header(). Combine with flip('x') or flip('y') to obtain positive/negative 90 degree rotations.
    """
    self.data = self.data.swapaxes(0, 1)
    self.push_to_header()
Swap the x and y axes of this image by transposing the underlying data array. Combine with flip('x') or flip('y') to achieve positive/negative 90 degree rotations.
def fill_holes(self):
    """
    Fill non-finite (nan / inf) values with the maximum of their 3x3 neighbourhood, removing 1-pixel holes.

    This is a greyscale dilation in which non-finite values count as 0, applied to each band independently.
    Only values that were non-finite are replaced, so a hole whose neighbours are all holes (or all negative)
    is filled with 0.
    """
    # explicit submodule import: `import scipy` alone does not expose `scipy.ndimage` in all SciPy releases
    from scipy import ndimage

    mask = np.logical_not(np.isfinite(self.data))
    dilate = self.data.copy()
    dilate[mask] = 0

    # a (3, 3, 1) window dilates every band with a 3x3 kernel in a single call, without mixing bands
    dilate = ndimage.grey_dilation(dilate, size=(3, 3, 1))

    # map back to holes in dataset
    self.data[mask] = dilate[mask]
Replaces nan pixels with the maximum of their 3x3 neighbourhood (a greyscale dilation in which nan values count as 0), thus removing 1-pixel large holes from an image. Each band is filled independently. This can be slow for large images.
def blur(self, n=3):
    """
    Smooth each band of this image with an n x n mean (box) filter using OpenCV.

    Every kernel weight is 1 / n**2 (this is not a gaussian kernel). Values that were nan before filtering are
    set back to nan afterwards.

    Args:
        n (int): the width/height of the kernel. Must be an integer >= 3. Default is 3. Increase for more blurry results.

    Raises:
        ValueError: if n is not an integer >= 3.
    """
    # validate before importing OpenCV so bad arguments fail fast
    if not isinstance(n, int) or n < 3:
        raise ValueError("Error - invalid kernel. N must be an integer >= 3. ")
    import cv2  # import this here to avoid errors if opencv is not installed properly

    nanmask = np.isnan(self.data)
    kernel = np.ones((n, n), np.float32) / (n ** 2)
    blurred = cv2.filter2D(self.data, -1, kernel)
    if blurred.ndim == 2:
        # OpenCV drops a trailing singleton channel axis; restore the data[x, y, band] layout
        blurred = blurred[:, :, None]
    self.data = blurred
    if nanmask.any():
        self.data[nanmask] = np.nan  # restore the original nan mask
Applies an n x n mean (box) filter, with every kernel weight equal to 1/n², to each band of the image using OpenCV.
Arguments:
- n (int): the dimensions of the gaussian kernel to convolve. Default is 3. Increase for more blurry results.
def erode(self, size=3, iterations=1):
    """
    Erode the valid region of this image so that background pixels grow, using OpenCV's erode.

    For float images a pixel is valid if any band is finite, and eroded pixels are set to nan in all bands.
    Otherwise a pixel is valid if any band is non-zero, and eroded pixels are set to 0.

    Args:
        size (int): the width/height of the square erode kernel. Default is 3 (a 3x3 kernel).
        iterations (int): the number of erode iterations. Default is 1.
    """
    import cv2  # import this here to avoid errors if opencv is not installed properly

    kernel = np.ones((size, size), np.uint8)
    floating = self.is_float()
    if floating:
        valid = np.isfinite(self.data).any(axis=-1)
    else:
        valid = (self.data != 0).any(axis=-1)
    eroded = cv2.erode(valid.astype(np.uint8), kernel, iterations=iterations)
    self.data[eroded == 0, :] = np.nan if floating else 0
Apply an erode filter to this image to expand background (nan) pixels. Refer to open-cv's erode function for more details.
Arguments:
- size (int): the size of the erode filter. Default is a 3x3 kernel.
- iterations (int): the number of erode iterations. Default is 1.
def resize(self, newdims: tuple, interpolation: int = 1):
    """
    Resize this image with OpenCV.

    Args:
        newdims (tuple): the new image dimensions as (xdim, ydim).
        interpolation (int): OpenCV interpolation flag. Default is 1 (cv2.INTER_LINEAR).
    """
    import cv2  # import this here to avoid errors if opencv is not installed properly

    # OpenCV takes the target size as (columns, rows), i.e. (ydim, xdim) for data indexed [x, y, band]
    resized = cv2.resize(self.data, (newdims[1], newdims[0]), interpolation=interpolation)
    if resized.ndim == 2:
        # OpenCV drops a trailing singleton channel axis; restore the data[x, y, band] layout
        resized = resized[:, :, None]
    self.data = resized
Resize this image with opencv.
Arguments:
- newdims (tuple): the new image dimensions as (xdim, ydim).
- interpolation (int): opencv interpolation method. Default is cv2.INTER_LINEAR.
def despeckle(self, size=5):
    """
    Despeckle each band of this image (independently) using a median filter (OpenCV medianBlur).

    Float images are converted to float32 first. Note that OpenCV only supports float data for kernel sizes
    3 and 5; larger kernels require 8-bit data.

    Args:
        size (int): the size of the median filter kernel. Default is 5. Must be an odd number.
    """
    assert (size % 2) == 1, "Error - size must be an odd integer"
    import cv2  # import this here to avoid errors if opencv is not installed properly

    data = self.data.astype(np.float32) if self.is_float() else self.data

    # cv2.medianBlur only accepts 1, 3 or 4 channel images (and drops a singleton channel axis),
    # so filter band by band and restack to keep the data[x, y, band] layout
    filtered = [cv2.medianBlur(np.ascontiguousarray(data[:, :, b]), size) for b in range(data.shape[-1])]
    self.data = np.stack(filtered, axis=-1)
Despeckle each band of this image (independently) using a median filter.
Arguments:
- size (int): the size of the median filter kernel. Default is 5. Must be an odd number.
def get_keypoints(self, band, eq=False, mask=True, method='sift', cfac=0.0, bfac=0.0, **kwds):
    """
    Get feature keypoints and descriptors from the specified band(s).

    Args:
        band (int,float,str,tuple): the band index (int), wavelength (float) or name (str) to extract features from,
            resolved with get_band_index(...). Alternatively, a tuple (min, max) can be passed, in which case the bands
            from index(min) up to (but excluding) index(max) are averaged before feature detection.
        eq (bool): True if the image should be histogram equalized first. Default is False.
        mask (bool): True if pixels that are 0 after normalisation (including nan pixels) should be masked. Default is True.
        method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'.
        cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
        bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.
        **kwds: keyword arguments are passed to the opencv feature detector. For SIFT the defaults are:

            - contrastThreshold: default is 0.01.
            - edgeThreshold: default is 10.
            - sigma: default is 1.0

            For ORB these are:

            - nfeatures = the number of features to detect. Default is 5000.

    Returns:
        Tuple containing

        - k (ndarray): the keypoints detected
        - d (ndarray): corresponding feature descriptors
    """
    import cv2  # import this here to avoid errors if opencv is not installed properly

    # get image as a float copy, so the normalisation below never modifies self.data
    # (the old in-place `image -= ...` wrote through the band view into the image data)
    if isinstance(band, (int, float, str, np.integer, np.floating)):  # single band
        image = self.data[:, :, self.get_band_index(band)].astype(np.float64)
    elif isinstance(band, tuple):  # range of bands (averaged)
        idx0 = self.get_band_index(band[0])
        idx1 = self.get_band_index(band[1])

        # deal with out of range errors
        if idx0 is None:
            idx0 = 0
        if idx1 is None:
            idx1 = self.band_count()

        image = np.nanmean(self.data[:, :, idx0:idx1], axis=2).astype(np.float64)
    else:
        assert False, "Error, unrecognised band %s" % (band,)

    # normalise image to range 0 - 1 (guarding against constant images)
    image = image - np.nanmin(image)
    peak = np.nanmax(image)
    if peak > 0:
        image = image / peak
    # nan pixels become 0 so they are never cast to uint8 as nan, and are excluded by the mask below
    image = np.nan_to_num(image, nan=0.0)

    # apply brightness/contrast adjustment
    image = np.clip((1.0 + cfac) * image + bfac, 0.0, 1.0)

    # convert image to uint8 for opencv
    image = np.uint8(255 * image)
    if eq:
        image = cv2.equalizeHist(image)

    if mask:
        mask = np.zeros(image.shape, dtype=np.uint8)
        mask[image != 0] = 255  # include only non-zero pixels
    else:
        mask = None

    if 'sift' in method.lower():  # SIFT
        kwds.setdefault("contrastThreshold", 0.01)
        kwds.setdefault("edgeThreshold", 10)
        kwds.setdefault("sigma", 1.0)
        # forward the documented keywords (they were previously built but never passed to the detector)
        alg = cv2.SIFT_create(**kwds)
    elif 'orb' in method.lower():  # ORB
        kwds.setdefault('nfeatures', 5000)
        alg = cv2.ORB_create(scoreType=cv2.ORB_FAST_SCORE, **kwds)
    else:
        assert False, "Error - %s is not a recognised feature detector." % method

    # detect keypoints
    kp = alg.detect(image, mask)

    # extract and return feature vectors
    return alg.compute(image, kp)
Get feature descriptors from the specified band.
Arguments:
- band (int,float,str,tuple): the band index (int) or wavelength (float) to extract features from. Alternatively, a tuple can be passed containing a range of bands (min : max) to average before feature matching.
- eq (bool): True if the image should be histogram equalized first. Default is False.
- mask (bool): True if 0 value pixels should be masked. Default is True.
- method (str): the feature detector to use. Options are 'SIFT' and 'ORB' (faster but less accurate). Default is 'SIFT'.
- cfac (float): contrast adjustment to apply to hyperspectral bands before matching. Default is 0.0.
- bfac (float): brightness adjustment to apply to hyperspectral bands before matching. Default is 0.0.
**kwds: keyword arguments are passed to the opencv feature detector. For SIFT these are:
- contrastThreshold: default is 0.01.
- edgeThreshold: default is 10.
- sigma: default is 1.0
For ORB these are:
- nfeatures = the number of features to detect. Default is 5000.
Returns: Tuple containing
- k (ndarray): the keypoints detected
- d (ndarray): corresponding feature descriptors
@classmethod
def match_keypoints(cls, kp1, kp2, d1, d2, method='SIFT', dist=0.7, tree=5, check=100, min_count=5):
    """
    Compare keypoint feature descriptors from two images and return the coordinates of matching pairs.

    Matches are found with an OpenCV FLANN matcher (k=2 nearest neighbours) and filtered with Lowe's ratio test.

    Args:
        kp1 (ndarray): keypoints from the first image (objects with a .pt attribute)
        kp2 (ndarray): keypoints from the second image
        d1 (ndarray): descriptors for the keypoints from the first image
        d2 (ndarray): descriptors for the keypoints from the second image
        method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'.
        dist (float): ratio threshold (0 to 1) for Lowe's ratio test; a match is kept only if its distance is less
            than dist times the distance of the second-best candidate. Default is 0.7.
        tree (int): value passed to the FLANN index parameters as `trees`. Default is 5. See the OpenCV FLANN docs.
        check (int): value passed to the FLANN search parameters as `checks`. Default is 100.
        min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are
            found (or either descriptor array is None or empty), the function returns None, None. Default is 5.

    Returns:
        Tuple containing

        - src_pts (ndarray): float32 array of shape (n, 1, 2) with the matched keypoint coordinates in the first image.
        - dst_pts (ndarray): float32 array of shape (n, 1, 2) with the corresponding coordinates in the second image.
    """
    # FLANN index types, named explicitly: kd-trees for float descriptors (SIFT) and locality-sensitive
    # hashing for binary descriptors (ORB). These equal the cv2.NORM_INF / cv2.NORM_HAMMING values that were
    # used previously, so matching behaviour is unchanged.
    FLANN_INDEX_KDTREE = 1
    FLANN_INDEX_LSH = 6
    if 'sift' in method.lower():
        algorithm = FLANN_INDEX_KDTREE
    elif 'orb' in method.lower():
        algorithm = FLANN_INDEX_LSH
    else:
        assert False, "Error - unknown matching algorithm %s" % method

    # nothing to match if either image produced no descriptors
    if d1 is None or d2 is None or len(d1) == 0 or len(d2) == 0:
        return None, None

    import cv2  # import this here to avoid errors if opencv is not installed properly

    # calculate flann matches
    index_params = dict(algorithm=algorithm, trees=tree)
    search_params = dict(checks=check)
    flann = cv2.FlannBasedMatcher(index_params, search_params)
    matches = flann.knnMatch(d1, d2, k=2)

    # store all the good matches as per Lowe's ratio test; knnMatch may return fewer than two candidates
    # for a descriptor (which previously crashed the tuple unpacking), so such entries are skipped
    good = [pair[0] for pair in matches if len(pair) == 2 and pair[0].distance < dist * pair[1].distance]

    if len(good) < min_count:
        return None, None
    src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
    dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
    return src_pts, dst_pts
Compares keypoint feature vectors from two images and returns matching pairs.
Arguments:
- kp1 (ndarray): keypoints from the first image
- kp2 (ndarray): keypoints from the second image
- d1 (ndarray): descriptors for the keypoints from the first image
- d2 (ndarray): descriptors for the keypoints from the second image
- method (str): the method used to calculate the feature descriptors. Should be 'sift' or 'orb'. Default is 'sift'.
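- dist (float): ratio threshold (0 to 1) for Lowe's ratio test: a match is kept only if its distance is less than dist times the distance of the second-best match. Default is 0.7.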
- tree (int): the value passed to the FLANN index parameters as `trees`. Default is 5. See the OpenCV FLANN documentation.
- check (int): the value passed to the FLANN search parameters as `checks`. Default is 100. See the OpenCV FLANN documentation.
- min_count (int): the minimum number of matches to consider a valid matching operation. If fewer matches are found, then the function returns None, None. Default is 5.
def quick_plot(self, band=0, ax=None, bfac=0.0, cfac=0.0, samples=False, tscale=False, rot=False, flipX=False, flipY=False,
               **kwds):
    """
    Plot a band (or an RGB composite of several bands) using matplotlib.imshow(...).

    Args:
        band (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple
            (or list) is passed then each band in it (string or index) will be mapped to rgb. Bands with negative wavelengths or
            indices will be inverted before plotting.
        ax: an axis object to plot to. If None, a new figure and axes are created with plt.subplots(...).
        bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
        cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
        samples (bool): True if sample points (defined in the header file) should be plotted. Default is False. Otherwise, a list of
            [ (x,y), ... ] points can be passed.
        tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False.
            When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or
            (constant) values (float).
        rot (bool): if True, the x and y axis will be flipped (90 degree rotation) before plotting. Default is False.
        flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
        flipY (bool): if True, the y axis will be flipped before plotting (after applying rotations).
        **kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:

             - mask = a 2D boolean mask that is True for pixels that should be hidden (not drawn).
             - path = a file path to save the image to (at matching resolution; use fig.savefig(..) if you want to save the figure).
             - ticks = True if x- and y- ticks should be plotted. Default is False.
             - ps, pc = the size and color of sample points to plot. Can be constant or list.
             - figsize = a figsize for the figure to create (if ax is None).

    Returns:
        Tuple containing

        - fig: matplotlib figure object
        - ax: matplotlib axes object. For single-band plots the image returned by imshow is stored in ax.cbar
          (e.g. for creating a colorbar); for RGB plots ax.cbar is None.
    """

    def _ensure_dir(filename):
        # create the output directory if needed (a bare filename has no directory component)
        folder = os.path.dirname(filename)
        if folder:
            os.makedirs(folder, exist_ok=True)

    def _plot_samples():
        # scatter sample points (in [x, y] pixel coordinates) onto the axes
        if isinstance(samples, (list, np.ndarray)):
            if len(samples) > 0:
                ax.scatter([s[0] for s in samples], [s[1] for s in samples], s=ps, c=pc)
        elif samples:
            names = self.header.get_class_names()
            if names is None:
                names = []
            for n in names:
                points = np.array(self.header.get_sample_points(n))
                ax.scatter(points[:, 0], points[:, 1], s=ps, c=pc)

    def _resolve(key, reference, default):
        # integer vmin / vmax values are percentiles of `reference`; floats are used as-is
        value = kwds.get(key, None)
        if value is None:
            return default
        if isinstance(value, int):
            assert 0 <= value <= 100, "Error - integer %s values must be a percentile." % key
            return float(np.nanpercentile(reference, value))
        return value

    def _scale(values):
        # linearly map `values` so that vmin -> 0 and vmax -> 1 (defaults: the range of `values`)
        mn = _resolve("vmin", values, float(np.nanmin(values)))
        mx = _resolve("vmax", values, float(np.nanmax(values)))
        return (values - mn) / (mx - mn)

    # options consumed here, so they are never forwarded to imshow / imsave
    ticks = kwds.pop('ticks', False)
    ps = kwds.pop('ps', 5)
    pc = kwds.pop('pc', 'r')
    mask = kwds.pop('mask', None)
    path = kwds.pop('path', None)
    interpolation = kwds.pop('interpolation', 'none')  # change default interpolation to None
    figsize = kwds.pop('figsize', None)

    # create new axes?
    if ax is None:
        if figsize is None:
            figsize = (18, 18 * self.ydim() / self.xdim())
        fig, ax = plt.subplots(figsize=figsize)

    # deal with ticks
    if not ticks:
        ax.set_xticks([])
        ax.set_yticks([])

    # map individual band using colourmap
    if isinstance(band, (str, int, float, np.integer, np.floating)):
        if isinstance(band, str):
            data = self.data[:, :, self.get_band_index(band)]
        else:
            data = self.data[:, :, self.get_band_index(np.abs(band))]
            if band < 0:
                data = np.nanmax(data) - data  # invert

        # convert integer vmin and vmax values to percentiles
        for key in ('vmin', 'vmax'):
            if isinstance(kwds.get(key), int):
                kwds[key] = np.nanpercentile(data, kwds[key])

        # hide nans, the header's data-ignore value and any user mask
        hidden = np.isnan(data)
        ignore = self.header.get_data_ignore_value()
        if not np.isnan(ignore):
            # parenthesised: the old `mask + data == ignore` compared (mask + data) with the ignore value
            hidden = hidden | (data == ignore)
        if mask is not None:
            hidden = hidden | np.asarray(mask, dtype=bool)
        data = np.ma.array(data, mask=hidden)

        # apply rotations and flipping
        if rot:
            data = data.T
        if flipX:
            data = data[::-1, :]
        if flipY:
            data = data[:, ::-1]

        # save?
        if path is not None:
            _ensure_dir(path)
            plt.imsave(path, data.T, **kwds)

        _plot_samples()
        ax.cbar = ax.imshow(data.T, interpolation=interpolation, **kwds)

    # map 3 bands to RGB
    elif isinstance(band, (tuple, list)):
        rgb = [self.get_band_index(b) if isinstance(b, str) else self.get_band_index(np.abs(b)) for b in band]

        # slice image as a float copy, so scaled values are not truncated for integer data
        img = np.array(self.data[:, :, rgb])
        if not np.issubdtype(img.dtype, np.floating):
            img = img.astype(np.float32)
        if np.isnan(img).all():
            print("Warning - image contains no data.")
            return ax.get_figure(), ax

        # invert if needed
        for i, b in enumerate(band):
            if not isinstance(b, str) and (b < 0):
                img[..., i] = np.nanmax(img[..., i]) - img[..., i]

        # do scaling
        if tscale:  # scale bands independently (each band uses its own default range and percentiles)
            for b in range(img.shape[-1]):
                img[..., b] = _scale(img[..., b])
        else:  # scale bands together
            img = _scale(img)

        # apply brightness/contrast mapping
        img = np.clip((1.0 + cfac) * img + bfac, 0, 1.0)

        # apply masking so background is white
        img[np.logical_not(np.isfinite(img))] = 1.0
        if mask is not None:
            img[np.asarray(mask, dtype=bool), :] = 1.0

        # apply rotations and flipping
        if rot:
            img = np.transpose(img, (1, 0, 2))
        if flipX:
            img = img[::-1, :, :]
        if flipY:
            img = img[:, ::-1, :]

        # save?
        if path is not None:
            _ensure_dir(path)
            plt.imsave(path, np.transpose(np.clip(img * 255, 0, 255).astype(np.uint8), (1, 0, 2)))

        _plot_samples()
        ax.imshow(np.transpose(img, (1, 0, 2)), interpolation=interpolation, **kwds)
        ax.cbar = None  # no colorbar

    return ax.get_figure(), ax
Plot a band using matplotlib.imshow(...).
Arguments:
- band (str,int,float,tuple): the band name (string), index (integer) or wavelength (float) to plot. Default is 0. If a tuple is passed then each band in the tuple (string or index) will be mapped to rgb. Bands with negative wavelengths or indices will be inverted before plotting.
- ax: an axis object to plot to. If None, a new figure and axes are created with plt.subplots(...).
- bfac (float): a brightness adjustment to apply to RGB mappings (-1 to 1)
- cfac (float): a contrast adjustment to apply to RGB mappings (-1 to 1)
- samples (bool): True if sample points (defined in the header file) should be plotted. Default is False. Otherwise, a list of [ (x,y), ... ] points can be passed.
- tscale (bool): True if each band (for ternary images) should be scaled independently. Default is False. When using scaling, vmin and vmax can be used to set the clipping percentiles (integers) or (constant) values (float).
- rot (bool): if True, the x and y axis will be flipped (90 degree rotation) before plotting. Default is False.
- flipX (bool): if True, the x axis will be flipped before plotting (after applying rotations).
- flipY (bool): if True, the y axis will be flipped before plotting (after applying rotations).
**kwds: keywords are passed to matplotlib.imshow( ... ), except for the following:
- mask = a 2D boolean mask that is True for pixels that should be hidden (not drawn) and False otherwise.
- path = a file path to save the image to (at matching resolution; use fig.savefig(..) if you want to save the figure).
- ticks = True if x- and y- ticks should be plotted. Default is False.
- ps, pc = the size and color of sample points to plot. Can be constant or list.
- figsize = a figsize for the figure to create (if ax is None).
Returns:
Tuple containing
- fig: matplotlib figure object
- ax: matplotlib axes object. For single-band plots (band is a string, integer or float) the image returned by imshow is stored in ax.cbar (e.g. for creating a colorbar); for RGB plots ax.cbar is None.
def createGIF(self, path, bands=None, figsize=(10, 10), fps=10, **kwds):
    """
    Create and save an animated gif that loops through the bands of the image.

    Args:
        path (str): the path to save the .gif. Any existing file extension is replaced with ".gif".
        bands (tuple): a (start, end) tuple giving the range of band indices to draw, with end exclusive
                       (i.e. frames are drawn for range(start, end)). Default is the whole range.
        figsize (tuple): the size of the image to draw. Default is (10,10).
        fps (int): the framerate (frames per second) of the gif. Default is 10.
        **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
    """
    if bands is None:
        bands = (0, self.band_count())
    else:
        # start is inclusive and end is exclusive, so (0, band_count()) is the full (valid) range
        assert 0 <= bands[0] < bands[1] <= self.band_count(), "Error - invalid range."

    # render one frame per band
    frames = []
    for i in range(bands[0], bands[1]):
        fig, ax = plt.subplots(figsize=figsize)
        ax.imshow(self.data[:, :, i], **kwds)
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb() (deprecated in matplotlib 3.8, removed in 3.10) and already has
        # the correct (height, width, 4) shape; np.array copies it so the frame survives closing the figure.
        frames.append(np.array(fig.canvas.buffer_rgba())[:, :, :3])
        plt.close(fig)

    # save gif
    imageio.mimsave(os.path.splitext(path)[0] + ".gif", frames, fps=fps)
Create and save an animated gif that loops through the bands of the image.
Arguments:
- path (str): the path to save the .gif
- bands (tuple): a (start, end) tuple giving the range of band indices to draw; the end index is exclusive. Default is the whole range.
- figsize (tuple): the size of the image to draw. Default is (10,10).
- fps (int): the framerate (frames per second) of the gif. Default is 10.
- **kwds: keywords are passed directly to matplotlib.imshow. Use this to specify cmap etc.
def mask(self, mask=None, flag=np.nan, invert=False, crop=False, bands=None):
    """
    Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the
    image in-situ.

    Args:
        mask (ndarray, list, str): either a polygon of the format [[x1,y1],[x2,y2],...] (pixel coordinates), a
            boolean array whose shape matches the first two dimensions of the image (True values are masked across
            all bands), or a path to a .npy file containing either of these (loaded using np.load( ... )). If None
            is passed then pickPolygons( ... ) is used to interactively define a polygon. Default is None.
        flag (float): the value to use for masked pixels. Default is np.nan
        invert (bool): only used for polygons. If True, pixels within the polygon will be masked. If False, pixels
            outside the polygon are masked. Default is False.
        crop (bool): True if rows/columns containing only masked pixels should be removed. Default is False.
        bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used.

    Returns:
        mask (ndarray): a boolean array with True where pixels are masked and False elsewhere, or None if no
        polygon was picked interactively (in which case the image is left unchanged).
    """
    if mask is None:  # pick mask interactively
        if bands is None:
            bands = int(self.band_count() / 2)

        regions = self.pickPolygons(region_names=["mask"], bands=bands)

        # the user bailed without picking a mask?
        if len(regions) == 0:
            print("Warning - no mask picked/applied.")
            return

        # extract polygon mask
        mask = regions[0]
    elif isinstance(mask, (str, os.PathLike)):
        mask = np.load(mask)  # file paths are documented as a valid mask source

    mask = np.asarray(mask)  # also accept plain python lists of vertices

    # convert polygon mask to binary mask. Boolean arrays are always pixel masks, even for images with ydim == 2.
    if mask.dtype != bool and mask.ndim == 2 and mask.shape[1] == 2:
        # coordinates of each pixel (x varies fastest within each y row)
        xx, yy = np.meshgrid(np.arange(self.xdim()), np.arange(self.ydim()))
        points = np.vstack([xx.flatten(), yy.flatten()]).T

        # calculate per-pixel mask and reshape back to [x, y] indexing
        inside = path.Path(mask).contains_points(points)
        mask = inside.reshape((self.ydim(), self.xdim())).T

        # flip as we want to mask (==True) outside points (unless invert is true)
        if not invert:
            mask = np.logical_not(mask)
    else:
        # a non-boolean array would otherwise be used as integer (fancy) indices rather than a pixel mask
        mask = mask.astype(bool, copy=False)

    # apply binary image mask across all bands
    assert mask.ndim == 2 and mask.shape == self.data.shape[:2], \
        "Error - mask shape %s does not match image shape %s" % (mask.shape, self.data.shape)
    self.data[mask, :] = flag

    # crop image
    if crop:
        # calculate non-masked pixels and whether each row/column contains any
        valid = np.logical_not(mask)
        xdata = valid.any(axis=1)
        ydata = valid.any(axis=0)

        # calculate domain containing valid pixels (exclusive upper bounds)
        xmin = np.argmax(xdata)
        xmax = xdata.shape[0] - np.argmax(xdata[::-1])
        ymin = np.argmax(ydata)
        ymax = ydata.shape[0] - np.argmax(ydata[::-1])

        # crop
        self.data = self.data[xmin:xmax, ymin:ymax, :]

    return mask
Apply a mask to an image, flagging masked pixels with the specified value. Note that this applies the mask to the image in-situ.
Arguments:
- flag (float): the value to use for masked pixels. Default is np.nan
- mask (ndarray): a numpy array defining the mask polygon of the format [[x1,y1],[x2,y2],...]. If None is passed then pickPolygon( ... ) is used to interactively define a polygon. If a file path is passed then the polygon will be loaded using np.load( ... ). Alternatively if mask.shape == image.shape[0,1] then it is treated as a binary image mask (must be boolean) and True values will be masked across all bands. Default is None.
- invert (bool): if True, pixels within the polygon will be masked. If False, pixels outside the polygon are masked. Default is False.
- crop (bool): True if rows/columns containing only masked pixels should be removed. Default is False.
- bands (tuple): the bands of the image to plot if no mask is specified. If None, the middle band is used.
Returns:
- mask (ndarray): a boolean array with True where pixels are masked and False elsewhere.
Returns None (and leaves the image unchanged) if the mask was to be picked interactively but no polygon was picked.
def crop_to_data(self):
    """
    Remove padding of nan or zero pixels from image. Note that this is performed in place.

    A pixel counts as data if at least one of its bands is both finite and non-zero. If the image contains no
    such pixels it is left unchanged.
    """
    # evaluate finiteness and non-zero per band *together*; otherwise a [nan, 0] padding pixel
    # would count as valid (nan != 0 is True)
    valid = (np.isfinite(self.data) & (self.data != 0)).any(axis=-1)

    rows = np.flatnonzero(valid.any(axis=1))  # x indices containing data
    cols = np.flatnonzero(valid.any(axis=0))  # y indices containing data
    if rows.size == 0:
        return  # nothing but padding; nothing sensible to crop to

    # +1 because slice ends are exclusive and rows[-1] / cols[-1] are the last valid indices
    self.data = self.data[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1, :]
Remove padding of nan or zero pixels from image. Note that this is performed in place.
def pickPolygons(self, region_names, bands=0):
    """
    Creates a matplotlib gui for selecting polygon regions in an image.

    Args:
        region_names (list, tuple, str): the names of the regions to pick. If a string is passed only one name is used.
        bands (tuple): the bands of the image to plot.

    Returns:
        a list containing an (N, 2) array of [x, y] vertex coordinates for each picked region, in the order the
        regions are stored by roipoly.MultiRoi.
    """
    if isinstance(region_names, str):
        region_names = [region_names]
    elif isinstance(region_names, tuple):
        region_names = list(region_names)

    assert isinstance(region_names, list), "Error - names must be a list or a string."

    # set matplotlib backend
    backend = matplotlib.get_backend()
    matplotlib.use('Qt5Agg')  # need this backend for ROIPoly to work

    try:
        # plot image and extract roi's (MultiRoi returns once the user has finished drawing)
        fig, ax = self.quick_plot(bands)
        roi = MultiRoi(roi_names=region_names)
        plt.close(fig)  # close figure

        # store each region as a stack of [x, y] vertices
        regions = [np.vstack([r.x, r.y]).T for r in roi.rois.values()]
    finally:
        # restore matplotlib backend (if possible), even if picking failed
        try:
            matplotlib.use(backend)
        except Exception:
            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")

    return regions
Creates a matplotlib gui for selecting polygon regions in an image.
Arguments:
- region_names (list, str): a list containing the names of the regions to pick. If a string is passed only one name is used.
- bands (tuple): the bands of the image to plot.
def pickPoints(self, n=-1, bands=hylite.RGB, integer=True, title="Pick Points", **kwds):
    """
    Creates a matplotlib gui for picking pixels from an image.

    Args:
        n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
        bands (tuple): the bands of the image to plot. Default is hylite.RGB
        integer (bool): True if points coordinates should be cast to integers (for use as indices). Default is True.
        title (str): The title of the point picking window.
        **kwds: Keywords are passed to HyImage.quick_plot( ... ).

    Returns:
        A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].
    """
    # set matplotlib backend
    backend = matplotlib.get_backend()
    matplotlib.use('Qt5Agg')  # need an interactive backend for point picking

    try:
        # create figure
        fig, ax = self.quick_plot(bands, **kwds)
        ax.set_title(title)

        # get points. timeout=0 disables ginput's default 30 s timeout, which would otherwise
        # silently end open-ended (n=-1) picking.
        points = fig.ginput(n, timeout=0)
        plt.close(fig)  # don't leave picking windows open (pickSamples calls this repeatedly)
    finally:
        # restore matplotlib backend (if possible), even if picking failed
        try:
            matplotlib.use(backend)
        except Exception:
            print("Warning: could not reset matplotlib backend. Plots will remain interactive...")

    if integer:
        # imshow (default extent, as used by quick_plot unless overridden) centres pixel i on coordinate i,
        # so rounding - not truncating - gives the index of the pixel that was clicked.
        points = [(int(round(p[0])), int(round(p[1]))) for p in points]

    return points
Creates a matplotlib gui for picking pixels from an image.
Arguments:
- n (int): the number of pixels to pick, or -1 if the user can select as many as they wish. Default is -1.
- bands (tuple): the bands of the image to plot. Default is hylite.RGB
- integer (bool): True if points coordinates should be cast to integers (for use as indices). Default is True.
- title (str): The title of the point picking window.
- **kwds: Keywords are passed to HyImage.quick_plot( ... ).
Returns:
A list containing the picked point coordinates [ (x1,y1), (x2,y2), ... ].
def pickSamples(self, names=None, store=True, **kwds):
    """
    Pick sample probe points and store these in the image header file.

    Args:
        names (str, list, tuple): the name of the sample to pick, or a list of names to pick multiple. Required.
        store (bool): True if sample should be stored in the image header file (for later access). Default is True.
            Each sample is stored under the header key 'sample <name>' and its name is added (once) to
            the header's 'class names'.
        **kwds: Keywords are passed to HyImage.quick_plot( ... )

    Returns:
        a list containing a list of points for each sample.

    Raises:
        TypeError: if names is None.
    """
    if names is None:
        raise TypeError("Error - names must be a string or a list of sample names.")
    if isinstance(names, str):
        names = [names]
    names = list(names)  # accept tuples / other iterables

    # pick points
    points = []
    for name in names:
        picked = self.pickPoints(title=str(name), **kwds)
        if store:
            self.header['sample %s' % name] = picked  # store in header
        points.append(picked)

    # add class names to header file, without duplicating names that are already present
    if store:
        class_names = self.header.get_class_names()
        class_names = [] if class_names is None else list(class_names)
        for name in names:
            if name not in class_names:
                class_names.append(name)
        self.header['class names'] = class_names

    return points
Pick sample probe points and store these in the image header file.
Arguments:
- names (str, list): the name of the sample to pick, or a list of names to pick multiple.
- store (bool): True if sample should be stored in the image header file (for later access). Default is True.
- **kwds: Keywords are passed to HyImage.quick_plot( ... )
Returns:
a list containing a list of points for each sample.
Inherited Members
- hylite.hydata.HyData
- to_grey
- set_header
- push_to_header
- has_wavelengths
- get_wavelengths
- has_band_names
- get_band_names
- has_fwhm
- get_fwhm
- set_wavelengths
- set_band_names
- set_fwhm
- is_image
- is_point
- is_classification
- band_count
- samples
- lines
- is_int
- is_float
- export_bands
- delete_nan_bands
- set_as_nan
- mask_bands
- mask_water_features
- get_band
- get_band_grey
- get_raveled
- X
- eval
- set_raveled
- get_band_index
- resample
- contiguous_chunks
- smooth_median
- smooth_savgol
- fill_gaps
- plot_spectra
- compress
- decompress
- getQuantized
- fromQuanta
- normalise
- percent_clip
- correct_spectral_shift