LSTM Implementation


Reference video: 30. The principles of PyTorch LSTM and LSTMP and their hand-written reimplementation (bilibili)


PyTorch documentation

LSTM — PyTorch 1.12 documentation

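The per-timestep update equations given in the documentation are:

$$
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$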

  • i_t: the input gate
  • f_t: the forget gate
  • g_t: the cell (candidate) gate
  • o_t: the output gate
  • c_t: the cell state; the LSTM keeps filtering and updating historical information through c_t
  • h_t: the output (hidden state) at the current time step

Constructor arguments (an instantiation sketch follows the list)

  • input_size: size of each input x
  • hidden_size: size of the hidden state h
  • num_layers: number of stacked layers
  • bias: whether the biases b_ih and b_hh are used; if False they are dropped
  • batch_first: whether the batch dimension comes first in the input/output tensors
  • dropout: dropout applied between layers
  • bidirectional: whether the LSTM is bidirectional
  • proj_size: enables the LSTMP variant, which projects the hidden state to reduce the LSTM's parameters and computation
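A minimal instantiation sketch (the sizes are arbitrary and only for illustration):

import torch.nn as nn

# plain LSTM: the hidden state and the cell state both have size hidden_size
lstm = nn.LSTM(input_size=4, hidden_size=5, num_layers=1,
               bias=True, batch_first=True, bidirectional=False)

# LSTMP: the hidden state/output is projected down to proj_size=3,
# while the cell state keeps size hidden_size=5
lstmp = nn.LSTM(input_size=4, hidden_size=5, batch_first=True, proj_size=3)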

Inputs

Inputs: input, (h_0, c_0)

First argument: input

Second argument: (h_0, c_0)

h_0 and c_0 must be passed together as a tuple


  • input: the input features

batch_first=False: shape (seq_len, batch_size, input_size)

batch_first=True: shape (batch_size, seq_len, input_size)

  • h_0: the initial hidden state, shape (D * num_layers, batch_size, H_out), where D = 2 if bidirectional else 1, and H_out = proj_size if proj_size > 0 else hidden_size

  • c_0: the initial cell state, shape (D * num_layers, batch_size, hidden_size)

Outputs

Outputs: output, (h_n, c_n)

First return value: output

Second return value: (h_n, c_n)

h_n and c_n are returned together as a tuple

  • output: the hidden-state output for every time step of the sequence

batch_first=False: shape (seq_len, batch_size, D * H_out)

batch_first=True: shape (batch_size, seq_len, D * H_out)

  • h_n: the final hidden state, shape (D * num_layers, batch_size, H_out)

If proj_size is set, h_n is compressed (projected) down to proj_size.

  • c_n: the final cell state, shape (D * num_layers, batch_size, hidden_size)
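A quick shape check of these inputs and outputs (a small sketch; the sizes are made up, batch_first is left at its default of False, and proj_size is set so that H_out differs from hidden_size):

import torch
import torch.nn as nn

L, N, H_in, H_cell, proj = 7, 2, 4, 5, 3
lstm = nn.LSTM(H_in, H_cell, proj_size=proj)   # batch_first=False by default
x  = torch.randn(L, N, H_in)
h0 = torch.randn(1, N, proj)                   # H_out = proj_size when proj_size > 0
c0 = torch.randn(1, N, H_cell)                 # the cell state always keeps hidden_size

out, (h_n, c_n) = lstm(x, (h0, c0))
print(out.shape)   # torch.Size([7, 2, 3])  -> (seq_len, batch_size, H_out)
print(h_n.shape)   # torch.Size([1, 2, 3])  -> (num_layers, batch_size, H_out)
print(c_n.shape)   # torch.Size([1, 2, 5])  -> (num_layers, batch_size, hidden_size)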

Variables (learned parameters)

For a single-layer, unidirectional LSTM the parameters and their shapes are:

  • weight_ih_l0: (4 * hidden_size, input_size)
  • weight_hh_l0: (4 * hidden_size, hidden_size), or (4 * hidden_size, proj_size) when proj_size > 0
  • bias_ih_l0: (4 * hidden_size)
  • bias_hh_l0: (4 * hidden_size)
  • weight_hr_l0: (proj_size, hidden_size), only present when proj_size > 0

Implementation with the PyTorch API

# reproduce LSTM and LSTMP and compare against the official implementation
import torch
import torch.nn as nn

# define constants
batch_size, T, i_size, h_size = 2, 3, 4, 5

input = torch.randn(batch_size, T, i_size)  # input sequence
c0 = torch.randn(batch_size, h_size)        # initial cell state, does not need to be trained
h0 = torch.randn(batch_size, h_size)        # initial hidden state, does not need to be trained

# call the official API
lstm_layer = nn.LSTM(i_size, h_size, batch_first=True)
# h0 and c0 as defined above lack the leading num_layers dimension, so unsqueeze dim 0
output, (h_final, c_final) = lstm_layer(input, (h0.unsqueeze(0), c0.unsqueeze(0)))
print(output)
print(h_final)
print(c_final)


Inspect the LSTM's parameters

for k,v in lstm_layer.named_parameters():
    print(k,v)


Inspect the parameter shapes

for k, v in lstm_layer.named_parameters():
    print(k, v.shape)
# with i_size=4, h_size=5: weight_ih_l0 is (20, 4), weight_hh_l0 is (20, 5), both biases are (20,)


Hand-written implementation

# a hand-written LSTM forward pass
def lstm_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh):
    h0, c0 = initial_states  # initial states
    batch_size, T, i_size = input.shape
    # the weights stack the 4 gates, so each gate owns one quarter of the rows
    h_size = w_ih.shape[0] // 4

    prev_h = h0
    prev_c = c0

    output_size = h_size
    output = torch.zeros(batch_size, T, output_size)  # output sequence

    # without a batch dimension, w cannot take part in a batched multiply with x
    # w_ih shape: (4*hidden_size, input_size)
    batch_w_ih = w_ih.unsqueeze(0).tile(batch_size, 1, 1)  # (batch_size, 4*hidden_size, input_size)
    batch_w_hh = w_hh.unsqueeze(0).tile(batch_size, 1, 1)  # (batch_size, 4*hidden_size, hidden_size)

    for t in range(T):
        x = input[:, t, :]  # input vector at the current time step, still 2-D, needs an extra dim
        # x.shape: (batch_size, input_size)
        # after x.unsqueeze(-1): (batch_size, input_size, 1)

        w_times_x = torch.bmm(batch_w_ih, x.unsqueeze(-1))  # (batch_size, 4*hidden_size, 1)

        # the trailing size-1 dimension is not needed
        w_times_x = w_times_x.squeeze(-1)  # (batch_size, 4*hidden_size)

        w_times_h_prev = torch.bmm(batch_w_hh, prev_h.unsqueeze(-1))
        w_times_h_prev = w_times_h_prev.squeeze(-1)

        # compute the input gate (i), forget gate (f), cell gate (g) and output gate (o) separately
        # w_times_x concatenates the weights of all four gates, so each gate uses one quarter of it

        i_t = torch.sigmoid(w_times_x[:, :h_size] + w_times_h_prev[:, :h_size] + b_ih[:h_size] + b_hh[:h_size])

        f_t = torch.sigmoid(w_times_x[:, h_size:2*h_size] + w_times_h_prev[:, h_size:2*h_size] + b_ih[h_size:2*h_size] + b_hh[h_size:2*h_size])

        g_t = torch.tanh(w_times_x[:, 2*h_size:3*h_size] + w_times_h_prev[:, 2*h_size:3*h_size] + b_ih[2*h_size:3*h_size] + b_hh[2*h_size:3*h_size])

        o_t = torch.sigmoid(w_times_x[:, 3*h_size:4*h_size] + w_times_h_prev[:, 3*h_size:4*h_size] + b_ih[3*h_size:4*h_size] + b_hh[3*h_size:4*h_size])

        prev_c = f_t * prev_c + i_t * g_t
        prev_h = o_t * torch.tanh(prev_c)

        output[:, t, :] = prev_h

    return output, (prev_h, prev_c)

  • Why unsqueeze is applied to w

Inside the for loop, x has shape (batch_size, input_size). Multiplying w with x as a batched matrix multiplication requires torch.bmm(), whose two operands must both carry a leading batch dimension. Here w has shape (4*hidden_size, input_size) with no batch dimension, so we first unsqueeze dim 0 and then replicate it across the batch with tile(batch_size, 1, 1), leaving the remaining dimensions unchanged.
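A minimal sketch of just this batched multiply (the sizes mirror the constants above; the variable names are only illustrative):

import torch

w = torch.randn(4 * 5, 4)                     # (4*hidden_size, input_size)
x = torch.randn(2, 4)                         # (batch_size, input_size)
batch_w = w.unsqueeze(0).tile(2, 1, 1)        # (batch_size, 4*hidden_size, input_size)
out = torch.bmm(batch_w, x.unsqueeze(-1)).squeeze(-1)  # (batch_size, 4*hidden_size)
print(out.shape)                              # torch.Size([2, 20])

An equivalent design choice would be a broadcasted matmul such as x @ w.T, which avoids materializing the tiled weight; the bmm form is used here because it makes the batch dimension explicit.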

# feed the official layer's weights into the hand-written forward; the result should match the API output above
output, (h_final, c_final) = lstm_forward(input, (h0, c0), lstm_layer.weight_ih_l0, lstm_layer.weight_hh_l0, lstm_layer.bias_ih_l0, lstm_layer.bias_hh_l0)
print(output)
print(h_final)
print(c_final)
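One way to make the comparison explicit (a sketch reusing lstm_layer, input, h0 and c0 from above; the *_api and *_my names are only illustrative) is to check the tensors numerically:

# compare the hand-written forward with the official API
out_api, (h_api, c_api) = lstm_layer(input, (h0.unsqueeze(0), c0.unsqueeze(0)))
out_my, (h_my, c_my) = lstm_forward(input, (h0, c0),
                                    lstm_layer.weight_ih_l0, lstm_layer.weight_hh_l0,
                                    lstm_layer.bias_ih_l0, lstm_layer.bias_hh_l0)
print(torch.allclose(out_api, out_my, atol=1e-6))         # expected: True (up to float tolerance)
print(torch.allclose(h_api.squeeze(0), h_my, atol=1e-6))  # expected: True
print(torch.allclose(c_api.squeeze(0), c_my, atol=1e-6))  # expected: True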


LSTMP with the PyTorch API

batch_size, T, i_size, h_size = 2, 3, 4, 5
proj_size = 3
input = torch.randn(batch_size, T, i_size)  # input sequence
c0 = torch.randn(batch_size, h_size)        # initial cell state, does not need to be trained

# initial hidden state without proj_size
h0 = torch.randn(batch_size, h_size)
# initial hidden state with proj_size
h0_proj_size = torch.randn(batch_size, proj_size)

lstm_layer1 = nn.LSTM(i_size, h_size, batch_first=True)
output1, (h_final1, c_final1) = lstm_layer1(input, (h0.unsqueeze(0), c0.unsqueeze(0)))

lstm_layer2 = nn.LSTM(i_size, h_size, batch_first=True, proj_size=proj_size)
output2, (h_final2, c_final2) = lstm_layer2(input, (h0_proj_size.unsqueeze(0), c0.unsqueeze(0)))
print('-'*10, 'without proj_size', '-'*10)
print(output1.shape, h_final1.shape, c_final1.shape)  # (2, 3, 5) (1, 2, 5) (1, 2, 5)
print('-'*10, 'with proj_size', '-'*10)
print(output2.shape, h_final2.shape, c_final2.shape)  # (2, 3, 3) (1, 2, 3) (1, 2, 5)


LSTMP only compresses the output (the hidden state); it does not compress the cell state.
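Concretely, setting proj_size only adds a projection of the hidden state at the end of each step (as described in the PyTorch documentation), so h_t, the output and h_n all have proj_size dimensions while c_t keeps hidden_size:

$$
h_t = W_{hr}\,\bigl(o_t \odot \tanh(c_t)\bigr)
$$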


Inspect the LSTMP's parameters

for k, v in lstm_layer2.named_parameters():
    print(k, v.shape)
# weight_hh_l0 becomes (20, 3) because the recurrent weights now act on the projected h,
# and an extra weight_hr_l0 of shape (3, 5) appears: the projection matrix


Hand-written LSTMP (without the API)

# w_hr decides whether this is an LSTMP: if it is given, the hidden state is projected
def lstm_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh, w_hr=None):
    h0, c0 = initial_states  # initial states
    batch_size, T, i_size = input.shape
    # the weights stack the 4 gates, so each gate owns one quarter of the rows
    h_size = w_ih.shape[0] // 4

    prev_h = h0
    prev_c = c0

    # changes introduced for proj_size ********** changed part **********
    if w_hr is not None:
        proj_size = w_hr.shape[0]
        output_size = proj_size
        # add a batch dimension to w_hr as well
        batch_w_hr = w_hr.unsqueeze(0).tile(batch_size, 1, 1)  # (batch_size, proj_size, hidden_size)
    else:
        output_size = h_size
    # ********** changed part **********

    output = torch.zeros(batch_size, T, output_size)  # output sequence

    # without a batch dimension, w cannot take part in a batched multiply with x
    # w_ih shape: (4*hidden_size, input_size)
    batch_w_ih = w_ih.unsqueeze(0).tile(batch_size, 1, 1)  # (batch_size, 4*hidden_size, input_size)
    batch_w_hh = w_hh.unsqueeze(0).tile(batch_size, 1, 1)  # (batch_size, 4*hidden_size, hidden_size or proj_size)

    for t in range(T):
        x = input[:, t, :]  # input vector at the current time step, still 2-D, needs an extra dim
        # x.shape: (batch_size, input_size)
        # after x.unsqueeze(-1): (batch_size, input_size, 1)

        w_times_x = torch.bmm(batch_w_ih, x.unsqueeze(-1))  # (batch_size, 4*hidden_size, 1)

        # the trailing size-1 dimension is not needed
        w_times_x = w_times_x.squeeze(-1)  # (batch_size, 4*hidden_size)

        w_times_h_prev = torch.bmm(batch_w_hh, prev_h.unsqueeze(-1))
        w_times_h_prev = w_times_h_prev.squeeze(-1)

        # compute the input gate (i), forget gate (f), cell gate (g) and output gate (o) separately
        # w_times_x concatenates the weights of all four gates, so each gate uses one quarter of it

        i_t = torch.sigmoid(w_times_x[:, :h_size] + w_times_h_prev[:, :h_size] + b_ih[:h_size] + b_hh[:h_size])

        f_t = torch.sigmoid(w_times_x[:, h_size:2*h_size] + w_times_h_prev[:, h_size:2*h_size] + b_ih[h_size:2*h_size] + b_hh[h_size:2*h_size])

        g_t = torch.tanh(w_times_x[:, 2*h_size:3*h_size] + w_times_h_prev[:, 2*h_size:3*h_size] + b_ih[2*h_size:3*h_size] + b_hh[2*h_size:3*h_size])

        o_t = torch.sigmoid(w_times_x[:, 3*h_size:4*h_size] + w_times_h_prev[:, 3*h_size:4*h_size] + b_ih[3*h_size:4*h_size] + b_hh[3*h_size:4*h_size])

        prev_c = f_t * prev_c + i_t * g_t
        prev_h = o_t * torch.tanh(prev_c)

        # project (compress) the hidden state ********** changed part **********
        if w_hr is not None:
            prev_h = torch.bmm(batch_w_hr, prev_h.unsqueeze(-1))  # (batch_size, proj_size, 1)
            prev_h = prev_h.squeeze(-1)                           # (batch_size, proj_size)
        # ********** changed part **********

        output[:, t, :] = prev_h

    return output, (prev_h, prev_c)

lstm_layer = nn.LSTM(i_size, h_size, batch_first=True, proj_size=proj_size)
output, (h_final, c_final) = lstm_layer(input, (h0_proj_size.unsqueeze(0), c0.unsqueeze(0)))
print('---- API implementation ----')
print(output)
print(h_final)
print(c_final)
print('---- hand-written implementation ----')
output, (h_final, c_final) = lstm_forward(input, (h0_proj_size, c0), lstm_layer.weight_ih_l0, lstm_layer.weight_hh_l0, lstm_layer.bias_ih_l0, lstm_layer.bias_hh_l0, lstm_layer.weight_hr_l0)
print(output)
print(h_final)
print(c_final)

