循环卷积网络-编码器-解码器架构

 

首先我们回忆一下CNN:
image
在CNN中,输入一张图片,经过多层的卷积层,最后到输出层判别图片中的物体的类别。CNN中使用卷积层做特征提取,使用Softmax回归做预测,从某种意义上来说,特征提取可以看成是编码,Softmax回归可以看成是解码

  • 编码器:将输入编程成中间表达形式(特征),就像上面的卷积层一样。
  • 解码器:将中间表示解码成输出
    然后我们以编码器、解码器的角度来看看RNN:
    image
    对于RNN来讲,输入一个句子,然后对其进行向量输出
  • 如果将RNN最后时刻隐藏层的矩阵当成输出的话,这部分也可以当成是编码器
  • 最后通过全连接层得到最终的输出的话,这部分可以看成是解码器
  • 编码器:将文本表示成向量
  • 解码器:将向量表示成输出

编码器-解码器架构

机器翻译是序列转换模型的⼀个核心问题,其输入和输出都是长度可变的序列。为了处理这种类型的输入和输出,我们可以设计⼀个包含两个主要组件的架构:第⼀个组件是⼀个编码器
(encoder):它接受⼀个长度可变的序列作为输入,并将其转换为具有固定形状的编码状态。第二个组件是解码器(decoder):它将固定形状的编码状态映射到长度可变的序列。这被称为编码器-解码器(encoder-decoder)架构。
image
一个模型被分为两块:

  • 编码器(encoder)处理输入:接受一个长度可变的序列作为输入,并将其转换为具有固定形状的编码状态。编码器在拿到输入之后,将其表示成为中间状态或者中间表示(如隐藏状态、特征图)
  • 解码器(decoder)生成输出:解码器将固定形状的编码状态映射到长度可变的序列。最简单的解码器能够直接将中间状态或者中间表示翻译成输出,解码器也能够结合一些额外的输入得到输出

代码实现

编码器

在编码器接口中,我们只指定长度可变的序列作为编码器的输入X。任何继承这个Encoder基类的模型将完成代码实现。

from torch import nn

class Encoder(nn.Module):
    """编码器-解码器结构的基本编码器接口。"""
    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *args):
        raise NotImplementedError

解码器

在下面的解码器接口中,我们新增一个init_state函数,用于将编码器的输出(enc_outputs)转换为编码后的状态。注意,此步骤可能需要额外的输入,例如:输入序列的有效长度。为了逐个地生成长度可变的词元序列,解码器在每个时间步都会将输入(例如:在前⼀时间步⽣成的词元)和编码后的状态映射成当前时间步的输出词元。

class EncoderDecoder(nn.Module):
    """编码器-解码器结构的基类。"""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
		#因为这个decoder的输出有有两个,output和state,而这个self.decoder.init_state返回的就是这个enc_outputs[1]就是state#这里返回的是state,也就是编码器最后一层的state
        return self.decoder(dec_X, dec_state)# 输入有decode的输出

小结
• “编码器-解码器”架构可以将长度可变的序列作为输入和输出,因此适用于机器翻译等序列转换问题。
• 编码器将长度可变的序列作为输入,并将其转换为具有固定形状的编码状态。
• 解码器将具有固定形状的编码状态映射为长度可变的序列。

posted @   lipu123  阅读(190)  评论(0编辑  收藏  举报
(评论功能已被禁用)
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
点击右上角即可分享
微信分享提示