《PyTorch深度学习》笔记(3)

过拟合,在训练集上训练的很好,但是在测试集上结果不好

判别法:通过观察验证机和训练集的曲线图(用visdom进行实时的图像输出,用numpy,matplotlib画最终的结果图)

模型要持久化,存盘

np.meshgrid()画三维图很有用

pytorch优点:动态图

一个简单的训练线性函数的pytorch代码
训练y=w*x中的w的值,数据为

x=[1.0,2.0,3.0]
y=[2.0,4.0,6.0]

#用pytorch进行简单的线性函数的参数计算
import torch

x_data=[1.0,2.0,3.0]
y_data=[2.0,4.0,6.0]

#1.0表示w这个权重的初始值
w = torch.tensor([1.0])
#是否计算梯度
w.requires_grad=True


#前馈计算
#其中w是tensor了
#其中*被重载了
def forward(x):
    return x * w

#损失函数
#(y_predict-y) ** 2表示平方
def loss(x,y):
    y_predict = forward(x)
    return (y_predict-y) ** 2

#计算之前,表示4*w的值,为1*4=4
print("predict (before trainning)",4,forward(4).item())

#开始计算
for epoch in range(100):
    for x,y in zip(x_data,y_data):
        l=loss(x,y)
        #求梯度,并且自动存在变量里,这里梯度存在w里
        #每进行一次反向传播都会生成反向传播计算图,每次计算完都会释放

        l.backward()
        #item()表示取出w张量元素的值,使w变成一个标量
        # w.grad.item()类型为:<class 'float'>
        print('\t',"x为"+str(x),"y为"+str(y),"梯度为"+str(w.grad.item()),"w的值为"+str(w.item()))
        # 用tensor变量进行计算就是在生成计算图,所以要转换成非tensor的
        #用法就是调用.data
        #w.grad.data的类型还是tensor,但是不会建立计算图
        w.data = w.data - 0.01*w.grad.data
        #权重里面的梯度的数据全部清零,不然的话会还是保存那里,这个操作是必须的
        w.grad.data.zero_()
    print("Progress:",epoch,l.item())

#计算之后,表示4*w的值,为2*4=8
print("predict (after trainning)",4,forward(4).item())

输出结果

predict (before trainning) 4 4.0
	 x为1.0 y为2.0 梯度为-2.0 w的值为1.0
	 x为2.0 y为4.0 梯度为-7.840000152587891 w的值为1.0199999809265137
	 x为3.0 y为6.0 梯度为-16.228801727294922 w的值为1.0983999967575073
Progress: 0 7.315943717956543
	 x为1.0 y为2.0 梯度为-1.478623867034912 w的值为1.260688066482544
	 x为2.0 y为4.0 梯度为-5.796205520629883 w的值为1.2754743099212646
	 x为3.0 y为6.0 梯度为-11.998146057128906 w的值为1.333436369895935
Progress: 1 3.9987640380859375
	 x为1.0 y为2.0 梯度为-1.0931644439697266 w的值为1.4534177780151367
	 x为2.0 y为4.0 梯度为-4.285204887390137 w的值为1.464349389076233
	 x为3.0 y为6.0 梯度为-8.870372772216797 w的值为1.5072014331817627
Progress: 2 2.1856532096862793
	 x为1.0 y为2.0 梯度为-0.8081896305084229 w的值为1.5959051847457886
	 x为2.0 y为4.0 梯度为-3.1681032180786133 w的值为1.6039870977401733
	 x为3.0 y为6.0 梯度为-6.557973861694336 w的值为1.635668158531189
Progress: 3 1.1946394443511963
	 x为1.0 y为2.0 梯度为-0.5975041389465332 w的值为1.7012479305267334
	 x为2.0 y为4.0 梯度为-2.3422164916992188 w的值为1.7072229385375977
	 x为3.0 y为6.0 梯度为-4.848389625549316 w的值为1.7306450605392456
Progress: 4 0.6529689431190491
	 x为1.0 y为2.0 梯度为-0.4417421817779541 w的值为1.779128909111023
	 x为2.0 y为4.0 梯度为-1.7316293716430664 w的值为1.7835463285446167
	 x为3.0 y为6.0 梯度为-3.58447265625 w的值为1.8008626699447632
Progress: 5 0.35690122842788696
	 x为1.0 y为2.0 梯度为-0.3265852928161621 w的值为1.836707353591919
	 x为2.0 y为4.0 梯度为-1.2802143096923828 w的值为1.8399732112884521
	 x为3.0 y为6.0 梯度为-2.650045394897461 w的值为1.8527753353118896
