TensorFlow中的语义分割套件

TensorFlow中的语义分割套件

描述

该存储库用作语义细分套件。目标是轻松实现,训练和测试新的语义细分模型!完成以下内容:

  • 训练和测试方式
  • 资料扩充
  • 几种最先进的模型。轻松随插即
  • 能够使用任何数据集
  • 评估包括准确性,召回率,f1得分,平均准确性,每类准确性和平均IoU
  • 绘制损失函数和准确性

欢迎提出任何改进此存储库的建议,包括希望看到的任何新细分模型。

也可以签出Transfer Learning Suite

引用

如果发现此存储库有用,请考虑使用回购链接将其引用:)

前端

当前提供以下特征提取模型:

模型

当前提供以下细分模型:

  • 基于SegNet的编解码器。该网络使用VGG样式的编码器/解码器,其中解码器中的升采样是使用转置卷积完成的。
  • 具有基于SegNet的跳过连接的编解码器。该网络使用VGG样式的编码器/解码器,其中解码器中的升采样是使用转置卷积完成的。另外,采用从编码器到解码器的附加跳过连接。
  • 用于语义分割的移动UNet。将MobileNets深度可分离卷积与UNet的思想相结合,以建立一个高速,低参数的语义分割模型。
  • 金字塔场景解析网络。在本文中,通过金字塔池模块以及所提出的金字塔场景解析网络(PSPNet)来应用基于不同区域的上下文聚合的全局上下文信息的功能。请注意,原始的PSPNet使用具有扩展卷积的ResNet,但是一个是此存储库仅具有常规ResNet
  • The One Hundred Layers Tiramisu:用于语义分割的完全卷积DenseNet。使用下采样-上采样样式的编码器-解码器网络。每个阶段(即池化层之间的阶段)都使用密集块。此外,还连接了从编码器到解码器的跳过连接。在代码中,这是FC-DenseNet模型。
  • 对Atrous卷积语义图像分割的再思考。这是DeepLabV3网络。使用Atrous空间金字塔池通过使用多个atrous速率来捕获多尺度上下文。这产生了一个大的接受场。
  • RefineNet:用于高分辨率语义分割的多路径优化网络。一个多路径优化网络,该网络显式地利用降采样过程中的所有可用信息,以实现使用远程残差连接的高分辨率预测。这样,可以使用早期卷积中的细粒度特征直接完善捕获更深层的高级语义特征。
  • 用于街道场景语义分割的全分辨率残差网络。通过使用网络中的两个处理流,将多尺度上下文与像素级精度结合在一起。残留流以全图像分辨率传输信息,从而可以精确地遵守分割边界。池化流经过一系列池化操作以获得可靠的功能以进行识别。两个流使用残差以全图像分辨率耦合。在代码中,这是FRRN模型。
  • 大内核问题-通过全球卷积网络改进语义分割。提出了一个全球卷积网络来解决语义分割的分类和本地化问题。使用较大的可分离内核扩展接收场,并使用边界细化块进一步提高边界附近的定位性能。
  • AdapNet:不利环境条件下的自适应语义分段通过使用具有无规则卷积的多尺度策略执行较低分辨率的处理,来修改ResNet50体系结构。这是使用双线性放大而不是转置卷积的稍微修改的版本,因为发现给出了更好的结果。
  • ICNet用于高分辨率图像的实时语义分割。提出了一种基于压缩PSPNet的图像级联网络(ICNet),该网络在适当的标签指导下合并了多分辨率分支,以应对这一挑战。大多数处理都是在低分辨率下高速完成的,多尺度辅助损耗有助于获得准确的模型。请注意,对于此模型,已经实现了网络,但尚未集成其训练
  • 带有可分解卷积的编解码器用于语义图像分割。这是DeepLabV3 +网络,在常规DeepLabV3模型的顶部添加了解码器模块。
  • DenseASPP在街道场景中的语义分割。使用膨胀卷积结合密集连接的多种不同尺度
  • 用于单通语义分割的密集解码器快捷连接。在细分模型的解码器阶段使用密集连接的密集解码器快捷连接。注意:由于ResNeXt模块的构造,该网络需要花费一些额外的时间来加载
  • BiSeNet:用于实时语义分割的双边分割网络。BiSeNet使用步幅较小的空间路径来保存空间信息并生成高分辨率特征,同时使用具有快速下采样策略的并行上下文路径来获得足够的接收场。
  • 或者自己制作并即插即用!

文件和目录

  • train.py对选择的数据集进行训练。默认为CamVid
  • test.py在选择的数据集上进行测试。默认为CamVid
  • predict.py使用新近训练的模型对单个图像进行预测
  • helper.py快速助手功能,用于数据准备和可视化
  • utils.py用于打印,调试,测试和评估的实用程序
  • models包含所有模型文件的文件夹。使用来构建模型,或使用预先构建的模型
  • CamVid用于语义分割的CamVid数据集作为测试平台。这是32类版本
  • checkpoints训练期间每个时期的检查点文件
  • Test测试结果包括图像,每类准确性,准确性,召回率和f1分数

安装

该项目具有以下依赖性:

  • Numpy的 sudo pip install numpy
  • OpenCV Python sudo apt-get install python-opencv
  • TensorFlow sudo pip install --upgrade tensorflow-gpu

