Debug Log - ModuleNotFoundError: No module named 'timm.models.layers.patch_embed'

Code being run:

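import os

import timm
import torch

# Build DeiT-Small with a new 6-class head, loading the pretrained weights from a
# local file instead of downloading them (pretrained_cfg_overlay overrides the weight source).
model = timm.create_model(
    'deit_small_patch16_224',
    pretrained=True,
    num_classes=6,
    pretrained_cfg_overlay=dict(file='/home/lingdu/zyt/works/pretrained_models/deit_small_patch16_224-cd65a155.pth'),
)

# Saving the whole model object pickles timm's module paths as well;
# torch.save(model.state_dict(), ...) stores only the tensors. The output directory must exist.
os.makedirs('timm_models', exist_ok=True)
torch.save(model, 'timm_models/deit_small.pth')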

Goal: use the timm library to create a deit_small_patch16_224 model, loading its weights from a local checkpoint file.
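Before digging into the traceback, it can help to confirm that the local file really is the original checkpoint. Filenames like deit_small_patch16_224-cd65a155.pth follow the torch.hub naming convention, in which the hex suffix after the hyphen is a prefix of the file's SHA-256 digest. The sketch below assumes this file follows that convention; the function name hash_prefix_matches is hypothetical and not part of timm or torch.

```python
import hashlib
import re
from pathlib import Path
from typing import Optional

# Modeled on the "name-<sha256 prefix>.ext" pattern torch.hub uses (assumed convention).
HASH_SUFFIX = re.compile(r"-([a-f0-9]+)\.")


def hash_prefix_matches(path: str) -> Optional[bool]:
    """Return True/False if the filename carries a hash prefix, else None."""
    match = HASH_SUFFIX.search(Path(path).name)
    if match is None:
        return None  # no hash suffix in the filename, nothing to compare
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        # Read in 1 MiB chunks so a large checkpoint is never held in memory at once.
        for chunk in iter(lambda: fh.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest().startswith(match.group(1))
```

Calling hash_prefix_matches('/home/lingdu/zyt/works/pretrained_models/deit_small_patch16_224-cd65a155.pth') and getting False would suggest the file is not the original download (for example, it was overwritten by a later torch.save).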

Error message:

  File "/home/lingdu/zyt/works/PD_6/get_model.py", line 10, in <module>
    model = timm.create_model(
  File "/home/lingdu/.conda/envs/codiff/lib/python3.8/site-packages/timm/models/_factory.py", line 117, in create_model
    model = create_fn(
  File "/home/lingdu/.conda/envs/codiff/lib/python3.8/site-packages/timm/models/deit.py", line 258, in deit_small_patch16_224
    model = _create_deit('deit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  File "/home/lingdu/.conda/envs/codiff/lib/python3.8/site-packages/timm/models/deit.py", line 123, in _create_deit
    model = build_model_with_cfg(
  File "/home/lingdu/.conda/envs/codiff/lib/python3.8/site-packages/timm/models/_builder.py", line 418, in build_model_with_cfg
    load_pretrained(
  File "/home/lingdu/.conda/envs/codiff/lib/python3.8/site-packages/timm/models/_builder.py", line 168, in load_pretrained
    state_dict = load_state_dict(pretrained_loc)
  File "/home/lingdu/.conda/envs/codiff/lib/python3.8/site-packages/timm/models/_helpers.py", line 54, in load_state_dict
    checkpoint = torch.load(checkpoint_path, map_location=device)
  File "/home/lingdu/.conda/envs/codiff/lib/python3.8/site-packages/torch/serialization.py", line 1025, in load
    return _load(opened_zipfile,
  File "/home/lingdu/.conda/envs/codiff/lib/python3.8/site-packages/torch/serialization.py", line 1446, in _load
    result = unpickler.load()
  File "/home/lingdu/.conda/envs/codiff/lib/python3.8/site-packages/torch/serialization.py", line 1439, in find_class
    return super().find_class(mod_name, name)
ModuleNotFoundError: No module named 'timm.models.layers.patch_embed'
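The last frames show the failure happens inside the unpickler's find_class. A pickled object stores the dotted module path of its class, not the class's code, and loading re-imports that path. The stdlib-only sketch below reproduces the same kind of error; the module name legacy_patch_embed and the PatchEmbed class are hypothetical stand-ins for timm's old layout, not timm itself.

```python
import pickle
import sys
import types

# Register a throwaway module under a hypothetical name (a stand-in, not timm).
legacy = types.ModuleType("legacy_patch_embed")
sys.modules["legacy_patch_embed"] = legacy


class PatchEmbed:
    """Stand-in class; pickle records where it lives, not its code."""


# Pretend the class lives in the throwaway module, as an old library layout would.
PatchEmbed.__module__ = "legacy_patch_embed"
PatchEmbed.__qualname__ = "PatchEmbed"
legacy.PatchEmbed = PatchEmbed

# Saving records the reference "legacy_patch_embed PatchEmbed" in the byte stream.
payload = pickle.dumps(PatchEmbed(), protocol=2)

# Simulate an environment where that module path no longer exists.
del sys.modules["legacy_patch_embed"]

error_message = None
try:
    pickle.loads(payload)
except ModuleNotFoundError as exc:
    error_message = str(exc)

print(error_message)  # No module named 'legacy_patch_embed'
```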

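As a Transformer-family model, DeiT naturally uses the patch_embed module, and this bug is caused by a module path mismatch. The traceback shows that the error is raised while torch.load unpickles the checkpoint, which refers to the module timm.models.layers.patch_embed.
Looking at timm's source on GitHub, newer versions of timm keep patch_embed under timm.layers.patch_embed instead.
The error occurs because the timm in the environment is an older version, while the checkpoint was produced with a different timm version, so the module paths do not match.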
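To see which module paths a checkpoint expects without importing anything, one option is to walk its pickle opcodes with the standard-library pickletools module. This is a sketch under two assumptions: the file uses PyTorch's zip-based format (the object graph lives in an entry ending in data.pkl), and it was written with torch.save's default pickle protocol 2, which records classes with GLOBAL opcodes. The helper name pickled_globals is hypothetical.

```python
import pickletools
import zipfile
from typing import List, Tuple


def pickled_globals(checkpoint_path: str) -> List[Tuple[str, str]]:
    """List the (module, name) pairs a zip-format torch checkpoint refers to."""
    with zipfile.ZipFile(checkpoint_path) as zf:
        # The object graph is stored as "<archive name>/data.pkl" inside the zip.
        pkl_entries = [n for n in zf.namelist() if n == "data.pkl" or n.endswith("/data.pkl")]
        if not pkl_entries:
            raise ValueError(f"no data.pkl entry found in {checkpoint_path}")
        payload = zf.read(pkl_entries[0])

    refs = set()
    # genops only decodes opcodes; it never imports the referenced modules.
    for opcode, arg, _pos in pickletools.genops(payload):
        if opcode.name == "GLOBAL":  # protocols 0-3 store "module name" here
            module, name = arg.split(" ", 1)
            refs.add((module, name))
    return sorted(refs)
```

Running it on the local checkpoint would list any timm.models.layers.* entries the file contains; a checkpoint holding only a plain state_dict typically references just torch and collections names.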


Solution:
Uninstall the old version of timm:
pip uninstall timm

Install the latest release (1.0.8 at the time of writing):
pip install timm==1.0.8

I looked up the latest version number on PyPI:
https://pypi.org/project/timm/
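After reinstalling, one way to confirm which timm release is installed and which of the two patch_embed module paths it exposes is sketched below using only importlib. The helper names are hypothetical; note that find_spec imports parent packages (here timm itself) as a side effect.

```python
import importlib.util
from importlib import metadata
from typing import Dict, Optional, Sequence


def installed_version(dist_name: str) -> Optional[str]:
    """Version string of an installed distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def module_exists(dotted_name: str) -> bool:
    """True if the module can be located; parent packages are imported to check."""
    try:
        return importlib.util.find_spec(dotted_name) is not None
    except (ImportError, ValueError):
        # ImportError covers a missing parent package (ModuleNotFoundError).
        return False


def timm_layout_report(
    paths: Sequence[str] = ("timm.models.layers.patch_embed", "timm.layers.patch_embed"),
) -> Dict[str, object]:
    """Installed timm version plus whether each module path can be found."""
    report: Dict[str, object] = {"timm": installed_version("timm")}
    for path in paths:
        report[path] = module_exists(path)
    return report
```

Printing timm_layout_report() before and after the upgrade shows whether the version change also changed which module paths resolve.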

posted on 2024-08-11 10:57  零度的python武器库