GDAL常用代码
GDAL常用代码
1.导入数据
from osgeo import gdal
import numpy as np
def LoadData(filename):
    """Read every band of a raster into a float32 NumPy array.

    Pixels equal to a band's nodata value are replaced with 0.

    Args:
        filename: Path of a raster that GDAL can open.

    Returns:
        A (rows, cols) array for a single-band raster, a (bands, rows, cols)
        array otherwise, or None (after printing a message) when the file
        cannot be opened.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        print(filename + " can't be opened!")
        return
    band_count = dataset.RasterCount
    bands = []
    for band_index in range(1, band_count + 1):
        band = dataset.GetRasterBand(band_index)
        nodata = band.GetNoDataValue()
        raw = band.ReadAsArray()
        data = raw.astype(np.float32)
        if nodata is not None:
            # Build the mask on the native dtype so the float32 cast cannot
            # alter the comparison; NaN never equals itself, so test it
            # separately.
            if np.isnan(nodata):
                mask = np.isnan(data)
            else:
                mask = raw == nodata
            data[mask] = 0
        bands.append(data)
    data = np.stack(bands, 0)
    if band_count == 1:
        data = data[0, :, :]
    return data
或者
import xarray as xr
# Alternative loader: take the first band of the raster as a NumPy array.
# NOTE(review): xr.open_rasterio is deprecated in newer xarray releases —
# confirm that the installed xarray version still provides it.
arr = xr.open_rasterio("路径").data[0,:,:]
2.写出数据
def WriteTiff(im_data, im_width, im_height, im_bands, im_geotrans, im_proj, path):
    """Write a NumPy array to a GeoTIFF file.

    Args:
        im_data: 2-D (rows, cols) or 3-D (bands, rows, cols) NumPy array.
        im_width: Kept for backward compatibility; the width is taken from
            ``im_data``.
        im_height: Kept for backward compatibility; the height is taken from
            ``im_data``.
        im_bands: Kept for backward compatibility; the band count is taken
            from ``im_data``.
        im_geotrans: Six-element GDAL geotransform.
        im_proj: Projection as a WKT string.
        path: Output file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
    """
    # Validate the shape before creating anything on disk.
    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    elif im_data.ndim == 2:
        # A 2-D array is exactly one band. Deriving the size from the data
        # avoids an IndexError when the caller passes the band count of a
        # multi-band reference raster.
        im_bands = 1
        im_height, im_width = im_data.shape
        im_data = im_data[np.newaxis, :, :]
    else:
        raise ValueError(
            "im_data must be 2-D or 3-D, got shape {}".format(im_data.shape))

    dtype_name = im_data.dtype.name
    if dtype_name in ('int8', 'uint8'):
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        # Signed data must not be stored as UInt16, otherwise negative
        # values are corrupted.
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        print(path + " can't be created!")
        return
    dataset.SetGeoTransform(im_geotrans)  # affine transform parameters
    dataset.SetProjection(im_proj)  # projection
    for band_index in range(im_bands):
        dataset.GetRasterBand(band_index + 1).WriteArray(im_data[band_index])
    dataset.FlushCache()
    del dataset
# Example usage: copy size, band count and georeferencing from an existing
# raster and write an array with them.
# NOTE(review): `path` and `arr` are not defined in this snippet — they must
# be set by the caller beforehand.
raster = gdal.Open(path)
im_width = raster.RasterXSize # number of columns of the raster
im_height = raster.RasterYSize # number of rows of the raster
im_bands = raster.RasterCount # number of bands
im_geotrans = raster.GetGeoTransform()# affine geotransform
im_proj = raster.GetProjection()# projection
ResultPath = "路径"
WriteTiff(arr, im_width, im_height, im_bands, im_geotrans, im_proj, ResultPath)
或者
def WriteTiff(im_data, inputdir, path):
    """Write a NumPy array to a GeoTIFF, georeferenced like a reference raster.

    Args:
        im_data: 2-D (rows, cols) or 3-D (bands, rows, cols) NumPy array.
        inputdir: Path of the reference raster whose geotransform and
            projection are copied; WGS84 is used when it has no projection.
        path: Output file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
    """
    # Validate the shape before opening or creating any file. Size and band
    # count always come from the data, not from the reference raster.
    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    elif im_data.ndim == 2:
        im_bands = 1
        im_height, im_width = im_data.shape
        im_data = im_data[np.newaxis, :, :]
    else:
        raise ValueError(
            "im_data must be 2-D or 3-D, got shape {}".format(im_data.shape))

    raster = gdal.Open(inputdir)
    if raster is None:
        print(inputdir + " can't be opened!")
        return
    im_geotrans = raster.GetGeoTransform()  # affine geotransform
    im_proj = raster.GetProjection()  # projection
    if im_proj == "":
        # No coordinate system on the reference raster: default to WGS84.
        from osgeo import osr  # not imported at module level
        osrs = osr.SpatialReference()
        osrs.SetWellKnownGeogCS('WGS84')
        im_proj = osrs.ExportToWkt()

    dtype_name = im_data.dtype.name
    if dtype_name in ('int8', 'uint8'):
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        # Signed data must not be stored as UInt16.
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        print(path + " can't be created!")
        return
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    for band_index in range(im_bands):
        dataset.GetRasterBand(band_index + 1).WriteArray(im_data[band_index])
    dataset.FlushCache()
    del dataset
# Example call. NOTE(review): im_data, inputdir and path are not defined in
# this snippet — they must be set by the caller beforehand.
WriteTiff(im_data,inputdir, path)
或者
def WriteTiff(im_data, inputdir, path):
    """Write a NumPy array to a GeoTIFF, georeferenced like a reference raster.

    Args:
        im_data: 2-D (rows, cols) or 3-D (bands, rows, cols) NumPy array.
        inputdir: Path of the reference raster whose geotransform and
            projection are copied.
        path: Output file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
    """
    # Validate the shape before opening or creating any file. Size and band
    # count always come from the data, not from the reference raster.
    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    elif im_data.ndim == 2:
        im_bands = 1
        im_height, im_width = im_data.shape
        im_data = im_data[np.newaxis, :, :]
    else:
        raise ValueError(
            "im_data must be 2-D or 3-D, got shape {}".format(im_data.shape))

    raster = gdal.Open(inputdir)
    if raster is None:
        print(inputdir + " can't be opened!")
        return
    im_geotrans = raster.GetGeoTransform()  # affine geotransform
    im_proj = raster.GetProjection()  # projection

    dtype_name = im_data.dtype.name
    if dtype_name in ('int8', 'uint8'):
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        # Signed data must not be stored as UInt16.
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        print(path + " can't be created!")
        return
    dataset.SetGeoTransform(im_geotrans)
    if im_proj == "":
        # The reference raster has no coordinate system: fall back to WGS84.
        from osgeo import osr  # not imported at module level
        sr = osr.SpatialReference()
        sr.SetWellKnownGeogCS('WGS84')
        dataset.SetProjection(sr.ExportToWkt())
    else:
        dataset.SetProjection(im_proj)
    for band_index in range(im_bands):
        dataset.GetRasterBand(band_index + 1).WriteArray(im_data[band_index])
    dataset.FlushCache()
    del dataset
3.影像拼接
把要合并的多个tif放在path路径下的文件夹file1、file2……fileN中,每个文件夹下的文件数量和文件名都相同
from osgeo import gdal
import math
def GetExtent(in_fn):
    """Return the extent of a raster as (min_x, max_y, max_x, min_y).

    The extent is computed from the geotransform origin, the pixel size and
    the raster size; the rotation terms of the geotransform are not used.
    NOTE(review): the min/max naming presumably assumes a north-up raster
    (negative pixel height) — confirm for your data.
    """
    dataset = gdal.Open(in_fn)
    origin_x, pixel_w, _, origin_y, _, pixel_h = dataset.GetGeoTransform()
    far_x = origin_x + dataset.RasterXSize * pixel_w
    far_y = origin_y + dataset.RasterYSize * pixel_h
    dataset = None  # release the GDAL handle
    return origin_x, origin_y, far_x, far_y
def mosaic(in_files, output_name, arr_files):
    """Mosaic band 1 of every GeoTIFF in a folder into one GeoTIFF.

    The output takes its pixel size, projection and data type from the first
    input (in sorted name order) and covers the union of all input extents.
    Where inputs overlap, later files overwrite earlier ones.

    Args:
        in_files: Folder containing the GeoTIFFs to merge. The process
            working directory is changed to this folder, as before.
        output_name: Output path; a relative path is resolved against
            ``in_files``.
        arr_files: Unused; kept for backward compatibility.
    """
    import os  # not imported at module level in this section

    # Resolve the folder before chdir so a relative path is not looked up
    # again from inside itself.
    in_dir = os.path.abspath(in_files)
    os.chdir(in_dir)
    out_path = os.path.abspath(output_name)
    # Keep only TIFF files and never feed a previous output back in as input.
    tif_names = sorted(
        name for name in os.listdir(in_dir)
        if name.lower().endswith((".tif", ".tiff"))
        and os.path.abspath(name) != out_path
    )
    if not tif_names:
        print(in_dir + " contains no tif files!")
        return

    # Union of the extents of all inputs.
    extents = [GetExtent(name) for name in tif_names]
    min_x = min(extent[0] for extent in extents)
    max_y = max(extent[1] for extent in extents)
    max_x = max(extent[2] for extent in extents)
    min_y = min(extent[3] for extent in extents)

    # Size of the mosaic in pixels, using the first input's pixel size.
    first_ds = gdal.Open(tif_names[0])
    geotrans = list(first_ds.GetGeoTransform())
    pixel_w = geotrans[1]
    pixel_h = geotrans[5]
    columns = math.ceil((max_x - min_x) / pixel_w)
    rows = math.ceil((max_y - min_y) / (-pixel_h))

    in_band = first_ds.GetRasterBand(1)
    driver = gdal.GetDriverByName('GTiff')
    out_ds = driver.Create(output_name, columns, rows, 1, in_band.DataType)
    out_ds.SetProjection(first_ds.GetProjection())
    geotrans[0] = min_x
    geotrans[3] = max_y
    out_ds.SetGeoTransform(geotrans)
    out_band = out_ds.GetRasterBand(1)

    # Write each input at the pixel offset of its top-left corner.
    for name in tif_names:
        in_ds = gdal.Open(name)
        # in_ds is the source raster, out_ds the destination raster.
        trans = gdal.Transformer(in_ds, out_ds, [])
        success, xyz = trans.TransformPoint(False, 0, 0)
        # Round instead of truncating: an offset such as 99.9999 must map to
        # pixel 100, not 99.
        x_off = int(round(xyz[0]))
        y_off = int(round(xyz[1]))
        data = in_ds.GetRasterBand(1).ReadAsArray()
        out_band.WriteArray(data, x_off, y_off)
        in_ds = None
    out_band.FlushCache()
    del out_band, out_ds
# Example usage: replace the placeholder strings with real paths.
in_files = "要合并的tif存放的路径"  # folder containing the tifs to merge
output_name = "输出的tif名称"  # name of the output tif
mosaic(in_files, output_name, None)  # the third argument is unused by mosaic
4.开闭运算去除小斑块
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 5 12:53:34 2021
@author: Xhpan
"""
from osgeo import gdal
import xarray as xr
from skimage import morphology as sm
import numpy as np
import os
def WriteTiff(im_data, im_width, im_height, im_bands, im_geotrans, im_proj, path):
    """Write a NumPy array to a GeoTIFF file.

    Args:
        im_data: 2-D (rows, cols) or 3-D (bands, rows, cols) NumPy array.
        im_width: Kept for backward compatibility; the width is taken from
            ``im_data``.
        im_height: Kept for backward compatibility; the height is taken from
            ``im_data``.
        im_bands: Kept for backward compatibility; the band count is taken
            from ``im_data``.
        im_geotrans: Six-element GDAL geotransform.
        im_proj: Projection as a WKT string.
        path: Output file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
    """
    # Validate the shape before creating anything on disk.
    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    elif im_data.ndim == 2:
        # A 2-D array is exactly one band. Deriving the size from the data
        # avoids an IndexError when the caller passes the band count of a
        # multi-band reference raster.
        im_bands = 1
        im_height, im_width = im_data.shape
        im_data = im_data[np.newaxis, :, :]
    else:
        raise ValueError(
            "im_data must be 2-D or 3-D, got shape {}".format(im_data.shape))

    dtype_name = im_data.dtype.name
    if dtype_name in ('int8', 'uint8'):
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        # Signed data must not be stored as UInt16, otherwise negative
        # values are corrupted.
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        print(path + " can't be created!")
        return
    dataset.SetGeoTransform(im_geotrans)  # affine transform parameters
    dataset.SetProjection(im_proj)  # projection
    for band_index in range(im_bands):
        dataset.GetRasterBand(band_index + 1).WriteArray(im_data[band_index])
    dataset.FlushCache()
    del dataset
