网络frozen
参考网页
https://blog.csdn.net/Orientliu96/article/details/104705912
device = torch.device("cuda" )
#Try to load models
model = DGCNN(args)
print(str(model))
model = model.to(device)
save_model = torch.load('model.t7')
model_dict = model.state_dict()
更新模型的参数,因为自己的网络比pretrain的模型更复杂
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
model_dict.update(state_dict)
model.load_state_dict(model_dict)
for name,p in model.named_parameters():
if name.startswith('conv1'): p.requires_grad = False
if name.startswith('conv2'): p.requires_grad = False
if name.startswith('conv3'): p.requires_grad = False
if name.startswith('conv4'): p.requires_grad = False
if name.startswith('bn1'): p.requires_grad = False
if name.startswith('bn2'): p.requires_grad = False
if name.startswith('bn3'): p.requires_grad = False
if name.startswith('bn4'): p.requires_grad = False
opt = optim.SGD(filter(lambda x: x.requires_grad is not False ,model.parameters()), lr=args.lr*100, momentum=args.momentum, weight_decay=1e-4)
解读
这段代码使用了PyTorch框架,并展示了如何在一个卷积神经网络(DGCNN)上执行模型参数的加载、更新、部分冻结参数,以及设置优化器的过程。下面是每部分代码的具体解读:
-
设置设备:
device = torch.device("cuda")
这行代码设置了使用CUDA设备,也就是GPU来加速模型的计算。
-
初始化和打印模型:
model = DGCNN(args) print(str(model)) model = model.to(device)
这里创建了一个DGCNN模型的实例,其中
args
可能包含了模型的一些配置参数(如层数、激活函数等)。打印模型的结构后,将模型迁移到之前指定的CUDA设备上。 -
加载预训练模型并更新当前模型的参数:
save_model = torch.load('model.t7') model_dict = model.state_dict() state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()} print(state_dict.keys()) model_dict.update(state_dict) model.load_state_dict(model_dict)
首先从
model.t7
文件中加载一个预训练模型的状态字典。然后获取当前模型的状态字典,并创建一个新的状态字典,只包含预训练模型中存在于当前模型中的参数。更新当前模型的状态字典,并将其加载回模型中。 -
冻结特定层的参数:
for name, p in model.named_parameters(): if name.startswith(('conv1', 'conv2', 'conv3', 'conv4', 'bn1', 'bn2', 'bn3', 'bn4')): p.requires_grad = False
这部分遍历模型的所有参数,并根据参数名字决定是否冻结参数。这里冻结了以
conv1
,conv2
,conv3
,conv4
,bn1
,bn2
,bn3
,bn4
开头的所有层的参数。requires_grad = False
意味着在训练过程中这些参数不会被更新。 -
设置优化器:
opt = optim.SGD(filter(lambda x: x.requires_grad is not False, model.parameters()), lr=args.lr*100, momentum=args.momentum, weight_decay=1e-4)
最后,设置了一个随机梯度下降(SGD)优化器,仅针对那些未被冻结(
requires_grad
不是False
)的参数。这里还对学习率进行了调整(可能是为了适应不同的训练需求),并设置了动量和权重衰减,这些都有助于优化训练过程。
整体上,这段代码展示了如何利用预训练模型来初始化一个更复杂的网络模型,并通过冻结部分层的参数来加速训练过程,同时专注于训练模型的特定部分。
对与预训练中重合的key做交集。再将做交集的内容网络值返回到现有的模型当中
-
torch.load('model.t7')
:这行代码从model.t7
文件加载模型权重。在PyTorch中,模型的状态通常保存在一个名为.t7
或.pth
的文件中,这个文件包含了模型参数的字典。 -
model.state_dict()
:每个PyTorch模型都有一个state_dict
方法,它返回模型当前的状态,即其所有参数的字典。 -
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
:这段代码遍历save_model
字典中的每一项(键值对)。它创建一个新的字典state_dict
,这个字典只包括那些键同时存在于save_model
和当前模型状态字典model_dict
中的项。这意味着,只有当预训练模型的某个权重与当前模型中的一个权重具有相同的键名时,这个权重才会被添加到新的字典中。 -
print(state_dict.keys())
:打印出state_dict
中所有键的列表,这通常用于验证哪些参数被载入了。 -
model_dict.update(state_dict)
:这行代码将state_dict
中的项更新到model_dict
中。如果model_dict
中已经有相同的键,则这些键对应的值会被state_dict
中的值覆盖。 -
model.load_state_dict(model_dict)
:最后,更新后的model_dict
被加载回模型中,这样模型就具有了部分预训练的权重。
这个过程允许模型只加载那些已知的、匹配的权重,这对于模型微调是很有用的,尤其是当你有一个预训练的模型,想要将它适配到一个新的、稍有不同的任务或网络架构时。这样做可以确保只有对应的权重被加载和更新,从而保持模型中其他自定义部分的参数不变。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· Manus的开源复刻OpenManus初探
· AI 智能体引爆开源社区「GitHub 热点速览」
· 三行代码完成国际化适配,妙~啊~
· .NET Core 中如何实现缓存的预热?