Pytorch载入保存的模型更换gpu卡号如何设置map_location

Pytorch在保存模型时,会把训练过程中使用的设备号(GPU:0, CPU等)也一起保存下来,当load保存的模型时,会默认把权重加载到训练时使用的设备上,要修改加载到的卡号可以如下:

torch.load('your_model.pth', map_location=lambda storage, loc : storage.cuda(1))

posted @ 2022-02-19 12:51  sqdtss  阅读(301)  评论(0编辑  收藏  举报