点此进入CSDN

点此添加QQ好友 加载失败时会显示




Pytorch手写线性回归

pytorch手写线性回归

 

import torch
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

LEARN_RATE = 0.1
#1.准备数据
x = torch.randn([500,1])
y_true = x*0.8+3

#2.计算预测值 t_tred = x*w + b

w = torch.rand([],requires_grad=True)
b = torch.tensor(0.,requires_grad=True)

plt.figure()
plt.grid(True)

#开启交互模式
plt.ion()
for i in range(50):

    plt.cla()

    for j in [w,b]:
        if j.grad is not None:
            j.grad.zero_()
    y_predict = x*w+b

    #3.计算损失,把参数的梯度置为0,进行反向传播

    loss = (y_predict-y_true).pow(2).mean()

    loss.backward()

    #4.更新参数,grad表示导数

    w.data = w.data - LEARN_RATE*w.grad
    b.data = b.data - LEARN_RATE*b.grad


    plt.scatter(x.numpy(),y_true.numpy())
    plt.plot(x.numpy(),y_predict.detach().numpy(),color="g")

    plt.pause(0.1)


    if i %50 ==0:
        print( "第{}次,损失{},权重w={},偏执b={}".format(i,loss.data,w.data,b.data))

#关闭交互模式
plt.ioff()
plt.show()

  

posted @ 2019-08-19 00:26  高颜值的殺生丸  阅读(317)  评论(0编辑  收藏  举报

作者信息

昵称:

刘新宇

园龄:4年6个月


粉丝:1209


QQ:522414928