def getBoundary(filename, urbanID, kernel, ResultPath):
    """Remove small patches from a class mask with a closing then an opening.

    Band 1 of ``filename`` is turned into a 0/1 mask (1 where the pixel
    equals ``urbanID``), closed and then opened with ``kernel``, and written
    to ``ResultPath`` with the georeferencing of the input.

    Args:
        filename: Path of the input raster.
        urbanID: Pixel value of the class to keep.
        kernel: Structuring element passed to skimage's closing/opening.
        ResultPath: Path of the output GeoTIFF.
    """
    # Open the file once with GDAL for both the pixels and the metadata.
    dataset = gdal.Open(filename)
    band_data = dataset.GetRasterBand(1).ReadAsArray()
    # 0/1 mask in the input's own dtype.
    mask = (band_data == urbanID).astype(band_data.dtype)
    img_close = sm.closing(mask, kernel)
    img_open = sm.opening(img_close, kernel)
    im_width = dataset.RasterXSize  # number of columns
    im_height = dataset.RasterYSize  # number of rows
    im_geotrans = dataset.GetGeoTransform()  # affine geotransform
    im_proj = dataset.GetProjection()  # projection
    # The result is always a single band, whatever the input's band count.
    WriteTiff(img_open, im_width, im_height, 1, im_geotrans, im_proj, ResultPath)
