Pytorch:计算图与动态图机制
计算图
computational graph
表示方法
计算图是用来描述运算的有向无环图
计算图有两个主要元素:结点(node)和边(edge)
结点表示数据,如向量,矩阵,张量
边表示运算,如加减乘除卷积等
计算图不仅使计算显得简洁,更重要的是其表示梯度求导更为方便
用计算图表示y=(x+w)*(w+1): 令 a=x+w b=w+1 则y=a*b
梯度求导结合题例的算式和计算图表示
从上述的计算图表示中,可以看到,除了x,w外的梯度都是可以用x,w或其xw梯度表示
叶子结点
叶子结点:用户创建的结点称为叶子结点,如x,w
torch.tensor.is_leaf:用来查看是否为叶子结点
叶子结点的用处:在反向传播之后,只有叶子结点的梯度会被保留,其他中间结点的梯度数据会被释放,以节省内存
若要保存某特定结点的梯度,可以在运行反向传播函数前,补上一句tensor.retain_grad(),此次反向传播后,该tensor的梯度会被保留
梯度方向
torch.tensor.grad_fn:记录创建该张量时所用的方法(函数)及梯度方向
作用:以便计算机知道求导某张量结点的梯度需用到的特定法则
AddBackward代表a张量为相加形成,此时的梯度在反向传播中
动态图
dynamic graph
动态图与静态图
根据计算图的搭建方式,可以将计算图分为动态图和静态图。
- 动态图采用搭建和运算同时进行:灵活、易调节(debug)
- 静态图为先搭建,后运算:高效但不灵活
pytorch采用动态图机制,tensorflow采用静态图机制
如下左图,tensorflow的搭建机制为先确立好所有的路线流向,再将tensor数据注入跑通整个过程。这也是tensorflow的名字由来
而下右图,pytorch中随着代码的数据定义与构建,同时也确立部分的计算图模块。