pytorch 模型加载和保存

模型加载

torch.load(f, map_location=None, pickle_module=<module 'pickle' from '/opt/conda/lib/python3.6/pickle.py'>, **pickle_load_args)

 

map_location适用于修改模型能在gpu上运行还是cpu上运行。

一般情况下,加载模型,主要用于预测新来的一组样本。预测的主要流程包括:输入数据——预处理——加载模型——预测得返回值(类别或者是属于某一类别的概率)

def predict(test_data, model_path, config):
    '''
    input:
           test_data:测试数据
           model_path:模型的保存路径 model_path = './save/20201104_204451.ckpt'
    output:
           score:模型输出属于某一类别的概率
    '''
    data = process_data_for_predict(test_data)#预处理数据,使得数据格式符合模型输入形式
    model = torch.load(model_path)#加载模型
    score = model(data)#模型预测
    return score #返回得分

 Pytorch模型 .pt, .pth, .pkl的区别

后缀名为.pt, .pth, .pkl的pytorch模型文件,在格式上其实没有区别,只是后缀不同而已

模型的保存和加载有两种方式:

(1) 仅仅保存和加载模型参数

# 保存
torch.save(the_model.state_dict(), PATH='mymodel.pth')        #只保存模型权重参数,不保存模型结构

# 调用
the_model = TheModelClass(*args, **kwargs)                    #这里需要重新模型结构,TheModelClass
the_model.load_state_dict(torch.load('mymodel.pth'))        #这里根据模型结构,调用存储的模型参数

(2) 保存和加载整个模型

# 保存
torch.save(the_model, PATH)                                    #保存整个model的状态

# 调用
the_model = torch.load(PATH)                                #这里已经不需要重构模型结构了,直接load就可以

第一种方式需要自己定义网络,并且其中的参数名称与结构要与保存的模型中的一致(可以是部分网络,比如只使用VGG的前几层),相对灵活,便于对网络进行修改。第二种方式则无需自定义网络,保存时已把网络结构保存,比较死板,不能调整网络结构。

posted @ 2024-08-06 10:08  Arxu  阅读(165)  评论(0)    收藏  举报