Chainer的初步学习
人们都说Chainer是一块非常灵活you要用的框架,今天接着项目里面的应用,初步接触一下,涨涨姿势,直接上源码吧,看着好理解。其实跟Tensorflow等其他框架都是一个套路,个人感觉更简洁了。
1 """ 2 测试使用 3 """ 4 import pickle 5 import time 6 import numpy as np 7 import matplotlib.pyplot as plt 8 from chainer import Chain, Variable, optimizers, serializers 9 import chainer.functions as F 10 import chainer.links as L 11 12 # 创建Chainer Variables变]量 13 a = Variable(np.array([3], dtype=np.float32)) 14 b = Variable(np.array([4], dtype=np.float32)) 15 c = a**2 +b**2 16 17 # 5通过data属性检查之前定义的变量 18 print('a.data:{0}, b.data{1}, c.data{2}'.format(a.data, b.data, c.data)) 19 20 # 使用backward()方法,对变量c进行反向传播.对c进行求导 21 c.backward() 22 # 通过在变量中存储的grad属性,检查其导数 23 print('dc/da = {0}, dc/db={1}, dc/dc={2}'.format(a.grad, b.grad, c.grad)) 24 25 # 在chainer中做线性回归 26 x = 30*np.random.rand(1000).astype(np.float32) 27 y = 7*x + 10 28 y += 10*np.random.randn(1000).astype(np.float32) 29 30 plt.scatter(x, y) 31 plt.xlabel('x') 32 plt.ylabel('y') 33 plt.show() 34 35 36 # 使用chainer做线性回归 37 38 # 从一个变量到另一个变量建立一个线性连接 39 linear_function = L.Linear(1, 1) 40 # 设置x和y作为chainer变量,以确保能够变形到特定形态 41 x_var = Variable(x.reshape(1000, -1)) 42 y_var = Variable(y.reshape(1000, -1)) 43 # 建立优化器 44 optimizer = optimizers.MomentumSGD(lr=0.001) 45 optimizer.setup(linear_function) 46 47 48 # 定义一个前向传播函数,数据作为输入,线性函数作为输出 49 def linear_forward(data): 50 return linear_function(data) 51 52 53 # 定义一个训练函数,给定输入数据,目标数据,迭代数 54 def linear_train(train_data, train_traget, n_epochs=200): 55 for _ in range(n_epochs): 56 # 得到前向传播结果 57 output = linear_forward(train_data) 58 # 计算训练目标数据和实际标数据的损失 59 loss = F.mean_squared_error(train_traget, output) 60 # 在更新之前将梯度取零,线性函数和梯度有非常密切的关系 61 # linear_function.zerograds() 62 linear_function.cleargrads() 63 # 计算并更新所有梯度 64 loss.backward() 65 # 优化器更新 66 optimizer.update() 67 68 69 # 绘制训练结果 70 plt.scatter(x, y, alpha=0.5) 71 for i in range(150): 72 # 训练 73 linear_train(x_var, y_var, n_epochs=5) 74 # 预测值 75 y_pred = linear_forward(x_var).data 76 plt.plot(x, y_pred, color=plt.cm.cool(i / 150.), alpha=0.4, lw=3) 77 78 slope = linear_function.W.data[0, 0] # linear_function是之前定义的连接,线性连接有两个参数W和b,此种形式可以获取训练后参数的值,slope是斜率的意思 79 intercept = linear_function.b.data[0] # intercept是截距的意思 80 plt.title("Final Line: {0:.3}x + {1:.3}".format(slope, intercept)) 81 plt.xlabel('x') 82 plt.ylabel('y') 83 plt.show()
时刻记着自己要成为什么样的人!