深度学习武器库-timm-非常好用的pytorch CV模型库 - 常用模型操作

简要介绍

timm库,全称pytorch-image-models,是最前沿的PyTorch图像模型、预训练权重和实用脚本的开源集合库,其中的模型可用于训练、推理和验证。

github源码链接
https://github.com/huggingface/pytorch-image-models

文档教程
文档:https://huggingface.co/docs/hub/timm
上手教程:https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055


优点

1、方便使用。在python环境中安装timm库,即可用几行代码创建网络模型,并可选择导入在imagenet等数据集上得到的预训练权重;无需再去扒每个模型的源代码,这对于跑模型对比实验是非常方便的,可以节省大量的时间;

2、灵活性高。导入模型的原始做法是,直接用.pth等权重文件导入,但这通常受到保存模型方法的限制,可能出现权重键名称不匹配、网络中间张量操作丢失(只能导入模型权重 却无法导入网络层之间的中间张量操作)等问题;而timm中的模型是对于整个模型的封装,在创建模型后,可以对相应网络层进行改动调整,非常灵活;

3、模型收录全面且前沿。在CV这个模型日新月异的领域,timm及时更新收录了最新的模型,比如,2024.8.8 添加了ECCV2024上的新模型RDNet(工作链接:https://github.com/naver-ai/rdnet)。
image

缺点

目前使用中感到一点不方便的是,使用函数请求下载这个库中的权重文件时,链接地址大部分是huggingface上的,而huggingface得用US的节点才能有较好的网速 。。。 我的PC可以访问,但服务器访问不了。。。 但这个缺点可以通过一些操作进行规避。

常用模型操作

  • 查看目前收录模型
    使用代码:
    timm.list_models('*')
    运行效果:
    image
    可以看到目前收录共有946个模型,查看目前已收录的模型,从这个列表中确定要导入的目标模型名称

    还可以通过正则表达式匹配目标模型名称,并通过指定pretrained=True筛选有预训练权重的模型
    如下匹配预训练resnet:timm.list_models('resnet*', pretrained=True)
    image

  • 创建模型
    使用代码(以resnet50为例):
    timm.create_model('resnet50', pretrained=True, in_chans=3, num_classes=6)
    这里的主要参数有四个:
    第一个是模型名称model_name
    第二个是是否预训练pretrained,
    第三个是输入图像的通道数in_chans
    第四个是分类类别数num_classes,指最后输出FC层的维度。

    注意:在创建模型这步中包含了从网络下载模型权重的操作,
    此时就会出现我在“缺点”部分讲到的问题:因为网络无法连接huggingface网站,而导致权重下载请求失败的情况 (多在服务器端出现)。
    image

    下面是解决方法
    先在能够连接huggingface网站的PC上,手动下载权重配置文件,使用代码如下:

    backbone_name = 'resnet50'
    
    pretrained_cfg = timm.create_model(backbone_name).default_cfg
    print(pretrained_cfg)
    

    运行后输出配置信息:

    {'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth', 'hf_hub_id': 'timm/resnet50.a1_in1k', 'architecture': 'resnet50', 'tag': 'a1_in1k', 'custom_load': False, 'input_size': (3, 224, 224), 'test_input_size': (3, 288, 288), 'fixed_input_size': False, 'interpolation': 'bicubic', 'crop_pct': 0.95, 'test_crop_pct': 1.0, 'crop_mode': 'center', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'num_classes': 1000, 'pool_size': (7, 7), 'first_conv': 'conv1', 'classifier': 'fc', 'origin_url': 'https://github.com/huggingface/pytorch-image-models', 'paper_ids': 'arXiv:2110.00476'}
    

    其中,url对应了模型的下载请求地址,直接将这个url复制粘贴到浏览器中,手动下载权重文件。
    获得权重文件后,再使用timm.create_model方法,通过将pretrained_cfg_overlay参数指定为权重文件,来创建模型,这样就是本地创建了:

    backbone_name = 'resnet50'
    ckpt_path = './ckpt/resnet50_a1_0-14fe96d1.pth'
    
    model = timm.create_model(backbone_name,
                                       pretrained=True,
                                       pretrained_cfg_overlay=dict(file=ckpt_path))
    
  • 手动调整模型
    在cv中,最常见的操作是将某个网络的主干层,用于特征提取
    timm中有专门的方法可以实现这个目的:
    feature_ouput = model.forward_features(image)
    feature_ouput即为网络在最后的head层之前输出的特征向量。
    但这个操作还无法完美解决问题,因为通常认为提取特征就是排除网络的最后一层,但有的网络最后一层中不仅包括全连接(FC)层,在FC层之前还包含池化层。这就需要更灵活的操作,来调整、构建我们想要的网络。
    下面是我的代码(以手动添加池化层为例):

    feature_extract_model = nn.Sequential(*list(model.children())[:-1],
                                               nn.AdaptiveAvgPool2d(1))
    

    另外,也可以自己定制最后的head层(但个人感觉这个用途不多),例如:

    model.fc = nn.Sequential(
        nn.BatchNorm1d(num_in_features),
        nn.Linear(in_features=num_in_features, out_features=512, bias=False),
        nn.ReLU(),
        nn.BatchNorm1d(512),
        nn.Dropout(0.4),
        nn.Linear(in_features=512, out_features=10, bias=False))
    

本期总结

timm是一个非常全面且便捷的CV图像模型库,能够大大提升我们跑实验的效率。我们同样也能运用其中的部分模块类,用到自己的编码中,也可以在它的源码中学习模型的代码写法。本期笔记只是介绍了本人近期跑对比实验,使用后感觉最常用的一些方法和操作,timm还有很多功能和用法需要去探索,比如还有数据增强、数据集和优化器等等功能。
posted on 2024-08-11 21:35  零度的python武器库  阅读(350)  评论(0编辑  收藏  举报