Python GDAL/OGR常用案例代码总结
安装
推荐使用conda安装python gdal环境,先查询gdal可用版本,再指定版本号,按需安装对应的gdal。
conda search gdal
conda install gdal=version_number
如果运行后面的代码提示"ERROR: proj_create_from_database: Cannot find proj.db",这是说明找不到proj.db,我们需要设置proj的环境变量为proj.db文件所在的目录(可直接在文件目录中搜索proj.db,如果是通过conda安装gdal,则proj.db在conda虚拟目录下的某个Library里):
或者在代码开头定义PROJ_LIB环境变量(更推荐前一种方式):
import os
os.environ["PROJ_LIB"] = "D:\ProgramData\CondaVirtualEnv\py39\Library\share\proj"
案例
import相应模块和常量定义
from osgeo import gdal, ogr, osr, gdalconst
DRIVER_SHAPE = "ESRI Shapefile"
DRIVER_GTIFF = "GTiff"
EPSG_WGS84 = 4326
矢量转栅格
def vector2raster(input_file, ouput_file, template_file, field="", all_touch="False"):
"""
矢量图层转栅格
Params:
input_file: 输入矢量文件
output_file: 输出栅格文件
template_file: 模板栅格文件
field: 输出栅格所用的input_file字段
all_touch: False仅矢量中心所在的栅格设置像元值,True只要和矢量相交的栅格都设置像元值
Ref: https://gdal.org/api/gdal_alg.html#_CPPv419GDALRasterizeLayers12GDALDatasetHiPiiP9OGRLayerH19GDALTransformerFuncPvPdPPc16GDALProgressFuncPv
"""
# 打开模板栅格
data = gdal.Open(template_file, gdalconst.GA_ReadOnly)
# 确定栅格大小
x_size = data.RasterXSize
y_size = data.RasterYSize
# 打开矢量vec_layer
vector = ogr.GetDriverByName(DRIVER_SHAPE).Open(input_file)
vec_layer = vector.GetLayer()
feat_count = vec_layer.GetFeatureCount()
# 创建输出的tiff栅格文件
target = gdal.GetDriverByName(DRIVER_GTIFF).Create(ouput_file, x_size, y_size, 1, gdal.GDT_Byte)
# 设置栅格坐标系与投影
target.SetGeoTransform(data.GetGeoTransform())
target.SetProjection(data.GetProjection())
if field:
gdal.RasterizeLayer(target, [1], vec_layer,
options=["ALL_TOUCHED="+all_touch, "ATTRIBUTE="+field])
else:
burn_values = [1] # 如果没有指定属性字段,则使用burn_value作为输出的像素值
gdal.RasterizeLayer(target, [1], vec_layer, burn_values=burn_values, options=["ALL_TOUCHED="+all_touch])
# 设置NoData
NoData_value = 0
target.GetRasterBand(1).SetNoDataValue(NoData_value)
target.GetRasterBand(1).FlushCache()
target = None
# 调用
vector2raster(
input_file="data/input/vec.shp",
ouput_file="data/output/vec2raster.tif",
template_file="data/input/template.tif",
field="Height"
)
栅格转矢量(多边形)
def raster2polygon(input_raster, output_file, layer_name):
"""
栅格转矢量多边形
Params:
input_raster: 输入栅格文件
output_file: 输出shapefile文件
layer_name: 矢量图层名称
Ref: https://gdal.org/api/gdal_alg.html#_CPPv415GDALFPolygonize15GDALRasterBandH15GDALRasterBandH9OGRLayerHiPPc16GDALProgressFuncPv
"""
data = gdal.Open(input_raster)
src_band = data.GetRasterBand(1)
srs = osr.SpatialReference()
srs.ImportFromWkt(data.GetProjection()) # 矢量的空间参考和栅格保持一致
target = ogr.GetDriverByName(DRIVER_SHAPE).CreateDataSource(output_file)
target_layer = target.CreateLayer(layer_name, srs=srs, geom_type=ogr.wkbPolygon)
# 给目标shp文件添加一个字段,存储原始栅格的pixel value
field = ogr.FieldDefn('value',ogr.OFTReal)
target_layer.CreateField(field)
gdal.Polygonize(src_band, src_band, target_layer, 0, [])
target = None
# 调用
raster2polygon(
input_raster="data/input/raster.tif",
output_file="data/output/poly.shp",
layer_name="poly",
)
矢量叠加
def intersect(input_file1, input_file2, output_file, output_layer_name=""):
"""
矢量叠加
Params:
input_file1: 输入矢量文件1
input_file2: 输入矢量文件2
output_file: 输出矢量文件路径
output_layer_name: 输出图层名
"""
driver = ogr.GetDriverByName(DRIVER_SHAPE)
shp1 = driver.Open(input_file1, gdalconst.GA_ReadOnly)
shp2 = driver.Open(input_file2, gdalconst.GA_ReadOnly)
src_layer1 = shp1.GetLayer()
src_layer2 = shp2.GetLayer()
srs1 = src_layer1.GetSpatialRef()
srs2 = src_layer2.GetSpatialRef()
if srs1.GetAttrValue('AUTHORITY',1) != srs2.GetAttrValue('AUTHORITY',1):
print("空间参考不一致!")
return
target_ds = ogr.GetDriverByName(DRIVER_SHAPE).CreateDataSource(output_file)
target_layer = target_ds.CreateLayer(output_layer_name, srs1, geom_type=ogr.wkbPolygon, options=["ENCODING=UTF-8"]) # 设置编码为UTF-8,防止中文出现乱码
for feat1 in src_layer1:
geom1 = feat1.GetGeometryRef()
for feat2 in src_layer2:
geom2 = feat2.GetGeometryRef()
if not geom1.Intersects(geom2):
continue
intersect = geom1.Intersection(geom2)
feature = ogr.Feature(target_layer.GetLayerDefn())
feature.SetGeometry(intersect)
target_layer.CreateFeature(feature)
# 清理引用
target_layer = None
ds = None
# 调用
intersect(
input_file1="data/input/vec1.shp",
input_file2="data/input/vec2.shp",
output_file="data/output/intersect.shp",
output_layer_name="intersect",
)
矢量擦除
def erase(erased_file, eraser_file, output_file, output_layer_name=""):
"""
矢量擦除
Params:
erased_file: 被擦除的矢量文件路径
eraser_file: 擦除矢量文件路径
output_file: 输出文件路径
output_layer_name: 输出图层名
"""
driver = ogr.GetDriverByName(DRIVER_SHAPE)
shp1 = driver.Open(erased_file, gdalconst.GA_ReadOnly)
shp2 = driver.Open(eraser_file, gdalconst.GA_ReadOnly)
src_layer1 = shp1.GetLayer()
src_layer2 = shp2.GetLayer()
srs1 = src_layer1.GetSpatialRef()
srs2 = src_layer2.GetSpatialRef()
if srs1.GetAttrValue('AUTHORITY',1) != srs2.GetAttrValue('AUTHORITY',1):
print("空间参考不一致!")
return
target_ds = ogr.GetDriverByName(DRIVER_SHAPE).CreateDataSource(output_file)
target_layer = target_ds.CreateLayer(output_layer_name, srs1, geom_type=ogr.wkbPolygon, options=["ENCODING=UTF-8"])
ds = src_layer1.Erase(src_layer2, target_layer)
ds = None
# 调用
erase(
erased_file="data/input/vec1.shp",
eraser_file="data/input/vec2.shp",
output_file="data/output/erase.shp",
)
缓冲区分析(以点为例)
def point_buffer(point, radius, output_file, layer_name):
"""
对点集建立缓冲区
Params:
point: 输入点坐标
range: 缓冲区半径
output_file: 缓冲区输出文件路径
layer_name: 输出图层名
"""
# 创建图层
srs = osr.SpatialReference()
srs.ImportFromEPSG(EPSG_WGS84) # 选择坐标系,注意地理坐标系和投影坐标系的radius单位不同,前者为度,后者为米
ds = ogr.GetDriverByName(DRIVER_SHAPE).CreateDataSource(output_file)
target_layer = ds.CreateLayer(layer_name, srs, geom_type=ogr.wkbPolygon, options=["ENCODING=UTF-8"])
# 创建geometry
wkt = "POINT (%f %f)" % (point[0], point[1])
geom = ogr.CreateGeometryFromWkt(wkt)
poly = geom.Buffer(radius)
# 创建feature
feature = ogr.Feature(target_layer.GetLayerDefn())
feature.SetGeometry(poly)
target_layer.CreateFeature(feature)
target_layer = None
ds = None
# 调用
point_buffer(
point=(121.531921, 25.013540),
range=0.1,
output_file="data/output/buffer.shp",
layer_name="buffer",
)
视域分析
def generate_viewshed(input_raster, output_file, location, height):
"""
视域分析
Params:
input_raster: 输入栅格文件
output_file: 视域分析输出文件
location: 观察者所在坐标
height: 观察者所在的高程
Ref: https://gdal.org/api/gdal_alg.html#_CPPv420GDALViewshedGenerate15GDALRasterBandHPKcPKc12CSLConstListddddddddd16GDALViewshedModed16GDALProgressFuncPv22GDALViewshedOutputType12CSLConstList
"""
raster = gdal.Open(input_raster)
band = raster.GetRasterBand(1)
gdal.ViewshedGenerate(
srcBand=band,
driverName=DRIVER_GTIFF,
targetRasterName=output_file,
creationOptions=[],
observerX=location[0],
observerY=location[1],
observerHeight=height,
targetHeight=0,
visibleVal=255,
invisibleVal=0,
outOfRangeVal=0,
noDataVal=0,
dfCurvCoeff=0.85714,
mode=2,
maxDistance=0,
)
# 调用
generate_viewshed(
input_raster="data/input/raster.tif",
output_file="data/output/viewshed.tif",
location=(120.568740, 32.013540),
height=2,
)
总结
相比C++,Python GDAL还是很容易上手和应用的,官方提供了比较详细的API使用说明,对于一些参数较多的函数,可直接查相应的C++函数签名,这样对参数的理解会更深。
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 清华大学推出第四讲使用 DeepSeek + DeepResearch 让科研像聊天一样简单!
· 实操Deepseek接入个人知识库
· 易语言 —— 开山篇
· 一个费力不讨好的项目,让我损失了近一半的绩效!
· 【全网最全教程】使用最强DeepSeekR1+联网的火山引擎,没有生成长度限制,DeepSeek本体