Training DETR on a Custom Dataset

Earlier posts covered DETR and its improved variant Deformable DETR; this one records how to train DETR on your own dataset. It is worth first looking at the rough source code given in the appendix of the DETR paper, whose overall structure is very clear.

The rough implementation process is recorded below.

1. Data Preparation

Annotate the data with labelme.

Then convert the annotations to COCO format, which yields the following files:

Here JPEGImages holds all the images, Visualization holds the visualization results, and annotations.json stores the annotation information for every image. A sketch of the conversion follows.
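For reference, here is a minimal sketch of what such a conversion can look like (ready-made labelme-to-COCO converters also exist); the single "teeth" category and the file layout are assumptions for illustration:

import json
from pathlib import Path

# minimal labelme -> COCO sketch; assumes polygon shapes and a single
# "teeth" category (ids start at 1, matching the CLASSES list used later)
def labelme_to_coco(labelme_dir, out_file, categories=("teeth",)):
    cat_index = {name: i + 1 for i, name in enumerate(categories)}
    coco = {"images": [], "annotations": [],
            "categories": [{"id": i, "name": n} for n, i in cat_index.items()]}
    ann_id = 1
    for img_id, jf in enumerate(sorted(Path(labelme_dir).glob("*.json")), 1):
        data = json.loads(jf.read_text())
        coco["images"].append({"id": img_id, "file_name": data["imagePath"],
                               "height": data["imageHeight"],
                               "width": data["imageWidth"]})
        for shape in data["shapes"]:
            xs = [p[0] for p in shape["points"]]
            ys = [p[1] for p in shape["points"]]
            x, y = min(xs), min(ys)
            w, h = max(xs) - x, max(ys) - y
            coco["annotations"].append({
                "id": ann_id, "image_id": img_id,
                "category_id": cat_index[shape["label"]],
                "bbox": [x, y, w, h],                        # COCO xywh format
                "area": w * h,
                "segmentation": [sum(shape["points"], [])],  # flattened polygon
                "iscrowd": 0})
            ann_id += 1
    Path(out_file).write_text(json.dumps(coco))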

2. Model Training

2.1 Writing the DataLoader

Create a custom_data.py file under the detr/datasets directory to handle your own data. In it, define a dataset class that mainly implements the __getitem__ and __len__ methods.

Then add a build function that DETR uses to construct the dataset; a minimal sketch follows.
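Since the annotations are already in COCO format, custom_data.py can simply reuse the CocoDetection wrapper and transforms from DETR's own datasets/coco.py (which already provide __getitem__ and __len__), so only the build function is new. The args.tooth_path name and the single shared annotation file are assumptions here:

# datasets/custom_data.py
from pathlib import Path

from .coco import CocoDetection, make_coco_transforms


def build(image_set, args):
    root = Path(args.tooth_path)
    assert root.exists(), f'provided data path {root} does not exist'
    # a single annotation file is shared here; a real train/val split
    # per image_set would normally be resolved at this point
    img_folder = root / "JPEGImages"
    ann_file = root / "annotations.json"
    return CocoDetection(img_folder, ann_file,
                         transforms=make_coco_transforms(image_set),
                         return_masks=args.masks)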

Then register the new dataset type in the __init__.py file of the same directory:

def build_dataset(image_set, args):
    if args.dataset_file == 'coco':
        return build_coco(image_set, args)
    if args.dataset_file == 'coco_panoptic':
        # to avoid making panopticapi required for coco
        from .coco_panoptic import build as build_coco_panoptic
        return build_coco_panoptic(image_set, args)
    if args.dataset_file == 'tooth':
        from .custom_data import build as build_tooth
        return build_tooth(image_set, args)
    raise ValueError(f'dataset {args.dataset_file} not supported')

2.2 Training

Modify the configuration parameters

Add a data-path argument in main.py.
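One possible form, in get_args_parser() in main.py (the flag name matches the training command further below):

parser.add_argument('--tooth_path', type=str, default='',
                    help='root directory of the custom dataset')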

Modify the number of classes in models/detr.py: set it to the actual number of classes plus 1, where the extra 1 is the background placeholder class (DETR's classification head then appends one more no-object logit internally).

num_classes = 2 if args.dataset_file != 'coco' else 91 

Load the pretrained model

if args.resume:
    if args.resume.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(args.resume, map_location='cpu')
        # ==============================================================
        # Modified: delete the parameters whose shapes no longer match the
        # new head, and call load_state_dict with strict=False below so that
        # only the pretrained weights shared with the new model are loaded
        del checkpoint["model"]["class_embed.weight"]
        del checkpoint["model"]["class_embed.bias"]
        del checkpoint["model"]["query_embed.weight"]
    model_without_ddp.load_state_dict(checkpoint['model'], strict=False)

Start training

python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py --tooth_path /home/jinhai_zhou/data/2D_seg/ --dataset_file tooth --output_dir ./output/path/box_model --resume "./models/detr-r50-e632da11.pth" 

In my case, detection started to converge after about 500 epochs, and segmentation got close to convergence after roughly 200 epochs.

If you are training a segmentation model, it is best to proceed in two stages: first train the detection model, then train the segmentation head on top of it (this matches the DETR repo's recipe); an example command is given below.
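Following the DETR repo's segmentation recipe (train boxes first, then pass the box checkpoint via --frozen_weights and enable --masks), the second stage can look like this; the checkpoint and output paths are placeholders:

python main.py --masks --dataset_file tooth --tooth_path /home/jinhai_zhou/data/2D_seg/ --frozen_weights ./output/path/box_model/checkpoint.pth --output_dir ./output/path/seg_model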

3. Testing

Add a predict.py file for testing.
It mainly contains two parts: detection and plotting of the results.

  • Detection
def detect(im, model, transform, threshold=0.7):
    # mean-std normalize the input image (batch-size: 1)
    img = transform(im).unsqueeze(0)
    print("image.shape:", img.shape)
    # demo model only support by default images with aspect ratio between 0.5 and 2
    # if you want to use images with an aspect ratio outside this range
    # rescale your image so that the maximum size is at most 1333 for best results
    # assert img.shape[-2] <= 1600 and img.shape[-1] <= 1600, 'demo model only supports images up to 1600 pixels on each side'

    # propagate through the model
    outputs = model(img)

    # keep only predictions above the confidence threshold (default 0.7)
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > threshold

    # convert boxes from [0; 1] to image scales
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
    return probas[keep], bboxes_scaled
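
detect relies on two small box helpers; these come from the official DETR demo notebook and are reproduced here so predict.py is self-contained:

def box_cxcywh_to_xyxy(x):
    # (cx, cy, w, h) -> (x0, y0, x1, y1)
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
    # scale normalized boxes back to absolute image coordinates
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b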
  • Plotting the results
def plot_results(pil_img, prob, boxes, output):
    CLASSES = [
         'N/A', 'teeth'
    ]

    # colors for visualization
    COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
              [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
    plt.figure(figsize=(16,10))
    plt.imshow(pil_img)
    ax = plt.gca()
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), COLORS * 100):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        cl = p.argmax()
        text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.savefig(output)
    plt.close()
    # plt.show()

Putting it together: the code below is the body of main(args) in predict.py. It additionally needs the usual imports (torch, random, numpy as np, torchvision.transforms as T, pathlib.Path, PIL.Image, util.misc as utils, and build_model from models); --img_path, --img_dirs and --strict are arguments added to the parser.

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model, criterion, postprocessors = build_model(args)
    model.to(device)

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    output_dir = Path(args.output_dir)
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu') 
        model.load_state_dict(checkpoint['model'], strict=args.strict)
        print("load model {} is success!".format(args.resume))
    else:
        print("Don't load model!")
        return
    
    # standard PyTorch mean-std input image normalization
    transform = T.Compose([
        T.Resize(800),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    if args.img_path is not None:
        assert Path(args.img_path).is_file(), "{} not an image path".format(args.img_path)
        im = Image.open(args.img_path)
        scores, boxes = detect(im, model, transform=transform)
        print("scores: ", scores)
        print("boxes: ", boxes)

    if args.img_dirs is not None:
        assert Path(args.img_dirs).is_dir(), "{} not a dir path".format(args.img_dirs)
        img_paths = Path(args.img_dirs).glob("*.jpg")
        # print("loads {} images".format(len(list(img_paths))))
        for idx, img_path in enumerate(img_paths):
            print(img_path)
            im = Image.open(img_path)
            scores, boxes = detect(im, model, transform=transform)
            print(" scores: ", scores)
            print("boxes: ", boxes)
            out_path = Path(output_dir) / img_path.name
            print("out_path: ", out_path)
            plot_results(im, scores, boxes, out_path) 
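
A possible invocation, assuming the extra arguments above were added to the parser (paths are placeholders):

python predict.py --dataset_file tooth --resume ./output/path/box_model/checkpoint.pth --img_dirs ./test_images --output_dir ./output/predictions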
