Fork me on Gitee

pytorch解决鸢尾花分类

半年前用numpy写了个鸢尾花分类200行。。每一步计算都是手写的  python构建bp神经网络_鸢尾花分类

现在用pytorch简单写一遍,pytorch语法解释请看上一篇pytorch搭建简单网络

 1 import pandas as pd
 2 import torch.nn as nn
 3 import torch
 4 
 5 
 6 class MyNet(nn.Module):
 7     def __init__(self):
 8         super(MyNet, self).__init__()
 9         self.fc = nn.Sequential(
10             nn.Linear(4, 3),
11             nn.Sigmoid(),
12             nn.Linear(3, 3),
13             nn.Sigmoid(),
14             nn.Linear(3, 1),
15         )
16         self.mls = nn.MSELoss()
17         self.opt = torch.optim.Adam(params=self.parameters(), lr=0.001)
18 
19     def get_data(self):
20         inputs = []
21         labels = []
22         with open('flower.csv') as file:
23             df = pd.read_csv(file, header=None)
24             x = df.iloc[:, 0:4].values
25             y = df.iloc[:, 4].values
26             for i in range(len(x)):
27                 inputs.append(x[i])
28             for j in range(len(y)):
29                 a = []
30                 a.append(y[j])
31                 labels.append(a)
32 
33         return inputs, labels
34 
35     def forward(self, inputs):
36         out = self.fc(inputs)
37         return out
38 
39     def train(self, x, label):
40         out = self.forward(x)
41         loss = self.mls(out, label)
42         self.opt.zero_grad()
43         loss.backward()
44         self.opt.step()
45 
46     def test(self, x):
47         return self.fc(x)
48 
49 
50 if __name__ == '__main__':
51     net = MyNet()
52     inputs, labels = net.get_data()
53     for i in range(1000):
54         for index, input in enumerate(inputs):
55             # 这里不加.float()会报错,可能是数据格式的问题吧
56             input = torch.from_numpy(input).float()
57             label = torch.Tensor(labels[index])
58             net.train(input, label)
59     # 简单测试一下
60     c = torch.Tensor([[5.6, 2.7, 4.2, 1.3]])
61     print(net.test(c))

运行结果趋近于0.5  正确,单纯练一下pytorch,就没有分训练集,测试集

1 tensor([[0.5392]], grad_fn=<AddmmBackward>)

不用手写反向传播和梯度下降 是多么幸福一件事~

posted @ 2018-12-12 16:30  MARK+  阅读(4339)  评论(4编辑  收藏  举报