机器学习线性模型-自定义各个模块阶段
例:尝试一个线性模型
让我们来使用目前为止学到的概念—Tensor,Variable,和 GradientTape—来创建和训练一个简单的模型。一般需要下面这些步骤:
定义模型
定义损失函数
获取训练数据
通过训练数据运行模型,使用 “optimizer” 来调整变量以满足数据
在这个教程中,我们使用一个简单线性模型作为示例:f(x) = x * W + b,有2个变量- W 和 b。另外,我们会生成数据让训练好的模型满足 W = 3.0 和 b = 2.0。
定义模型
定义一个简单的类封装变量和计算
class Model(object):
def __init__(self):
# 初始化变量值为(5.0, 0.0)
# 实际上,这些变量应该初始化为随机值
self.W = tf.Variable(5.0)
self.b = tf.Variable(0.0)
def __call__(self, x):
return self.W * x + self.b
model = Model()
assert model(3.0).numpy() == 15.0
定义损失函数
损失函数用来衡量在给定输入的情况下,模型的预测输出与实际输出的偏差。我们这里使用标准 L2 损失函数。
def loss(predicted_y, desired_y):
return tf.reduce_mean(tf.square(predicted_y - desired_y))
获取训练数据
我们来生成带噪声的训练数据。
TRUE_W = 3.0
TRUE_b = 2.0
NUM_EXAMPLES = 1000
inputs = tf.random.normal(shape=[NUM_EXAMPLES])
noise = tf.random.normal(shape=[NUM_EXAMPLES])
outputs = inputs * TRUE_W + TRUE_b + noise
在训练模型之前,我们来看看当前的模型表现。我们绘制模型的预测结果和训练数据,预测结果用红色表示,训练数据用蓝色表示。
import matplotlib.pyplot as plt
plt.scatter(inputs, outputs, c='b')
plt.scatter(inputs, model(inputs), c='r')
plt.show()
print('Current loss: '),
print(loss(model(inputs), outputs).numpy())
定义训练循环
我们已经定义了网络模型,并且获得了训练数据。现在对模型进行训练,采用梯度下降的方式,通过训练数据更新模型的变量(W 和 b)使得损失量变小。梯度下降中有很多参数,通过 tf.train.Optimizer 实现。我们强烈建议使用这些实现方式,但基于通过基本规则创建模型的精神,在这个特别示例中,我们自己实现基本的数学运算。
def train(model, inputs, outputs, learning_rate):
with tf.GradientTape() as t:
current_loss = loss(model(inputs), outputs)
dW, db = t.gradient(current_loss, [model.W, model.b])
model.W.assign_sub(learning_rate * dW)
model.b.assign_sub(learning_rate * db)
训练
最后,我们对训练数据重复地训练,观察 W 和 b 是怎么变化的。
model = Model()
# 收集 W 和 b 的历史数值,用于显示
Ws, bs = [], []
epochs = range(10)
for epoch in epochs:
Ws.append(model.W.numpy())
bs.append(model.b.numpy())
current_loss = loss(model(inputs), outputs)
train(model, inputs, outputs, learning_rate=0.1)
print('Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f' %
(epoch, Ws[-1], bs[-1], current_loss))
# 显示所有
plt.plot(epochs, Ws, 'r',
epochs, bs, 'b')
plt.plot([TRUE_W] * len(epochs), 'r--',
[TRUE_b] * len(epochs), 'b--')
plt.legend(['W', 'b', 'true W', 'true_b'])
plt.show()
posted on 2019-11-12 17:55 MrCharles在cnblogs 阅读(202) 评论(0) 编辑 收藏 举报