Python GDAL/OGR常用案例代码总结

安装

推荐使用conda安装python gdal环境,先查询gdal可用版本,再指定版本号,按需安装对应的gdal。

conda search gdal
conda install gdal=version_number

如果运行后面的代码提示"ERROR: proj_create_from_database: Cannot find proj.db",这是说明找不到proj.db,我们需要设置proj的环境变量为proj.db文件所在的目录(可直接在文件目录中搜索proj.db,如果是通过conda安装gdal,则proj.db在conda虚拟目录下的某个Library里):

或者在代码开头定义PROJ_LIB环境变量(更推荐前一种方式):

import os

os.environ["PROJ_LIB"] = "D:\ProgramData\CondaVirtualEnv\py39\Library\share\proj"

案例

import相应模块和常量定义

from osgeo import gdal, ogr, osr, gdalconst

DRIVER_SHAPE = "ESRI Shapefile"
DRIVER_GTIFF = "GTiff"
EPSG_WGS84 = 4326

矢量转栅格

def vector2raster(input_file, ouput_file, template_file, field="", all_touch="False"):
    """
    矢量图层转栅格

    Params:
        input_file: 输入矢量文件
        output_file: 输出栅格文件
        template_file: 模板栅格文件
        field: 输出栅格所用的input_file字段
        all_touch: False仅矢量中心所在的栅格设置像元值,True只要和矢量相交的栅格都设置像元值

    Ref: https://gdal.org/api/gdal_alg.html#_CPPv419GDALRasterizeLayers12GDALDatasetHiPiiP9OGRLayerH19GDALTransformerFuncPvPdPPc16GDALProgressFuncPv
    """
    # 打开模板栅格
    data = gdal.Open(template_file, gdalconst.GA_ReadOnly)
    # 确定栅格大小
    x_size = data.RasterXSize
    y_size = data.RasterYSize
    # 打开矢量vec_layer
    vector = ogr.GetDriverByName(DRIVER_SHAPE).Open(input_file)
    vec_layer = vector.GetLayer()
    feat_count = vec_layer.GetFeatureCount()
    # 创建输出的tiff栅格文件
    target = gdal.GetDriverByName(DRIVER_GTIFF).Create(ouput_file, x_size, y_size, 1, gdal.GDT_Byte)
    # 设置栅格坐标系与投影
    target.SetGeoTransform(data.GetGeoTransform())
    target.SetProjection(data.GetProjection())

    if field:
        gdal.RasterizeLayer(target, [1], vec_layer, 
        options=["ALL_TOUCHED="+all_touch, "ATTRIBUTE="+field])
    else:
        burn_values = [1] # 如果没有指定属性字段,则使用burn_value作为输出的像素值
        gdal.RasterizeLayer(target, [1], vec_layer, burn_values=burn_values, options=["ALL_TOUCHED="+all_touch])

    # 设置NoData
    NoData_value = 0
    target.GetRasterBand(1).SetNoDataValue(NoData_value)
    target.GetRasterBand(1).FlushCache()
    target = None


# 调用
vector2raster(
    input_file="data/input/vec.shp", 
    ouput_file="data/output/vec2raster.tif", 
    template_file="data/input/template.tif", 
    field="Height"
)

栅格转矢量(多边形)

def raster2polygon(input_raster, output_file, layer_name):
    """
    栅格转矢量多边形

    Params:
        input_raster: 输入栅格文件
        output_file: 输出shapefile文件
        layer_name: 矢量图层名称

    Ref: https://gdal.org/api/gdal_alg.html#_CPPv415GDALFPolygonize15GDALRasterBandH15GDALRasterBandH9OGRLayerHiPPc16GDALProgressFuncPv
    """
    data = gdal.Open(input_raster)
    src_band = data.GetRasterBand(1)
    srs = osr.SpatialReference()
    srs.ImportFromWkt(data.GetProjection()) # 矢量的空间参考和栅格保持一致
    target = ogr.GetDriverByName(DRIVER_SHAPE).CreateDataSource(output_file)
    target_layer = target.CreateLayer(layer_name, srs=srs, geom_type=ogr.wkbPolygon)
    # 给目标shp文件添加一个字段,存储原始栅格的pixel value
    field = ogr.FieldDefn('value',ogr.OFTReal)  
    target_layer.CreateField(field)
    gdal.Polygonize(src_band, src_band, target_layer, 0, [])
    target = None


# 调用
raster2polygon(
    input_raster="data/input/raster.tif",
    output_file="data/output/poly.shp",
    layer_name="poly",
)

矢量叠加

def intersect(input_file1, input_file2, output_file, output_layer_name=""):
    """
    矢量叠加

    Params:
        input_file1: 输入矢量文件1
        input_file2: 输入矢量文件2
        output_file: 输出矢量文件路径
        output_layer_name: 输出图层名
    """
    driver = ogr.GetDriverByName(DRIVER_SHAPE)
    shp1 = driver.Open(input_file1, gdalconst.GA_ReadOnly)
    shp2 = driver.Open(input_file2, gdalconst.GA_ReadOnly)
    src_layer1 = shp1.GetLayer()
    src_layer2 = shp2.GetLayer()
    srs1 = src_layer1.GetSpatialRef()
    srs2 = src_layer2.GetSpatialRef()
    if srs1.GetAttrValue('AUTHORITY',1) != srs2.GetAttrValue('AUTHORITY',1):
        print("空间参考不一致!")
        return
    
    target_ds = ogr.GetDriverByName(DRIVER_SHAPE).CreateDataSource(output_file)
    target_layer = target_ds.CreateLayer(output_layer_name, srs1, geom_type=ogr.wkbPolygon, options=["ENCODING=UTF-8"]) # 设置编码为UTF-8,防止中文出现乱码

    for feat1 in src_layer1:
        geom1 = feat1.GetGeometryRef()
        for feat2 in src_layer2:
            geom2 = feat2.GetGeometryRef()
            if not geom1.Intersects(geom2):
                continue
            intersect = geom1.Intersection(geom2)
            feature = ogr.Feature(target_layer.GetLayerDefn())
            feature.SetGeometry(intersect)
            target_layer.CreateFeature(feature)

    # 清理引用
    target_layer = None
    ds = None


