Pytorch02_GPU加速

GPU加速

 

1. 定义GPU设备

import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

 2. 将模型、张量等放在GPU设备上

# 
loss.to(device)
tensor.to(device)
model.to(device)

 3. 将数据等放回CPu

predict = model(data)
predict = predict.cpu().detach().numpy()  
# detach() 和 data效果相似,但detach是深拷贝,data是浅拷贝

 

 

--------------------------------

随有随更 2021.6.9

--------------------------------

posted @ 2021-06-09 13:10  Haozi_D17  阅读(50)  评论(0编辑  收藏  举报