pytorch的一些tricks

转载一篇技术文章,一些pytorch使用技巧

1、计算整体参数量

num_parameters = sum(torch.numel(parameter) for parameter in model.parameters())

2、提取模型中的某一层

modules()会返回模型中所有模块的迭代器,它能够访问到最内层,比如self.layer1.conv1这个模块,还有一个与它们相对应的是name_children()属性以及named_modules(),这两个不仅会返回模块的迭代器,还会返回网络层的名字。

# 取模型中的前两层
new_model = nn.Sequential(*list(model.children())[:2] 
# 如果希望提取出模型中的所有卷积层,可以像下面这样操作:
for layer in model.named_modules():
    if isinstance(layer[1],nn.Conv2d):
         conv_model.add_module(layer[0],layer[1])

3、GPU模型加载到cpu

model.load_state_dict(torch.load('model.pth', map_location='cpu'))

4、导入另一个模型的相同部分到新的模型

# model_new代表新的模型
# model_saved代表其他模型,比如用torch.load导入的已保存的模型
model_new_dict = model_new.state_dict()
model_common_dict = {k:v for k, v in model_saved.items() if k in model_new_dict.keys()}
model_new_dict.update(model_common_dict)
model_new.load_state_dict(model_new_dict)

5、指定GPU编号

# 设置当前使用的GPU设备仅为0号设备,设备名称为 /gpu:0:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# 设置当前使用的GPU设备为0, 1号两个设备,名称依次为 /gpu:0、/gpu:1: 
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" 
#根据顺序表示优先使用0号设备,然后使用1号设备。

6、 查看模型每层输出详情

from torchsummary import summary
summary(your_model, input_size=(channels, H, W))

7、防止验证模型时爆显存

with torch.no_grad():

8、冻结某些层的参数

net = Network()  # 获取自定义网络结构
for name, value in net.named_parameters():
    print('name: {0},\t grad: {1}'.format(name, value.requires_grad))

no_grad = [
    'cnn.VGG_16.convolution1_1.weight',
    'cnn.VGG_16.convolution1_1.bias',
    'cnn.VGG_16.convolution1_2.weight',
    'cnn.VGG_16.convolution1_2.bias'
]
net = Net.CTPN()  # 获取网络结构
for name, value in net.named_parameters():
    if name in no_grad:
        value.requires_grad = False
    else:
        value.requires_grad = True
		
optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.01)


9、扩展向量维度

import torch
vector = torch.FloatTensor(2,3,4)

vector1 = vector.view(1,*vector.size())
vector2 = vector[1,:,:,:]

vector3 = torch.unsqueeze(vector,dim = 0)
vecotr4 = torch.squeeze(vector1,dim = 0)


10、模型参数

model.state_dict()
model.parameters()
model.buffer()

来源:https://datawhale.club/
作者:z.defying
链接:https://mp.weixin.qq.com/s/yo9AKetSrwZsPooWYcmN1A

posted @   努力生活的叶子吖  阅读(59)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
点击右上角即可分享
微信分享提示