Chainer的初步学习

  人们都说Chainer是一块非常灵活you要用的框架,今天接着项目里面的应用,初步接触一下,涨涨姿势,直接上源码吧,看着好理解。其实跟Tensorflow等其他框架都是一个套路,个人感觉更简洁了。

 

 1 """
 2     测试使用
 3 """
 4 import pickle
 5 import time
 6 import numpy as np
 7 import matplotlib.pyplot as plt
 8 from chainer import Chain, Variable, optimizers, serializers
 9 import chainer.functions as F
10 import chainer.links as L
11 
12 # 创建Chainer Variables变]量
13 a = Variable(np.array([3], dtype=np.float32))
14 b = Variable(np.array([4], dtype=np.float32))
15 c = a**2 +b**2
16 
17 # 5通过data属性检查之前定义的变量
18 print('a.data:{0}, b.data{1}, c.data{2}'.format(a.data, b.data, c.data))
19 
20 # 使用backward()方法,对变量c进行反向传播.对c进行求导
21 c.backward()
22 # 通过在变量中存储的grad属性,检查其导数
23 print('dc/da = {0}, dc/db={1}, dc/dc={2}'.format(a.grad, b.grad, c.grad))
24 
25 # 在chainer中做线性回归
26 x = 30*np.random.rand(1000).astype(np.float32)
27 y = 7*x + 10
28 y += 10*np.random.randn(1000).astype(np.float32)
29 
30 plt.scatter(x, y)
31 plt.xlabel('x')
32 plt.ylabel('y')
33 plt.show()
34 
35 
36 # 使用chainer做线性回归
37 
38 # 从一个变量到另一个变量建立一个线性连接
39 linear_function = L.Linear(1, 1)
40 # 设置x和y作为chainer变量,以确保能够变形到特定形态
41 x_var = Variable(x.reshape(1000, -1))
42 y_var = Variable(y.reshape(1000, -1))
43 # 建立优化器
44 optimizer = optimizers.MomentumSGD(lr=0.001)
45 optimizer.setup(linear_function)
46 
47 
48 # 定义一个前向传播函数,数据作为输入,线性函数作为输出
49 def linear_forward(data):
50     return linear_function(data)
51 
52 
53 # 定义一个训练函数,给定输入数据,目标数据,迭代数
54 def linear_train(train_data, train_traget, n_epochs=200):
55     for _ in range(n_epochs):
56         # 得到前向传播结果
57         output = linear_forward(train_data)
58         # 计算训练目标数据和实际标数据的损失
59         loss = F.mean_squared_error(train_traget, output)
60         # 在更新之前将梯度取零,线性函数和梯度有非常密切的关系
61         # linear_function.zerograds()
62         linear_function.cleargrads()
63         # 计算并更新所有梯度
64         loss.backward()
65         # 优化器更新
66         optimizer.update()
67 
68 
69 # 绘制训练结果
70 plt.scatter(x, y, alpha=0.5)
71 for i in range(150):
72     # 训练
73     linear_train(x_var, y_var, n_epochs=5)
74     # 预测值
75     y_pred = linear_forward(x_var).data
76     plt.plot(x, y_pred, color=plt.cm.cool(i / 150.), alpha=0.4, lw=3)
77 
78 slope = linear_function.W.data[0, 0]        # linear_function是之前定义的连接,线性连接有两个参数W和b,此种形式可以获取训练后参数的值,slope是斜率的意思
79 intercept = linear_function.b.data[0]       # intercept是截距的意思
80 plt.title("Final Line: {0:.3}x + {1:.3}".format(slope, intercept))
81 plt.xlabel('x')
82 plt.ylabel('y')
83 plt.show()

 

posted @ 2018-09-27 15:44  今夜无风  阅读(4027)  评论(0编辑  收藏  举报