线性模型

复制代码
import tensorflow as tf
from numpy.random import RandomState

#定义训练数据batch的大小
batch_size=8

#定义神经网络的参数
w1=tf.Variable(tf.random_normal([2,3], stddev=1, seed=1))
w2=tf.Variable(tf.random_normal([3,1], stddev=1, seed=1))

#在shape的一个维度上使用None可以方便使用不大的batch大小,在训练时需要把数据分为
#较小的batch,但是在测试时,可以一次性使用全部数据。
x=tf.placeholder(tf.float32,shape=(None, 2), name="x-input")
y_=tf.placeholder(tf.float32, shape=(None,1), name="y-input")

#定义神经网络的前向传播过程
a=tf.matmul(x, w1)
y=tf.matmul(a, w2)

#定义损失函数
cross_entropy=-tf.reduce_mean(y_*tf.log(tf.clip_by_value(y,1e-10,1.0)))
#定义反向传播算法
train_step=tf.train.AdamOptimizer(0.001).minimize(cross_entropy)

#通过随机数产生一个模拟数据集
rdm=RandomState(1)
dataset_size=128
X=rdm.rand(dataset_size, 2)

#定义规则来给出样本的标签。
#定义x1+x2<1的样例被认为是正样本。其他情况为负样本
Y=[[int(x1+x2<1)]for(x1,x2) in X]

#创建一个会话来运行tensorflow程序
with tf.Session() as sess:
    #初始化——start——
    init_op=tf.initialize_all_variables()
    sess.run(init_op)
    #初始化——end——

    #输出原始参数值
    print("输出初始参数:")
    print(sess.run(w1))
    print(sess.run(w2))

    #训练——start——
    STEPS=5000
    for i in range(STEPS):
        '''每次选出batch_size个样本进行训练'''
        start=(i*batch_size)%dataset_size
        end=min(start+batch_size, dataset_size)

        #通过选出的样本训练神经网络并更新参数
        sess.run(train_step,
                 feed_dict={x:X[start:end],y_:Y[start:end]})
        if i%1000==0 :
            '''每隔一段时间计算在所有数据上的交叉熵并输出'''
            total_cross_entropy=sess.run(cross_entropy,feed_dict={x:X,y_:Y})
            print("-----------------------")
            print("在%d训练步后,所有数据的交叉熵是%g"%(i,total_cross_entropy))
            print("-----------------------")
        print("第%d轮->"%(i))
        print("参数w1:")
        print(sess.run(w1))
        print("参数w2:")
        print(sess.run(w2))
复制代码

 

作者:ALINGMAOMAO

出处:https://www.cnblogs.com/ALINGMAOMAO/articles/11561770.html

版权:本作品采用「署名-非商业性使用-相同方式共享 4.0 国际」许可协议进行许可。

posted @   青山新雨  阅读(31)  评论(0编辑  收藏  举报
编辑推荐:
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
阅读排行:
· 单线程的Redis速度为什么快?
· 展开说说关于C#中ORM框架的用法!
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· Pantheons:用 TypeScript 打造主流大模型对话的一站式集成库
· SQL Server 2025 AI相关能力初探
历史上的今天:
2018-09-21 Coprime (单色三角形+莫比乌斯反演(数论容斥))
2018-09-21 莫比乌斯函数 51nod-1240(合数分解试除法)
more_horiz
keyboard_arrow_up light_mode palette
选择主题
点击右上角即可分享
微信分享提示