行走的蓑衣客

导航

 

1.去除NoData

import cv2
import gdal
import scipy.interpolate
import numpy as np
 
def read_img(filename):
    """Read a raster file with GDAL and return its georeferencing and pixels.

    :param filename: path of the raster file to open.
    :return: tuple ``(im_proj, im_geotrans, im_width, im_height, im_data)``:
        the projection string from ``GetProjection()``, the geotransform from
        ``GetGeoTransform()``, the raster size from ``RasterXSize`` /
        ``RasterYSize`` and the full-extent pixel array from ``ReadAsArray``.
    :raises OSError: if GDAL returns no dataset for ``filename``.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        # A failed open can come back as None instead of raising; fail here
        # with a clear message rather than with an AttributeError below.
        raise OSError("GDAL could not open raster: {!r}".format(filename))

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    # Drop the last reference so GDAL releases the file handle.
    del dataset
    return im_proj, im_geotrans, im_width, im_height, im_data
 
 
def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a NumPy array to a GeoTIFF with the given georeferencing.

    :param filename: path of the GeoTIFF to create.
    :param im_proj: projection string passed to ``SetProjection``.
    :param im_geotrans: geotransform passed to ``SetGeoTransform``.
    :param im_data: 2-D ``(rows, cols)`` or 3-D ``(bands, rows, cols)`` array.
    :raises OSError: if the GTiff driver cannot create ``filename``.
    """
    # Exact dtype -> GDAL type mapping.  A substring test such as
    # ``'int16' in name`` also matches the signed type and would store
    # negative values in an unsigned band.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Int16,   # smallest signed type that holds int8 losslessly
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
        'float32': gdal.GDT_Float32,
        'float64': gdal.GDT_Float64,
    }
    # Anything else keeps the previous fallback of Float32.
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands = 1
        im_height, im_width = im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise OSError("GDAL could not create raster: {!r}".format(filename))

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        # GDAL band numbers are 1-based.
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])

    # Drop the last reference so GDAL flushes and closes the file.
    del dataset
 
def NoData_kill(in_path, out_path):
    """Fill NaN (NoData) pixels of a raster by inpainting and save the result.

    Each band that contains NaN values is inpainted on its own with
    ``cv2.inpaint`` (Telea method, radius 3), using that band's NaN positions
    as the inpainting mask.  A raster without NaN values is written unchanged.

    :param in_path: path of the raster to read.
    :param out_path: path of the GeoTIFF to write.
    """
    im_proj, im_geotrans, im_width, im_height, im_data = read_img(in_path)

    # Treat a single-band (2-D) raster as a one-band stack.
    single_band = im_data.ndim == 2
    bands = im_data[np.newaxis, ...] if single_band else im_data

    nan_mask = np.isnan(bands)
    if nan_mask.any():
        # Work band by band on a float32 copy: the array is band-first
        # (bands, rows, cols), so it cannot be handed to cv2.inpaint whole,
        # and float32 single-channel is the float layout inpaint accepts.
        filled = np.array(bands, dtype=np.float32)
        for i in range(filled.shape[0]):
            band_mask = nan_mask[i]
            if not band_mask.any():
                continue
            # Put a finite placeholder under the mask so no NaN reaches the
            # inpainting arithmetic; those pixels are overwritten anyway.
            src = np.where(band_mask, np.float32(0), filled[i])
            filled[i] = cv2.inpaint(
                src,
                band_mask.astype(np.uint8),  # non-zero marks pixels to fill
                inpaintRadius=3,
                flags=cv2.INPAINT_TELEA,
            )
        im_data = filled[0] if single_band else filled

    write_img(out_path, im_proj, im_geotrans, im_data)
 
def NoData_kill2(in_path, out_path):
    """Fill pure-white NoData pixels with the nearest valid pixel.

    A pixel is treated as missing when every band equals 255.  Each missing
    pixel receives the band values of its nearest valid pixel (Euclidean
    distance in pixel coordinates).  Only the missing pixels are looked up,
    with a single KD-tree shared by all bands.

    :param in_path: path of the raster to read.
    :param out_path: path of the GeoTIFF to write.
    """
    # Local import: the file only imports scipy.interpolate at module level.
    from scipy.spatial import cKDTree

    im_proj, im_geotrans, im_width, im_height, data = read_img(in_path)

    # Treat a single-band (2-D) raster as a one-band stack (bands, rows, cols).
    single_band = data.ndim == 2
    bands = data[np.newaxis, ...] if single_band else data

    # (rows, cols) masks of missing / valid pixels.
    missing = np.all(bands == 255, axis=0)
    valid = ~missing

    result = bands
    # Nothing to do when no pixel is missing; nothing to copy from when no
    # pixel is valid.
    if missing.any() and valid.any():
        valid_rc = np.argwhere(valid)      # (n_valid, 2) row/col coordinates
        missing_rc = np.argwhere(missing)  # (n_missing, 2) row/col coordinates

        # Index of the nearest valid pixel for every missing pixel.
        _, nearest = cKDTree(valid_rc).query(missing_rc)
        src_rows, src_cols = valid_rc[nearest].T

        result = bands.copy()
        result[:, missing_rc[:, 0], missing_rc[:, 1]] = bands[:, src_rows, src_cols]

    if single_band:
        result = result[0]
    write_img(out_path, im_proj, im_geotrans, result)
 
