pytorch处理模型过拟合
演示代码如下
1 import torch 2 from torch.autograd import Variable 3 import torch.nn.functional as F 4 import matplotlib.pyplot as plt 5 # make fake data 6 n_data = torch.ones(100, 2) 7 x0 = torch.normal(2*n_data, 1) #每个元素(x,y)是从 均值=2*n_data中对应位置的取值,标准差为1的正态分布中随机生成的 8 y0 = torch.zeros(100) # 给每个元素一个0标签 9 x1 = torch.normal(-2*n_data, 1) # 每个元素(x,y)是从 均值=-2*n_data中对应位置的取值,标准差为1的正态分布中随机生成的 10 y1 = torch.ones(100) # 给每个元素一个1标签 11 x = torch.cat((x0, x1), 0).type(torch.FloatTensor) # shape (200, 2) FloatTensor = 32-bit floating 12 y = torch.cat((y0, y1), ).type(torch.LongTensor) # shape (200,) LongTensor = 64-bit integer 13 # torch can only train on Variable, so convert them to Variable 14 x, y = Variable(x), Variable(y) 15 16 # draw the data 17 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy())#c是一个颜色序列 18 19 20 #plt.show() 21 #神经网络模块 22 net2 = torch.nn.Sequential( 23 torch.nn.Linear(2,10), 24 torch.nn.Dropout(0.2),#处理过拟合,当然这个模型本身很简单,不需要处理过拟合,这个只是一个演示 25 torch.nn.ReLU(), 26 torch.nn.Linear(10,2) 27 ) 28 29 plt.ion()#在Plt.ion和plt.ioff之间的代码,交互绘图 30 plt.show() 31 #神经网络优化器,主要是为了优化我们的神经网络,使他在我们的训练过程中快起来,节省社交网络训练的时间。 32 optimizer = torch.optim.SGD(net2.parameters(),lr = 0.01)#其实就是神经网络的反向传播,第一个参数是更新权重等参数,第二个对应的是学习率 33 loss_func = torch.nn.CrossEntropyLoss()#标签误差代价函数 34 35 for t in range(50): 36 out = net2(x) 37 loss = loss_func(out,y)#计算损失 38 optimizer.zero_grad()#梯度置零 39 loss.backward()#反向传播 40 optimizer.step()#计算结点梯度并优化, 41 if t % 2 == 0: 42 net2.eval()#模型做预测的时候不需要dropout,切换为eval()模式 43 plt.cla()# Clear axis即清除当前图形中的之前的轨迹 44 prediction = torch.max(F.softmax(out), 1)[1]#转换为概率,后面的一是最大值索引,如果为0则返回最大值 45 pred_y = prediction.data.numpy().squeeze() 46 target_y = y.data.numpy() 47 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn') 48 accuracy = sum(pred_y == target_y) / 200.#求准确率 49 plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 20, 'color': 'red'}) 50 plt.pause(0.1) 51 net2.train()#切花为训练模式 52 53 plt.ioff() 54 plt.show()
注意model.eval和model.train的使用
作者:你的雷哥
本文版权归作者和博客园共有,欢迎转载,但未经作者同意必须在文章页面给出原文连接,否则保留追究法律责任的权利。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列:基于图像分类模型对图像进行分类
· go语言实现终端里的倒计时
· 如何编写易于单元测试的代码
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
· 25岁的心里话
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· 闲置电脑爆改个人服务器(超详细) #公网映射 #Vmware虚拟网络编辑器
· 一起来玩mcp_server_sqlite,让AI帮你做增删改查!!
· 零经验选手,Compose 一天开发一款小游戏!