# Collect all files with a given extension under a directory.
def getTiffFileName(filepath, suffix):
    """List the files under ``filepath`` whose extension equals ``suffix``.

    Sub-folders are searched too.

    Args:
        filepath: Folder to search.
        suffix: Extension including the dot, e.g. ".tif" (case-sensitive).

    Returns:
        A pair ``(paths, names)``: the path of each matching file and its
        file name without extension, in the same order.
    """
    L1 = []
    L2 = []
    for root, dirs, files in os.walk(filepath):
        for file in files:
            (filename, extension) = os.path.splitext(file)
            if extension == suffix:
                # Build the path from the folder being walked, not from the
                # top folder, so files in sub-folders get a path that exists.
                L1.append(root + "/" + file)
                L2.append(filename)
    return L1, L2
# Batch driver: clean every tif in `filepath` and write the result under the
# same file name into `outputpath`.
urbanID = 1
filepath = r"D:\Work\doing\CNN_RNN\data\beijing\origindata\urban\1985_2017"
outputpath = r"D:\Work\doing\CNN_RNN\data\beijing\origindata\urban\1985_2017_1"
kernel = sm.octagon(2, 1)
inputPathFiles, inputNames = getTiffFileName(filepath, ".tif")
# Create the output folder first; otherwise GDAL cannot create the files.
os.makedirs(outputpath, exist_ok=True)
for name in inputNames:
    filename = filepath + "/" + name + ".tif"
    ResultPath = outputpath + "/" + name + ".tif"
    getBoundary(filename, urbanID, kernel, ResultPath)
    print(filename)