if __name__ == "__main__":
    # Script entry point: repair the sample raster and save the fixed copy.
    in_path, out_path = 'cq_test.tif', './fixed.tif'
    NoData_kill(in_path, out_path)

 

 

2.遥感图像增强

import cv2
import gdal
import numpy as np
from PIL import Image, ImageEnhance 
import matplotlib.pyplot as plt
 
def read_img(filename):
    """Read a raster file with GDAL and return its georeferencing and pixels.

    :param filename: path of the raster file to open.
    :return: tuple ``(im_proj, im_geotrans, im_width, im_height, im_data)``:
        the projection string from ``GetProjection()``, the geotransform from
        ``GetGeoTransform()``, the raster size from ``RasterXSize`` /
        ``RasterYSize`` and the full-extent pixel array from ``ReadAsArray``.
    :raises OSError: if GDAL returns no dataset for ``filename``.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        # A failed open can come back as None instead of raising; fail here
        # with a clear message rather than with an AttributeError below.
        raise OSError("GDAL could not open raster: {!r}".format(filename))

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    # Drop the last reference so GDAL releases the file handle.
    del dataset
    return im_proj, im_geotrans, im_width, im_height, im_data
 
 
def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a NumPy array to a GeoTIFF with the given georeferencing.

    :param filename: path of the GeoTIFF to create.
    :param im_proj: projection string passed to ``SetProjection``.
    :param im_geotrans: geotransform passed to ``SetGeoTransform``.
    :param im_data: 2-D ``(rows, cols)`` or 3-D ``(bands, rows, cols)`` array.
    :raises OSError: if the GTiff driver cannot create ``filename``.
    """
    # Exact dtype -> GDAL type mapping.  A substring test such as
    # ``'int16' in name`` also matches the signed type and would store
    # negative values in an unsigned band.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Int16,   # smallest signed type that holds int8 losslessly
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
        'float32': gdal.GDT_Float32,
        'float64': gdal.GDT_Float64,
    }
    # Anything else keeps the previous fallback of Float32.
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands = 1
        im_height, im_width = im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise OSError("GDAL could not create raster: {!r}".format(filename))

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        # GDAL band numbers are 1-based.
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])

    # Drop the last reference so GDAL flushes and closes the file.
    del dataset
 
def deaw_gray_hist(gray_img):
    """Plot the gray-level histogram of an image as a bar chart.

    :param gray_img: grayscale image of shape ``[h, w]`` holding non-negative
        integer gray levels.
    """
    # Count every gray level in one vectorised pass; minlength keeps the
    # 0..255 axis even when high levels are absent from the image.
    gray_hist = np.bincount(np.ravel(gray_img), minlength=256)
    x = np.arange(gray_hist.size)
    # Draw the gray-level histogram.
    plt.bar(x, gray_hist)
    plt.xlabel("gray Label")
    plt.ylabel("number of pixels")
    plt.show()
 
 
def im_enhance(in_path, out_path, factor=100.0):
    """Sharpen a raster with PIL's ``ImageEnhance.Sharpness`` and save it.

    :param in_path: path of the raster to read.
    :param out_path: path of the GeoTIFF to write.
    :param factor: value passed to ``Sharpness.enhance``; defaults to the
        previously hard-coded 100.0.
    """
    im_proj, im_geotrans, im_width, im_height, im_data = read_img(in_path)

    if im_data.ndim == 2:
        # Single-band raster: PIL can take the (rows, cols) array directly.
        pil_im = Image.fromarray(im_data)
        enhanced_im = np.array(ImageEnhance.Sharpness(pil_im).enhance(factor))
    else:
        # Band-first (bands, rows, cols) -> PIL layout (rows, cols, bands).
        hwc = np.ascontiguousarray(im_data.transpose(1, 2, 0))
        sharpened = ImageEnhance.Sharpness(Image.fromarray(hwc)).enhance(factor)
        # Back to band-first for write_img.
        enhanced_im = np.ascontiguousarray(np.array(sharpened).transpose(2, 0, 1))

    write_img(out_path, im_proj, im_geotrans, enhanced_im)
 
def linear_transform(in_path, a, b, out_path):
    """Apply the linear transform ``out = a * img + b`` and save as 8-bit.

    :param in_path: path of the raster to read.
    :param a: gain applied to every pixel.
    :param b: offset added to every pixel.
    :param out_path: path of the GeoTIFF to write.
    """
    im_proj, im_geotrans, im_width, im_height, im_data = read_img(in_path)

    # Compute in float64 so integer a/b cannot overflow the source dtype.
    out = a * im_data.astype(np.float64) + b
    # Clamp both ends before the cast: values below 0 would otherwise wrap
    # around when converted to uint8.
    out = np.clip(np.around(out), 0, 255).astype(np.uint8)
    write_img(out_path, im_proj, im_geotrans, out)
 
 
