1.1

OpenCV加载Pytorch模型出现Unsupported Lua type 解决方法

OpenCV加载Pytorch模型出现Unsupported Lua type 解决方法

原因

Torch有两个版本,一个就叫Torch一个专门给Python用的Pytorch,它们训练完之后保存下来的模型是不一样的.
说到这问题就很清楚了.OpenCV的ReadNetFromTorch支持的是前者...

解决方法

那么有没有解决办法呢,答案是有的.
PyTorch支持把模型保存为ONNX格式.而这个格式在opencv是支持的.
操作如下:

import torch
import torch.onnx
from torch.autograd import Variable

# ~~~~~~~~~~~~~~~~初始化与训练模型过程~~~~~~~~~~~~~

# 这是普通的pytorch模型保存方式:
torch.save(net.state_dict(), "torch.pt")

# 这是保存为ONNX的方法:
# 由于PyTorch的模型,是动态调整大小的,这里需要初始化一个指定格式的数据,用来调整模型大小
# 就是和你训练模型的时候用的数据一样的格式就行
dummy_input = Variable(torch.randn(1, 1, 28, 28)).to(device)
# 保存模型
torch.onnx.export(net, dummy_input, "torch.onnx")

注意,这里还有个坑!

虽然模型保存成了ONNX格式,但是OpenCV的ReadTensorFromONNX并不能加载! 需要用ReadNet方法加载!

posted @ 2020-05-08 17:45  asml  阅读(2571)  评论(0编辑  收藏  举报
@.@