基于Python+MXnet预训练模型的街景图像语义分割代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import mxnet as mx
from mxnet import image, gpu
from PIL import Image
import gluoncv
from gluoncv.data.transforms.presets.segmentation import test_transform
from gluoncv.utils.viz import get_color_pallete,plot_image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import pandas as pd
import shutil
 
 
 
#确定处理平台
ctx = mx.gpu(0)#若不是gpu版本需要改为cpu(0)
 
 
#下载在cityscape上的pspnet预训练模型
model = gluoncv.model_zoo.get_model('psp_resnet101_citys', ctx=ctx, pretrained=True)#此处的psp_resnet101_citys可以修改为其他的预训练模型
 
 
#准备文件路径
for n in range(1,12854):#这里对每一个坐标文件夹进行遍历,对文件夹内的图像进行处理
    print('The No.{} coordinate is handling'.format(n))
    general_path='G:/points/from/{}'.format(n)#替换为自己的路径
    try:
        general_file=os.walk(general_path)
    except FileNotFoundError:
        print("No.{} coordinate is missing".format(n))
    img_names=[]
    lonlat_names=[]
    for root, dirs, files in general_file:
        img_names[:] = [f for f in files if f.endswith(".png")]
        lonlat_names[:] = [g for g in files if g.endswith(".txt")]
        for lonlat in lonlat_names:
            save_path='G:/points/target/{}'.format(n)#替换为自己的路径
            os.mkdir(save_path)
            lonlat_path=general_path + '/' + lonlat
#将坐标信息文件复制到目标文件夹中继续储存
            shutil.copy(lonlat_path, save_path)
 
 
#对图像进行处理
        for img_path_ in img_names:
                df = pd.DataFrame(columns=['id','lng','lat','heading','road','sidewalk','building','wall','fence',
                                           'pole','traffic light','traffic sign','vegetation','terrain','sky',
                                           'person','rider','car','truck','bus','train','motorcycle','bicycle'])
                 
                 
                #读取图片并分割,返回的pred后续存入表格
                img_num=img_path_[0]
                img_path='G:/points/from/{}/{}.png'.format(n,img_num)#替换为自己的路径
                img = image.imread(img_path)
                img = test_transform(img,ctx=ctx)
                output = model.predict(img)
                predict = mx.nd.squeeze(mx.nd.argmax(output, 1)).asnumpy()
                col_map = {0:'road', 1:'sidewalk', 2:'building', 3:'wall', 4:'fence', 5:'pole', 6:'traffic light',
                               7:'traffic sign', 8:'vegetation', 9:'terrain', 10:'sky', 11:'person', 12:'rider',
                               13:'car', 14:'truck', 15:'bus', 16:'train', 17:'motorcycle', 18:'bicycle'}
                pred = []
                for i in range(19):
                    pred.append((len(predict[predict==i])/(predict.shape[0]*predict.shape[1])))
                pred = pd.Series(pred).rename(col_map)
                 
                 
                #将结果存入表格
                data_i = pd.Series({'id':img_num,}).append(pred)
                df = pd.concat([df, pd.DataFrame(data_i).T], axis=0, join='outer', ignore_index=True)
                print('---------Segmentation Is Ok--------')
                df.to_csv(save_path + "/img_seg_csv{}.csv".format(img_num))
 
                         
                 
                #将分割结果可视化并储存(可选)
                mask = get_color_pallete(predict, 'citys')
                base = Image.open(img_path)
                plt.figure(figsize=(10,5))
                plt.imshow(base)
                plt.imshow(mask,alpha=0.65)
                plt.axis('off')
                plt.savefig(save_path + "/img_seg_jpg{}.png".format(img_num),dpi=200,bbox_inches='tight')
                plt.close()

  

posted @   Victooor_swd  阅读(149)  评论(0编辑  收藏  举报
(评论功能已被禁用)
相关博文:
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
点击右上角即可分享
微信分享提示