[学习笔记] Hebb 学习规则和Hopfield网络

Hebb 学习规则和Hopfield网络

Hebb学习规则

Hebb学习规则是Donald Hebb在1949年提出的一种学习规则,用来描述神经元的行为是如何影响神经元之间的连接的,通俗的说,就是如果相链接的两个神经元同时被激活,显然我们可以认为这两个神经元之间的关系应该比较近,因此将这两个神经元之间连接的权值增加,而一个被激活一个被抑制,显然两者间的权值应该减小。

此外,Hebb还有一句非常注明的话,我在很多资料中都见到这句话的引用:“neurons that fire together, wire together”,这句话就是对权值如何更新的一个解释,同被激活者连接的权重理应增加。

公式表示为:

(1)Wij(t+1):=Wij(t)+αxixj

即表明神经元xi和xj的连接权重由两者输出决定。

尽管已经给出了生物学上的解释(或者说是启发),但其实仅看这么一个公式是不可能完全理解的,需要一个例子来说明,到底什么网络才需要这样的权值更新。

下面将以离散Hopfield网络为例说明这种权值更新的具体实现。

Hopfield 网络

定义和作用

Hopfield网络是一个有向完全图(仅仅以我看到的资料去定义,并非严谨或者官方定义),是一种递归神经网络。有向完全图可以理解,每两个节点之间都有连接,递归即一个输入经过多次训练最终收敛。本文仅讨论离散Hopfield网络,即节点取值为离散的。(下一篇尝试用连续Hopfield网络解决一下旅行商问题)

离散Hopfield网络的作用是:存储一个或更多的patterns,并且能够根据任意的输入从存储的这些patterns中将对应的原始数据还原。例如,我们的任务是一个手写数字分类,我们将数字1对应的图片作为一种pattern让网络存储起来,如果将数字1遮挡一半,我们希望网络利用存储的记忆将这个数字1恢复,并得到最接近的pattern,也就完成了一个分类任务。

训练

定义输入xi为第i个神经元节点的值,Wij为第i个和第j和节点之间的权值,则每个样本作为节点初始化的权值Wij定义为:

(2)Wij=xixj

则N个样本的权值经过N次更新为:

(3)Wij(N)=n=1Nxi(n)xj(n)

因此训练阶段很简单,仅仅是将所有样本的信息以权值求和的形式存储起来,因此,最终的权值存储的是每个样本的记忆,而测试阶段是需要利用这些权值恢复记忆。
那么这里的权值更新就是利用了Hebb学习规则。

测试

测试阶段先用测试样本初始化节点,利用训练阶段存储的权值,循环随机选择一个节点xi,将节点值根据下式更新:

(4)xi=sgn(j=1NWjixj)

经过若干iter,则所有节点会收敛到一个合适的值。

稳定性分析

当跑完实例后(可以先看下面代码例子),第一个问题就是:为什么Hopfield网络能够收敛,而且这么稳定?而这一切的解释其实是用一个稳定性指标来决定的。

在Hopfield网络中,通过引入李亚普洛夫函数作为能量函数,而能量函数就是稳定性指标,当能量达到最低点不变时,系统达到稳定,也就是说,我们需要证明该能量函数是递减的。

Hopfield网络中的能量函数定义为:

(5)E=(12)ijWijxixj+jθjxj

其中,Wij为第i个节点和第j个节点的链接权重,xi为第i个节点的节点值,θi为第i个节点的阈值(激活函数sgn可以的输入可以通过加减阈值调整激活的位置,即y=sgn(xθ)来调整)。

上式能量E可以化为:

(6)E=j{[(12)iWijxixj]+θjxj}

现在要定义能量E的变化量,我们假定t时刻到t+1时刻只有第i个神经元发生了变化,则能量变化量可以表示为:

(7)ΔE=i,j(12)Wij(xi^xj^xixj)+jθj(xj^xj)=k,j(12)Wkj(xk^xj^xkxj)+jθj(xj^xj)=k(12)Wki(xk^xi^xkxi)+j(12)Wji(xj^xi^xjxi)+jθj(xj^xj)=k(ki)Wki(xk^xi^xkxi)+θi(xi^xi)=(kWkixkθi)(xi^xi)

这里先解释一下第二行到第三行的变换,因为只有第i个神经元发生了变化,所以对k和j分四种情况讨论

  1. ki,ji,此时没有任何变化
  2. ki,j=i,此时固定j为i,将i带入得到左式
  3. k=i,ji,同理将k=i带入得到中间的式子
  4. i=j=k,此时无变化

然后解释一下第三行到第四行的变换,因为2. 3. 可通过变量代换结果一致,所以将j代换为k,而最后一项也很简单,对j求和,而当j不等于i的时候最后一项为0,所以直接把求和去掉。

下面讨论下最终结果:

  1. xi^>xi时,说明由负变正,而(kWkixkθi)表示的正式第i个节点的输出(sgn之前,大于0为1,小于0为-1的那个函数),即为正,所以能量变化为负。
  2. xi^<xi,说明由正变负,同理负负得正,能量变化为负。
  3. 由于认为第i个神经元的值变化,所以不讨论相等。

综上,能量变化一直为负,故会朝着能量减小的方向迭代。

所以说明了Hopfield能够稳定。(注意咯,下面的实验中随机挑选神经元用当前状态其他所有神经元作为输入,计算当前神经元的结果可能和上一时刻该神经元的状态相同,如果所有神经元都是如此,那么相当于每次迭代并没有改变任意神经元的值,此时收敛,能量不变,可以由实验的图看出。)

Coding

以一张二值图为输入,将每个像素值定义为-1或+1,来初始化节点,探究遮挡一般图片时利用Hopfied网络恢复图像。本文以矩阵运算代替所有循环。

原图:

masked:

恢复的(iter = 10):

能量变化:

import numpy as np
import cv2 as cv
from matplotlib import pyplot as plt
class Hopfield():
    def __init__(self,size = 64,iter = 10):
        self.iter = iter
        self.size = size
        self.W = np.zeros((size**2,size**2))
        
    def train(self,X): 
        n = self.size**2
        for x in X: # (-1,64*64)
            x = np.reshape(x,(n,1))
            xT = np.reshape(x,(1,n))
            self.W += x*xT/n
        self.W[np.diag_indices_from(self.W)] = 0
    def test_one_frame(self,x):
        n = self.size **2
        x = np.reshape(x,(n,))
        energy = []
        for iter in range(self.iter):
            h = np.zeros((n,))
            for i in range(n):
                i = np.random.randint(n)
                h[i] = self.W[i,:].dot(x)
            x[h>0] = 1
            x[h<0] = -1
            energy.append(self.cal_energy(x))
            
        return np.resize(x,(self.size,self.size)),energy
    def cal_energy(self,x):
        n = self.size **2
        energy = np.sum(self.W.dot(x) * x)
        
        return -0.5 * energy
def show(x):
    img = np.where(x >0,255,0).astype(np.uint8)
    cv.imshow("img",img)
    cv.waitKey(0)

if __name__ =="__main__":
    img = cv.imread("/home/xueaoru/图片/摄像头/handsome_boy.jpg",0)
    
    img = cv.resize(img,(64,64))
    x = np.where(img>255/2.5,1,-1)
    x_masked = x.copy()
    x_masked[64//2:,:] = -1
    #show(x_masked)
    model = Hopfield()
    model.train([x])
    y,energy = model.test_one_frame(x_masked)
    show(y)
    plt.plot(energy, label='energy')
    plt.show()
posted @   aoru45  阅读(8054)  评论(0编辑  收藏  举报
编辑推荐:
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· .NET10 - 预览版1新功能体验(一)
历史上的今天:
2019-03-08 [论文理解] Connectionist Text Proposal Network
点击右上角即可分享
微信分享提示