5.时间序列订正
# -*- coding: utf-8 -*-
"""
Created on Tue Sep 28 13:17:24 2021
@author: Xhpan
"""
import numpy as np
from osgeo import gdal
import xarray as xr
import os
def WriteTiff(im_data, im_width, im_height, im_bands, im_geotrans, im_proj, path):
    """Write a NumPy array to a GeoTIFF file.

    Args:
        im_data: 2-D (rows, cols) or 3-D (bands, rows, cols) NumPy array.
        im_width: Kept for backward compatibility; the width is taken from
            ``im_data``.
        im_height: Kept for backward compatibility; the height is taken from
            ``im_data``.
        im_bands: Kept for backward compatibility; the band count is taken
            from ``im_data``.
        im_geotrans: Six-element GDAL geotransform.
        im_proj: Projection as a WKT string.
        path: Output file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
    """
    # Validate the shape before creating anything on disk.
    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    elif im_data.ndim == 2:
        # A 2-D array is exactly one band. Deriving the size from the data
        # avoids an IndexError when the caller passes the band count of a
        # multi-band reference raster.
        im_bands = 1
        im_height, im_width = im_data.shape
        im_data = im_data[np.newaxis, :, :]
    else:
        raise ValueError(
            "im_data must be 2-D or 3-D, got shape {}".format(im_data.shape))

    dtype_name = im_data.dtype.name
    if dtype_name in ('int8', 'uint8'):
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        # Signed data must not be stored as UInt16, otherwise negative
        # values are corrupted.
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        print(path + " can't be created!")
        return
    dataset.SetGeoTransform(im_geotrans)  # affine transform parameters
    dataset.SetProjection(im_proj)  # projection
    for band_index in range(im_bands):
        dataset.GetRasterBand(band_index + 1).WriteArray(im_data[band_index])
    dataset.FlushCache()
    del dataset
