代码改变世界

nlp RNN 梯度消失和爆炸

2022-04-05 15:01  jym蒟蒻  阅读(193)  评论(0编辑  收藏  举报

之前的RNN,无法很好地学习到时序数据的长期依赖关系。因为BPTT会发生梯度消失和梯度爆炸的问题。

RNN梯度消失和爆炸

对于RNN来说,输入时序数据xt时,RNN 层输出ht。这个ht称为RNN 层的隐藏状态,它记录过去的信息。

语言模型的任务是根据已经出现的单词预测下一个将要出现的单词。

学习正确解标签过程中,RNN层通过向过去传递有意义的梯度,能够学习时间方向上的依赖关系。如果这个梯度在中途变弱(甚至没有包含任何信息),权重参数将不会被更新,也就是所谓的RNN层无法学习长期的依赖关系。梯度的流动如下图绿色箭头。

在这里插入图片描述

随着时间增加,RNN会产生梯度变小(梯度消失)或梯度变大(梯度爆炸)。

RNN 层在时间方向上的梯度传播,如下图。

在这里插入图片描述

反向传播的梯度流经tanh、+、MatMul(矩阵乘积)运算。

+的反向传播,将上游传来的梯度原样传给下游,梯度值不变。

tanh的计算图如下。它将上游传来的梯度乘以tanh的导数传给下游。

在这里插入图片描述

y=tanh(x)的值及其导数的值如下图。导数值小于1,x越远离0,值越小。反向传播梯度经过tanh节点要乘上tanh的导数,这就导致梯度越来越小。

如果RNN层的激活函数使用ReLU,可以抑制梯度消失,当ReLU输入为x时,输出是max(0,x)。x大于0时,反向传播将上游的梯度原样传递到下游,梯度不会退化。

在这里插入图片描述

对于MatMul(矩阵乘积)节点。仅关注RNN层MatMul节点时的梯度反向传播如下图。每一次矩阵乘积计算都使用相同的权重Wh。

在这里插入图片描述

N = 2  # mini-batch的大小
H = 3  # 隐藏状态向量的维数
T = 20  # 时序数据的长度

dh = np.ones((N, H))#初始化为所有元素均为 1 的矩阵,dh是梯度

np.random.seed(3)

Wh = np.random.randn(H, H)#梯度的大小随时间步长呈指数级增加,发生梯度爆炸
#Wh = np.random.randn(H, H) * 0.5
#梯度的大小随时间步长呈指数级减小,发生梯度消失,权重梯度不能被更新,模型无法学习长期的依赖关系
norm_list = []
for t in range(T):
    dh = np.dot(dh, Wh.T)#根据反向传播的 MatMul 节点的数量更新 dh 相应次数
    norm = np.sqrt(np.sum(dh**2)) / N#mini-batch(N)中的平均L2 范数,L2 范数对所有元素的平方和求平方根.
    norm_list.append(norm)#将各步的 dh 的大小(范数)添加到 norm_list 中

print(norm_list)

# 绘制图形
plt.plot(np.arange(len(norm_list)), norm_list)
plt.xticks([0, 4, 9, 14, 19], [1, 5, 10, 15, 20])
plt.xlabel('time step')
plt.ylabel('norm')
plt.show()

如果Wh是标量,由于Wh被反复乘了T次,当Wh大于1时,梯度呈指数级增加;当 Wh 小于1时,梯度呈指数级减小。

如果wh是矩阵,矩阵的奇异值表示数据的离散程度,根据奇异值(多个奇异值中的最大值)是否大于1,可以预测梯度大小的变化。奇异值比1大是梯度爆炸的必要非充分条件。

在这里插入图片描述

在这里插入图片描述

梯度裁剪gradients clipping

梯度裁剪(gradients clipping)是解决解决梯度爆炸的一个方法。

将神经网络用到的所有参数的梯度整合成一个,用g表示,将阈值设置为threshold,如果梯度g的L2范数大于等于该阈值,就按如下方式修正梯度。

在这里插入图片描述

dW1 = np.random.rand(3, 3) * 10
dW2 = np.random.rand(3, 3) * 10
grads = [dW1, dW2]
max_norm = 5.0#阈值


def clip_grads(grads, max_norm):
    total_norm = 0
    for grad in grads:
        total_norm += np.sum(grad ** 2)
    total_norm = np.sqrt(total_norm)#L2 范数对所有元素的平方和求平方根

    rate = max_norm / (total_norm + 1e-6)
    if rate < 1:#如果梯度的L2范数total_norm大于等于阈值max_norm,rate是小于1的,此时就需要修正梯度
        for grad in grads:
            grad *= rate


print('before:', dW1.flatten())
clip_grads(grads, max_norm)
print('after:', dW1.flatten())

before: [7.14418135 3.58857143 7.82910303 8.04057218 8.8617387  1.89963886
 3.0606848  8.14163088 5.25490409]
after: [1.43122195 0.71891263 1.56843501 1.61079946 1.77530697 0.38056213
 0.61315903 1.63104494 1.05273561]

解决梯度消失

为了解决梯度消失,需要从根本上改变 RNN 层的结构。

LSTM 和GRU中增加了一种门结构,可以学习到时序数据的长期依赖关系。