import os
import sys
import glob
from osgeo import gdal
import numpy as np
import cv2


def CalHistogram(img):
    """Compute the value histogram of an image for percent-clip stretching.

    uint8 and uint16 images are histogrammed exactly, one bin per possible
    value (256 or 65536 bins), so the returned sample values are the integer
    pixel values themselves. Every other dtype (signed ints, floats, uint32)
    is binned into 2**16 equal-width bins spanning [img.min(), img.max()].

    The first and last bins are zeroed so that pixels at the extremes (for
    example a 0 fill value or saturated pixels) do not influence the clip
    limits.

    Args:
        img: numpy array of any shape; it is flattened.

    Returns:
        (hist, s_values): bin counts and the value associated with each bin.
        On the exact (uint8/uint16) path ``s_values`` has the same length as
        ``hist``. On the binned path it holds the ``len(hist) + 1`` bin edges,
        so ``s_values[i]`` is the left edge of bin ``i``.
    """
    flat = img.reshape(-1)
    if img.dtype in (np.uint8, np.uint16):
        n_bins = 256 if img.dtype == np.uint8 else 2 ** 16
        hist = np.bincount(flat, minlength=n_bins)
        s_values = np.arange(n_bins)
    else:
        # uint32 previously used the exact path with 2**32 bins, which needs a
        # ~32 GiB count array; it is now binned like floats instead.
        n_bins = 2 ** 16
        hist, s_values = np.histogram(flat, bins=n_bins, range=(flat.min(), flat.max()))
    hist[0] = 0
    hist[-1] = 0
    return hist, s_values


def GetPercentStretchValue(img, left_clip=0.001, right_clip=0.001):
    """Return the (low, high) pixel values at the requested cumulative fractions.

    Args:
        img: numpy image array (any shape).
        left_clip: fraction of (counted) pixels to clip at the low end.
        right_clip: fraction of (counted) pixels to clip at the high end.

    Returns:
        (img_min_clip, img_max_clip): sample values from CalHistogram whose
        cumulative fraction is closest to ``left_clip`` and ``1 - right_clip``.
    """
    hist, s_values = CalHistogram(img)
    cdf = np.cumsum(hist).astype(np.float64)
    # The small epsilon avoids 0/0 when every counted bin is empty.
    cdf /= (cdf[-1] + 1.0E-5)
    left_clip_index = np.argmin(np.abs(cdf - left_clip))
    right_clip_index = np.argmin(np.abs(cdf - (1.0 - right_clip)))
    return s_values[left_clip_index], s_values[right_clip_index]


def _stretch_band(band, left_clip, right_clip):
    """Linearly map one band (any shape) onto uint8 0..255 between its clip limits."""
    clip_min, clip_max = GetPercentStretchValue(band, left_clip=left_clip, right_clip=right_clip)
    span = float(clip_max) - float(clip_min)
    if span <= 0:
        # Constant (or fully excluded) band: the formula below would divide
        # by zero and cast NaN to uint8, so map it to 0 deterministically.
        return np.zeros(band.shape, dtype=np.uint8)
    # Work in float64 so unsigned inputs can never wrap around on subtraction.
    clipped = np.clip(band, clip_min, clip_max).astype(np.float64)
    return ((clipped - float(clip_min)) / span * 255).astype(np.uint8)


def percent_stretch_image(input_image_data, left_clip=0.001, right_clip=0.001, left_mask=None, right_mask=None):
    """Percent-clip stretch an image to 8-bit, band by band.

    Args:
        input_image_data: 2-D (rows, cols) or 3-D (rows, cols, bands) array.
        left_clip: low-end clip fraction passed to GetPercentStretchValue.
        right_clip: high-end clip fraction passed to GetPercentStretchValue.
        left_mask: accepted for backward compatibility; not used.
        right_mask: accepted for backward compatibility; not used.

    Returns:
        A uint8 array with the same shape as the input, or None when the
        input is None.
    """
    if input_image_data is None:
        return None
    # A 2-D image, or a 3-D image with a single trailing band, is stretched
    # as one unit and keeps its original shape.
    if input_image_data.ndim == 2 or input_image_data.shape[-1] == 1:
        return _stretch_band(input_image_data, left_clip, right_clip)
    out_8bit_data = np.zeros(input_image_data.shape, dtype=np.uint8)
    for i_band in range(input_image_data.shape[-1]):
        out_8bit_data[:, :, i_band] = _stretch_band(input_image_data[:, :, i_band], left_clip, right_clip)
    return out_8bit_data


