代码改变世界

利用TensorFlow实现多元线性回归

  猎手家园  阅读(3334)  评论(0编辑  收藏  举报

利用TensorFlow实现多元线性回归,代码如下:

复制代码
# -*- coding:utf-8 -*-
import tensorflow as tf
import numpy as np
from sklearn import linear_model
from sklearn import preprocessing

# Read x and y
x_data = np.loadtxt("ex3x.dat").astype(np.float32)
y_data = np.loadtxt("ex3y.dat").astype(np.float32)

# We evaluate the x and y by sklearn to get a sense of the coefficients.
reg = linear_model.LinearRegression()
reg.fit(x_data, y_data)
print ("Coefficients of sklearn: K=%s, b=%f" % (reg.coef_, reg.intercept_))

# Now we use tensorflow to get similar results.
# Before we put the x_data into tensorflow, we need to standardize it
# in order to achieve better performance in gradient descent;
# If not standardized, the convergency speed could not be tolearated.
# Reason:  If a feature has a variance that is orders of magnitude larger than others,
# it might dominate the objective function
# and make the estimator unable to learn from other features correctly as expected.
scaler = preprocessing.StandardScaler().fit(x_data)
print (scaler.mean_, scaler.scale_)
x_data_standard = scaler.transform(x_data)

W = tf.Variable(tf.zeros([2, 1]))
b = tf.Variable(tf.zeros([1, 1]))
y = tf.matmul(x_data_standard, W) + b

loss = tf.reduce_mean(tf.square(y - y_data.reshape(-1, 1)))/2
optimizer = tf.train.GradientDescentOptimizer(0.3)
train = optimizer.minimize(loss)

init = tf.initialize_all_variables()

sess = tf.Session()
sess.run(init)
for step in range(100):
    sess.run(train)
    if step % 10 == 0:
        print (step, sess.run(W).flatten(), sess.run(b).flatten())

print ("Coefficients of tensorflow (input should be standardized): K=%s, b=%s" % (sess.run(W).flatten(), sess.run(b).flatten()))
print ("Coefficients of tensorflow (raw input): K=%s, b=%s" % (sess.run(W).flatten() / scaler.scale_, sess.run(b).flatten() - np.dot(scaler.mean_ / scaler.scale_, sess.run(W))))
复制代码

 

数据集下载:下载地址

编辑推荐:
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 单线程的Redis速度为什么快?
· 展开说说关于C#中ORM框架的用法!
· SQL Server 2025 AI相关能力初探
· Pantheons:用 TypeScript 打造主流大模型对话的一站式集成库
历史上的今天:
2016-05-12 如何使用Hive&R从Hadoop集群中提取数据进行分析
2016-05-12 CentOS6.5下实现R绘图
点击右上角即可分享
微信分享提示