pytorch的一些tricks
转载一篇技术文章,一些pytorch使用技巧
1、计算整体参数量
num_parameters = sum(torch.numel(parameter) for parameter in model.parameters())
2、提取模型中的某一层
modules()会返回模型中所有模块的迭代器,它能够访问到最内层,比如self.layer1.conv1这个模块,还有一个与它们相对应的是name_children()属性以及named_modules(),这两个不仅会返回模块的迭代器,还会返回网络层的名字。
# 取模型中的前两层
new_model = nn.Sequential(*list(model.children())[:2]
# 如果希望提取出模型中的所有卷积层,可以像下面这样操作:
for layer in model.named_modules():
if isinstance(layer[1],nn.Conv2d):
conv_model.add_module(layer[0],layer[1])
3、GPU模型加载到cpu
model.load_state_dict(torch.load('model.pth', map_location='cpu'))
4、导入另一个模型的相同部分到新的模型
# model_new代表新的模型
# model_saved代表其他模型,比如用torch.load导入的已保存的模型
model_new_dict = model_new.state_dict()
model_common_dict = {k:v for k, v in model_saved.items() if k in model_new_dict.keys()}
model_new_dict.update(model_common_dict)
model_new.load_state_dict(model_new_dict)
5、指定GPU编号
# 设置当前使用的GPU设备仅为0号设备,设备名称为 /gpu:0:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# 设置当前使用的GPU设备为0, 1号两个设备,名称依次为 /gpu:0、/gpu:1:
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
#根据顺序表示优先使用0号设备,然后使用1号设备。
6、 查看模型每层输出详情
from torchsummary import summary
summary(your_model, input_size=(channels, H, W))
7、防止验证模型时爆显存
with torch.no_grad():
8、冻结某些层的参数
net = Network() # 获取自定义网络结构
for name, value in net.named_parameters():
print('name: {0},\t grad: {1}'.format(name, value.requires_grad))
no_grad = [
'cnn.VGG_16.convolution1_1.weight',
'cnn.VGG_16.convolution1_1.bias',
'cnn.VGG_16.convolution1_2.weight',
'cnn.VGG_16.convolution1_2.bias'
]
net = Net.CTPN() # 获取网络结构
for name, value in net.named_parameters():
if name in no_grad:
value.requires_grad = False
else:
value.requires_grad = True
optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.01)
9、扩展向量维度
import torch
vector = torch.FloatTensor(2,3,4)
vector1 = vector.view(1,*vector.size())
vector2 = vector[1,:,:,:]
vector3 = torch.unsqueeze(vector,dim = 0)
vecotr4 = torch.squeeze(vector1,dim = 0)
10、模型参数
model.state_dict()
model.parameters()
model.buffer()
来源:https://datawhale.club/
作者:z.defying
链接:https://mp.weixin.qq.com/s/yo9AKetSrwZsPooWYcmN1A
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律