def read_tif(file_path):
    """Read every band of a raster into a (rows, cols, bands) numpy array.

    Args:
        file_path: path of a raster GDAL can open.

    Returns:
        (data, geo_transform, projection), or None (after printing an error)
        when GDAL cannot open the file.
    """
    ds = gdal.Open(file_path)
    if ds is None:
        print("Error || Can't open {0} as tif file.".format(file_path))
        return None
    geo = ds.GetGeoTransform()
    pro = ds.GetProjection()
    # band.ReadAsArray() already returns the matching numpy dtype. The old code
    # cast via the lower-cased GDAL type name, which turned "Byte" into numpy
    # "byte" (= signed int8) and corrupted values above 127.
    bands = [ds.GetRasterBand(i + 1).ReadAsArray() for i in range(ds.RasterCount)]
    if bands:
        data_set = np.stack(bands, axis=-1)
    else:
        data_set = np.zeros((ds.RasterYSize, ds.RasterXSize, 0))
    del ds
    return data_set, geo, pro


def write_tif(file_path, data, transform=None, projection=None):
    """Write a 2-D or (rows, cols, bands) array to a GeoTIFF.

    Args:
        file_path: output path; its parent directory must already exist.
        data: numpy array. uint8 -> Byte, uint16 -> UInt16, int16 -> Int16,
            uint32 -> UInt32, int32 -> Int32, float64 -> Float64; any other
            dtype is written as Float32.
        transform: optional GDAL geotransform; not set when None.
        projection: optional projection string; not set when None.

    Raises:
        OSError: if the GTiff driver cannot create the output file.
    """
    w_data = data
    if w_data.ndim == 2:
        w_data = w_data.reshape(w_data.shape[0], w_data.shape[1], 1)
    gdal_types = {
        np.dtype(np.uint8): gdal.GDT_Byte,
        np.dtype(np.uint16): gdal.GDT_UInt16,
        np.dtype(np.int16): gdal.GDT_Int16,
        np.dtype(np.uint32): gdal.GDT_UInt32,
        np.dtype(np.int32): gdal.GDT_Int32,
        np.dtype(np.float64): gdal.GDT_Float64,
    }
    datatype = gdal_types.get(w_data.dtype, gdal.GDT_Float32)
    rows, cols, n_bands = w_data.shape

    # Create the output file.
    driver = gdal.GetDriverByName('GTiff')
    image = driver.Create(file_path, cols, rows, n_bands, datatype)
    if image is None:
        raise OSError("Can't create {0} as tif file.".format(file_path))
    if transform is not None:
        image.SetGeoTransform(transform)
    if projection is not None:
        image.SetProjection(projection)
    # WriteArray needs 2-D arrays, so every band (including a single one) is
    # written as its own (rows, cols) slice.
    for i in range(n_bands):
        image.GetRasterBand(i + 1).WriteArray(w_data[:, :, i])
    image.FlushCache()
    # Drop the dataset reference so GDAL closes it; the data stays on disk.
    del image


if __name__ == '__main__':
    # Optional command-line overrides: script.py [input_tif [output_dir]].
    in_file = sys.argv[1] if len(sys.argv) > 1 else r'E:\11111\image\1.tif'
    out_path = sys.argv[2] if len(sys.argv) > 2 else r"E:\11111\resule"
    result = read_tif(in_file)
    if result is None:
        sys.exit(1)
    img, geo, pro = result
    print(img.shape)
    n_dim = img.ndim
    img_bands = 1 if n_dim == 2 else img.shape[n_dim - 1]
    print(img.min(), img.mean(), img.max())
    # Plain min-max stretch for comparison; a constant image maps to 0
    # instead of dividing by zero.
    img_range = float(img.max()) - float(img.min())
    img_raw_s = (img.astype(np.float64) - float(img.min())) / (img_range if img_range > 0 else 1.0) * 255
    print('img raw s:', img_raw_s.min(), img_raw_s.mean(), img_raw_s.max())
    img = percent_stretch_image(img)
    os.makedirs(out_path, exist_ok=True)
    out_file = os.path.join(out_path, os.path.basename(in_file))
    write_tif(out_file, img.astype(np.uint8), geo, pro)
    write_tif(out_file + '_raw.tif', img_raw_s.astype(np.uint8), geo, pro)
    if img_bands > 3:
        write_tif(out_file + '_RGB.tif', img[:, :, [2, 1, 0]].astype(np.uint8), geo, pro)
        write_tif(out_file + '_RGB_raw.tif', img_raw_s[:, :, [2, 1, 0]].astype(np.uint8), geo, pro)