GDAL常用代码

GDAL常用代码

1.导入数据

from osgeo import gdal
import numpy as np


def LoadData(filename):
    """Read every band of a raster into a float32 NumPy array.

    Pixels equal to a band's NoData value are replaced with 0.

    Args:
        filename: Path of a raster that GDAL can open.

    Returns:
        The bands stacked along axis 0 as float32, or just the single band
        when the raster has exactly one.  ``None`` is returned (after
        printing a message) when the file cannot be opened.
    """
    file = gdal.Open(filename)
    if file is None:
        print(filename + " can't be opened!")
        return None
    nb = file.RasterCount
    layers = []
    for i in range(1, nb + 1):  # GDAL band indices start at 1
        band = file.GetRasterBand(i)
        background = band.GetNoDataValue()
        data = band.ReadAsArray().astype(np.float32)
        # A band may have no NoData value at all (None), and a NaN NoData
        # value never compares equal with ``==``, so handle both explicitly.
        if background is not None:
            if np.isnan(background):
                data[np.isnan(data)] = 0
            else:
                data[data == background] = 0
        layers.append(data)
    data = np.stack(layers, 0)
    if nb == 1:
        data = data[0, :, :]
    return data

或者

import xarray as xr

# Alternative loader: open the raster with xarray and keep only the first
# entry along the leading axis of its data array.
# NOTE(review): `open_rasterio` appears to be deprecated/removed in recent
# xarray releases (rioxarray offers a replacement) - confirm against the
# installed version before relying on this snippet.
arr = xr.open_rasterio("路径").data[0,:,:]

2.写出数据

def WriteTiff(im_data, im_width, im_height, im_bands, im_geotrans, im_proj, path):
    """Write a 2-D or 3-D NumPy array to a GeoTIFF.

    Args:
        im_data: Array shaped ``(rows, cols)`` or ``(bands, rows, cols)``.
        im_width, im_height, im_bands: Accepted for backward compatibility
            only; the raster size and band count are taken from ``im_data``
            so they can never disagree with the pixels being written.
        im_geotrans: Geotransform passed to ``SetGeoTransform``.
        im_proj: Projection string passed to ``SetProjection``.
        path: Output file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
        RuntimeError: If GDAL cannot create the output file.
    """
    # Match the dtype name exactly.  A substring test such as
    # "'int16' in name" also matches 'uint16', so signed data used to be
    # written into an unsigned band and negative values were corrupted.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        # GDT_Int8 only exists in newer GDAL builds; Int16 holds every int8.
        'int8': getattr(gdal, 'GDT_Int8', gdal.GDT_Int16),
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
        'float64': gdal.GDT_Float64,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if im_data.ndim == 2:
        im_data = im_data[np.newaxis, :, :]  # treat as a single band
    if im_data.ndim != 3:
        raise ValueError(
            "im_data must be 2-D or 3-D, got %d dimensions" % im_data.ndim)
    im_bands, im_height, im_width = im_data.shape

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("GDAL could not create " + str(path))
    dataset.SetGeoTransform(im_geotrans)  # affine transform parameters
    dataset.SetProjection(im_proj)  # projection
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    dataset.FlushCache()
    dataset = None  # dropping the reference closes the file


# Usage example: copy the georeferencing of an existing raster and write
# `arr` with it.
# NOTE(review): `path` is not assigned anywhere in this snippet; it must be
# set to the source raster path before these lines run.
raster = gdal.Open(path)
im_width = raster.RasterXSize  # number of columns
im_height = raster.RasterYSize  # number of rows
im_bands = raster.RasterCount  # number of bands
im_geotrans = raster.GetGeoTransform()  # affine transform
im_proj = raster.GetProjection()  # projection
ResultPath = "路径"
WriteTiff(arr, im_width, im_height, im_bands, im_geotrans, im_proj, ResultPath)

或者

