pt与pth

.pt和.pth文件都是用于保存PyTorch模型的文件格式,习惯上:

.pt是完整的模型(full model):包括模型的架构和权重。

.pth只有模型的权重(weights)

.pt的保存与加载

import torch

# 假设有一个训练好的模型 model
torch.save(model, 'full_model.pt')

model = torch.load('full_model.pt')
model.eval()

.pth的保存与加载

import torch

# 假设有一个训练好的模型 model
torch.save(model.state_dict(), 'model_weights.pth')

# 假设有一个模型类 ModelClass
model = ModelClass()
model.load_state_dict(torch.load('model_weights.pth'))

pth的好处:训练时滑窗的步长设置为5,使用时我想用步长为3的,不需要再次训练模型。权重偏置等无需再重新训练得到,只是调整下步长。

初始化模型结构时将参数指定为3,然后载入pth就可以。之后保存成pt给其他人(转engine部署时,pt先转onnx,然后在部署机器上onnx转engine)

官方推荐保存和加载.pth,直接加载模型的权重数据可以减少内存使用、加快加载速度。模型结构和优化器状态等其他元数据通过其他方式单独加载。

【其他】

【Pytorch】model.train() 和 model.eval() 原理与用法-腾讯云开发者社区-腾讯云

posted @ 2022-08-11 15:47  夕西行  阅读(335)  评论(0编辑  收藏  举报