pytorch 踩坑笔记之w.grad.data.zero_()

  在使用pytorch实现多项线性回归中,在grad更新时,每一次运算后都需要将上一次的梯度记录清空,运用如下方法:

     w.grad.data.zero_()
     b.grad.data.zero_() 

   但是,运行程序就会报如下错误:

  报错,grad没有data这个属性,

  原因是,在系统将w的grad值初始化为none,第一次求梯度计算是在none值上进行报错,自然会没有data属性

  修改方法:添加一个判断语句,从第二次循环开始执行求导运算

for i in range(100):
    y_pred = multi_linear(x_train)
    loss = getloss(y_pred,y_train)
    if i != 0:
        w.grad.data.zero_()
        b.grad.data.zero_()
    loss.backward()
    w.data = w.data - 0.001 * w.grad.data
    b.data = b.data - 0.001 * b.grad.data

 

posted @ 2019-07-22 17:30  去冰七分糖  阅读(7500)  评论(0编辑  收藏  举报