def WriteTiff(im_data, inputdir, path):
    """Write an array to a GeoTIFF using a template raster's georeferencing.

    Args:
        im_data: Array shaped ``(rows, cols)`` or ``(bands, rows, cols)``.
        inputdir: Path of the raster whose geotransform and projection are
            copied.  WGS84 is used when that raster has no projection.
        path: Output file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
        RuntimeError: If the template cannot be opened or the output file
            cannot be created.
    """
    raster = gdal.Open(inputdir)
    if raster is None:
        raise RuntimeError("GDAL could not open " + str(inputdir))
    im_geotrans = raster.GetGeoTransform()  # affine transform
    im_proj = raster.GetProjection()  # projection
    if im_proj == "":
        # No coordinate system on the template: default to WGS84.
        # `osr` is not imported by this snippet, so import it here.
        from osgeo import osr
        osrs = osr.SpatialReference()
        osrs.SetWellKnownGeogCS('WGS84')
        im_proj = osrs.ExportToWkt()

    # Match the dtype name exactly.  A substring test such as
    # "'int16' in name" also matches 'uint16', so signed data used to be
    # written into an unsigned band and negative values were corrupted.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        # GDT_Int8 only exists in newer GDAL builds; Int16 holds every int8.
        'int8': getattr(gdal, 'GDT_Int8', gdal.GDT_Int16),
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
        'float64': gdal.GDT_Float64,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    # Size and band count always come from the array being written, never
    # from the template, so the two cannot disagree.
    if im_data.ndim == 2:
        im_data = im_data[np.newaxis, :, :]
    if im_data.ndim != 3:
        raise ValueError(
            "im_data must be 2-D or 3-D, got %d dimensions" % im_data.ndim)
    im_bands, im_height, im_width = im_data.shape

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("GDAL could not create " + str(path))
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    dataset.FlushCache()
    dataset = None  # dropping the reference closes the file


# Usage example.
# NOTE(review): `im_data`, `inputdir` and `path` are not assigned in this
# snippet; define them before running this call.
WriteTiff(im_data,inputdir, path)

或者

def WriteTiff(im_data, inputdir, path):
    """Write an array to a GeoTIFF, borrowing georeferencing from a raster.

    Args:
        im_data: Array shaped ``(rows, cols)`` or ``(bands, rows, cols)``.
        inputdir: Path of the raster that supplies the geotransform and
            projection.  If it has no projection, WGS84 is written instead.
        path: Output file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
        RuntimeError: If the template cannot be opened or the output file
            cannot be created.
    """
    raster = gdal.Open(inputdir)
    if raster is None:
        raise RuntimeError("GDAL could not open " + str(inputdir))
    im_geotrans = raster.GetGeoTransform()  # affine transform
    im_proj = raster.GetProjection()  # projection

    # Exact dtype lookup: the former substring test ("'int16' in name")
    # also matched the unsigned variants and stored signed data in unsigned
    # bands.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        # GDT_Int8 only exists in newer GDAL builds; Int16 holds every int8.
        'int8': getattr(gdal, 'GDT_Int8', gdal.GDT_Int16),
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
        'float64': gdal.GDT_Float64,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    # The output dimensions follow the array, not the template raster.
    if im_data.ndim == 2:
        im_data = im_data[np.newaxis, :, :]
    if im_data.ndim != 3:
        raise ValueError(
            "im_data must be 2-D or 3-D, got %d dimensions" % im_data.ndim)
    im_bands, im_height, im_width = im_data.shape

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("GDAL could not create " + str(path))
    dataset.SetGeoTransform(im_geotrans)
    if im_proj == "":
        # No coordinate system on the template: fall back to WGS84.
        # `osr` is not imported by this snippet, so import it here.
        from osgeo import osr
        sr = osr.SpatialReference()
        sr.SetWellKnownGeogCS('WGS84')
        dataset.SetProjection(sr.ExportToWkt())
    else:
        dataset.SetProjection(im_proj)
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    dataset.FlushCache()
    dataset = None  # dropping the reference closes the file

3.影像拼接

把要合并的多个tif放在path路径下的文件夹file1、 file2...filen,每个file下文件数量名字都相同

from osgeo import gdal
import math
import os


def GetExtent(in_fn):
    """Return ``(min_x, max_y, max_x, min_y)`` of a raster.

    Only the origin and pixel-size terms of the geotransform are used, so
    the rotation terms are ignored.

    Raises:
        RuntimeError: If the raster cannot be opened.
    """
    ds = gdal.Open(in_fn)
    if ds is None:
        raise RuntimeError("GDAL could not open " + str(in_fn))
    geotrans = ds.GetGeoTransform()
    min_x = geotrans[0]
    max_y = geotrans[3]
    max_x = min_x + ds.RasterXSize * geotrans[1]
    min_y = max_y + ds.RasterYSize * geotrans[5]
    ds = None
    return min_x, max_y, max_x, min_y


