register_buffer vs register_parameter

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
先来看下nn.Module的成员:
 
    def __init__(self):
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")
 
        self.training = True
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()

 

register_bufferregister_parameter只涉及到_buffer_parameters,调用这两个函数分别会向两个成员写入数据。

_buffer_parameter都会被state_dict返回,且可以通过.cpu().cuda()在设备间进行转换。
_buffer中的元素不会被优化器更新,如果在模型中需要需要一些参数,并且要通过state_dict返回,且不需要被优化器训练,那么这些参数可以注册在_buffer中。
例如在maskrcnn_benchmark中的anchor_generator生成中就用到了register_buffer,以及detectron2中的BatchNorm2d

如果定义self.param1=torch.randn(2,2),那么param1是不会被state_dict返回的,且不会被.cpu().cuda()在设备间进行转换。

 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
from torch import nn
 
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        print('before register buffer:\n', self._buffers, end='\n\n')
        self.register_buffer('mybuffer1', torch.randn(2, 2))
        print('after register buffer:\n', self._buffers, end='\n\n')
 
        print('before register parameter:\n', self._parameters, end='\n\n')
        self.register_parameter('my_param1', nn.Parameter(torch.randn(3, 3)))
        print('after register parameter:\n', self._parameters, end='\n\n')
        self.param1 = torch.randn(3, 3)
 
    def forward(self, x):
        return x
 
mymodel = MyModel()
mymodel.cuda()
print(list(mymodel.parameters()))
print(list(mymodel.buffers()))
print(mymodel.param1)

返回如下 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
before register buffer:
 OrderedDict()
 
after register buffer:
 OrderedDict([('mybuffer1', tensor([[-0.4997, -1.0214],
        [ 0.5604, -2.3252]]))])
 
before register parameter:
 OrderedDict()
 
after register parameter:
 OrderedDict([('my_param1', Parameter containing:
tensor([[ 0.14651.1252, -0.2854],
        [ 2.2109, -0.39190.0385],
        [ 0.33470.15970.7505]], requires_grad=True))])
 
[Parameter containing:
tensor([[ 0.14651.1252, -0.2854],
        [ 2.2109, -0.39190.0385],
        [ 0.33470.15970.7505]], device='cuda:0', requires_grad=True)]
[tensor([[-0.4997, -1.0214],
        [ 0.5604, -2.3252]], device='cuda:0')]
tensor([[ 0.6994, -2.60782.0409],
        [-0.12101.0048, -1.3913],
        [-1.3752, -1.3748, -2.4478]])

 

posted @   水木清扬  阅读(334)  评论(0编辑  收藏  举报
编辑推荐:
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
点击右上角即可分享
微信分享提示