# Collect all files with a given extension under a directory.
def getTiffFileName(filepath, suffix):
    """List the files under ``filepath`` whose extension equals ``suffix``.

    Sub-folders are searched too.

    Args:
        filepath: Folder to search.
        suffix: Extension including the dot, e.g. ".tif" (case-sensitive).

    Returns:
        A pair ``(paths, names)``: the path of each matching file and its
        file name without extension, in the same order.
    """
    L1 = []
    L2 = []
    for root, dirs, files in os.walk(filepath):
        for file in files:
            (filename, extension) = os.path.splitext(file)
            if extension == suffix:
                # Build the path from the folder being walked, not from the
                # top folder, so files in sub-folders get a path that exists.
                L1.append(root + "/" + file)
                L2.append(filename)
    return L1, L2
def timeSeriesCorrection(filepath, outputpath):
    """Accumulate a series of rasters so each output includes all earlier ones.

    The tifs in ``filepath`` are processed in sorted name order. Band 1 of
    each file is added to the running result and every sum of 2 is set back
    to 1; the running result is written to ``outputpath`` under the name of
    the file just added. No output is written for the first file.
    NOTE(review): assumes the file names sort chronologically (e.g. years)
    and that the rasters hold 0/1 values — confirm for your data.

    Args:
        filepath: Folder containing the input tifs.
        outputpath: Folder for the results; created when missing.
    """
    if not os.path.exists(outputpath):
        os.makedirs(outputpath)
    inputPathFiles, inputNames = getTiffFileName(filepath, ".tif")
    # os.walk gives no ordering guarantee, but the accumulation only makes
    # sense in time order.
    inputNames = sorted(inputNames)
    if not inputNames:
        print(filepath + " contains no tif files!")
        return

    def read_first_band(name):
        """Return band 1 of ``<filepath>/<name>.tif`` as a NumPy array."""
        dataset = gdal.Open(filepath + "/" + str(name) + ".tif")
        return dataset.GetRasterBand(1).ReadAsArray()

    raster = gdal.Open(filepath + "/" + str(inputNames[0]) + ".tif")
    im_width = raster.RasterXSize  # number of columns
    im_height = raster.RasterYSize  # number of rows
    im_geotrans = raster.GetGeoTransform()  # affine geotransform
    im_proj = raster.GetProjection()  # projection

    accumulated = read_first_band(inputNames[0])
    for name in inputNames[1:]:
        accumulated = accumulated + read_first_band(name)
        accumulated[accumulated == 2] = 1
        ResultPath = outputpath + "/" + str(name) + ".tif"
        # The result is always a single band.
        WriteTiff(accumulated, im_width, im_height, 1, im_geotrans, im_proj, ResultPath)
        print(ResultPath)