def mosaic(in_files, output_name, arr_files=None):
    """Mosaic band 1 of every GeoTIFF in a folder into one GeoTIFF.

    Pixel size, projection and data type are taken from the first input
    (in sorted file-name order).  Inputs are written in that order, so
    where they overlap the later file wins.

    Args:
        in_files: Folder containing the ``.tif`` / ``.tiff`` files to merge.
        output_name: Output path.  A relative path is resolved against
            ``in_files`` (the previous version chdir'ed there).
        arr_files: Unused; kept so existing three-argument calls still work.

    Raises:
        ValueError: If the folder holds no GeoTIFF to merge.
        RuntimeError: If a raster cannot be opened or created, or a pixel
            offset cannot be computed.
    """
    # Work with absolute paths instead of os.chdir(): changing the process
    # working directory is a global side effect, and it made the following
    # listdir() fail for relative folders.
    in_dir = os.path.abspath(in_files)
    output_name = os.path.join(in_dir, output_name)
    out_key = os.path.normcase(os.path.normpath(output_name))

    tif_paths = []
    for entry in sorted(os.listdir(in_dir)):
        if not entry.lower().endswith(('.tif', '.tiff')):
            continue  # sidecar files (.xml, .ovr, ...) are not rasters
        candidate = os.path.join(in_dir, entry)
        if os.path.normcase(os.path.normpath(candidate)) == out_key:
            continue  # never feed a previous output back in as an input
        tif_paths.append(candidate)
    if not tif_paths:
        raise ValueError("no GeoTIFF files found in " + str(in_dir))

    # Union of all input extents.
    extents = [GetExtent(p) for p in tif_paths]
    min_x = min(e[0] for e in extents)
    max_y = max(e[1] for e in extents)
    max_x = max(e[2] for e in extents)
    min_y = min(e[3] for e in extents)

    # Size of the mosaic in pixels.  Round before ceil so floating-point
    # noise (e.g. 200.0000000001) does not add an empty row or column.
    first_ds = gdal.Open(tif_paths[0])
    geotrans = list(first_ds.GetGeoTransform())
    pixel_w = geotrans[1]
    pixel_h = geotrans[5]
    columns = math.ceil(round((max_x - min_x) / pixel_w, 6))
    rows = math.ceil(round((max_y - min_y) / (-pixel_h), 6))

    driver = gdal.GetDriverByName('GTiff')
    out_ds = driver.Create(output_name, columns, rows, 1,
                           first_ds.GetRasterBand(1).DataType)
    if out_ds is None:
        raise RuntimeError("GDAL could not create " + str(output_name))
    out_ds.SetProjection(first_ds.GetProjection())
    geotrans[0] = min_x
    geotrans[3] = max_y
    out_ds.SetGeoTransform(geotrans)
    out_band = out_ds.GetRasterBand(1)
    first_ds = None

    for path in tif_paths:
        in_ds = gdal.Open(path)
        # Map the source's upper-left pixel (0, 0) into output pixel space.
        trans = gdal.Transformer(in_ds, out_ds, [])
        success, xyz = trans.TransformPoint(False, 0, 0)
        if not success:
            raise RuntimeError("cannot locate " + str(path) + " in the mosaic")
        # Round rather than truncate: 99.9999999 must become 100, not 99.
        x = int(round(xyz[0]))
        y = int(round(xyz[1]))
        data = in_ds.GetRasterBand(1).ReadAsArray()
        out_band.WriteArray(data, x, y)  # x, y: upper-left target pixel
        in_ds = None

    out_band.FlushCache()
    out_band = None
    out_ds = None  # dropping the reference closes the file


# Placeholders: replace with the folder of GeoTIFFs to merge and the name
# of the output GeoTIFF, then call mosaic(in_files, output_name).
in_files = "要合并的tif存放的路径"
output_name = "输出的tif名称"

4.开闭运算去除小斑块

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 5 12:53:34 2021

