pt与pth
.pt和.pth文件都是用于保存PyTorch模型的文件格式,习惯上:
.pt是完整的模型(full model):包括模型的架构和权重。
.pth只有模型的权重(weights)
.pt的保存与加载
import torch # 假设有一个训练好的模型 model torch.save(model, 'full_model.pt') model = torch.load('full_model.pt') model.eval()
.pth的保存与加载
import torch # 假设有一个训练好的模型 model torch.save(model.state_dict(), 'model_weights.pth') # 假设有一个模型类 ModelClass model = ModelClass() model.load_state_dict(torch.load('model_weights.pth'))
pth的好处:训练时滑窗的步长设置为5,使用时我想用步长为3的,不需要再次训练模型。权重偏置等无需再重新训练得到,只是调整下步长。
初始化模型结构时将参数指定为3,然后载入pth就可以。之后保存成pt给其他人(转engine部署时,pt先转onnx,然后在部署机器上onnx转engine)
官方推荐保存和加载.pth,直接加载模型的权重数据可以减少内存使用、加快加载速度。模型结构和优化器状态等其他元数据通过其他方式单独加载。
【其他】