行走的蓑衣客

导航

 
import os
import sys
import glob
from osgeo import gdal
import numpy as np
import cv2

def CalHistogram(img):
    """Compute an intensity histogram of ``img`` with its two end bins zeroed.

    ``uint8`` and ``uint16`` arrays get an exact per-value count (256 or
    65536 bins). Every other dtype, including ``uint32``, is binned into
    65536 equal-width bins spanning the array's own minimum and maximum.
    An exact per-value count for ``uint32`` would need 2**32 bins, which is
    far too large to allocate, so it deliberately uses the binned path.

    The first and last bin are always reset to zero, so values that fall in
    the two extreme bins do not contribute to the returned counts.

    Args:
        img: NumPy array of any shape; it is flattened before counting.

    Returns:
        A ``(hist, s_values)`` tuple. For ``uint8``/``uint16`` input,
        ``s_values`` is ``np.arange(n_bins)`` and has the same length as
        ``hist``. For every other dtype ``s_values`` holds the bin edges
        returned by ``np.histogram`` and is one element longer than ``hist``.

    Raises:
        ValueError: If ``img`` contains no elements.
    """
    flat = img.reshape(-1)
    if flat.size == 0:
        raise ValueError("CalHistogram requires a non-empty image")

    if img.dtype == np.uint8 or img.dtype == np.uint16:
        # Small unsigned ranges: count every possible value exactly.
        n_bins = 2 ** 8 if img.dtype == np.uint8 else 2 ** 16
        hist = np.bincount(flat, minlength=n_bins)
        s_values = np.arange(n_bins)
    else:
        # Wide or non-integer ranges: fixed number of equal-width bins
        # between the observed minimum and maximum.
        hist, s_values = np.histogram(flat, bins=2 ** 16, range=(flat.min(), flat.max()))

    # Drop the two extreme bins from the statistics.
    hist[0] = 0
    hist[-1] = 0
    return hist, s_values


def GetPercentStretchValue(img, left_clip=0.001, right_clip=0.001):
    """Pick the lower and upper intensity cut-offs for a percent-clip stretch.

    The histogram from ``CalHistogram`` is turned into a normalised
    cumulative distribution. The cut-offs are the histogram levels whose
    cumulative fraction is closest to ``left_clip`` and ``1 - right_clip``.

    Args:
        img: NumPy array holding the pixel values of one band.
        left_clip: Target cumulative fraction for the lower cut-off.
        right_clip: Fraction measured from the top; the upper cut-off
            targets a cumulative fraction of ``1 - right_clip``.

    Returns:
        ``(img_min_clip, img_max_clip)`` taken from the second value
        returned by ``CalHistogram``.
    """
    counts, levels = CalHistogram(img)

    # Normalised cumulative distribution; the small epsilon avoids a
    # division by zero when every count is zero.
    cdf = np.cumsum(counts).astype(np.float64)
    cdf /= (cdf[-1] + 1.0E-5)

    upper_fraction = 1.0 - right_clip
    low_idx = np.argmin(np.abs(cdf - left_clip))
    high_idx = np.argmin(np.abs(cdf - upper_fraction))
    return levels[low_idx], levels[high_idx]


def _stretch_band_to_uint8(band, left_clip, right_clip):
    """Clip one band to its percent-stretch limits and rescale it to 0-255."""
    clip_min, clip_max = GetPercentStretchValue(band, left_clip=left_clip, right_clip=right_clip)
    clip_range = clip_max - clip_min
    if clip_range == 0:
        # Both limits coincide (for example a constant band). Dividing by
        # the zero range would yield NaN/inf, so emit an all-zero band.
        return np.zeros(band.shape, dtype=np.uint8)
    clipped = np.clip(band, clip_min, clip_max)
    return ((clipped - clip_min) / clip_range * 255).astype(np.uint8)


def percent_stretch_image(input_image_data, left_clip=0.001, right_clip=0.001, left_mask=None,
                          right_mask=None):
    """Stretch an image to 8 bit with a per-band percent clip.

    For every band the clip limits come from ``GetPercentStretchValue``.
    Values are clipped to those limits, rescaled linearly to 0-255 and
    truncated to ``uint8``. A band whose two limits are equal is returned as
    all zeros.

    Args:
        input_image_data: 2-D array, or array whose last axis is the band
            axis. ``None`` is passed through.
        left_clip: Forwarded to ``GetPercentStretchValue``.
        right_clip: Forwarded to ``GetPercentStretchValue``.
        left_mask: Accepted for interface compatibility; not used.
        right_mask: Accepted for interface compatibility; not used.

    Returns:
        ``None`` if ``input_image_data`` is ``None``. With more than one
        band, a ``(rows, cols, bands)`` ``uint8`` array. With exactly one
        band, a ``uint8`` array with the same shape as the input.
    """
    if input_image_data is None:
        return None

    n_dim = input_image_data.ndim
    img_bands = 1 if n_dim == 2 else input_image_data.shape[n_dim - 1]
    xsize = input_image_data.shape[1]
    ysize = input_image_data.shape[0]

    if img_bands > 1:
        out_8bit_data = np.zeros((ysize, xsize, img_bands), dtype=np.uint8)
        for i_band in range(img_bands):
            out_8bit_data[:, :, i_band] = _stretch_band_to_uint8(
                input_image_data[:, :, i_band], left_clip, right_clip)
        return out_8bit_data

    if img_bands == 1:
        # A single band is stretched as a whole, keeping the input's shape.
        return _stretch_band_to_uint8(input_image_data, left_clip, right_clip)

    # Zero bands: nothing to stretch.
    return np.zeros((ysize, xsize), dtype=np.uint8)