@author: Xhpan
"""
from osgeo import gdal
import xarray as xr
from skimage import morphology as sm
import numpy as np
import os


def WriteTiff(im_data, im_width, im_height, im_bands, im_geotrans, im_proj, path):
    """Write a 2-D or 3-D NumPy array to a GeoTIFF.

    Args:
        im_data: Array shaped ``(rows, cols)`` or ``(bands, rows, cols)``.
        im_width, im_height, im_bands: Accepted for backward compatibility
            only; the raster size and band count are taken from ``im_data``
            so they can never disagree with the pixels being written.
        im_geotrans: Geotransform passed to ``SetGeoTransform``.
        im_proj: Projection string passed to ``SetProjection``.
        path: Output file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
        RuntimeError: If GDAL cannot create the output file.
    """
    # Match the dtype name exactly.  A substring test such as
    # "'int16' in name" also matches 'uint16', so signed data used to be
    # written into an unsigned band and negative values were corrupted.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        # GDT_Int8 only exists in newer GDAL builds; Int16 holds every int8.
        'int8': getattr(gdal, 'GDT_Int8', gdal.GDT_Int16),
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
        'float64': gdal.GDT_Float64,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if im_data.ndim == 2:
        im_data = im_data[np.newaxis, :, :]  # treat as a single band
    if im_data.ndim != 3:
        raise ValueError(
            "im_data must be 2-D or 3-D, got %d dimensions" % im_data.ndim)
    im_bands, im_height, im_width = im_data.shape

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("GDAL could not create " + str(path))
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    dataset.FlushCache()
    dataset = None  # dropping the reference closes the file


def getBoundary(filename, urbanID, kernel, ResultPath):
    """Clean the ``urbanID`` class of a raster with closing then opening.

    Band 1 is turned into a 0/1 mask (1 where the pixel equals
    ``urbanID``), morphologically closed and then opened with ``kernel``,
    and written to ``ResultPath`` with the source georeferencing.

    Raises:
        RuntimeError: If the input raster cannot be opened.
    """
    # Read with GDAL (already imported) rather than xr.open_rasterio, which
    # newer xarray releases no longer provide.
    src = gdal.Open(filename)
    if src is None:
        raise RuntimeError("GDAL could not open " + str(filename))
    band = src.GetRasterBand(1).ReadAsArray()
    # Keep the band's dtype so the result is stored as before (0 and 1).
    mask = (band == urbanID).astype(band.dtype)
    img_close = sm.closing(mask, kernel)
    img_open = sm.opening(img_close, kernel)
    WriteTiff(img_open, src.RasterXSize, src.RasterYSize, src.RasterCount,
              src.GetGeoTransform(), src.GetProjection(), ResultPath)


def getTiffFileName(filepath, suffix):
    """List the files under ``filepath`` whose extension equals ``suffix``.

    Returns:
        ``(paths, names)``: two parallel lists, sorted by file name, holding
        the full path and the extension-less name of every match.
    """
    found = []
    for root, dirs, files in os.walk(filepath):
        for file in files:
            filename, extension = os.path.splitext(file)
            if extension == suffix:
                # Join with `root`, not `filepath`: os.walk also descends
                # into sub-folders, whose files live under `root`.
                found.append((filename, root + "/" + file))
    # os.walk order is arbitrary; sort so callers get a stable sequence.
    found.sort()
    return [path for _, path in found], [name for name, _ in found]


urbanID = 1
filepath = r"D:\Work\doing\CNN_RNN\data\beijing\origindata\urban\1985_2017"
kernel = sm.octagon(2, 1)
inputPathFiles, inputNames = getTiffFileName(filepath, ".tif")
# GDAL cannot create a file inside a folder that does not exist yet.
os.makedirs(r"D:\Work\doing\CNN_RNN\data\beijing\origindata\urban\1985_2017_1", exist_ok=True)
for filename, name in zip(inputPathFiles, inputNames):
    ResultPath = r"D:\Work\doing\CNN_RNN\data\beijing\origindata\urban\1985_2017_1" + "/" + name + ".tif"
    getBoundary(filename, urbanID, kernel, ResultPath)
    print(filename)

5.时间序列订正

# -*- coding: utf-8 -*-
"""
Created on Tue Sep 28 13:17:24 2021

