DETR训练自己的数据集---使用方法

参考:1、[GitHub](https://github.com/DataXujing/detr_transformer)

           2、[Bilibili视频](https://www.bilibili.com/video/BV1GC4y1h77h)


1、拷贝代码
```
git clone https://github.com/facebookresearch/detr.git
```

2、创建新的虚拟环境(推荐)
```
conda create -n detr python=3.7
conda activate detr
```

3、安装依赖库
安装PyTorch 1.5+ and torchvision 0.6+,安装scipy
```
conda install -c pytorch pytorch torchvision
conda install cython scipy
```

安装pycocotools (for evaluation on COCO):
```
pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
```

4、数据准备(格式、路径如下所示)
```
├─annotations # 标注的json文件,coco类型的标注
├─instances_train.json
├─instances_val.json
├─train # 训练图像的存放地址
├─xxx.jpg
├─val # 验证图像的存放地址
└─xxxx.jpg
```

5、下载预训练模型,并修改类别数
修改配置文件change.py,将num_class改为“类别+1”
```
import torch

pretrained_weights = torch.load("./detr-r50-e632da11.pth")

num_class = 3 + 1
pretrained_weights["model"]["class_embed.weight"].resize_(num_class+1,256)
pretrained_weights["model"]["class_embed.bias"].resize_(num_class+1)

torch.save(pretrained_weights,'detr_r50_%d.pth'%num_class)
```

```
python change.py # 在项目文件夹下生成detr_r50_{class_num}.pth
```

6、修改./models/detr.py中的build方法
```
def build(args):
num_classes = 3 if args.dataset_file != 'coco' else 3+1 #<--类别数 + 1
if args.dataset_file == "coco_panoptic": # 全景分割
num_classes = 3+1 # <-------------
device = torch.device(args.device)
```

7、按照自己需要修改main.py文件

8、训练
```
python main.py --dataset_file "coco" --coco_path "/myData/coco" --epoch 500 --lr=1e-4 --batch_size=8 --num_workers=4 --output_dir="outputs" --resume="detr_r50_4.pth"

python main.py --dataset_file "coco" --coco_path "/data1/hzy/COCO2007" --epoch 50 --batch_size=4 --num_workers=4 --output_dir="outputs_1" --resume="detr_r50_55.pth"
```

9、测试
```
inference_img.py
```

posted @ 2021-05-17 22:54  零纪年  阅读(6572)  评论(0编辑  收藏  举报