YOLOv5 Transfer Learning Pitfalls

The pretrained model has 5 classes, but the dataset I train on next has 8.
Running training directly fails with:

size mismatch for model.24.m.0.weight: copying a param with shape torch.Size([30, 192, 1, 1]) from checkpoint, the shape in current model is torch.Size([39, 192, 1, 1]).
size mismatch for model.24.m.0.bias: copying a param with shape torch.Size([30]) from checkpoint, the shape in current model is torch.Size([39]).
size mismatch for model.24.m.1.weight: copying a param with shape torch.Size([30, 384, 1, 1]) from checkpoint, the shape in current model is torch.Size([39, 384, 1, 1]).
size mismatch for model.24.m.1.bias: copying a param with shape torch.Size([30]) from checkpoint, the shape in current model is torch.Size([39]).
size mismatch for model.24.m.2.weight: copying a param with shape torch.Size([30, 768, 1, 1]) from checkpoint, the shape in current model is torch.Size([39, 768, 1, 1]).
size mismatch for model.24.m.2.bias: copying a param with shape torch.Size([30]) from checkpoint, the shape in current model is torch.Size([39]).
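
The 30 and 39 in these messages are consistent with a detection head that outputs anchors x (classes + 5) channels per scale. Here is a minimal sketch of that arithmetic. It assumes 3 anchors per scale, which the post does not state, and head_channels is a name made up for illustration:

```python
def head_channels(num_classes: int, anchors_per_scale: int = 3) -> int:
    # Each anchor predicts 4 box coordinates, 1 objectness score,
    # and one score per class.
    return anchors_per_scale * (num_classes + 5)


old_channels = head_channels(5)  # checkpoint trained on 5 classes
new_channels = head_channels(8)  # current model configured for 8 classes
print(old_channels, new_channels)  # prints: 30 39
```

Under that assumption, only the three output convolutions of the head (model.24.m.0 to model.24.m.2) change shape when the class count changes, which matches the list of mismatched keys.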

Fix: do not load these layers from the checkpoint. Instead, manually replace them with random tensors of the new shape:

# EMA
ema = ModelEMA(model) if RANK in {-1, 0} else None

# Resume
start_epoch, best_fitness = 0, 0.0
if pretrained:
    # Collect the parameter names stored in the checkpoint's EMA model
    ckpt['new_ema'] = []
    for emaa in ckpt['ema'].state_dict():
        ckpt['new_ema'].append(emaa)

    # Rebuild the weight list: detection-head tensors (model.24.m.*) get the
    # new 8-class shape filled with random values, everything else is kept
    new_weights = []
    for k, v in ckpt['ema'].float().state_dict().items():
        if k.startswith('model.24.m.0.weight'):
            new_v = torch.rand([39, 192, 1, 1])
            new_weights.append(new_v)
        elif k.startswith('model.24.m.1.weight'):
            new_v = torch.rand([39, 384, 1, 1])
            new_weights.append(new_v)
        elif k.startswith('model.24.m.2.weight'):
            new_v = torch.rand([39, 768, 1, 1])
            new_weights.append(new_v)
        elif k.startswith('model.24.m'):
            # Remaining head entries are the biases
            new_v = torch.rand([39])
            new_weights.append(new_v)
        else:
            new_weights.append(v)

    # Pair the names with the patched tensors and load them into the EMA model
    ckpt['my_weight'] = dict(zip(ckpt['new_ema'], new_weights))
    if ema and ckpt.get('ema'):
        ema.ema.load_state_dict(ckpt['my_weight'])
        ema.updates = ckpt['updates']
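
The snippet above hard-codes the head shapes and fills them with torch.rand, which draws uniform values in [0, 1). A more generic variant, sketched below, skips every checkpoint tensor whose shape does not match and leaves those layers at the new model's own initialization. This is not the method used in this post: keep_matching and make_net are hypothetical names, and the two-layer network is only a stand-in for a backbone plus detection head.

```python
import torch
import torch.nn as nn


def keep_matching(ckpt_state, model_state):
    # Keep only checkpoint entries whose key exists in the model
    # and whose tensor shape is identical.
    return {k: v for k, v in ckpt_state.items()
            if k in model_state and v.shape == model_state[k].shape}


def make_net(out_channels):
    # Hypothetical stand-in: a shared "backbone" conv plus a 1x1 "head" conv.
    return nn.Sequential(nn.Conv2d(3, 16, 3), nn.Conv2d(16, out_channels, 1))


old_net = make_net(30)  # plays the role of the 5-class checkpoint
new_net = make_net(39)  # plays the role of the 8-class model

filtered = keep_matching(old_net.state_dict(), new_net.state_dict())
# strict=False allows keys to be missing from the loaded dict; the skipped
# head keys keep the new model's freshly initialized values.
result = new_net.load_state_dict(filtered, strict=False)
print(sorted(filtered))     # backbone keys that were copied
print(result.missing_keys)  # head keys that were left untouched
```

With this approach nothing has to be edited when the number of classes or the channel widths change, since the mismatched layers are detected by shape rather than by name.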
posted @ GhostCai