行走的蓑衣客

导航

 

 

"""
将16位遥感图像压缩至8位,并保持色彩一致
"""
from osgeo import gdal
import os
import glob
import numpy as np
import matplotlib.pyplot as plt # plt 用于显示图片
import matplotlib.image as mpimg # mpimg 用于读取图片


def read_tiff(input_file):
    """
    Read every band of a raster into one array.

    :param input_file: path of the input raster
    :return: (band data as a float64 array of shape (band_count, rows, cols),
              geotransform from GetGeoTransform(), projection string from
              GetProjection(), rows, cols, band_count)
    :raises RuntimeError: if GDAL cannot open ``input_file``
    """
    print("**********", input_file)

    dataset = gdal.Open(input_file)
    if dataset is None:
        # Without this check the failure surfaces as an opaque AttributeError below.
        raise RuntimeError("GDAL could not open raster: {}".format(input_file))

    rows = dataset.RasterYSize
    cols = dataset.RasterXSize

    geo = dataset.GetGeoTransform()
    proj = dataset.GetProjection()

    couts = dataset.RasterCount

    array_data = np.zeros((couts, rows, cols))

    for i in range(couts):
        # GDAL band indices are 1-based.
        band = dataset.GetRasterBand(i + 1)
        array_data[i, :, :] = band.ReadAsArray()

    # Release the GDAL handle now instead of waiting for garbage collection.
    dataset = None

    # Return the real band count so callers process every band, not just the first.
    return array_data, geo, proj, rows, cols, couts


def compress(origin_16, output_8):
    """
    Convert a raster to 8 bits using a per-band cumulative-histogram stretch.

    For each band, cumulativehistogram() supplies the lower/upper cut points
    (2% / 98% of the pixel count by default). Pixels are clipped to
    [cutmin, cutmax] and mapped linearly onto 0-255. The result is written
    with write_tiff() using the input's geotransform and projection.

    :param origin_16: path of the input (typically 16-bit) raster
    :param output_8: path of the 8-bit GeoTIFF to create
    """
    array_data, geo, proj, rows, cols, couts = read_tiff(origin_16)

    compress_data = np.zeros((couts, rows, cols), dtype=np.uint8)

    for i in range(couts):
        band = array_data[i, :, :]
        band_min = np.min(band)
        band_max = np.max(band)

        cutmin, cutmax = cumulativehistogram(band, rows, cols, band_min, band_max)

        # A constant (degenerate) band has nothing to stretch; a zero scale would
        # turn every pixel into inf/NaN, so the band is left at 0.
        if cutmax <= cutmin:
            continue

        compress_scale = (cutmax - cutmin) / 255
        # Vectorised form of "clamp each pixel to [cutmin, cutmax], then rescale".
        stretched = (np.clip(band, cutmin, cutmax) - cutmin) / compress_scale
        # Round explicitly so the Byte output does not depend on GDAL's float->byte conversion.
        compress_data[i, :, :] = np.round(stretched).astype(np.uint8)

    write_tiff(output_8, compress_data, rows, cols, couts, geo, proj)


def write_tiff(output_file, array_data, rows, cols, counts, geo, proj):
    """
    Write a multi-band array to a GeoTIFF with 8-bit (GDT_Byte) bands.

    :param output_file: path of the GeoTIFF to create
    :param array_data: array indexed as [band, row, col]
    :param rows: number of rows (raster height)
    :param cols: number of columns (raster width)
    :param counts: number of bands to write
    :param geo: geotransform to attach to the output
    :param proj: projection string to attach to the output
    :raises RuntimeError: if GDAL cannot create ``output_file``
    """
    driver = gdal.GetDriverByName("Gtiff")
    dataset = driver.Create(output_file, cols, rows, counts, gdal.GDT_Byte)
    if dataset is None:
        # Create() signals failure (bad path, no permission, ...) by returning None.
        raise RuntimeError("GDAL could not create raster: {}".format(output_file))

    dataset.SetGeoTransform(geo)
    dataset.SetProjection(proj)

    for i in range(counts):
        # GDAL band indices are 1-based.
        band = dataset.GetRasterBand(i + 1)
        band.WriteArray(array_data[i, :, :])

    # Flush and close explicitly so the file is complete on disk when we return,
    # rather than whenever the handle happens to be garbage-collected.
    band = None
    dataset.FlushCache()
    dataset = None


def cumulativehistogram(array_data, rows, cols, band_min, band_max, *,
                        low_fraction=0.02, high_fraction=0.98):
    """
    Find linear-stretch cut points from a band's cumulative histogram.

    A histogram with one bin per integer gray level, starting at ``band_min``,
    is built over ``array_data[:rows, :cols]``. Each cut point is the first
    gray level whose cumulative count is strictly greater than the given
    fraction of the pixel count.

    :param array_data: 2-D band array indexed as [row, col]
    :param rows: number of rows to include
    :param cols: number of columns to include
    :param band_min: minimum value of the band (gray level of bin 0)
    :param band_max: maximum value of the band
    :param low_fraction: cumulative fraction for the lower cut (default 0.02)
    :param high_fraction: cumulative fraction for the upper cut (default 0.98)
    :return: (cutmin, cutmax) expressed as gray values, i.e. bin index + band_min
    """
    window = np.asarray(array_data)[:rows, :cols]
    gray_level = int(band_max - band_min + 1)

    # astype() truncates toward zero like int(); pixels are assumed to lie
    # within [band_min, band_max] (negative offsets make bincount raise).
    offsets = (window - band_min).astype(np.int64).ravel()
    cumulative = np.cumsum(np.bincount(offsets, minlength=gray_level))

    counts = window.size
    last_bin = len(cumulative) - 1

    def _cut(fraction):
        # side="right" -> first bin whose cumulative count exceeds the threshold.
        # Bin 0 is a valid answer (e.g. a constant band); the clamp only matters
        # for an empty window, where no bin exceeds a threshold of 0.
        index = int(np.searchsorted(cumulative, counts * fraction, side="right"))
        return min(index, last_bin) + band_min

    cutmin = _cut(low_fraction)
    cutmax = _cut(high_fraction)

    return cutmin, cutmax


if __name__ == '__main__':
    import re

    input_path = r"D:\test"

    # Results are written back into input_path as out<N>.tif; recognise them so a
    # re-run does not feed previous outputs back in as inputs.
    output_name = re.compile(r"out\d+\.tif", re.IGNORECASE)

    # os.listdir also returns sub-directories and unrelated files that GDAL
    # cannot open, so keep only GeoTIFF files. Sorting makes the out<N>
    # numbering reproducible between runs.
    inputs = sorted(
        name
        for name in os.listdir(input_path)
        if name.lower().endswith((".tif", ".tiff"))
        and os.path.isfile(os.path.join(input_path, name))
        and not output_name.fullmatch(name)
    )

    for i, name in enumerate(inputs):
        output_8 = os.path.join(input_path, "out{}.tif".format(i))
        compress(os.path.join(input_path, name), output_8)

 

posted on 2022-11-21 19:22  行走的蓑衣客  阅读(226)  评论(0)  编辑  收藏  举报