Progress: 6 0.195076122879982
	 x为1.0 y为2.0 梯度为-0.24144840240478516 w的值为1.8792757987976074
	 x为2.0 y为4.0 梯度为-0.9464778900146484 w的值为1.881690263748169
	 x为3.0 y为6.0 梯度为-1.9592113494873047 w的值为1.8911550045013428
Progress: 7 0.10662525147199631
	 x为1.0 y为2.0 梯度为-0.17850565910339355 w的值为1.9107471704483032
	 x为2.0 y为4.0 梯度为-0.699742317199707 w的值为1.9125322103500366
	 x为3.0 y为6.0 梯度为-1.4484672546386719 w的值为1.919529676437378
Progress: 8 0.0582793727517128
	 x为1.0 y为2.0 梯度为-0.1319713592529297 w的值为1.9340143203735352
	 x为2.0 y为4.0 梯度为-0.5173273086547852 w的值为1.9353340864181519
	 x为3.0 y为6.0 梯度为-1.070866584777832 w的值为1.940507411956787
Progress: 9 0.03185431286692619
	 x为1.0 y为2.0 梯度为-0.09756779670715332 w的值为1.9512161016464233
	 x为2.0 y为4.0 梯度为-0.3824653625488281 w的值为1.9521918296813965
	 x为3.0 y为6.0 梯度为-0.7917022705078125 w的值为1.9560165405273438
Progress: 10 0.017410902306437492
	 x为1.0 y为2.0 梯度为-0.07213282585144043 w的值为1.9639335870742798
	 x为2.0 y为4.0 梯度为-0.2827606201171875 w的值为1.9646549224853516
	 x为3.0 y为6.0 梯度为-0.5853137969970703 w的值为1.967482566833496
Progress: 11 0.009516451507806778
	 x为1.0 y为2.0 梯度为-0.053328514099121094 w的值为1.9733357429504395
	 x为2.0 y为4.0 梯度为-0.2090473175048828 w的值为1.9738690853118896
	 x为3.0 y为6.0 梯度为-0.43272972106933594 w的值为1.9759595394134521
Progress: 12 0.005201528314501047
	 x为1.0 y为2.0 梯度为-0.039426326751708984 w的值为1.9802868366241455
	 x为2.0 y为4.0 梯度为-0.15455150604248047 w的值为1.98068106174469
	 x为3.0 y为6.0 梯度为-0.3199195861816406 w的值为1.9822266101837158
Progress: 13 0.0028430151287466288
	 x为1.0 y为2.0 梯度为-0.029148340225219727 w的值为1.9854258298873901
	 x为2.0 y为4.0 梯度为-0.11426162719726562 w的值为1.9857172966003418
	 x为3.0 y为6.0 梯度为-0.23652076721191406 w的值为1.986859917640686
Progress: 14 0.0015539465239271522
	 x为1.0 y为2.0 梯度为-0.021549701690673828 w的值为1.989225149154663
	 x为2.0 y为4.0 梯度为-0.08447456359863281 w的值为1.989440679550171
	 x为3.0 y为6.0 梯度为-0.17486286163330078 w的值为1.9902853965759277
Progress: 15 0.0008493617060594261
	 x为1.0 y为2.0 梯度为-0.01593184471130371 w的值为1.9920340776443481
	 x为2.0 y为4.0 梯度为-0.062453269958496094 w的值为1.992193341255188
	 x为3.0 y为6.0 梯度为-0.12927818298339844 w的值为1.9928178787231445
Progress: 16 0.00046424579340964556
	 x为1.0 y为2.0 梯度为-0.011778593063354492 w的值为1.9941107034683228
	 x为2.0 y为4.0 梯度为-0.046172142028808594 w的值为1.994228482246399
	 x为3.0 y为6.0 梯度为-0.09557533264160156 w的值为1.994690179824829
Progress: 17 0.0002537401160225272
	 x为1.0 y为2.0 梯度为-0.00870823860168457 w的值为1.9956458806991577
	 x为2.0 y为4.0 梯度为-0.03413581848144531 w的值为1.9957330226898193
	 x为3.0 y为6.0 梯度为-0.07066154479980469 w的值为1.9960744380950928
Progress: 18 0.00013869594840798527
	 x为1.0 y为2.0 梯度为-0.006437778472900391 w的值为1.9967811107635498
	 x为2.0 y为4.0 梯度为-0.025236129760742188 w的值为1.9968454837799072
	 x为3.0 y为6.0 梯度为-0.052239418029785156 w的值为1.9970978498458862
	 ...

可以看出w的值逐渐趋近于2

posted @ 2021-08-02 15:18  猪猪猪猪侠  阅读(31)  评论(0编辑  收藏  举报