pytorch requires_grad = True的意思
计算图通常包含两种元素,一个是 tensor,另一个是 Function。张量 tensor 不必多说,但是大家可能对 Function 比较陌生。这里 Function 指的是在计算图中某个节点(node)所进行的运算,比如加减乘除卷积等等之类的,Function 内部有 forward() 和 backward() 两个方法,分别应用于正向、反向传播。
当我们创建一个张量 (tensor) 的时候,如果没有特殊指定的话,那么这个张量是默认是不需要求导的。们在训练一个网络的时候,我们从 DataLoader
中读取出来的一个 mini-batch 的数据,这些输入默认是不需要求导的,其次,网络的输出我们没有特意指明需要求导吧,Ground Truth 我们也没有特意设置需要求导吧。这么一想,哇,那我之前的那些 loss 咋还能自动求导呢?其实原因就是上边那条规则,虽然输入的训练数据是默认不求导的,但是,我们的 model 中的所有参数,它默认是求导的,这么一来,其中只要有一个需要求导,那么输出的网络结果必定也会需要求的。来看个实例:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | input = torch.randn( 8 , 3 , 50 , 100 ) print ( input .requires_grad) # False net = nn.Sequential(nn.Conv2d( 3 , 16 , 3 , 1 ), nn.Conv2d( 16 , 32 , 3 , 1 )) for param in net.named_parameters(): print (param[ 0 ], param[ 1 ].requires_grad) # 0.weight True # 0.bias True # 1.weight True # 1.bias True output = net( input ) print (output.requires_grad) # True |
在写代码的过程中,不要把网络的输入和 Ground Truth 的 requires_grad
设置为 True。虽然这样设置不会影响反向传播,但是需要额外计算网络的输入和 Ground Truth 的导数,增大了计算量和内存占用不说,这些计算出来的导数结果也没啥用。因为我们只需要神经网络中的参数的导数,用来更新网络,其余的导数都不需要。
原文链接:https://zhuanlan.zhihu.com/p/67184419
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 25岁的心里话
· 闲置电脑爆改个人服务器(超详细) #公网映射 #Vmware虚拟网络编辑器
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· 零经验选手,Compose 一天开发一款小游戏!
· 一起来玩mcp_server_sqlite,让AI帮你做增删改查!!
2019-03-18 《构建之法》阅读笔记02
2019-03-18 二维数组