@author: Xhpan
"""
import numpy as np
from osgeo import gdal
import xarray as xr
import os


def WriteTiff(im_data, im_width, im_height, im_bands, im_geotrans, im_proj, path):
    """Write a 2-D or 3-D NumPy array to a GeoTIFF.

    Args:
        im_data: Array shaped ``(rows, cols)`` or ``(bands, rows, cols)``.
        im_width, im_height, im_bands: Accepted for backward compatibility
            only; the raster size and band count are taken from ``im_data``
            so they can never disagree with the pixels being written.
        im_geotrans: Geotransform passed to ``SetGeoTransform``.
        im_proj: Projection string passed to ``SetProjection``.
        path: Output file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
        RuntimeError: If GDAL cannot create the output file.
    """
    # Match the dtype name exactly.  A substring test such as
    # "'int16' in name" also matches 'uint16', so signed data used to be
    # written into an unsigned band and negative values were corrupted.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        # GDT_Int8 only exists in newer GDAL builds; Int16 holds every int8.
        'int8': getattr(gdal, 'GDT_Int8', gdal.GDT_Int16),
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
        'float64': gdal.GDT_Float64,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if im_data.ndim == 2:
        im_data = im_data[np.newaxis, :, :]  # treat as a single band
    if im_data.ndim != 3:
        raise ValueError(
            "im_data must be 2-D or 3-D, got %d dimensions" % im_data.ndim)
    im_bands, im_height, im_width = im_data.shape

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("GDAL could not create " + str(path))
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    dataset.FlushCache()
    dataset = None  # dropping the reference closes the file


def getTiffFileName(filepath, suffix):
    """List the files under ``filepath`` whose extension equals ``suffix``.

    Returns:
        ``(paths, names)``: two parallel lists, sorted by file name, holding
        the full path and the extension-less name of every match.
    """
    found = []
    for root, dirs, files in os.walk(filepath):
        for file in files:
            filename, extension = os.path.splitext(file)
            if extension == suffix:
                # Join with `root`, not `filepath`: os.walk also descends
                # into sub-folders, whose files live under `root`.
                found.append((filename, root + "/" + file))
    # os.walk order is arbitrary; the time series below needs a stable,
    # name-sorted sequence.
    found.sort()
    return [path for _, path in found], [name for name, _ in found]


def _read_first_band(path):
    """Return band 1 of the raster at ``path`` as a NumPy array."""
    ds = gdal.Open(path)
    if ds is None:
        raise RuntimeError("GDAL could not open " + str(path))
    return ds.GetRasterBand(1).ReadAsArray()


def timeSeriesCorrection(filepath, outputpath):
    """Accumulate a series of rasters so a value of 1 persists once set.

    The ``.tif`` files under ``filepath`` are processed in file-name order.
    Each raster is added to the running result and every 2 is reset to 1;
    the running result is written to ``outputpath`` under the name of the
    raster just added.  The first raster only seeds the running result and
    is not written itself.

    Raises:
        ValueError: If ``filepath`` contains no ``.tif`` file.
        RuntimeError: If a raster cannot be opened or written.
    """
    os.makedirs(outputpath, exist_ok=True)
    inputPathFiles, inputNames = getTiffFileName(filepath, ".tif")
    if not inputPathFiles:
        raise ValueError("no .tif files found in " + str(filepath))

    # Georeferencing is taken from the first raster of the series.
    raster = gdal.Open(inputPathFiles[0])
    if raster is None:
        raise RuntimeError("GDAL could not open " + str(inputPathFiles[0]))
    im_width = raster.RasterXSize  # number of columns
    im_height = raster.RasterYSize  # number of rows
    im_bands = raster.RasterCount  # number of bands
    im_geotrans = raster.GetGeoTransform()  # affine transform
    im_proj = raster.GetProjection()  # projection

    # Rasters are read with GDAL (already imported) rather than
    # xr.open_rasterio, which newer xarray releases no longer provide.
    accumulated = _read_first_band(inputPathFiles[0])
    for path, name in zip(inputPathFiles[1:], inputNames[1:]):
        accumulated = accumulated + _read_first_band(path)
        accumulated[accumulated == 2] = 1
        ResultPath = outputpath + "/" + name + ".tif"
        WriteTiff(accumulated, im_width, im_height, im_bands,
                  im_geotrans, im_proj, ResultPath)
        print(ResultPath)


filepath = r"D:\Work\doing\CNN_RNN\data\beijing\origindata\urban\1985_2017_1"
outputpath = r"D:\Work\doing\CNN_RNN\data\beijing\origindata\urban\1985_2017_2"
timeSeriesCorrection(filepath,outputpath)

6.统一图像的行列号

# -*- coding: utf-8 -*-
"""
Created on Wed Oct 6 14:26:52 2021

