【极简】Pytorch中的register_buffer()
register buffer
定义模型能用torch.save
保存的、但是不更新参数。
使用:只要是nn.Module
的子类就能直接self.
调用使用:
class A(nn.Module):
#...
self.register_buffer(
'betas', torch.linspace(beta_1, beta_T, T).double())
#...
手动定义参数
上述的参数显然可以直接用一个变量直接定义超参。但是缺点是在用torch.save()
保存的时候不能保存在参数里面,只能用个文本文件保存在外面。不能直接用torch.load加载,不是很方便。
举个例子,假设你有100个超参,难不成要一个一个记录之后,手动造轮子解析保存的txt嘛?当然也行,但是麻烦。
就比如Diffusion Model中的beta和alpha,在每个timestep时候都是不一样的,这时候手动保存会相当麻烦,用register buffer会相当方便。
普通参数
一般来说模型中的可变参数都是nn.Parameter()
类的,这些都是可变的,optimizer会去优化它们。
要是跟register buffer硬凑在一起,把Parameter的require_grad
改成False也能充当。但是何必呢?
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· winform 绘制太阳,地球,月球 运作规律
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人