# Run the correction: read the tifs in `filepath` and write the results to
# `outputpath`.
filepath = r"D:\Work\doing\CNN_RNN\data\beijing\origindata\urban\1985_2017_1"
outputpath = r"D:\Work\doing\CNN_RNN\data\beijing\origindata\urban\1985_2017_2"
timeSeriesCorrection(filepath,outputpath)
6.统一图像的行列号
# -*- coding: utf-8 -*-
"""
Created on Wed Oct 6 14:26:52 2021
@author: Xhpan
"""
from osgeo import gdal
import math
def GetExtent(in_fn):
    """Return the extent of a raster as (min_x, max_y, max_x, min_y).

    The extent is computed from the geotransform origin, the pixel size and
    the raster size; the rotation terms of the geotransform are not used.
    NOTE(review): the min/max naming presumably assumes a north-up raster
    (negative pixel height) — confirm for your data.
    """
    dataset = gdal.Open(in_fn)
    origin_x, pixel_w, _, origin_y, _, pixel_h = dataset.GetGeoTransform()
    far_x = origin_x + dataset.RasterXSize * pixel_w
    far_y = origin_y + dataset.RasterYSize * pixel_h
    dataset = None  # release the GDAL handle
    return origin_x, origin_y, far_x, far_y
def UnifiedLineNumber(in_fn, criterion_fn, output_name):
    """Rewrite band 1 of a raster on a grid sized like a reference raster.

    The output has the row/column count, pixel size, projection and data type
    of ``criterion_fn``; its top-left corner is the top-left corner of
    ``in_fn``. Input pixels that fall outside the output grid are dropped.

    Args:
        in_fn: Path of the raster to process.
        criterion_fn: Path of the reference raster.
        output_name: Path of the output GeoTIFF.
    """
    ref_ds = gdal.Open(criterion_fn)
    geotrans = list(ref_ds.GetGeoTransform())
    pixel_w = geotrans[1]
    pixel_h = geotrans[5]

    # Row/column count of the output, from the reference extent.
    min_x, max_y, max_x, min_y = GetExtent(criterion_fn)
    columns = math.ceil((max_x - min_x) / pixel_w)
    rows = math.ceil((max_y - min_y) / (-pixel_h))

    driver = gdal.GetDriverByName('GTiff')
    out_ds = driver.Create(output_name, columns, rows, 1,
                           ref_ds.GetRasterBand(1).DataType)
    out_ds.SetProjection(ref_ds.GetProjection())

    # Anchor the output grid at the top-left corner of the input raster.
    src_min_x, src_max_y, _, _ = GetExtent(in_fn)
    geotrans[0] = src_min_x
    geotrans[3] = src_max_y
    out_ds.SetGeoTransform(geotrans)
    out_band = out_ds.GetRasterBand(1)

    in_ds = gdal.Open(in_fn)
    # in_ds is the source raster, out_ds the destination raster; this maps
    # the input's top-left pixel to a pixel position in the output.
    trans = gdal.Transformer(in_ds, out_ds, [])
    success, xyz = trans.TransformPoint(False, 0, 0)
    # Round instead of truncating so tiny float errors cannot shift the data
    # by one pixel.
    x_off = int(round(xyz[0]))
    y_off = int(round(xyz[1]))

    data = in_ds.GetRasterBand(1).ReadAsArray()
    # Crop to the part that fits the output grid; writing a larger array
    # fails in GDAL.
    data = data[:max(rows - y_off, 0), :max(columns - x_off, 0)]
    if data.size:
        out_band.WriteArray(data, x_off, y_off)
    out_band.FlushCache()
    del in_ds, out_band, out_ds
# Example usage: replace the placeholder strings with real paths.
in_fn = "需要处理的tif路径"
criterion_fn = "标准的tif路径"
output_name = "输出tif路径"
UnifiedLineNumber(in_fn,criterion_fn,output_name)
7.矢量按位置(相交和相邻)/属性选择
import os
from osgeo import ogr
from tqdm import trange
def create_shp_by_layer(shp, layer):
    """Copy a layer into a new shapefile, replacing any existing file.

    Uses the module-level ``driver``.

    Args:
        shp: Path of the shapefile to write.
        layer: OGR layer to copy; CopyLayer is called on it as given.
    """
    if os.access(shp, os.F_OK):
        driver.DeleteDataSource(shp)
    newds = driver.CreateDataSource(shp)
    if newds is None:
        print(shp + " can't be created!")
        return
    newds.CopyLayer(layer, '')
    # Drop the reference explicitly so the shapefile is flushed and closed
    # before the caller reopens it.
    newds = None
def totxt(resultname, L):
    """Write each item of ``L`` to a text file, one per line.

    Args:
        resultname: Path of the text file; an existing file is overwritten.
        L: Iterable of items; each is converted with ``str``.
    """
    # The context manager closes the file even if a write fails.
    with open(resultname, "w") as f:
        f.writelines(str(name) + '\n' for name in L)
