tensorflow实现线性回归总结
1、知识点
""" 模拟一个y = 0.7x+0.8的案例 报警: 1、initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02 解决方法:由于使用了tf.initialize_all_variables() 初始化变量,该方法已过时,使用tf.global_variables_initializer()就不会了 tensorboard查看数据: 1、收集变量信息 tf.summary.scalar() tf.summary.histogram() merge = tf.summary.merge_all() 2、创建事件机制 fileWriter = tf.summary.FileWriter(logdir='',graph=sess.graph) 3、在sess中运行并合并merge summary = sess.run(merge) 4、在循环训练中将变量添加到事件中 fileWriter.add_summary(summary,i) #i为训练次数 保存并加载训练模型: 1、创建保存模型saver对象 saver = tf.train.Saver() 2、保存模型 saver.save(sess,'./ckpt/model') 3、利用保存的模型加载模型,变量初始值从保存模型读取 if os.path.exists('./ckpt/checkpoint'): saver.restore(sess,'./ckpt/model') 创建变量域: with tf.variable_scope("data"): """
2、代码
# coding = utf-8 import tensorflow as tf import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' def myLinear(): """ 自实现线性回归 :return: """ with tf.variable_scope("data"): #1、准备数据 x = tf.random_normal((100,1),mean=0.5,stddev=1,name='x') y_true = tf.matmul(x,[[0.7]])+0.8 #矩阵相乘至少为2维 with tf.variable_scope("model"): #2、初始化权重和偏置 weight = tf.Variable(tf.random_normal((1,1)),name='w') bias = tf.Variable(0.0,name='b') y_predict = tf.matmul(x,weight)+bias with tf.variable_scope("loss"): #3、计算损失值 loss = tf.reduce_mean(tf.square(y_true-y_predict)) with tf.variable_scope("train"): #4、梯度下降优化loss train_op = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(loss) #初始化变量 init_op = tf.global_variables_initializer() ############收集变量信息存到tensorboard查看############### #收集变量 tf.summary.scalar('losses',loss)#1维 tf.summary.histogram('weight',weight) #高维 tf.summary.histogram('bias', bias) # 高维 merged = tf.summary.merge_all() #将变量合并 ######################################################### #####################保存并加载模型############### saver = tf.train.Saver() ################################################# #5、循环训练 with tf.Session() as sess: sess.run(init_op) #运行是初始化变量 if os.path.exists('./ckpt/checkpoint'): saver.restore(sess,'./ckpt/model') #建立事件机制 fileWriter = tf.summary.FileWriter(logdir='./tmp',graph=sess.graph) print("初始化权重为:%f,偏置为:%f" %(weight.eval(),bias.eval())) for i in range(501): summary = sess.run(merged) # 运行并合并 fileWriter.add_summary(summary,i) sess.run(train_op) if i%10==0 : print("第%d次训练权重为:%f,偏置为:%f" % (i,weight.eval(), bias.eval())) saver.save(sess,'./ckpt/model') return None if __name__ == '__main__': myLinear()
3、代码
import tensorflow as tf import csv import numpy as np import matplotlib.pyplot as plt # 设置学习率 learning_rate = 0.01 # 设置训练次数 train_steps = 1000 with open('D:/Machine Learning/Data_wrangling/鲍鱼数据集.csv') as file: reader = csv.reader(file) a, b = [], [] for item in reader: b.append(item[8]) del(item[8]) a.append(item) file.close() x_data = np.array(a) y_data = np.array(b) for i in range(len(x_data)): y_data[i] = float(y_data[i]) for j in range(len(x_data[i])): x_data[i][j] = float(x_data[i][j]) # 定义各影响因子的权重 weights = tf.Variable(np.ones([8,1]),dtype = tf.float32) x_data_ = tf.placeholder(tf.float32, [None, 8]) y_data_ = tf.placeholder(tf.float32, [None, 1]) bias = tf.Variable(1.0, dtype = tf.float32)#定义偏差值 # 构建模型为:y_model = w1X1 + w2X2 + w3X3 + w4X4 + w5X5 + w6X6 + w7X7 + w8X8 + bias y_model = tf.add(tf.matmul(x_data_ , weights), bias) # 定义损失函数 loss = tf.reduce_mean(tf.pow((y_model - y_data_), 2)) #训练目标为损失值最小,学习率为0.01 train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print("Start training!") lo = [] sample = np.arange(train_steps) for i in range(train_steps): for (x,y) in zip(x_data, y_data): z1 = x.reshape(1,8) z2 = y.reshape(1,1) sess.run(train_op, feed_dict = {x_data_ : z1, y_data_ : z2}) l = sess.run(loss, feed_dict = {x_data_ : z1, y_data_ : z2}) lo.append(l) print(weights.eval(sess)) print(bias.eval(sess)) # 绘制训练损失变化图 plt.plot(sample, lo, marker="*", linewidth=1, linestyle="--", color="red") plt.title("The variation of the loss") plt.xlabel("Sampling Point") plt.ylabel("Loss") plt.grid(True) plt.show()
本文来自博客园,作者:小白啊小白,Fighting,转载请注明原文链接:https://www.cnblogs.com/ywjfx/p/10911610.html