第四节,线性回归案例
最*在网上也看了不少相关深度学习的视频,大部分都在讲解原理,对代码的实现讲解较少,为此苦苦寻找一本实战的书籍,黄天不负有心人,终于找到一本很好的书籍,<深度学习之TensorFlow入门、原理与进阶实战>,作者是李金洪。在这里就记录一下我的学习之路,也希望对和我一样在学习深度学习路上迷茫的同学有一定的帮助。
一、解决问题
本节内容来源于书中第三章内容,TensorFlow基本开发步骤-以线性回归拟合二维数据为例。
本节主要解决一个什么问题呢?假设我们有一组数据集,数据集是二维的,其中x和y对应的关系*似为y=2x,我们的目的就是从这度数据中求解出y和x之间这样的关系。
我们在解决这样的问题过程中积累了一定的规律。主要遵循以下步骤:
- 准备数据
- 搭建模型
- 迭代训练
- 使用模型进行预测
二、准备数据
数据我们可以利用y=2x的公式来生成带有一定干扰噪声的数据集。
import tensorflow as tf import numpy as np import matplotlib.pyplot as plt ''' 一准备数据 ''' #设定训练集数据长度 n_train = 100 #生成x数据,[-1,1]之间,均分成n_train个数据 train_x = np.linspace(-1,1,n_train) #把x乘以2,加入(0,0.3)的高斯正太分布 train_y = 2*train_x + np.random.normal(loc=0.0,scale=0.3,size=n_train) #绘制x,y波形
plt.figure() plt.plot(train_x,train_y,'ro',label='y=2x') #o使用圆点标记一个点
plt.legend()
plt.show()
我们可以看一看生成的数据点
三、搭建模型
因为只有一个因变量,所以逻辑线性回归方程为 y = w1x+b,也可以看做神经网络中一个神经元,只有两个参数w1和b。我们可以搭建一个这样的模型,代码如下:
''' 二 搭建模型 ''' ''' 前向反馈 ''' #创建占位符 input_x = tf.placeholder(dtype=tf.float32) input_y = tf.placeholder(dtype=tf.float32) #模型参数 w = tf.Variable(tf.truncated_normal(shape=[1],mean=0.0,stddev=1),name='w') #设置正太分布参数 初始化权重 b = tf.Variable(tf.truncated_normal(shape=[1],mean=0.0,stddev=1),name='b') #设置正太分布参数 初始化偏置 #前向结构 pred = tf.multiply(w,input_x) + b ''' 反向传播bp ''' #定义代价函数 选取二次代价函数 cost = tf.reduce_mean(tf.square(input_y - pred)) #设置求解器 采用梯度下降法 学习了设置为0.001 train = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(cost)
四、迭代模型
我们可以定义代价函数为二次代价函数,然后利用梯度下降法求解参数,代码如下:
''' 三 迭代模型 ''' #设置迭代次数 training_epochs = 200 display_step = 20 with tf.Session() as sess: #初始化所有张量 sess.run(tf.global_variables_initializer()) #存放批次值和代价值 plotdata = {'batch_size':[],'loss':[]} #开始迭代 for epoch in range(training_epochs): for (x,y) in zip(train_x,train_y): #开始执行图 sess.run(train,feed_dict={input_x:x,input_y:y}) #一轮训练完成后 打印输出信息 if epoch % display_step == 0: #计算代价值 loss = sess.run(cost,feed_dict={input_x:train_x,input_y:train_y}) print('Epoch {0} cost {1} w {2} b{3}'.format(epoch,loss,sess.run(w),sess.run(b))) #保存每display_step轮训练后的代价值以及当前迭代轮数 if not loss == np.nan: plotdata['batch_size'].append(epoch) plotdata['loss'].append(loss) #输出最终结果 print('Finished!') print('cost {0} w {1} b {2}'.format(sess.run(cost,feed_dict={input_x:train_x,input_y:train_y}),sess.run(w),sess.run(b)))
运行程序输出结果如下:
我们可以看到w的值*似为2,b*似为0,这正是我们之前假设的公式参数。
五、模型预测
我们预测输入为2,4,5,7时输出的值:
亲爱的读者和支持者们,自动博客加入了打赏功能,陆陆续续收到了各位老铁的打赏。在此,我想由衷地感谢每一位对我们博客的支持和打赏。你们的慷慨与支持,是我们前行的动力与源泉。
日期 | 姓名 | 金额 |
---|---|---|
2023-09-06 | *源 | 19 |
2023-09-11 | *朝科 | 88 |
2023-09-21 | *号 | 5 |
2023-09-16 | *真 | 60 |
2023-10-26 | *通 | 9.9 |
2023-11-04 | *慎 | 0.66 |
2023-11-24 | *恩 | 0.01 |
2023-12-30 | I*B | 1 |
2024-01-28 | *兴 | 20 |
2024-02-01 | QYing | 20 |
2024-02-11 | *督 | 6 |
2024-02-18 | 一*x | 1 |
2024-02-20 | c*l | 18.88 |
2024-01-01 | *I | 5 |
2024-04-08 | *程 | 150 |
2024-04-18 | *超 | 20 |
2024-04-26 | .*V | 30 |
2024-05-08 | D*W | 5 |
2024-05-29 | *辉 | 20 |
2024-05-30 | *雄 | 10 |
2024-06-08 | *: | 10 |
2024-06-23 | 小狮子 | 666 |
2024-06-28 | *s | 6.66 |
2024-06-29 | *炼 | 1 |
2024-06-30 | *! | 1 |
2024-07-08 | *方 | 20 |
2024-07-18 | A*1 | 6.66 |
2024-07-31 | *北 | 12 |
2024-08-13 | *基 | 1 |
2024-08-23 | n*s | 2 |
2024-09-02 | *源 | 50 |
2024-09-04 | *J | 2 |
2024-09-06 | *强 | 8.8 |
2024-09-09 | *波 | 1 |
2024-09-10 | *口 | 1 |
2024-09-10 | *波 | 1 |
2024-09-12 | *波 | 10 |
2024-09-18 | *明 | 1.68 |
2024-09-26 | B*h | 10 |
2024-09-30 | 岁 | 10 |
2024-10-02 | M*i | 1 |
2024-10-14 | *朋 | 10 |
2024-10-22 | *海 | 10 |
2024-10-23 | *南 | 10 |
2024-10-26 | *节 | 6.66 |
2024-10-27 | *o | 5 |
2024-10-28 | W*F | 6.66 |
2024-10-29 | R*n | 6.66 |
2024-11-02 | *球 | 6 |
2024-11-021 | *鑫 | 6.66 |
2024-11-25 | *沙 | 5 |
2024-11-29 | C*n | 2.88 |

分类:
tensorflow
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了