机器学习——编码器和解码器架构
正如我们在 9.5节中所讨论的, 机器翻译是序列转换模型的一个核心问题, 其输入和输出都是长度可变的序列。 为了处理这种类型的输入和输出, 我们可以设计一个包含两个主要组件的架构: 第一个组件是一个编码器(encoder): 它接受一个长度可变的序列作为输入, 并将其转换为具有固定形状的编码状态。 第二个组件是解码器(decoder): 它将固定形状的编码状态映射到长度可变的序列。 这被称为编码器-解码器(encoder-decoder)架构, 如 图9.6.1 所示。
我们以英语到法语的机器翻译为例: 给定一个英文的输入序列:“They”“are”“watching”“.”。 首先,这种“编码器-解码器”架构将长度可变的输入序列编码成一个“状态”, 然后对该状态进行解码, 一个词元接着一个词元地生成翻译后的序列作为输出: “Ils”“regordent”“.”。 由于“编码器-解码器”架构是形成后续章节中不同序列转换模型的基础, 因此本节将把这个架构转换为接口方便后面的代码实现。
编码器
在编码器接口中,我们只指定长度可变的序列作为编码器的输入X
。 任何继承这个Encoder
基类的模型将完成代码实现。
1 2 3 4 5 6 7 8 9 10 11 | from torch import nn #@save class Encoder(nn.Module): """编码器-解码器架构的基本编码器接口""" def __init__( self , * * kwargs): super (Encoder, self ).__init__( * * kwargs) def forward( self , X, * args): raise NotImplementedError |
解码器
在下面的解码器接口中,我们新增一个init_state
函数, 用于将编码器的输出(enc_outputs
)转换为编码后的状态。 注意,此步骤可能需要额外的输入,例如:输入序列的有效长度, 这在 9.5.4节中进行了解释。 为了逐个地生成长度可变的词元序列, 解码器在每个时间步都会将输入 (例如:在前一时间步生成的词元)和编码后的状态 映射成当前时间步的输出词元。
1 2 3 4 5 6 7 8 9 10 11 | #@save class Decoder(nn.Module): """编码器-解码器架构的基本解码器接口""" def __init__( self , * * kwargs): super (Decoder, self ).__init__( * * kwargs) def init_state( self , enc_outputs, * args): raise NotImplementedError def forward( self , X, state): raise NotImplementedError |
合并编码器和解码器
1 2 3 4 5 6 7 8 9 10 11 12 | #@save class EncoderDecoder(nn.Module): """编码器-解码器架构的基类""" def __init__( self , encoder, decoder, * * kwargs): super (EncoderDecoder, self ).__init__( * * kwargs) self .encoder = encoder self .decoder = decoder def forward( self , enc_X, dec_X, * args): enc_outputs = self .encoder(enc_X, * args) dec_state = self .decoder.init_state(enc_outputs, * args) return self .decoder(dec_X, dec_state) |
总结
-
“编码器-解码器”架构可以将长度可变的序列作为输入和输出,因此适用于机器翻译等序列转换问题。
-
编码器将长度可变的序列作为输入,并将其转换为具有固定形状的编码状态。
-
解码器将具有固定形状的编码状态映射为长度可变的序列。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)