Python|遥感影像语义分割:使用训练好的权重文件进行大范围预测
前言
在模型预测过程中,如果将较大的待分类遥感影像直接输入到网络模型中会造成内存溢出,故一般将待分类图像裁剪为一系列较小图像分别输入网络进行预测,然后将预测结果按照裁剪顺序拼接成一张最终结果图像。
原理
如果采用常规的规则格网裁剪然后预测拼接的话效果不好。因为每张图像块的边缘区域的上下文信息较少,所以预测结果精度较低,进而还会导致出现明显的拼接痕迹。因此采用忽略边缘预测的方法,即有重叠地裁剪影像,并在拼接时忽略每个图像块的边缘部分。如图1所示,实际裁剪图像预测的结果为 $A$,用于拼接的结果为 $a$,$a$ 占 $A$ 的面积百分比为 $r$,相邻裁剪图像的重叠比例为 $1-\sqrt{r}$。这里借用知乎大佬的图来说明一下。
代码实现
我们先把大图像裁剪成一系列与相邻图像块有特定重复区域的图像块,并把它们存在列表里,然后依次把每个图像块送入网络进行预测。最后对预测结果只取中间部分进行拼接。代码注释写得相对比较详细,直接看代码:
import math
import numpy as np
import torch.nn.functional as F
import torch
from osgeo import gdal
from unet import UNet
import torchvision
# Read a TIFF raster (or a window of it) through GDAL
def readTif(fileName, xoff=0, yoff=0, data_width=0, data_height=0):
    """Read pixel data from a raster file with GDAL.

    Args:
        fileName: Path passed to ``gdal.Open``.
        xoff: Column offset of the window's left edge.
        yoff: Row offset of the window's top edge.
        data_width: Window width in pixels.
        data_height: Window height in pixels. When both ``data_width`` and
            ``data_height`` are 0, the window extends from the offset to
            the right/bottom edge of the raster.

    Returns:
        The value returned by ``dataset.ReadAsArray`` for the window.

    Raises:
        RuntimeError: If ``gdal.Open`` returns ``None`` for ``fileName``.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        # Fail here with a clear message instead of crashing later on
        # an attribute access on None.
        raise RuntimeError(fileName + "文件无法打开")

    # Number of columns / rows of the raster
    width = dataset.RasterXSize
    height = dataset.RasterYSize

    if data_width == 0 and data_height == 0:
        # Default window: everything from the offset to the raster edge.
        # Subtracting the offsets keeps the window inside the raster.
        data_width = width - xoff
        data_height = height - yoff

    return dataset.ReadAsArray(xoff, yoff, data_width, data_height)
# Save an array as a GeoTIFF file
def writeTiff(fileName, data, im_geotrans=(0, 0, 0, 0, 0, 0), im_proj=""):
    """Write a 2-D or 3-D NumPy array to ``fileName`` using the GTiff driver.

    Args:
        fileName: Output path passed to ``driver.Create``.
        data: Array indexed ``(height, width)`` or ``(bands, height, width)``.
        im_geotrans: Value passed to ``SetGeoTransform``.
        im_proj: Value passed to ``SetProjection``.

    Raises:
        ValueError: If ``data`` is neither 2-D nor 3-D.
        RuntimeError: If the driver fails to create the output dataset.
    """
    # The GDAL type is chosen by substring match on the dtype name, so
    # "int8" also matches "uint8" and "int16" also matches "uint16";
    # every other dtype is written as Float32.
    dtype_name = data.dtype.name
    if 'int8' in dtype_name:
        datatype = gdal.GDT_Byte
    elif 'int16' in dtype_name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(data.shape) == 2:
        # Promote a single band to (1, height, width)
        data = np.array([data])
    elif len(data.shape) != 3:
        raise ValueError(
            "data must be 2-D or 3-D, got shape {}".format(data.shape)
        )
    im_bands, im_height, im_width = data.shape

    # Create the output file
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(fileName, int(im_width), int(im_height), int(im_bands), datatype)
    if dataset is None:
        raise RuntimeError("could not create output file: {}".format(fileName))

    dataset.SetGeoTransform(im_geotrans)  # affine transform parameters
    dataset.SetProjection(im_proj)  # projection
    for band_index in range(im_bands):
        dataset.GetRasterBand(band_index + 1).WriteArray(data[band_index])
    # Dropping the last reference closes the dataset
    del dataset
# Crop an image into overlapping square tiles (image array, ignored margin)
def TifCroppingArray(img, SideLength, tile_size=256):
    """Cut ``img`` into overlapping ``tile_size`` x ``tile_size`` tiles.

    Tiles are laid out on a regular grid whose step is
    ``tile_size - 2 * SideLength``. One extra tile row and one extra tile
    column, aligned to the bottom and right edges of the image, are always
    appended so the remainder of the image is covered.

    Args:
        img: Array whose first two axes are sliced as rows and columns.
        SideLength: Width in pixels of the margin on each tile side.
        tile_size: Side length of every tile in pixels.

    Returns:
        A tuple ``(tiles, RowOver, ColumnOver)``. ``tiles`` is a list of
        lists of slices of ``img``, indexed ``tiles[tile_row][tile_col]``.
        ``RowOver`` is the remainder along axis 1 plus ``SideLength``;
        ``ColumnOver`` is the same quantity along axis 0.

    Raises:
        ValueError: If ``SideLength`` is negative, if the margins leave no
            positive step, or if ``img`` is smaller than one tile along
            either of its first two axes.
    """
    stride = tile_size - SideLength * 2
    if SideLength < 0 or stride <= 0:
        raise ValueError(
            "SideLength must satisfy 0 <= 2 * SideLength < tile_size, "
            "got SideLength={} and tile_size={}".format(SideLength, tile_size)
        )

    height, width = img.shape[0], img.shape[1]
    if height < tile_size or width < tile_size:
        # A negative slice start would wrap around and yield wrong-size crops.
        raise ValueError(
            "image of size {}x{} is smaller than one {}x{} tile".format(
                height, width, tile_size, tile_size
            )
        )

    # Number of regular tiles along axis 0 and along axis 1
    ColumnNum = (height - SideLength * 2) // stride
    RowNum = (width - SideLength * 2) // stride

    # Regular grid positions followed by the edge-aligned position
    row_starts = [i * stride for i in range(ColumnNum)] + [height - tile_size]
    col_starts = [j * stride for j in range(RowNum)] + [width - tile_size]

    TifArrayReturn = [
        [img[top: top + tile_size, left: left + tile_size] for left in col_starts]
        for top in row_starts
    ]

    # Leftover along axis 0 / axis 1, including one margin
    ColumnOver = (height - SideLength * 2) % stride + SideLength
    RowOver = (width - SideLength * 2) % stride + SideLength
    return TifArrayReturn, RowOver, ColumnOver
# Build the stitched result matrix from per-tile predictions
def _tile_span(index, count, extent, over, margin, tile_size):
    """Return ``(dst, src)`` slices along one axis for tile number ``index``.

    ``dst`` selects positions in the stitched result and ``src`` selects the
    matching positions inside the tile.
    """
    if index == 0:
        # First tile: keep its leading edge, drop only the trailing margin.
        keep = tile_size - margin
        return slice(0, keep), slice(0, keep)
    if index == count - 1:
        # Last tile is aligned to the far edge: keep its final `over` pixels.
        return slice(extent - over, extent), slice(tile_size - over, tile_size)
    # Interior tile: keep the centre, drop the margin on both sides.
    stride = tile_size - 2 * margin
    start = index * stride + margin
    return slice(start, start + stride), slice(margin, tile_size - margin)


def Result(shape, TifArray, npyfile, RepetitiveLength, RowOver, ColumnOver, tile_size=256):
    """Stitch tile predictions into one ``uint8`` array of ``shape``.

    Args:
        shape: Shape of the output array; ``shape[0]`` and ``shape[1]`` are
            the extents along the two stitched axes.
        TifArray: Nested list of tiles. Only ``len(TifArray)`` (tile rows)
            and ``len(TifArray[0])`` (tile columns) are used.
        npyfile: Iterable of per-tile 2-D predictions, ordered row by row.
        RepetitiveLength: Margin in pixels ignored on interior tile sides.
        RowOver: Number of trailing pixels along axis 1 taken from the
            last tile column.
        ColumnOver: Number of trailing pixels along axis 0 taken from the
            last tile row.
        tile_size: Side length of each tile in pixels.

    Returns:
        The stitched ``numpy.uint8`` array.
    """
    result = np.zeros(shape, np.uint8)
    n_rows = len(TifArray)
    n_cols = len(TifArray[0])

    for i, img in enumerate(npyfile):
        # Tiles arrive row by row, so the flat index gives (row, column).
        row, col = divmod(i, n_cols)
        dst_rows, src_rows = _tile_span(
            row, n_rows, shape[0], ColumnOver, RepetitiveLength, tile_size
        )
        dst_cols, src_cols = _tile_span(
            col, n_cols, shape[1], RowOver, RepetitiveLength, tile_size
        )
        result[dst_rows, dst_cols] = img[src_rows, src_cols]
    return result
# ---- Configuration ----
# Fraction of each tile's area that is kept when stitching
area_perc = 0.5
TifPath = r"343.tif"
model_paths = [
    r"MODEl.pth"
]
ResultPath = r"predict_result1.tif"
# Margin ignored on each tile side: the kept centre has side
# 256 * sqrt(area_perc), and the rest is split over the two sides.
RepetitiveLength = int((1 - math.sqrt(area_perc)) * 256 / 2)

# ---- Load the image and cut it into overlapping tiles ----
big_image = readTif(TifPath)
# Move axis 0 to the end: (0, 1, 2) -> (1, 2, 0).
# NOTE(review): assumes readTif returned a 3-D (bands, rows, cols) array.
big_image = big_image.swapaxes(1, 0).swapaxes(1, 2)
TifArray, RowOver, ColumnOver = TifCroppingArray(big_image, RepetitiveLength)

# ---- Build the network (replace with your own model) ----
model = UNet(n_channels=3, n_classes=2, bilinear=False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# ---- Predict every tile ----
to_tensor = torchvision.transforms.ToTensor()
# Flatten row by row; this is the order Result expects.
tiles = [tile for tile_row in TifArray for tile in tile_row]
predicts = [None] * len(tiles)

for model_path in model_paths:
    # Load each checkpoint once, not once per tile. map_location lets a
    # checkpoint saved on GPU be loaded on a CPU-only machine.
    # If several paths are listed, each model overwrites the previous
    # model's predictions, so the last path determines the output.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    for k, image in enumerate(tiles):
        img = to_tensor(image).unsqueeze(0)
        img = img.to(device=device, dtype=torch.float32)

        with torch.no_grad():
            output = model(img)

        if model.n_classes > 1:
            probs = F.softmax(output, dim=1)[0]
        else:
            probs = torch.sigmoid(output)[0]

        # Resize takes (height, width); for a NumPy tile those are
        # shape[0] and shape[1].
        tf = torchvision.transforms.Compose([
            torchvision.transforms.ToPILImage(),
            torchvision.transforms.Resize((image.shape[0], image.shape[1])),
            torchvision.transforms.ToTensor()
        ])
        mask = tf(probs.cpu()).squeeze()

        if model.n_classes == 1:
            # Single-channel output: the thresholded mask is the prediction.
            pred = (mask > 0.5).numpy()
        else:
            mask = F.one_hot(mask.argmax(dim=0), model.n_classes).permute(2, 0, 1).numpy()
            # Channel 1 marks the pixels whose arg-max class is 1.
            pred = mask[1]
        predicts[k] = pred

# ---- Stitch the tile predictions and save the result ----
result_shape = (big_image.shape[0], big_image.shape[1])
result_data = Result(result_shape, TifArray, predicts, RepetitiveLength, RowOver, ColumnOver)
writeTiff(ResultPath, result_data)
参考文献
王振庆,周艺,王世新,王福涛,徐知宇.2021.IEU-Net高分辨率遥感影像房屋建筑物提取.遥感学报,25(11): 2245-2254 DOI: 10.11834/jrs.20210042. Wang Z Q,Zhou Y,Wang S X,Wang F T and Xu Z Y. 2021. House building extraction from high-resolution remote sensing images based on IEU-Net. National Remote Sensing Bulletin, 25(11):2245-2254 DOI: 10.11834/jrs.20210042.