# For every polygon of the shapefile, export the features selected by a
# spatial filter on that polygon to resultshp/<cyid>.shp and list the cyid
# values of the other selected features in result/<cyid>.txt.
filename = 'Export_Output.shp'
resultpath = "result"
if not os.path.exists(resultpath):
    os.makedirs(resultpath)
resultpath1 = "resultshp"
if not os.path.exists(resultpath1):
    os.makedirs(resultpath1)
driver = ogr.GetDriverByName("ESRI Shapefile")

# The same file is opened twice: one layer is filtered, the other is read
# sequentially, so the two do not interfere.
target_shp = filename
target_ds = ogr.Open(target_shp)
target_layer = target_ds.GetLayer(0)  # first layer
source_shp = filename
source_ds = ogr.Open(source_shp)
source_layer = source_ds.GetLayer(0)  # first layer

# Read the source features sequentially instead of by index, so nothing
# depends on feature ids being 0..n-1.
source_layer.ResetReading()
for i in trange(source_layer.GetFeatureCount()):
    source_feats = source_layer.GetNextFeature()
    if source_feats is None:
        break
    source_id = source_feats.GetField('cyid')  # cyid of this polygon
    poly = source_feats.GetGeometryRef()  # geometry of this polygon
    # Select this polygon and the features that pass the spatial filter.
    target_layer.SetSpatialFilter(poly)
    shp = resultpath1 + "/" + str(source_id) + ".shp"
    create_shp_by_layer(shp, target_layer)

    # Read the cyid values back from the exported shapefile, skipping the
    # polygon itself.
    result_ds = ogr.Open(shp)
    result_layer = result_ds.GetLayer(0)  # first layer
    filter_names = []
    for fea in result_layer:
        result_id = fea.GetField('cyid')
        if result_id != source_id:
            filter_names.append(result_id)
    resultname = resultpath + "/" + str(source_id) + ".txt"
    totxt(resultname, filter_names)

    # Clear the filter before the next polygon.
    target_layer.SetSpatialFilter(None)
    target_layer.ResetReading()
8.根据矢量裁剪栅格
from osgeo import gdal
def clip_raster(in_raster, out_raster, mask_shp, wkid=4326):
    """Clip a raster to a vector mask and write the result as a GeoTIFF.

    The output is cropped to the mask extent, stored as Float64, and uses
    -9999 as its nodata value.

    :param in_raster: input raster path
    :param out_raster: output raster path
    :param mask_shp: vector file used as the cutline
    :param wkid: EPSG code of the output coordinate system (default 4326)
    :return: None
    """
    result = gdal.Warp(out_raster,
                       in_raster,
                       format='GTiff',
                       dstSRS='EPSG:{}'.format(wkid),
                       cutlineDSName=mask_shp,
                       cropToCutline=True,  # crop to the extent of the mask
                       dstNodata=-9999,
                       outputType=gdal.GDT_Float64)
    if result is None:
        print(out_raster + " can't be created!")
        return
    # Drop the reference so the output is flushed and closed.
    result = None
# Example usage: replace the placeholder strings with real paths.
in_raster = "栅格路径"
out_raster = r"test.tif"
mask_shp = "矢量路径"
clip_raster(in_raster, out_raster, mask_shp)
9. ArcGIS将shp按照属性字段进行分割为多个polygon矢量
from osgeo import ogr
import os
# Split a shapefile into one shapefile per distinct value of the 'ID' field.
shpfile = r"输入shp"
resultpath = r"输出文件"
if not os.path.exists(resultpath):
    os.makedirs(resultpath)
driver = ogr.GetDriverByName("ESRI Shapefile")
ds = ogr.Open(shpfile)
layer = ds.GetLayer(0)
layer_defn = layer.GetLayerDefn()

# Collect the distinct IDs first, so each output file is created exactly once
# even when several features share an ID.
unique_ids = []
seen_ids = set()
layer.ResetReading()
for feat in layer:
    id_value = feat.GetField('ID')
    if id_value not in seen_ids:
        seen_ids.add(id_value)
        unique_ids.append(id_value)