def read_tif(file_path):
    """Read every band of a raster file into a ``(rows, cols, bands)`` array.

    Args:
        file_path: Path handed to ``gdal.Open``.

    Returns:
        ``(data_set, geo, pro)`` where ``data_set`` is the pixel array,
        ``geo`` is the value of ``GetGeoTransform()`` and ``pro`` the value
        of ``GetProjection()``. Returns ``None`` (after printing an error)
        when ``gdal.Open`` yields no dataset.

    Raises:
        RuntimeError: If a band cannot be read.
    """
    ds = gdal.Open(file_path)
    if ds is None:
        print("Error || Can't open {0} as tif file.".format(file_path))
        return None

    cols = ds.RasterXSize
    rows = ds.RasterYSize
    bands = ds.RasterCount
    geo = ds.GetGeoTransform()
    pro = ds.GetProjection()

    # Keep the dtype that ReadAsArray returns. Deriving it from the
    # lower-cased GDAL type name is unreliable: "byte" is NumPy's *signed*
    # 8-bit type while GDAL's Byte is unsigned, and the complex type names
    # are not valid NumPy dtype strings.
    band_arrays = []
    for i in range(bands):
        band_data = ds.GetRasterBand(i + 1).ReadAsArray()
        if band_data is None:
            raise RuntimeError("Can't read band {0} of {1}.".format(i + 1, file_path))
        band_arrays.append(band_data)

    if band_arrays:
        data_set = np.stack(band_arrays, axis=-1)
    else:
        # No raster bands: return an empty band axis.
        data_set = np.zeros((rows, cols, 0))

    del ds  # release the dataset handle
    return data_set, geo, pro


def write_tif(file_path, data, transform=None, projection=None):
    """Write an array to a GeoTIFF file.

    Args:
        file_path: Output path handed to the GTiff driver.
        data: 2-D ``(rows, cols)`` or 3-D ``(rows, cols, bands)`` array.
            ``uint8`` is written as ``GDT_Byte``, ``uint16`` as
            ``GDT_UInt16`` and every other dtype as ``GDT_Float32``.
        transform: Geotransform applied with ``SetGeoTransform``; skipped
            when ``None``.
        projection: Projection applied with ``SetProjection``; skipped when
            ``None``.

    Raises:
        RuntimeError: If the driver cannot create the output file.
    """
    w_data = data
    if len(w_data.shape) == 2:
        # Give 2-D input an explicit band axis so all inputs are handled alike.
        w_data = w_data.reshape(w_data.shape[0], w_data.shape[1], 1)

    if w_data.dtype == np.uint8:
        datatype = gdal.GDT_Byte
    elif w_data.dtype == np.uint16:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    n_rows = w_data.shape[0]
    n_cols = w_data.shape[1]
    n_bands = w_data.shape[2]

    # Create the output file.
    driver = gdal.GetDriverByName('GTiff')
    image = driver.Create(file_path, n_cols, n_rows, n_bands, datatype)
    if image is None:
        raise RuntimeError("Can't create {0} as tif file.".format(file_path))

    # Use the values passed by the caller rather than module-level globals.
    if transform is not None:
        image.SetGeoTransform(transform)
    if projection is not None:
        image.SetProjection(projection)

    # WriteArray takes a 2-D array, so always write one band slice at a time
    # (this also covers the single-band case).
    for i in range(n_bands):
        image.GetRasterBand(i + 1).WriteArray(w_data[:, :, i])

    del image  # drop the dataset reference so the data is flushed to disk

if __name__ == '__main__':
    in_file = r'E:\11111\image\1.tif'
    out_path = r"E:\11111\resule"

    # read_tif returns None (after printing an error) when the file cannot
    # be opened; stop here instead of failing on the tuple unpacking.
    loaded = read_tif(in_file)
    if loaded is None:
        sys.exit(1)
    img, geo, pro = loaded

    print(img.shape)
    n_dim = img.ndim
    img_bands = 1 if n_dim == 2 else img.shape[n_dim - 1]
    print(img.min(), img.mean(), img.max())

    # Plain min-max stretch over the whole image, written next to the
    # percent-stretched result for comparison.
    value_range = img.max() - img.min()
    if value_range == 0:
        # Constant image: avoid dividing by zero.
        img_raw_s = np.zeros(img.shape, dtype=np.float64)
    else:
        img_raw_s = (img - img.min()) / value_range * 255
    print('img raw s:', img_raw_s.min(), img_raw_s.mean(), img_raw_s.max())
    img = percent_stretch_image(img)

    # Make sure the output directory exists before creating files in it.
    os.makedirs(out_path, exist_ok=True)
    out_file = os.path.join(out_path, os.path.basename(in_file))
    write_tif(out_file, img.astype(np.uint8), geo, pro)
    write_tif(out_file + '_raw.tif', img_raw_s.astype(np.uint8), geo, pro)

    if img_bands > 3:
        # Write the first three bands in reversed order as a separate file.
        write_tif(out_file + '_RGB.tif', img[:, :, [2, 1, 0]].astype(np.uint8), geo, pro)
        write_tif(out_file + '_RGB_raw.tif', img_raw_s[:, :, [2, 1, 0]].astype(np.uint8), geo, pro)

 

posted on 2022-11-21 19:28  行走的蓑衣客  阅读(199)  评论(0)  编辑  收藏  举报