def normalize_transform(in_path, out_path):
    """Histogram normalisation: stretch every band linearly to 0..255.

    For each band the minimum maps to 0 and the maximum to 255 via
    ``out = a * band + b`` with ``a = 255 / (max - min)``.  A band whose
    values are all equal is written as 0.  The result is saved as 8-bit.

    :param in_path: path of the raster to read.
    :param out_path: path of the GeoTIFF to write.
    """
    im_proj, im_geotrans, im_width, im_height, im_data = read_img(in_path)

    # Treat a single-band (2-D) raster as a one-band stack.
    single_band = im_data.ndim == 2
    bands = im_data[np.newaxis, ...] if single_band else im_data

    out_min, out_max = 0.0, 255.0
    nor_out = np.zeros(bands.shape, dtype=np.uint8)
    # Every band is processed, including the last one.
    for i, band in enumerate(bands):
        band = band.astype(np.float64)
        in_min, in_max = band.min(), band.max()
        if in_max == in_min:
            # Constant band: no range to stretch, and dividing would be by zero.
            continue
        a = (out_max - out_min) / (in_max - in_min)
        b = out_min - a * in_min
        stretched = np.clip(np.around(a * band + b), out_min, out_max)
        nor_out[i] = stretched.astype(np.uint8)

    if single_band:
        nor_out = nor_out[0]
    write_img(out_path, im_proj, im_geotrans, nor_out)
 
def equalize_transfrom(in_path, out_path):
    """Global histogram equalisation of every band with ``cv2.equalizeHist``.

    ``cv2.equalizeHist(src)`` requires a single-channel uint8 array and
    returns an array of the same shape, so bands are equalised one by one.

    :param in_path: path of the raster to read.
    :param out_path: path of the GeoTIFF to write.
    """
    im_proj, im_geotrans, im_width, im_height, im_data = read_img(in_path)

    if im_data.ndim == 2:
        # Single-band raster: equalise the (rows, cols) array directly.
        temp = cv2.equalizeHist(im_data)
    else:
        temp = np.zeros_like(im_data)
        for i, band in enumerate(im_data):
            temp[i] = cv2.equalizeHist(band)
    write_img(out_path, im_proj, im_geotrans, temp)
 
def restrict_hist(in_path, out_path):
    """Contrast-limited adaptive histogram equalisation (CLAHE) per band.

    Uses ``cv2.createCLAHE(clipLimit, tileGridSize)`` with a clip limit of 3
    and an 8x8 tile grid.  ``clipLimit`` is the contrast-limiting threshold:
    histogram counts above it are redistributed.  ``tileGridSize`` is the grid
    the image is divided into.

    :param in_path: path of the raster to read.
    :param out_path: path of the GeoTIFF to write.
    """
    im_proj, im_geotrans, im_width, im_height, im_data = read_img(in_path)

    # The CLAHE settings are the same for every band, so build the object once.
    clahe = cv2.createCLAHE(3, (8, 8))

    if im_data.ndim == 2:
        # Single-band raster: apply to the (rows, cols) array directly.
        temp = clahe.apply(im_data)
    else:
        temp = np.zeros_like(im_data)
        for i, band in enumerate(im_data):
            temp[i] = clahe.apply(band)
    write_img(out_path, im_proj, im_geotrans, temp)
 
def gamma_trans(in_path, out_path, gamma=0.4):
    """Gamma transform: ``out = 255 * (img / 255) ** gamma``, saved as 8-bit.

    :param in_path: path of the raster to read.
    :param out_path: path of the GeoTIFF to write.
    :param gamma: exponent of the power law; defaults to the previously
        hard-coded 0.4.
    :raises ValueError: if ``gamma`` is not positive.
    """
    if gamma <= 0:
        raise ValueError("gamma must be positive, got {!r}".format(gamma))

    im_proj, im_geotrans, im_width, im_height, im_data = read_img(in_path)

    # Normalise to [0, 1] in floating point.  The clip keeps values outside
    # 0..255 from producing NaN or wrapping around in the uint8 cast below.
    img_norm = np.clip(im_data.astype(np.float64) / 255.0, 0.0, 1.0)
    img_gamma = np.power(img_norm, gamma) * 255.0
    img_gamma = img_gamma.astype(np.uint8)
    write_img(out_path, im_proj, im_geotrans, img_gamma)
 
if __name__ == "__main__":
    # Script entry point: apply the gamma transform to the repaired raster.
    # The other transforms defined above (linear_transform with a=2.0 and
    # b=10.0, normalize_transform, equalize_transfrom, restrict_hist) can be
    # run the same way; the author used './enhance2.tif' .. './enhance5.tif'
    # as their respective output paths.
    in_path, out_path = './fixed2.tif', './enhance6.tif'
    gamma_trans(in_path, out_path)

 

posted on 2022-05-17 20:27  行走的蓑衣客  阅读(445)  评论(0)  编辑  收藏  举报