@author: Xhpan
"""
from osgeo import gdal
import math


def GetExtent(in_fn):
    """Return ``(min_x, max_y, max_x, min_y)`` of a raster.

    Only the origin and pixel-size terms of the geotransform are used, so
    the rotation terms are ignored.

    Raises:
        RuntimeError: If the raster cannot be opened.
    """
    ds = gdal.Open(in_fn)
    if ds is None:
        raise RuntimeError("GDAL could not open " + str(in_fn))
    geotrans = ds.GetGeoTransform()
    min_x = geotrans[0]
    max_y = geotrans[3]
    max_x = min_x + ds.RasterXSize * geotrans[1]
    min_y = max_y + ds.RasterYSize * geotrans[5]
    ds = None
    return min_x, max_y, max_x, min_y


def UnifiedLineNumber(in_fn, criterion_fn, output_name):
    """Rewrite a raster so it has the row/column count of a reference.

    The output takes its size, pixel size, projection and band data type
    from ``criterion_fn`` and its origin from ``in_fn``; band 1 of
    ``in_fn`` is then written into it.  Input pixels that fall outside the
    output grid are dropped.

    Args:
        in_fn: Raster to be resized.
        criterion_fn: Reference raster that defines the target size.
        output_name: Output GeoTIFF path.

    Raises:
        RuntimeError: If a raster cannot be opened or created, or the
            pixel offset cannot be computed.
    """
    ref_ds = gdal.Open(criterion_fn)
    if ref_ds is None:
        raise RuntimeError("GDAL could not open " + str(criterion_fn))
    src_ds = gdal.Open(in_fn)
    if src_ds is None:
        raise RuntimeError("GDAL could not open " + str(in_fn))

    geotrans = list(ref_ds.GetGeoTransform())
    # Use the reference size directly.  Re-deriving it as
    # ceil(extent / pixel_size) can gain a row or column from float noise.
    columns = ref_ds.RasterXSize
    rows = ref_ds.RasterYSize

    driver = gdal.GetDriverByName('GTiff')
    out_ds = driver.Create(output_name, columns, rows, 1,
                           ref_ds.GetRasterBand(1).DataType)
    if out_ds is None:
        raise RuntimeError("GDAL could not create " + str(output_name))
    out_ds.SetProjection(ref_ds.GetProjection())

    # Anchor the output at the input raster's upper-left corner.
    min_x1, max_y1, max_x1, min_y1 = GetExtent(in_fn)
    geotrans[0] = min_x1
    geotrans[3] = max_y1
    out_ds.SetGeoTransform(geotrans)
    out_band = out_ds.GetRasterBand(1)

    # Map the input's upper-left pixel (0, 0) into output pixel space.
    trans = gdal.Transformer(src_ds, out_ds, [])
    success, xyz = trans.TransformPoint(False, 0, 0)
    if not success:
        raise RuntimeError("cannot locate " + str(in_fn) + " in the output grid")
    # Round rather than truncate so -1e-9 or 0.9999999 land on the right pixel.
    x = int(round(xyz[0]))
    y = int(round(xyz[1]))

    data = src_ds.GetRasterBand(1).ReadAsArray()
    if x >= 0 and y >= 0:
        # Crop to the part that fits; WriteArray rejects oversized blocks.
        data = data[:max(rows - y, 0), :max(columns - x, 0)]
    if data.size:
        out_band.WriteArray(data, x, y)  # x, y: upper-left target pixel

    out_band.FlushCache()
    src_ds = None
    out_band = None
    out_ds = None  # dropping the reference closes the file


in_fn = "需要处理的tif路径"
criterion_fn = "标准的tif路径"
output_name = "输出tif路径"
UnifiedLineNumber(in_fn,criterion_fn,output_name)

7.矢量按位置(相交和相邻)/属性选择

import os
from osgeo import ogr
from tqdm import trange


def create_shp_by_layer(shp, layer):
    """Copy ``layer`` into a new shapefile at ``shp``, replacing any old one.

    Uses the module-level ``driver``.  The calling script relies on the
    copy containing only the features selected by the filter currently set
    on ``layer``.

    Raises:
        RuntimeError: If the output shapefile cannot be created.
    """
    outputfile = shp
    if os.access(outputfile, os.F_OK):
        driver.DeleteDataSource(outputfile)
    newds = driver.CreateDataSource(outputfile)
    if newds is None:
        raise RuntimeError("OGR could not create " + str(outputfile))
    newds.CopyLayer(layer, '')
    # Release the datasource explicitly so the file is flushed to disk
    # before the caller re-opens it.
    newds = None


def totxt(resultname, L):
    """Write every item of ``L`` to ``resultname``, one per line."""
    with open(resultname, "w") as f:
        f.writelines(str(name) + '\n' for name in L)


filename = 'Export_Output.shp'
resultpath = "result"
if not os.path.exists(resultpath):
    os.makedirs(resultpath)
resultpath1 = "resultshp"
if not os.path.exists(resultpath1):
    os.makedirs(resultpath1)

driver = ogr.GetDriverByName("ESRI Shapefile")

# The same shapefile is opened twice: once as the layer being selected
# from (target) and once as the layer supplying the query polygons (source).
target_shp = filename
target_ds = ogr.Open(target_shp)
target_layer = target_ds.GetLayer(0)  # first layer
source_shp = filename
source_ds = ogr.Open(source_shp)
source_layer = source_ds.GetLayer(0)  # first layer

# For every polygon, export the features selected by its spatial filter and
# list the `cyid` of every selected feature other than the polygon itself.
for i in trange(source_layer.GetFeatureCount()):
    source_feats = source_layer.GetFeature(i)
    source_id = source_feats.GetField('cyid')  # cyid of the query polygon
    poly = source_feats.GetGeometryRef()  # geometry of the query polygon
    target_layer.SetSpatialFilter(poly)  # select this polygon and its neighbours
    shp = resultpath1 + "/" + str(source_id) + ".shp"
    create_shp_by_layer(shp, target_layer)

    # Read the cyid values back from the exported shapefile.
    filter_names = []
    result_ds = ogr.Open(shp)
    result_layer = result_ds.GetLayer(0)  # first layer
    for fea in result_layer:
        # Use the feature the iterator yields; fetching it again by index
        # with GetFeature() was a redundant second lookup.
        result_id = fea.GetField('cyid')
        if result_id != source_id:
            filter_names.append(result_id)
    resultname = resultpath + "/" + str(source_id) + ".txt"
    totxt(resultname, filter_names)

    # Clear the filters and rewind before the next polygon.
    source_layer.SetSpatialFilter(None)
    source_layer.ResetReading()
    target_layer.SetSpatialFilter(None)
    target_layer.ResetReading()
    result_layer.SetSpatialFilter(None)
    result_layer.ResetReading()

8.根据矢量裁剪栅格

from osgeo import gdal


def clip_raster(in_raster, out_raster, mask_shp, dst_srs='EPSG:4326',
                dst_nodata=-9999, output_type=gdal.GDT_Float64):
    """Clip a raster to a vector cutline and write the result as a GeoTIFF.

    :param in_raster: input raster path
    :param out_raster: output raster path
    :param mask_shp: vector file used as the cutline
    :param dst_srs: target spatial reference (default ``'EPSG:4326'``)
    :param dst_nodata: NoData value of the output (default ``-9999``)
    :param output_type: GDAL data type of the output (default Float64)
    :raises RuntimeError: if ``gdal.Warp`` does not return a dataset
    :return: None
    """
    result = gdal.Warp(out_raster,
                       in_raster,
                       format='GTiff',
                       dstSRS=dst_srs,
                       cutlineDSName=mask_shp,
                       cropToCutline=True,  # crop to the cutline extent
                       dstNodata=dst_nodata,
                       outputType=output_type)
    if result is None:
        raise RuntimeError("gdal.Warp could not clip " + str(in_raster))
    result = None  # dropping the reference closes the output file


in_raster = "栅格路径"
out_raster = r"test.tif"
mask_shp = "矢量路径"
clip_raster(in_raster, out_raster, mask_shp)

9. 将shp按照属性字段分割为多个polygon矢量(使用OGR实现,效果等同于ArcGIS的按属性分割)

from osgeo import ogr
import os

# Split a shapefile into one polygon shapefile per distinct value of the
# `ID` field.
shpfile = r"输入shp"
resultpath = r"输出文件"
if not os.path.exists(resultpath):
    os.makedirs(resultpath)

driver = ogr.GetDriverByName("ESRI Shapefile")
ds = ogr.Open(shpfile)
if ds is None:
    raise RuntimeError("OGR could not open " + shpfile)
layer = ds.GetLayer(0)
layer_defn = layer.GetLayerDefn()

# Collect the distinct IDs first (keeping first-seen order).  Looping over
# every feature instead would rebuild the same output file once per
# duplicate ID.
source_ids = []
seen_ids = set()
for feature in layer:
    source_id = feature.GetField('ID')
    if source_id not in seen_ids:
        seen_ids.add(source_id)
        source_ids.append(source_id)

for i, source_id in enumerate(source_ids):
    # String IDs must be quoted (and embedded quotes doubled) in the filter.
    if isinstance(source_id, str):
        id_literal = "'" + source_id.replace("'", "''") + "'"
    else:
        id_literal = str(source_id)
    layer.SetAttributeFilter("ID = {}".format(id_literal))

    extfile = resultpath + "/" + str(source_id).zfill(2) + ".shp"  # named after the ID
    if os.path.exists(extfile):
        driver.DeleteDataSource(extfile)  # allow the script to be re-run
    newds = driver.CreateDataSource(extfile)
    if newds is None:
        raise RuntimeError("OGR could not create " + extfile)
    # Carry over the spatial reference and attribute schema so the pieces
    # keep their coordinate system and field values.
    lyrn = newds.CreateLayer('rect', layer.GetSpatialRef(), ogr.wkbPolygon)
    for field_index in range(layer_defn.GetFieldCount()):
        lyrn.CreateField(layer_defn.GetFieldDefn(field_index))

    layer.ResetReading()
    feat = layer.GetNextFeature()
    while feat is not None:
        out_feat = ogr.Feature(lyrn.GetLayerDefn())
        out_feat.SetFrom(feat)  # copies geometry and matching fields
        lyrn.CreateFeature(out_feat)
        feat = layer.GetNextFeature()
    newds.Destroy()
    print(i)

layer.SetAttributeFilter(None)  # leave the layer unfiltered afterwards

10.分区统计(多tif批量)

import time
import geopandas as gpd
import rasterio
from rasterstats import zonal_stats
import pandas as pd
from osgeo import gdal
import numpy as np


def LoadData(filename):
    """Read every band of a raster into a float32 NumPy array.

    Pixels equal to a band's NoData value are replaced with 0.

    Args:
        filename: Path of a raster that GDAL can open.

    Returns:
        The bands stacked along axis 0 as float32, or just the single band
        when the raster has exactly one.  ``None`` is returned (after
        printing a message) when the file cannot be opened.
    """
    file = gdal.Open(filename)
    if file is None:
        print(filename + " can't be opened!")
        return None
    nb = file.RasterCount
    layers = []
    for i in range(1, nb + 1):  # GDAL band indices start at 1
        band = file.GetRasterBand(i)
        background = band.GetNoDataValue()
        data = band.ReadAsArray().astype(np.float32)
        # A band may have no NoData value at all (None), and a NaN NoData
        # value never compares equal with ``==``, so handle both explicitly.
        if background is not None:
            if np.isnan(background):
                data[np.isnan(data)] = 0
            else:
                data[data == background] = 0
        layers.append(data)
    data = np.stack(layers, 0)
    if nb == 1:
        data = data[0, :, :]
    return data


# Batch zonal statistics: one column per raster, one row per zone polygon.
start = time.time()
shp_path = '../data/urbandata/shp/basin.shp'
stats = ['mean']  # ['min', 'max', 'mean', 'median', 'majority']
shp_driver = gpd.read_file(shp_path)
df = shp_driver['ID'].to_frame()
names = ["Chen","He","Zhou","Hyde","LUH","Gao"]
for name in names:
    ras_path = "../data/urbandata/" + name + ".tif"
    # Only the affine transform is needed from rasterio; close the dataset
    # straight away instead of leaking one handle per raster.
    with rasterio.open(ras_path) as ras_driver:
        affine = ras_driver.transform
    array = LoadData(ras_path)
    if array is None:
        raise RuntimeError("cannot read " + ras_path)
    # The value of the top-left pixel is treated as background and zeroed.
    array[np.where(array == array[0][0])] = 0
    zs = zonal_stats(shp_path, array, affine=affine, stats=stats)
    # Only the first requested statistic is stored in the table.
    values = [zone[stats[0]] for zone in zs]
    df['{}'.format(name)] = values
    print(name)
df.to_excel("../urba_mean.xlsx")

__EOF__

本文作者skypanxh
本文链接https://www.cnblogs.com/skypanxh/p/15214516.html
关于博主:评论和私信会在第一时间回复。或者直接私信我。
版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!
声援博主:如果您觉得文章对您有帮助,可以点击文章右下角推荐一下。您的鼓励是博主的最大动力!
posted @   skypanxh  阅读(653)  评论(0)  编辑  收藏  举报
编辑推荐:
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
点击右上角即可分享
微信分享提示