# 调用
intersect(
    input_file1="data/input/vec1.shp",
    input_file2="data/input/vec2.shp",
    output_file="data/output/intersect.shp",
    output_layer_name="intersect",
)

矢量擦除

def erase(erased_file, eraser_file, output_file, output_layer_name=""):
    """
    矢量擦除

    Params:
        erased_file: 被擦除的矢量文件路径
        eraser_file: 擦除矢量文件路径
        output_file: 输出文件路径
        output_layer_name: 输出图层名
    """
    driver = ogr.GetDriverByName(DRIVER_SHAPE)
    shp1 = driver.Open(erased_file, gdalconst.GA_ReadOnly)
    shp2 = driver.Open(eraser_file, gdalconst.GA_ReadOnly)
    src_layer1 = shp1.GetLayer()
    src_layer2 = shp2.GetLayer()
    srs1 = src_layer1.GetSpatialRef()
    srs2 = src_layer2.GetSpatialRef()
    if srs1.GetAttrValue('AUTHORITY',1) != srs2.GetAttrValue('AUTHORITY',1):
        print("空间参考不一致!")
        return
    target_ds = ogr.GetDriverByName(DRIVER_SHAPE).CreateDataSource(output_file)
    target_layer = target_ds.CreateLayer(output_layer_name, srs1, geom_type=ogr.wkbPolygon, options=["ENCODING=UTF-8"])
    ds = src_layer1.Erase(src_layer2, target_layer)
    ds = None


# 调用
erase(
    erased_file="data/input/vec1.shp",
    eraser_file="data/input/vec2.shp", 
    output_file="data/output/erase.shp",
)

缓冲区分析(以点为例)

def point_buffer(point, radius, output_file, layer_name):
    """
    对点集建立缓冲区

    Params:
        point: 输入点坐标
        range: 缓冲区半径
        output_file: 缓冲区输出文件路径
        layer_name: 输出图层名
    """
    # 创建图层
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(EPSG_WGS84) # 选择坐标系,注意地理坐标系和投影坐标系的radius单位不同,前者为度,后者为米
    ds = ogr.GetDriverByName(DRIVER_SHAPE).CreateDataSource(output_file)
    target_layer = ds.CreateLayer(layer_name, srs, geom_type=ogr.wkbPolygon, options=["ENCODING=UTF-8"])
    # 创建geometry
    wkt = "POINT (%f %f)" % (point[0], point[1])
    geom = ogr.CreateGeometryFromWkt(wkt)
    poly = geom.Buffer(radius)
    # 创建feature
    feature = ogr.Feature(target_layer.GetLayerDefn())
    feature.SetGeometry(poly)
    target_layer.CreateFeature(feature)

    target_layer = None
    ds = None


# 调用
point_buffer(
    point=(121.531921, 25.013540),
    range=0.1,
    output_file="data/output/buffer.shp",
    layer_name="buffer",
)

视域分析

def generate_viewshed(input_raster, output_file, location, height):
    """
    视域分析

    Params:
        input_raster: 输入栅格文件
        output_file: 视域分析输出文件
        location: 观察者所在坐标
        height: 观察者所在的高程

    Ref: https://gdal.org/api/gdal_alg.html#_CPPv420GDALViewshedGenerate15GDALRasterBandHPKcPKc12CSLConstListddddddddd16GDALViewshedModed16GDALProgressFuncPv22GDALViewshedOutputType12CSLConstList
    """
    raster = gdal.Open(input_raster)
    band = raster.GetRasterBand(1)
    gdal.ViewshedGenerate(
        srcBand=band, 
        driverName=DRIVER_GTIFF, 
        targetRasterName=output_file,
        creationOptions=[], 
        observerX=location[0], 
        observerY=location[1], 
        observerHeight=height, 
        targetHeight=0, 
        visibleVal=255,
        invisibleVal=0, 
        outOfRangeVal=0, 
        noDataVal=0, 
        dfCurvCoeff=0.85714, 
        mode=2,
        maxDistance=0,
    )


# 调用
generate_viewshed(
    input_raster="data/input/raster.tif",
    output_file="data/output/viewshed.tif",
    location=(120.568740, 32.013540),
    height=2,
)

总结

相比C++,Python GDAL还是很容易上手和应用的,官方提供了比较详细的API使用说明,对于一些参数较多的函数,可直接查相应的C++函数签名,这样对参数的理解会更深。

参考

posted @   g2012  阅读(1102)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 清华大学推出第四讲使用 DeepSeek + DeepResearch 让科研像聊天一样简单!
· 实操Deepseek接入个人知识库
· 易语言 —— 开山篇
· 一个费力不讨好的项目,让我损失了近一半的绩效!
· 【全网最全教程】使用最强DeepSeekR1+联网的火山引擎,没有生成长度限制,DeepSeek本体
点击右上角即可分享
微信分享提示