机器学习笔记(十八)——BP神经网络(异或问题)

本博客仅用于个人学习,不用于传播教学,主要是记自己能够看得懂的笔记(

学习知识、资源和数据来自:机器学习算法基础-覃秉丰_哔哩哔哩_bilibili

这次B站上课的老师写的代码有个地方错了,我代码中会指出来。

首先,BP神经网络的示意图如下:

BP神经网络有一点点难以理解,不过也就是梯度下降法的极致应用。通过几层感知器的共同运算,得到最后结果,最后再对结果求偏导,得出各个参数应该如何变化。

具体推导过程可以看下面的参考博客(或者我现在给出也行:详解BP神经网络_fanxin_i的博客-CSDN博客_bp神经网络),下面直接给出推到结果:

不过呢,我自己写的神经网络,在隐藏层还有个偏置值b,所以是下面这种结构:

所以在上面提到的推导博客里,作者给出的结论是:

倒数第二层到最后一层(共n层,真实值t,激活函数f(),隐藏层y):

Δw(n-1)=lr*(t-z1)*f '(z1)*y(n-1)

Δb(n-1)=lr*(t-z1)*f '(z1)

δ(n-1)=-(t-z1)*f '(z1)

第i层到第i+1层:

Δw(i)=-lr*δ(i+1)*y(i)

Δb(i)=-lr*δ(i+1)

δ(i)=δ(i+1)*w(i+1).T*f '(y(i+1))

写得比较混乱,或许到代码中才看得清一些。

import numpy as np

def sig(x): #sigmoid函数
    return 1/(1+np.exp(-x))

def dsig(x): #sigmoid求导后函数,准确的表述应该是这样:f'(x)=f(x)[1-f(x)]。视频中老师给的代码写错了。
    return sig(x)*(1-sig(x))

x_data=np.array([[1,0,0],[1,1,0],[1,0,1],[1,1,1]]) #还是异或问题
y_data=np.array([[0],[1],[1],[0]])
v=np.random.random([3,4])*2-1 #输入层到隐藏层的权值,因为x_data包含了1,所以包含了截距
w=np.random.random([4,1])*2-1 #隐藏层到输出层的权值
b=np.random.random([4,1])*2-1 #隐藏层到输出层的偏置,相当于截距
print(v)
print(w)
lr=0.11

for i in range(20001):
    xa=sig(np.dot(x_data,v)) #隐藏层的中间值
    y=sig(np.dot(xa,w)+b) #这一次的预测值

    w_d=(y_data-y)*dsig(y) #隐藏层到输出层的δ值(delta)
    v_d=np.dot(w_d,w.T)*dsig(xa) #输入层到隐藏层的δ值
    w_=lr*np.dot(xa.T,w_d) #w的增加值
    v_=lr*np.dot(x_data.T,v_d) #v的增加值
    b_=lr*w_d #b的增加值
    #print(v_.shape,v_d.shape)
    w+=w_
    v+=v_
    b+=b_

    if i%500==0:
        xa=sig(np.dot(x_data,v))
        y=sig(np.dot(xa,w)+b)
        print("Error:",np.mean(np.abs(y_data-y))) #输出偏差

xa=sig(np.dot(x_data,v))
y=sig(np.dot(xa,w)+b)
print(y) #输出预测结果

得到结果:

[[ 0.23207948 -0.70037517 -0.49481911 -0.23703071]
[ 0.67476117 0.60464094 -0.04328313 0.3111562 ]
[ 0.51628408 0.03375034 0.8703467 0.36401558]]
[[ 0.49344442]
[-0.80654527]
[ 0.03139705]
[ 0.77230891]]
Error: 0.37876493377804066
Error: 0.07558647763636259
Error: 0.04023716754300484
Error: 0.027270108988748765
Error: 0.02058821808269875
Error: 0.016523621565233015
Error: 0.013793650073407254
Error: 0.011834955942627942
Error: 0.01036173671209702
Error: 0.00921370631086146
Error: 0.00829406537233357
Error: 0.007540925616132301
Error: 0.006912884796831388
Error: 0.0063812040668992515
Error: 0.005925313585748153
Error: 0.005530104412823155
Error: 0.005184230200208028
Error: 0.004879005269621099
Error: 0.004607668341604377
Error: 0.004364877829424915
Error: 0.004146357991551615
Error: 0.003948645849170563
Error: 0.003768906919157519
Error: 0.0036047988855851558
Error: 0.0034543692686396462
Error: 0.0033159775972754018
Error: 0.0031882355047766968
Error: 0.003069960111147376
Error: 0.0029601373776167537
Error: 0.002857893030882395
Error: 0.0027624692939987464
Error: 0.002673206114919778
Error: 0.002589525910352445
Error: 0.0025109210803193832
Error: 0.0024369437237478954
Error: 0.0023671971154225293
Error: 0.002301328602207482
Error: 0.0022390236503129083
Error: 0.002180000831780275
Error: 0.0021240075817627597
Error: 0.0020708165918279257
[[0.00182558]
[0.99768826]
[0.99768902]
[0.00183496]]

参考博客:

详解BP神经网络_fanxin_i的博客-CSDN博客_bp神经网络

posted @ 2021-08-03 17:57  Lcy的瞎bb  阅读(465)  评论(0编辑  收藏  举报