for i, source_id in enumerate(unique_ids):
    if isinstance(source_id, str):
        # String values must be quoted in the attribute filter.
        layer.SetAttributeFilter(
            "ID = '{}'".format(source_id.replace("'", "''")))
    else:
        layer.SetAttributeFilter("ID = {}".format(source_id))
    extfile = resultpath + "/" + str(source_id).zfill(2) + ".shp"  # named after the ID
    if os.path.exists(extfile):
        driver.DeleteDataSource(extfile)
    newds = driver.CreateDataSource(extfile)
    # Carry over the spatial reference and the attribute fields so the
    # output keeps its georeferencing and attributes.
    lyrn = newds.CreateLayer('rect', layer.GetSpatialRef(), ogr.wkbPolygon)
    for field_index in range(layer_defn.GetFieldCount()):
        lyrn.CreateField(layer_defn.GetFieldDefn(field_index))
    out_defn = lyrn.GetLayerDefn()
    layer.ResetReading()
    feat = layer.GetNextFeature()
    while feat is not None:
        out_feat = ogr.Feature(out_defn)
        out_feat.SetFrom(feat)
        lyrn.CreateFeature(out_feat)
        feat = layer.GetNextFeature()
    newds.Destroy()
    print(i)
layer.SetAttributeFilter(None)
10.分区统计(多tif批量)
import time
import geopandas as gpd
import rasterio
from rasterstats import zonal_stats
import pandas as pd
from osgeo import gdal
import numpy as np
def LoadData(filename):
    """Read every band of a raster into a float32 NumPy array.

    Pixels equal to a band's nodata value are replaced with 0.

    Args:
        filename: Path of a raster that GDAL can open.

    Returns:
        A (rows, cols) array for a single-band raster, a (bands, rows, cols)
        array otherwise, or None (after printing a message) when the file
        cannot be opened.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        print(filename + " can't be opened!")
        return
    band_count = dataset.RasterCount
    bands = []
    for band_index in range(1, band_count + 1):
        band = dataset.GetRasterBand(band_index)
        nodata = band.GetNoDataValue()
        raw = band.ReadAsArray()
        data = raw.astype(np.float32)
        if nodata is not None:
            # Build the mask on the native dtype so the float32 cast cannot
            # alter the comparison; NaN never equals itself, so test it
            # separately.
            if np.isnan(nodata):
                mask = np.isnan(data)
            else:
                mask = raw == nodata
            data[mask] = 0
        bands.append(data)
    data = np.stack(bands, 0)
    if band_count == 1:
        data = data[0, :, :]
    return data
# Zonal statistics of several rasters over the polygons of one shapefile,
# collected into one table with a column per raster.
start = time.time()
shp_path = '../data/urbandata/shp/basin.shp'
stats = ['mean']  # ['min', 'max', 'mean', 'median', 'majority']
shp_driver = gpd.read_file(shp_path)
df = shp_driver['ID'].to_frame()
names = ["Chen", "He", "Zhou", "Hyde", "LUH", "Gao"]
for name in names:
    ras_path = "../data/urbandata/" + name + ".tif"
    # Only the affine transform is needed from rasterio; close the file
    # right after reading it.
    with rasterio.open(ras_path) as ras_driver:
        affine = ras_driver.transform
    array = LoadData(ras_path)
    # Every pixel equal to the top-left pixel is set to 0.
    array[array == array[0][0]] = 0
    zs = zonal_stats(shp_path, array, affine=affine, stats=stats)
    df['{}'.format(name)] = [zone[stats[0]] for zone in zs]
    print(name)
df.to_excel("../urba_mean.xlsx")
__EOF__

本文作者:skypanxh
本文链接:https://www.cnblogs.com/skypanxh/p/15214516.html
关于博主:评论和私信会在第一时间回复。或者直接私信我。
版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!
声援博主:如果您觉得文章对您有帮助,可以点击文章右下角【推荐】一下。您的鼓励是博主的最大动力!
本文链接:https://www.cnblogs.com/skypanxh/p/15214516.html
关于博主:评论和私信会在第一时间回复。或者直接私信我。
版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!
声援博主:如果您觉得文章对您有帮助,可以点击文章右下角【推荐】一下。您的鼓励是博主的最大动力!
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)