用法

唯一要做的就是按照以下结构设置文件夹:

├── "dataset_name"                  

|   ├── train

|   ├── train_labels

|   ├── val

|   ├── val_labels

|   ├── test

|   ├── test_labels

将一个文本文件放在数据集目录下,称为“ class_dict.csv”,其中包含类列表以及R,G,B颜色标签,以可视化分割结果。这种字典通常随数据集一起提供。这是CamVid数据集的示例:

name,r,g,b

Animal,64,128,64

Archway,192,0,128

Bicyclist,0,128, 192

Bridge,0, 128, 64

Building,128, 0, 0

Car,64, 0, 128

CartLuggagePram,64, 0, 192

Child,192, 128, 64

Column_Pole,192, 192, 128

Fence,64, 64, 128

LaneMkgsDriv,128, 0, 192

LaneMkgsNonDriv,192, 0, 64

Misc_Text,128, 128, 64

MotorcycleScooter,192, 0, 192

OtherMoving,128, 64, 64

ParkingBlock,64, 192, 128

Pedestrian,64, 64, 0

Road,128, 64, 128

RoadShoulder,128, 128, 192

Sidewalk,0, 0, 192

SignSymbol,192, 128, 128

Sky,128, 128, 128

SUVPickupTruck,64, 128,192

TrafficCone,0, 0, 64

TrafficLight,0, 64, 64

Train,192, 64, 128

Tree,128, 128, 0

Truck_Bus,192, 128, 192

Tunnel,64, 0, 64

VegetationMisc,192, 192, 0

Void,0, 0, 0

Wall,64, 192, 0

注意:如果使用的是依赖于预训练的ResNet的任何网络,则需要使用提供的脚本下载预训练的权重。当前是:PSPNet,RefineNet,DeepLabV3,DeepLabV3 +,GCN。

然后,可以简单地运行train.py!查看可选的命令行参数:

usage: train.py [-h] [--num_epochs NUM_EPOCHS]

                [--checkpoint_step CHECKPOINT_STEP]

                [--validation_step VALIDATION_STEP] [--image IMAGE]

                [--continue_training CONTINUE_TRAINING] [--dataset DATASET]

                [--crop_height CROP_HEIGHT] [--crop_width CROP_WIDTH]

                [--batch_size BATCH_SIZE] [--num_val_images NUM_VAL_IMAGES]

                [--h_flip H_FLIP] [--v_flip V_FLIP] [--brightness BRIGHTNESS]

                [--rotation ROTATION] [--model MODEL] [--frontend FRONTEND]

optional arguments:

  -h, --help            show this help message and exit

  --num_epochs NUM_EPOCHS

                        Number of epochs to train for

  --checkpoint_step CHECKPOINT_STEP

                        How often to save checkpoints (epochs)

  --validation_step VALIDATION_STEP

                        How often to perform validation (epochs)

  --image IMAGE         The image you want to predict on. Only valid in

                        "predict" mode.

  --continue_training CONTINUE_TRAINING

                        Whether to continue training from a checkpoint

  --dataset DATASET     Dataset you are using.

  --crop_height CROP_HEIGHT

                        Height of cropped input image to network

  --crop_width CROP_WIDTH

                        Width of cropped input image to network

  --batch_size BATCH_SIZE

                        Number of images in each batch

  --num_val_images NUM_VAL_IMAGES

                        The number of images to used for validations

  --h_flip H_FLIP       Whether to randomly flip the image horizontally for

                        data augmentation

  --v_flip V_FLIP       Whether to randomly flip the image vertically for data

                        augmentation

  --brightness BRIGHTNESS

                        Whether to randomly change the image brightness for

                        data augmentation. Specifies the max bightness change

                        as a factor between 0.0 and 1.0. For example, 0.1

                        represents a max brightness change of 10% (+-).

  --rotation ROTATION   Whether to randomly rotate the image for data

                        augmentation. Specifies the max rotation angle in

                        degrees.

  --model MODEL         The model you are using. See model_builder.py for

                        supported models

  --frontend FRONTEND   The frontend you are using. See frontend_builder.py

                        for supported models

结果

这些是带有11个类的CamVid数据集的一些示例结果(先前的研究版本)。

在训练中,使用的批处理大小为1,图像大小为352x480。以下结果适用于训练了300个纪元的FC-DenseNet103模型。使用RMSProp的学习速率为0.001,衰减率为0.995。没有像本文中那样使用任何数据增强。也没有使用任何类平衡。这些只是一些快速而肮脏的示例结果。

请注意,检查点文件未上传到此存储库,因为对于GitHub太大(大于100 MB

原始精度

准确性

天空

93.0

94.1

建造

83.0

81.2

37.8

38.3

94.5

97.5

路面

82.2

87.9

77.3

75.5

标志符号

43.9

49.7

围栏

37.1

69.0

汽车

77.3

87.0

行人

59.6

60.3

单车

50.5

75.3

未贴标签

不适用

40.9

全球

91.5

89.6

 

损失与时代

 累积 vs时代

   
posted @ 2020-07-14 15:48  吴建明wujianming  阅读(784)  评论(0编辑  收藏  举报