register_buffer vs register_parameter
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 | 先来看下nn.Module的成员: def __init__( self ): """ Initializes internal Module state, shared by both nn.Module and ScriptModule. """ torch._C._log_api_usage_once( "python.nn_module" ) self .training = True self ._parameters = OrderedDict() self ._buffers = OrderedDict() self ._backward_hooks = OrderedDict() self ._forward_hooks = OrderedDict() self ._forward_pre_hooks = OrderedDict() self ._state_dict_hooks = OrderedDict() self ._load_state_dict_pre_hooks = OrderedDict() self ._modules = OrderedDict() |
register_buffer
和register_parameter
只涉及到_buffer
和_parameters
,调用这两个函数分别会向两个成员写入数据。
_buffer
和_parameter
都会被state_dict
返回,且可以通过.cpu()
和.cuda()
在设备间进行转换。
_buffer
中的元素不会被优化器更新,如果在模型中需要需要一些参数,并且要通过state_dict
返回,且不需要被优化器训练,那么这些参数可以注册在_buffer
中。
例如在maskrcnn_benchmark中的anchor_generator生成中就用到了register_buffer
,以及detectron2中的BatchNorm2d。
如果定义self.param1=torch.randn(2,2)
,那么param1
是不会被state_dict
返回的,且不会被.cpu()
和.cuda()
在设备间进行转换。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 | import torch from torch import nn class MyModel(nn.Module): def __init__( self ): super (MyModel, self ).__init__() print ( 'before register buffer:\n' , self ._buffers, end = '\n\n' ) self .register_buffer( 'mybuffer1' , torch.randn( 2 , 2 )) print ( 'after register buffer:\n' , self ._buffers, end = '\n\n' ) print ( 'before register parameter:\n' , self ._parameters, end = '\n\n' ) self .register_parameter( 'my_param1' , nn.Parameter(torch.randn( 3 , 3 ))) print ( 'after register parameter:\n' , self ._parameters, end = '\n\n' ) self .param1 = torch.randn( 3 , 3 ) def forward( self , x): return x mymodel = MyModel() mymodel.cuda() print ( list (mymodel.parameters())) print ( list (mymodel.buffers())) print (mymodel.param1) |
返回如下
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | before register buffer : OrderedDict() after register buffer : OrderedDict([( 'mybuffer1' , tensor([[ - 0.4997 , - 1.0214 ], [ 0.5604 , - 2.3252 ]]))]) before register parameter: OrderedDict() after register parameter: OrderedDict([( 'my_param1' , Parameter containing: tensor([[ 0.1465 , 1.1252 , - 0.2854 ], [ 2.2109 , - 0.3919 , 0.0385 ], [ 0.3347 , 0.1597 , 0.7505 ]], requires_grad = True ))]) [Parameter containing: tensor([[ 0.1465 , 1.1252 , - 0.2854 ], [ 2.2109 , - 0.3919 , 0.0385 ], [ 0.3347 , 0.1597 , 0.7505 ]], device = 'cuda:0' , requires_grad = True )] [tensor([[ - 0.4997 , - 1.0214 ], [ 0.5604 , - 2.3252 ]], device = 'cuda:0' )] tensor([[ 0.6994 , - 2.6078 , 2.0409 ], [ - 0.1210 , 1.0048 , - 1.3913 ], [ - 1.3752 , - 1.3748 , - 2.4478 ]]) |
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