GRU: Principle and Implementation


https://www.bilibili.com/video/BV1jm4y1Q7uh/?spm_id_from=333.788&vd_source=91219057315288b0881021e879825aa3

Under the same input and hidden sizes, a GRU has about 0.75 times the parameters of an LSTM, since a GRU stacks 3 gate weight blocks (r, z, n) while an LSTM stacks 4 (i, f, g, o).

Formulas

(Figure: GRU update equations)
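For reference, these are the standard GRU equations in the convention the code below follows (the same one used by PyTorch's nn.GRU), where * denotes element-wise multiplication:

$$
\begin{aligned}
r_t &= \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}) \\
z_t &= \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}) \\
n_t &= \tanh\big(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{t-1} + b_{hn})\big) \\
h_t &= (1 - z_t) * n_t + z_t * h_{t-1}
\end{aligned}
$$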

1 − z_t weights the current candidate state, while z_t retains part of the previous time step's hidden state; the * in the formulas denotes element-wise multiplication.

Checking the number of model parameters

The GRU has roughly 0.75 times as many parameters as the LSTM.

(Figure: parameter counts of GRU vs. LSTM)
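One way to verify the 0.75 ratio is to count the trainable parameters of an LSTM and a GRU built with the same input and hidden sizes. This is a minimal sketch; the sizes 4 and 5 are just example values:

import torch.nn as nn

# an LSTM and a GRU with the same input and hidden sizes
lstm = nn.LSTM(input_size=4, hidden_size=5, batch_first=True)
gru = nn.GRU(input_size=4, hidden_size=5, batch_first=True)

lstm_params = sum(p.numel() for p in lstm.parameters())  # 4 gate blocks: i, f, g, o
gru_params = sum(p.numel() for p in gru.parameters())    # 3 gate blocks: r, z, n
print(lstm_params, gru_params, gru_params / lstm_params)  # 220 165 0.75 for these sizes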

Initialization parameters

(Figure: nn.GRU initialization parameters)
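As a sketch of what the figure shows, these are the main constructor arguments of nn.GRU; the values below are just example choices:

import torch.nn as nn

gru = nn.GRU(
    input_size=4,         # number of features in each input vector x_t
    hidden_size=5,        # number of features in the hidden state h_t
    num_layers=1,         # number of stacked GRU layers
    bias=True,            # use the bias terms b_ih and b_hh
    batch_first=True,     # tensors are [batch, seq, feature] instead of [seq, batch, feature]
    dropout=0.0,          # dropout between layers, effective only when num_layers > 1
    bidirectional=False,  # set True for a bidirectional GRU
)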

Input parameters

Both input tensors (the input sequence and the initial hidden state) are 3-dimensional.

(Figure: nn.GRU input parameters)

Output parameters

(Figure: nn.GRU output parameters)
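A small sketch to check the input and output shapes with batch_first=True, using the same example sizes as the code below: both inputs are 3-D tensors, and so are both outputs.

import torch
import torch.nn as nn

bs, T, i_size, h_size = 2, 3, 4, 5
gru = nn.GRU(i_size, h_size, batch_first=True)

x = torch.randn(bs, T, i_size)   # input: [batch_size, seq_len, input_size]
h0 = torch.randn(1, bs, h_size)  # h_0:   [num_layers * num_directions, batch_size, hidden_size]

out, h_n = gru(x, h0)
print(out.shape)  # torch.Size([2, 3, 5])  -> [batch_size, seq_len, hidden_size]
print(h_n.shape)  # torch.Size([1, 2, 5])  -> [num_layers * num_directions, batch_size, hidden_size]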

Implementation with the PyTorch API

import torch
import torch.nn as nn

batch_size, T, i_size, h_size = 2, 3, 4, 5
input = torch.randn(batch_size, T, i_size)  # input sequence [batch_size, T, i_size]
h0 = torch.randn(batch_size, h_size)        # initial hidden state

# implementation with the PyTorch API
gru_layer = nn.GRU(i_size, h_size, batch_first=True)
# nn.GRU expects h0 of shape [num_layers * num_directions, batch_size, h_size]
output, h_final = gru_layer(input, h0.unsqueeze(0))
print(output)  # output: [batch_size, T, h_size], h_final: [1, batch_size, h_size]

Custom implementation

def gru_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh):
    prev_h = initial_states
    bs, T, i_size = input.shape
    # only the r, z, n gates have weights, and they are stacked along dim 0,
    # so w_ih has shape [3*h_size, i_size]
    h_size = w_ih.shape[0] // 3

    # w_ih and w_hh are 2-D, while input and initial_states are batched 3-D tensors,
    # so both weight matrices get an extra batch dimension
    batch_w_ih = w_ih.unsqueeze(0).tile(bs, 1, 1)
    batch_w_hh = w_hh.unsqueeze(0).tile(bs, 1, 1)

    output = torch.zeros(bs, T, h_size)  # output of the GRU network

    for t in range(T):
        x = input[:, t, :]  # input feature vector of the GRU cell at time t, [bs, i_size]
        w_times_x = torch.bmm(batch_w_ih, x.unsqueeze(-1))  # [bs, 3*h_size, 1]
        w_times_x = w_times_x.squeeze(-1)  # [bs, 3*h_size]

        w_times_h_prev = torch.bmm(batch_w_hh, prev_h.unsqueeze(-1))  # [bs, 3*h_size, 1]
        w_times_h_prev = w_times_h_prev.squeeze(-1)  # [bs, 3*h_size]

        # reset gate
        r_t = torch.sigmoid(w_times_x[:, :h_size] + w_times_h_prev[:, :h_size] + b_ih[:h_size] + b_hh[:h_size])
        # update gate
        z_t = torch.sigmoid(w_times_x[:, h_size:2*h_size] + w_times_h_prev[:, h_size:2*h_size] + b_ih[h_size:2*h_size] + b_hh[h_size:2*h_size])
        # candidate state
        n_t = torch.tanh(w_times_x[:, 2*h_size:3*h_size] + b_ih[2*h_size:3*h_size] + r_t * (w_times_h_prev[:, 2*h_size:3*h_size] + b_hh[2*h_size:3*h_size]))
        # hidden-state update: mix the candidate with the previous hidden state
        prev_h = (1 - z_t) * n_t + z_t * prev_h

        output[:, t, :] = prev_h

    # return outside the loop, after all T time steps have been processed
    return output, prev_h
# call the custom implementation, reusing the weights and biases of gru_layer
output_custom, h_final_custom = gru_forward(input, h0, gru_layer.weight_ih_l0, gru_layer.weight_hh_l0, gru_layer.bias_ih_l0, gru_layer.bias_hh_l0)
print(output_custom)

Checking that the two implementations agree

print(torch.allclose(output, output_custom))  # should print True
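The final hidden states can be compared in the same way. Note that nn.GRU returns h_final with a leading num_layers * num_directions dimension, so it is squeezed here before the comparison:

print(torch.allclose(h_final.squeeze(0), h_final_custom))  # should also print True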