程序项目代做,有需求私信(小程序、网站、爬虫、电路板设计、驱动、应用程序开发、毕设疑难问题处理等)

第十节,利用隐藏层解决非线性问题-异或问题

多层神经网络非常好理解,就是在输入和输出中间多加一些层,每一层可以加多个神经元。下面的例子是通过加入一个隐藏层后对异或数据进行分类。

一 异或数据集

所谓的"异或数据"是来源于异或操作,可以绘制在直角坐标系中,如下图所示:

我们可以看到这张图很难通过一个直线把这两类数据风格开来,但是我们可以通过类似支持向量机中核函数一样的函数把数据映射到高维空间,然后通过线性分类的方法把数据分类。在输入层和输出层之间加一层隐藏层,就是起到核函数的作用。

生成数据集的代码如下:

复制代码
'''
生成模拟数据集
'''
train_x = np.array([[0,0],[0,1],[1,0],[1,1]],dtype=np.float32)
#非one_hot编码
#train_y = np.array([[0],[1],[1],[0]],dtype = np.float32)
#输出层节点个数
#n_label = 1


#one_hot编码
train_y = np.array([[1, 0], [0, 1], [0, 1], [1, 0]],dtype = np.float32)
#输出层节点个数
n_label = 2
复制代码

二 定义参数

复制代码
'''
定义变量
'''
#学习率
learning_rate = 1e-4
#输入层节点个数
n_input = 2
#隐藏层节点个数
n_hidden = 2


input_x = tf.placeholder(tf.float32,[None,n_input])
input_y = tf.placeholder(tf.float32,[None,n_label])

'''
定义学习参数

h1 代表隐藏层
h2 代表输出层
'''
weights = {
        'h1':tf.Variable(tf.truncated_normal(shape=[n_input,n_hidden],stddev = 0.01)),     #方差0.01
        'h2':tf.Variable(tf.truncated_normal(shape=[n_hidden,n_label],stddev=0.01))
        }


biases = {
        'h1':tf.Variable(tf.zeros([n_hidden])),    
        'h2':tf.Variable(tf.zeros([n_label]))
        }
复制代码

三 定义网络结构

复制代码
'''
定义网络模型
'''
#隐藏层
layer_1 = tf.nn.relu(tf.add(tf.matmul(input_x,weights['h1']),biases['h1']))


#1 softmax 方法
y_pred = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['h2']),biases['h2'])) 
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits( labels=input_y,logits=y_pred))


#2 tanh方法+平方差#输出层
#y_pred = tf.nn.tanh(tf.add(tf.matmul(layer_1,weights['h2']),biases['h2']))
#定义代价函数  二次代价函数
#loss = tf.reduce_mean((y_pred - input_y)**2)



train = tf.train.AdamOptimizer(learning_rate).minimize(loss)
复制代码

四 开始训练

复制代码
'''
开始训练
'''
training_epochs = 100000
sess = tf.InteractiveSession()

#初始化
sess.run(tf.global_variables_initializer())

for epoch in range(training_epochs):
    _,lo = sess.run([train,loss],feed_dict={input_x:train_x,input_y:train_y})
    if epoch % 10000 == 0:
        print(lo)
    
#计算预测值
print(sess.run(y_pred,feed_dict={input_x:train_x}))


#查看隐藏层的输出
print(sess.run(layer_1,feed_dict={input_x:train_x}))
复制代码

运行结果如下:

亲爱的读者和支持者们,自动博客加入了打赏功能,陆陆续续收到了各位老铁的打赏。在此,我想由衷地感谢每一位对我们博客的支持和打赏。你们的慷慨与支持,是我们前行的动力与源泉。

日期姓名金额
2023-09-06*源19
2023-09-11*朝科88
2023-09-21*号5
2023-09-16*真60
2023-10-26*通9.9
2023-11-04*慎0.66
2023-11-24*恩0.01
2023-12-30I*B1
2024-01-28*兴20
2024-02-01QYing20
2024-02-11*督6
2024-02-18一*x1
2024-02-20c*l18.88
2024-01-01*I5
2024-04-08*程150
2024-04-18*超20
2024-04-26.*V30
2024-05-08D*W5
2024-05-29*辉20
2024-05-30*雄10
2024-06-08*:10
2024-06-23小狮子666
2024-06-28*s6.66
2024-06-29*炼1
2024-06-30*!1
2024-07-08*方20
2024-07-18A*16.66
2024-07-31*北12
2024-08-13*基1
2024-08-23n*s2
2024-09-02*源50
2024-09-04*J2
2024-09-06*强8.8
2024-09-09*波1
2024-09-10*口1
2024-09-10*波1
2024-09-12*波10
2024-09-18*明1.68
2024-09-26B*h10
2024-09-3010
2024-10-02M*i1
2024-10-14*朋10
2024-10-22*海10
2024-10-23*南10
2024-10-26*节6.66
2024-10-27*o5
2024-10-28W*F6.66
2024-10-29R*n6.66
2024-11-02*球6
2024-11-021*鑫6.66
2024-11-25*沙5
2024-11-29C*n2.88
posted @   大奥特曼打小怪兽  阅读(1855)  评论(0编辑  收藏  举报
编辑推荐:
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
如果有任何技术小问题,欢迎大家交流沟通,共同进步

公告 & 打赏

>>

欢迎打赏支持我 ^_^

最新公告

程序项目代做,有需求私信(小程序、网站、爬虫、电路板设计、驱动、应用程序开发、毕设疑难问题处理等)。

了解更多

点击右上角即可分享
微信分享提示