Pytorch | Pytorch格式 .pt .pth .bin .onnx 详解
Pytorch是深度学习领域中非常流行的框架之一,支持的模型保存格式包括.pt和.pth .bin .onnx。这几种格式的文件都可以保存Pytorch训练出的模型,但是它们的区别是什么呢?
模型的保存与加载到底在做什么?
我们在使用pytorch构建模型并且训练完成后,下一步要做的就是把这个模型放到实际场景中应用,或者是分享给其他人学习、研究、使用。因此,我们开始思考一个问题,提供哪些模型信息,能够让对方能够完全复现我们的模型?
-
模型代码:
-
- (1)包含了我们如何定义模型的结构,包括模型有多少层/每层有多少神经元等等信息;
- (2)包含了我们如何定义的训练过程,包括epoch batch_size等参数;
- (3)包含了我们如何加载数据和使用;
- (4)包含了我们如何测试评估模型。
-
模型参数:提供了模型代码之后,对方确实能够复现模型,但是运行的参数需要重新训练才能得到,而没有办法在我们的模型参数基础上继续训练,因此对方还希望我们能够把模型的参数也保存下来给对方。
-
- (1)包含model.state_dict(),这是模型每一层可学习的节点的参数,比如weight/bias;
- (2)包含optimizer.state_dict(),这是模型的优化器中的参数;
- (3)包含我们其他参数信息,如epoch/batch_size/loss等。
-
数据集:
-
- (1)包含了我们训练模型使用的所有数据;
- (2)可以提示对方如何去准备同样格式的数据来训练模型。
-
使用文档:
-
- (1)根据使用文档的步骤,每个人都可以重现模型;
- (2)包含了模型的使用细节和我们相关参数的设置依据等信息。
可以看到,根据我们提供的模型代码/模型参数/数据集/使用文档,我们就可以有理由相信对方是有手就会了,那么目的就达到了。
现在我们反转一下思路,我们希望别人给我们提供模型的时候也能够提供这些信息,那么我们就可以拿捏住别人的模型了。
为什么要约定格式?
根据上一段的思路,我们知道模型重现的关键是模型结构/模型参数/数据集,那么我们提供或者希望别人提供这些信息,需要一个交流的规范,这样才不会1000个人给出1000种格式,而 .pt .pth .bin 以及 .onnx 就是约定的格式。
torch.save: Saves a serialized object to disk. This function uses Python’s pickle utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved using this function.
不同的后缀只是用于提示我们文件可能包含的内容,但是具体的内容需要看模型提供者编写的README.md才知道。而在使用torch.load()方法加载模型信息的时候,并不是根据文件的后缀进行的读取,而是根据文件的实际内容自动识别的,因此对于torch.load()方法而言,不管你把后缀改成是什么,只要文件是对的都可以读取。
torch.load: Uses pickle’s unpickling facilities to deserialize pickled object files to memory. This function also facilitates the device to load the data into (see Saving & Loading Model Across Devices).
顺便提一下,“一切皆文件”的思维才是正确打开计算机世界的思维方式,文件后缀只作为提示作用,在Windows系统中也会用于提示系统默认如何打开或执行文件,除此之外,文件后缀不应该成为我们认识和了解文件阻碍。
格式汇总
下面是一个整理了 .pt
、.pth
、.bin
、ONNX 和 TorchScript 等 PyTorch 模型文件格式的表格:
格式 | 解释 | 适用场景 | 可对应的后缀 |
---|---|---|---|
.pt 或 .pth | PyTorch 的默认模型文件格式,用于保存和加载完整的 PyTorch 模型,包含模型的结构和参数等信息。 | 需要保存和加载完整的 PyTorch 模型的场景,例如在训练中保存最佳的模型或在部署中加载训练好的模型。 | .pt 或 .pth |
.bin | 一种通用的二进制格式,可以用于保存和加载各种类型的模型和数据。 | 需要将 PyTorch 模型转换为通用的二进制格式的场景。 | .bin |
ONNX | 一种通用的模型交换格式,可以用于将模型从一个深度学习框架转换到另一个深度学习框架或硬件平台。在 PyTorch 中,可以使用 torch.onnx.export 函数将 PyTorch 模型转换为 ONNX 格式。 | 需要将 PyTorch 模型转换为其他深度学习框架或硬件平台可用的格式的场景。 | .onnx |
TorchScript | PyTorch 提供的一种序列化和优化模型的方法,可以将 PyTorch 模型转换为一个序列化的程序,并使用 JIT 编译器对模型进行优化。在 PyTorch 中,可以使用 torch.jit.trace 或 torch.jit.script 函数将 PyTorch 模型转换为 TorchScript 格式。 | 需要将 PyTorch 模型序列化和优化,并在没有 Python 环境的情况下运行模型的场景。 | .pt 或 .pth |
.pt .pth格式
一个完整的Pytorch模型文件,包含了如下参数:
- model_state_dict:模型参数
- optimizer_state_dict:优化器的状态
- epoch:当前的训练轮数
- loss:当前的损失值
下面是一个.pt文件的保存和加载示例(注意,后缀也可以是 .pth ):
- .state_dict():包含所有的参数和持久化缓存的字典,model和optimizer都有这个方法
- torch.save():将所有的组件保存到文件中
模型保存
import torch
import torch.nn as nn
# 定义一个简单的模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 1)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
loss = 0.55
# 保存模型
state = {
'epoch': 10,
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
'loss': loss
}
checkpoint="./PATH/model-1.pt"
torch.save(state,checkpoint)
模型加载
import torch
import torch.nn as nn
# 定义同样的模型结构
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 1)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
# 加载模型
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
.bin格式
.bin文件是一个二进制文件,可以保存Pytorch模型的参数和持久化缓存。.bin
文件的大小较小,加载速度较快,因此在生产环境中使用较多。
下面是一个.bin文件的保存和加载示例(注意:也可以使用 .pt .pth 后缀):
保存模型
import torch
import torch.nn as nn
# 定义一个简单的模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 1)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
model = Net()
# 保存参数到.bin文件
checkpoint="./PATH/model.bin"
torch.save(model.state_dict(), checkpoint)
加载模型
import torch
import torch.nn as nn
# 定义相同的模型结构
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 1)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
# 加载.bin文件
model = Net()
checkpoint="./PATH/model.bin"
model.load_state_dict(torch.load(checkpoint))
model.eval()
.onnx格式
上述保存的文件可以通过PyTorch提供的torch.onnx.export
函数转化为ONNX格式,这样可以在其他深度学习框架中使用PyTorch训练的模型。转化方法如下:
import torch
import torch.onnx
# 将模型保存为.bin文件
model = torch.nn.Linear(3, 1)
torch.save(model.state_dict(), "model.bin")
# torch.save(model.state_dict(), "model.pt")
# torch.save(model.state_dict(), "model.pth")
# 将.bin文件转化为ONNX格式
model = torch.nn.Linear(3, 1)
model.load_state_dict(torch.load("model.bin"))
# model.load_state_dict(torch.load("model.pt"))
# model.load_state_dict(torch.load("model.pth"))
example_input = torch.randn(1, 3)
torch.onnx.export(model, example_input, "model.onnx", input_names=["input"], output_names=["output"])
加载ONNX格式的代码可以参考以下示例代码:
import onnx
import onnxruntime
# 加载ONNX文件
onnx_model = onnx.load("model.onnx")
# 将ONNX文件转化为ORT格式
ort_session = onnxruntime.InferenceSession("model.onnx")
# 输入数据
input_data = np.random.random(size=(1, 3)).astype(np.float32)
# 运行模型
outputs = ort_session.run(None, {"input": input_data})
# 输出结果
print(outputs)
注意,需要安装onnx
和onnxruntime
两个Python包。此外,还需要使用numpy
等其他常用的科学计算库。
直接保存完整模型
可以看出来,我们在之前的报错方式中,都是保存了.state_dict(),但是没有保存模型的结构,在其他地方使用的时候,必须先重新定义相同结构的模型(或兼容模型),才能够加载模型参数进行使用,如果我们想直接把整个模型都保存下来,避免重新定义模型,可以按如下操作:
# 保存模型
PATH = "entire_model.pt"
# PATH = "entire_model.pth"
# PATH = "entire_model.bin"
torch.save(model, PATH)
加载模型
# 加载模型
model = torch.load("entire_model.pt")
model.eval()
结语
本文介绍了pytorch可以导出的模型的几种后缀格式,但是模型导出的关键并不是后缀,而是到处时候提供的信息到底是什么,只要知道了模型的model.state_dict()
和optimizer.state_dict()
,以及相应的epoch batch_size loss等信息,我们就能够重建出模型,至于要导出哪些信息,就取决于你了,务必在readme.md中写清楚,你导出了哪些信息。
保存场景 | 保存方法 | 文件后缀 |
---|---|---|
整个模型 | model = Net() torch.save(model, PATH) | .pt .pth .bin |
仅模型参数 | model = Net() torch.save(model.state_dict(), PATH) | .pt .pth .bin |
checkpoints使用 | model = Net() torch.save({ 'epoch': 10, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, }, PATH) | .pt .pth .bin |
ONNX通用保存 | model = Net() model.load_state_dict(torch.load("model.bin")) example_input = torch.randn(1, 3) torch.onnx.export(model, example_input, "model.onnx", input_names=["input"], output_names=["output"]) | .onnx |
TorchScript无python环境使用 | model = Net() model_scripted = torch.jit.script(model) # Export to TorchScript model_scripted.save('model_scripted.pt') 使用时: model = torch.jit.load('model_scripted.pt') model.eval() | .pt .pth |