RNN Notes

Model Structure

Recurrent neural networks can be used to handle sequence problems. Their network structure is shown below (image from colah's blog).

At time $t$, the input feature $x_t$ is transformed by $A$ into $h_t$, where $A$ stands for some processing step; different RNN variants differ in what this step is.

The figure below shows the most basic RNN structure, with parameters ${\rm U}$, ${\rm W}$, ${\rm V}$, input $\mathbf x_t$, states ${\rm s}_{t-1}$, ${\rm s}_t$, and output ${\rm o}_t$. The corresponding transformation is:

$$
\begin{aligned}
s_t &= \sigma\left({\rm W}s_{t-1} + {\rm U}x_t + b_s\right)\\
o_t &= \sigma\left({\rm V}s_t + b_o\right)
\end{aligned}
$$
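As a quick illustration, here is a minimal numpy sketch of one step of this recurrence (the dimensions are made up and sigmoid is used for $\sigma$; this snippet is not part of the original code):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Illustrative dimensions only
input_dim, hidden_dim, output_dim = 4, 3, 2
rng = np.random.default_rng(0)
U = rng.normal(size=(hidden_dim, input_dim))
W = rng.normal(size=(hidden_dim, hidden_dim))
V = rng.normal(size=(output_dim, hidden_dim))
b_s, b_o = np.zeros(hidden_dim), np.zeros(output_dim)

def rnn_step(x_t, s_prev):
    """One step of the recurrence: the new state s_t depends on x_t and s_{t-1}."""
    s_t = sigmoid(U @ x_t + W @ s_prev + b_s)
    o_t = sigmoid(V @ s_t + b_o)
    return s_t, o_t

# Carry the hidden state through a short random sequence
s = np.zeros(hidden_dim)
for x_t in rng.normal(size=(5, input_dim)):
    s, o = rnn_step(x_t, s)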

Because of the vanishing-gradient problem, the RNN above cannot capture long-term information well. To address this, researchers proposed the LSTM and its simplified variant, the GRU.

LSTM

The figure below shows the internal structure of the LSTM (image from colah's blog).

Besides keeping the state $h_t$, it introduces a cell state $C$. The transformations are:

$$
\begin{aligned}
f_t &= \sigma\left(U_f x_t + W_f h_{t-1} + b_f\right) &&\text{forget gate}\\
r_t &= \sigma\left(U_r x_t + W_r h_{t-1} + b_r\right) &&\text{gates the incoming information}\\
z_t &= \tanh\left(U_z x_t + W_z h_{t-1} + b_z\right) &&\text{new information for } C\\
c_t &= c_{t-1}\circ f_t + r_t\circ z_t\\
o_t &= \sigma\left(U_o x_t + W_o h_{t-1} + b_o\right)\\
h_t &= o_t\circ\tanh\left(c_t\right)
\end{aligned}
$$
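Written directly from these equations, one LSTM step looks roughly like the following numpy sketch (the parameter dict and its names are illustrative, not from the original post):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Illustrative dimensions and randomly initialised parameters
input_dim, hidden_dim = 4, 3
rng = np.random.default_rng(0)
p = {name: rng.normal(size=(hidden_dim, input_dim)) for name in ("Uf", "Ur", "Uz", "Uo")}
p.update({name: rng.normal(size=(hidden_dim, hidden_dim)) for name in ("Wf", "Wr", "Wz", "Wo")})
p.update({name: np.zeros(hidden_dim) for name in ("bf", "br", "bz", "bo")})

def lstm_step(x_t, h_prev, c_prev, p):
    """One LSTM step, following the equations above."""
    f_t = sigmoid(p["Uf"] @ x_t + p["Wf"] @ h_prev + p["bf"])   # forget gate
    r_t = sigmoid(p["Ur"] @ x_t + p["Wr"] @ h_prev + p["br"])   # gate on the incoming information
    z_t = np.tanh(p["Uz"] @ x_t + p["Wz"] @ h_prev + p["bz"])   # candidate information for the cell
    c_t = c_prev * f_t + r_t * z_t                              # new cell state
    o_t = sigmoid(p["Uo"] @ x_t + p["Wo"] @ h_prev + p["bo"])   # output gate
    h_t = o_t * np.tanh(c_t)                                    # new hidden state
    return h_t, c_t

h, c = np.zeros(hidden_dim), np.zeros(hidden_dim)
for x_t in rng.normal(size=(5, input_dim)):
    h, c = lstm_step(x_t, h, c, p)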

For an interpretation of each gate, see colah's blog.

GRU

The GRU is a simplified version of the LSTM; its internal structure is shown below (image from colah's blog):

Compared with the LSTM, the GRU removes the cell state. The corresponding transformations are:

$$
\begin{aligned}
r_t &= \sigma\left(U_r x_t + W_r h_{t-1} + b_r\right)\\
z_t &= \sigma\left(U_z x_t + W_z h_{t-1} + b_z\right)\\
\bar h_t &= \tanh\left(U_h x_t + W_h\left(r_t\circ h_{t-1}\right) + b_h\right)\\
h_t &= h_{t-1}\circ\left(1 - z_t\right) + \bar h_t\circ z_t
\end{aligned}
$$
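Correspondingly, a sketch of one GRU step (same caveats as the LSTM sketch above):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Illustrative dimensions and randomly initialised parameters
input_dim, hidden_dim = 4, 3
rng = np.random.default_rng(0)
p = {name: rng.normal(size=(hidden_dim, input_dim)) for name in ("Ur", "Uz", "Uh")}
p.update({name: rng.normal(size=(hidden_dim, hidden_dim)) for name in ("Wr", "Wz", "Wh")})
p.update({name: np.zeros(hidden_dim) for name in ("br", "bz", "bh")})

def gru_step(x_t, h_prev, p):
    """One GRU step, following the equations above."""
    r_t = sigmoid(p["Ur"] @ x_t + p["Wr"] @ h_prev + p["br"])            # reset gate: selects old information
    z_t = sigmoid(p["Uz"] @ x_t + p["Wz"] @ h_prev + p["bz"])            # update gate
    h_bar = np.tanh(p["Uh"] @ x_t + p["Wh"] @ (r_t * h_prev) + p["bh"])  # candidate state
    return h_prev * (1 - z_t) + h_bar * z_t                              # interpolate old state and candidate

h = np.zeros(hidden_dim)
for x_t in rng.normal(size=(5, input_dim)):
    h = gru_step(x_t, h, p)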

Although the RNN variants differ, the general idea is the same: first select from the old information (the $f_t$ in the LSTM and the $r_t$ in the GRU), then introduce new information based on the current input $x_t$ (the $z_t$ in the LSTM and the $\bar h_t$ in the GRU).

Back Propagation

Because recurrent networks involve time, e.g. the states $s_t$ and $s_{t-1}$, back propagation must take the time dimension into account. Below, back propagation through time is described for the RNN, LSTM, and GRU.

Back Propagation: RNN

Assume the label $\mathbf y_t$ is one-hot, the output layer $o_t$ uses softmax, the hidden layer $s_t$ uses the sigmoid activation, and cross entropy is the loss function.

Taking the differential of the loss function:

$$
\begin{aligned}
{\rm d}l &= {\rm d}\left(-\mathbf y_t^T\log o_t\right)\\
&= {\rm d}\left(-\mathbf y_t^T\left(Vs_t+b_o-\mathbf 1\log \mathbf 1^Te^{Vs_t+b_o}\right)\right)\\
&= -\mathbf y_t^T{\rm d}\left(Vs_t\right)+\mathbf y_t^T\mathbf 1\,{\rm d}\left(\log \mathbf 1^Te^{Vs_t+b_o}\right)\\
&= -\mathbf y_t^T{\rm d}\left(Vs_t\right)+\frac{{\rm d}\left(\mathbf 1^Te^{Vs_t+b_o}\right)}{\mathbf 1^Te^{Vs_t+b_o}}\\
&= -\mathbf y_t^T{\rm d}\left(Vs_t\right)+\frac{\mathbf 1^T\left(e^{Vs_t+b_o}\circ{\rm d}\left(Vs_t+b_o\right)\right)}{\mathbf 1^Te^{Vs_t+b_o}}\\
&= -\mathbf y_t^T{\rm d}\left(Vs_t\right)+\frac{\left(e^{Vs_t+b_o}\right)^T{\rm d}\left(Vs_t\right)}{\mathbf 1^Te^{Vs_t+b_o}}\\
&= \left(o_t-\mathbf y_t\right)^T{\rm d}\left(Vs_t\right)\\
&= {\rm tr}\left(s_t\left(o_t-\mathbf y_t\right)^T{\rm d}V\right)+{\rm tr}\left(\left(o_t-\mathbf y_t\right)^TV{\rm d}s_t\right)
\end{aligned}
$$

Here $\mathbf y_t^T\mathbf 1 = 1$ because $\mathbf y_t$ is one-hot, and $o_t = e^{Vs_t+b_o}/\mathbf 1^Te^{Vs_t+b_o}$ by the definition of softmax.

For ${\rm d}s_t$:

$$
\begin{aligned}
{\rm d}s_t &= {\rm d}\left(\sigma\left(U_sx_t+W_ss_{t-1}+b_s\right)\right)\\
&= \left(s_t\circ\left(1-s_t\right)\right)\circ{\rm d}\left(U_sx_t+W_ss_{t-1}+b_s\right)
\end{aligned}
$$

Substituting into ${\rm d}l$ gives:

$$
\begin{aligned}
{\rm d}l &= {\rm tr}\left(s_t\left(o_t-\mathbf y_t\right)^T{\rm d}V\right)+{\rm tr}\left(\left(o_t-\mathbf y_t\right)^TV\left(\left(s_t\circ\left(1-s_t\right)\right)\circ{\rm d}\left(U_sx_t+W_ss_{t-1}+b_s\right)\right)\right)\\
&= {\rm tr}\left(s_t\left(o_t-\mathbf y_t\right)^T{\rm d}V\right)+{\rm tr}\left(\left(V^T\left(o_t-\mathbf y_t\right)\circ\left(s_t\circ\left(1-s_t\right)\right)\right)^T{\rm d}\left(U_sx_t+W_ss_{t-1}+b_s\right)\right)\\
&= {\rm tr}\left(s_t\left(o_t-\mathbf y_t\right)^T{\rm d}V\right)+{\rm tr}\left(\left(V^T\left(o_t-\mathbf y_t\right)\circ\left(s_t\circ\left(1-s_t\right)\right)\right)^T\left({\rm d}U\,x_t+{\rm d}W\,s_{t-1}+W_s\,{\rm d}s_{t-1}\right)\right)
\end{aligned}
$$

The ${\rm d}s_{t-1}$ term has the same form as ${\rm d}s_t$ and again produces ${\rm dU}$ and ${\rm dW}$ terms, so the expansion can be carried back from time $t$ to $t-1$ and so on; in practice, however, back propagation only looks back a specified number of steps.

For the parameter $V$: the ${\rm d}s_{t-1}$ term does not involve it, so $\nabla_V l = \left(o_t-\mathbf y_t\right)s_t^T$.

For the parameters ${\rm U}$ and ${\rm W}$, the contribution at time $t$ is $\nabla_U l = \left(V^T\left(o_t-\mathbf y_t\right)\circ \left(s_t\circ\left(1-s_t\right)\right)\right)\mathbf x_t^T$ and $\nabla_W l = \left(V^T\left(o_t-\mathbf y_t\right)\circ \left(s_t\circ\left(1-s_t\right)\right)\right)s_{t-1}^T$; the contributions from earlier time steps are obtained by propagating the error back through ${\rm W}^T$ (and the sigmoid derivative) as ${\rm d}s_{t-1}$ is expanded.
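As a sanity check on these formulas, the following sketch (not from the original post; bias terms are dropped and only a single time step is considered, so $s_{t-1}$ is treated as a constant) compares the closed-form gradients with a finite-difference estimate:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

# Illustrative dimensions and random values
rng = np.random.default_rng(0)
d_in, d_h, d_out = 4, 3, 5
U, W, V = rng.normal(size=(d_h, d_in)), rng.normal(size=(d_h, d_h)), rng.normal(size=(d_out, d_h))
x_t, s_prev = rng.normal(size=d_in), rng.normal(size=d_h)
y = np.eye(d_out)[2]                                  # one-hot label

def loss(U, W, V):
    s_t = sigmoid(U @ x_t + W @ s_prev)
    return -y @ np.log(softmax(V @ s_t))

# Closed-form single-step gradients from the derivation above
s_t = sigmoid(U @ x_t + W @ s_prev)
o_t = softmax(V @ s_t)
delta = (V.T @ (o_t - y)) * s_t * (1 - s_t)
grad_V = np.outer(o_t - y, s_t)
grad_U = np.outer(delta, x_t)
grad_W = np.outer(delta, s_prev)

# Finite-difference check on one entry of V (U and W can be checked the same way)
h = 1e-5
V_plus, V_minus = V.copy(), V.copy()
V_plus[0, 0] += h
V_minus[0, 0] -= h
numerical = (loss(U, W, V_plus) - loss(U, W, V_minus)) / (2 * h)
print(np.isclose(numerical, grad_V[0, 0]))            # expected: True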

Back Propagation: LSTM

Back Propagation: GRU

Back Propagation Code

RNN

The RNN code comes from WILDML (GitHub repository) and is used to generate text.

The data are collected Reddit comments. The 8000 most frequent words are kept, and low-frequency words are replaced with the UNKNOWN token.

# Data preprocessing
path = "your file path"

import sys
import os
from datetime import datetime
import numpy as np
import csv
import nltk
import operator
import itertools

import matplotlib.pyplot as plt
nltk.download("book")

vocabulary_size = 8000
unknown_token = "UNKNOWN_TOKEN"
sentence_start_token = "SENTENCE_START"
sentence_end_token = "SENTENCE_END"


try:
    f = open(path, 'rt', encoding='utf-8')
except OSError:
    print("Failed to open the file")
    sys.exit(-1)

# Read the data and split it into sentences
try:
    reader = csv.reader(f, skipinitialspace=True)
    next(reader)  # skip the header row
    sentences = itertools.chain(*[nltk.sent_tokenize(x[0].lower()) for x in reader])
    # Wrap each sentence with the start and end tokens
    sentences = ["%s %s %s" % (sentence_start_token, x, sentence_end_token) for x in sentences]
    print("Parsed %d sentences." % (len(sentences)))
except Exception:
    sys.exit(-1)


tokenized_sentences = [nltk.word_tokenize(sent) for sent in sentences]

word_freq = nltk.FreqDist(itertools.chain(*tokenized_sentences))
print("Found %d unique words tokens." % len(word_freq.items()))

vocab = word_freq.most_common(vocabulary_size-1)
index_to_word = [x[0] for x in vocab]
index_to_word.append(unknown_token)
word_to_index = dict([(w,i) for i,w in enumerate(index_to_word)])

print("Using vocabulary size %d." % vocabulary_size)
print("The least frequent word in our vocabulary is '%s' and appeared %d times." % (vocab[-1][0], vocab[-1][1]))

# Replace all words not in our vocabulary with the unknown token
for i, sent in enumerate(tokenized_sentences):
    tokenized_sentences[i] = [w if w in word_to_index else unknown_token for w in sent]

print("\nExample sentence: '%s'" % sentences[0])
print("\nExample sentence after Pre-processing: '%s'" % tokenized_sentences[0])

The input features are the tokens from the start marker up to the one before the end marker, and the targets are the tokens from the one after the start marker up to the end marker:

# Build the input features and target outputs
X_train = np.asarray([[word_to_index[w] for w in sent[:-1]] for sent in tokenized_sentences])
y_train = np.asarray([[word_to_index[w] for w in sent[1:]] for sent in tokenized_sentences])

During back propagation there are three parameters to train, ${\rm U}$, ${\rm W}$, and ${\rm V}$, of which ${\rm U}$ and ${\rm W}$ involve the time dimension.

import operator
from datetime import datetime

import numpy as np

def softmax(x):
    xt = np.exp(x - np.max(x))
    return xt / xt.sum()

class RNN():
    
    def __init__(self, word_dim, hidden_dim=100, bptt_truncate=4):
        """
        param:
            word_dim: input dimension
            hidden_dim: hidden layers dimension
            bptt_truncate: the number of time steps that back propagation will look back
        """
        self.word_dim = word_dim
        self.hidden_dim = hidden_dim
        self.bptt_truncate = bptt_truncate
        self.W = np.random.uniform(-np.sqrt(1./hidden_dim), np.sqrt(1./hidden_dim), (hidden_dim, hidden_dim))
        self.U = np.random.uniform(-np.sqrt(1./word_dim), np.sqrt(1./word_dim),(hidden_dim, word_dim))
        self.V = np.random.uniform(-np.sqrt(1./hidden_dim), np.sqrt(1./hidden_dim), (word_dim, hidden_dim))
        
    def forward_propagation(self, x):
        """
        param:
            x: array like, dim(x) = 1
        """
        # The total number of time steps
        T = len(x)
        s = np.zeros((T+1, self.hidden_dim))
        s[-1] = np.zeros(self.hidden_dim)
        o = np.zeros((T, self.word_dim))
        
        for t in range(T):
            s[t] = np.tanh(self.U[:, x[t]]+self.W.dot(s[t-1]))
            o[t] = softmax(self.V.dot(s[t]))
        return [o, s]
    
    def predict(self, x):
        o, _ = self.forward_propagation(x)
        return np.argmax(o, axis=1)
    
    def calculate_total_loss(self, x, y):
        """
        param:
            x: dim(x) = 2, sentences on axis 0, word indices on axis 1
            y: dim(y) = 2
        """
        L = 0
        for i in range(len(y)):
            o, s = self.forward_propagation(x[i])
            correct_word_predictions = o[np.arange(len(y[i])), y[i]]
            L += -1 * np.sum(np.log(correct_word_predictions))
        
        return L
    
    def calculate_loss(self, x, y):
        """
        param:
            x: dim(x) = 2, sentences on axis 0, word indices on axis 1
            y: same shape as x
        """
        N = sum(len(y_i) for y_i in y)
        return self.calculate_total_loss(x, y) / N
        
    def back_propagation(self, x, y):
        """
        param:
            x: dim(x) = 1
            y: dim(y) = 1
        return:
            dW: derivative of W
            dU: derivative of U
            dV: derivative of V
        """
        dV = np.zeros_like(self.V)
        dW = np.zeros_like(self.W)
        dU = np.zeros_like(self.U)
        
        T = len(x)
        o, s = self.forward_propagation(x)
        delta_o = o
        delta_o[np.arange(T), y] = delta_o[np.arange(T), y] - 1.
        for t in range(T-1, -1, -1):
            dV += np.outer(delta_o[t], s[t])
            delta_t = self.V.T.dot(delta_o[t]) * (1 - s[t]**2)
            # look back at most bptt_truncate steps through time
            for j in range(t, max(t - self.bptt_truncate, -1), -1):
                dW += np.outer(delta_t, s[j-1])
                dU[:, x[j]] += delta_t
                delta_t = self.W.T.dot(delta_t) * (1 - s[j-1]**2)
        return [dV, dU, dW]
    
    
    def gradient_check(self, x,y, h=0.001, error_threshold=0.01):
        gradient = self.back_propagation(x, y)
        model_parameters = ['V', 'U', 'W']
        for pidx, pname in enumerate(model_parameters):
            parameter = operator.attrgetter(pname)(self)
            print("Performing gradient check for parameter %s with size %d" % 
                  (pname, np.prod(parameter.shape)))
            it = np.nditer(parameter, flags=['multi_index'], op_flags=['readwrite'])
            while not it.finished:
                ix = it.multi_index
                original_value = parameter[ix]

                parameter[ix] = original_value + h
                gradplus = self.calculate_total_loss([x], [y])

                parameter[ix] = original_value - h
                gradminus = self.calculate_total_loss([x], [y])

                estimated_gradient = (gradplus - gradminus) / (2. * h)
                parameter[ix] = original_value

                backprop_gradient = gradient[pidx][ix]
                relative_error = np.abs(backprop_gradient - estimated_gradient)/(np.abs(backprop_gradient) + np.abs(estimated_gradient))

                if relative_error > error_threshold:
                    print("Gradient Check ERROR: parameter=%s ix=%s" % (pname, ix))
                    print("+h Loss: %f" % gradplus)
                    print("-h Loss: %f" % gradminus)
                    print("Estimated_gradient: %f" % estimated_gradient)
                    print("Backpropagation gradient: %f" % backprop_gradient)
                    print("Relative Error: %f" % relative_error)
                    return 
                it.iternext()
            print("Gradient check for parameter %s passed." % (pname))
            
    def sgd_step(self, x, y, learning_rate):
        
        dV, dU, dW = self.back_propagation(x, y)
        self.V -= learning_rate * dV
        self.U -= learning_rate * dU
        self.W -= learning_rate * dW
        
    def train_with_sgd(self, X_train, y_train, learning_rate=0.005, nepoch=10, evaluate_loss_after=5):
        losses = []
        num_examples_seen = 0
        for epoch in range(nepoch):
            if (epoch % evaluate_loss_after == 0):
                loss = self.calculate_loss(X_train, y_train)
                losses.append((epoch, loss))
                time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                print("%s: Loss after num_examples_seen=%d epoch=%d: %f" %(time, num_examples_seen, epoch, loss))
            
            if (len(losses)>1 and losses[-1][1] > losses[-2][1]):
                learning_rate = 0.5 * learning_rate
            
            for i in range(len(y_train)):
                self.sgd_step(X_train[i], y_train[i], learning_rate)
                num_examples_seen += 1
        return losses
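
A minimal usage sketch of the class above, assuming vocabulary_size, X_train, and y_train come from the preprocessing code earlier (training on a small slice keeps the run short):

# Quick sanity check of the analytic gradients on a tiny model
np.random.seed(10)
small_model = RNN(10, hidden_dim=5)
small_model.gradient_check(np.array([0, 1, 2, 3]), np.array([1, 2, 3, 4]))

# Train a full-vocabulary model on a small slice of the data and watch the loss decrease
model = RNN(vocabulary_size)
losses = model.train_with_sgd(X_train[:100], y_train[:100], nepoch=10, evaluate_loss_after=1)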

References

[1] BPTT: https://ir.hit.edu.cn/~jguo/docs/notes/bptt.pdf

[2] BPTT Tutorial: https://www.cs.ubc.ca/~minchenl/doc/BPTTTutorial.pdf

[3] 矩阵求导术 (matrix differentiation), Zhihu, by 长躯鬼侠

[4] WILDML: Recurrent Neural Networks Tutorial

[5] colah's blog: Understanding LSTM Networks

[6] Nico's blog: Simple LSTM
