transformer的读书笔记
transformer相比较rnn解决了两个重要问题:1.并行化问题。2.解决了长时依赖问题。
以下两篇博文我觉得写的不错:
深入理解Transformer及其源码 - ZingpLiu - 博客园 (cnblogs.com)
Transformer模型详解(图解最完整版) - 知乎 (zhihu.com)
Transformer代码及解析(Pytorch) - 知乎 (zhihu.com)
1.模型结构
其模型结构还是比较简单的,主要就是编码器和解码器两个部分。
WE以及PE是对输入的预处理,考虑到了位置及词义,离散化映射到了空间上,方便计算向量之间的距离。
编码器encoder有前馈神经网络和多头自相关注意力组成。
多头自相关注意力是一个新的网络模块,尽管他的本质还是一个全连接层,但是该网络结合了注意力机制和自相关方法,使得网络可以计算向量与向量之间的相关系数。
虽然几个单纯的全连接层,也可以起到同样的作用。还需要注意,他的norm是layer norm。
解码器decoder中的mask-multi-head-attention 是一个令人疑惑的点,为什么要加上mask呢?原作者认为,这样是为了防止在求取attention score的时候作弊,当前时刻只会与过去时刻求相关性。
我认为有一定道理,但是为什么decoder中另一个multi-head-attention不加上mask呢?很难自圆其说,我认为,这个mask-multi-head-attention其他两个输入来自encoder,这两个输入可以看成k、v。
另一个输入来自decoder自身,看成是q。当你在询问(Q)某一事物的时候,不知道未来发生了什么,所以遮挡,K、V是既定的客观事物,因此不需要遮挡。
2.训练方法
一般使用交叉熵损失函数,adam优化器。与RNN、CNN、resnet的训练方法没什么区别,甚至transformer也可以叠加上千层。
3.并行训练时的掩码问题(转自https://blog.csdn.net/zhaohongfei_358/article/details/125858248)
通常我们在网上看Masked Attention相关的文章时,会说mask的目的是为了防止网络看到不该看到的内容。本节主要来解释一下这句话。
从图上可以看出,Transformer的训练过程和推理过程主要有以下几点异同:
源输入src相同:对于Transformer的inputs部分(src参数)一样,都是要被翻译的句子。
目标输入tgt不同:在Transformer推理时,tgt是从<bos>开始,然后每次加入上一次的输出(第二次输入为<bos> 我)。但在训练时是一次将“完整”的结果给到Transformer,这样其实和一个一个给结果上一致。这里还有一个细节,就是tgt比src少了一位,src是7个token,而tgt是6个token。这是因为我们在最后一次推理时,只会传入前n-1个token。举个例子:假设我们要预测<bos> 我 爱 你 <eos>(这里忽略pad),我们最后一次的输入tgt是<bos> 我 爱 你(没有<eos>),因此我们的输入tgt一定不会出现目标的最后一个token,所以一般tgt处理时会将目标句子删掉最后一个token。
输出数量变多:在训练时,transformer会一次输出多个概率分布。例如上图,我就的等价于是tgt为<bos>时的输出,爱就等价于tgt为<bos> 我时的输出,依次类推。当然在训练时,得到输出概率分布后就可以计算loss了,并不需要将概率分布再转成对应的文字。注意这里也有个细节,我们的输出数量是6,对应到token就是我 爱 你 <eos> <pad> <pad>,这里少的是<bos>,因为<bos>不需要预测。计算loss时,我们也是要和的这几个token进行计算,所以我们的label不包含<bos>。代码中通常命名为tgt_y。
其实总结一下就一句话:Transformer推理时是一个一个词预测,而训练时会把所有的结果一次性给到Transformer,但效果等同于一个一个词给,而之所以可以达到该效果,就是因为对tgt进行了掩码,防止其看到后面的信息,也就是不要让前面的字具备后面字的上下文信息。
可能看了这句总结还是很难理解,所以我们接下来来做个实验,我们的实验内容为:首先模拟Transformer的推理过程,然后再模拟Transformer的训练过程,看看训练时一次性给到所有的tgt和推理时一个一个给的结果是否一致。
这里我们要用到Pytorch中的nn.Transformer,用法可参考这篇文章。
首先我们来定义模型:
1 2 3 4 | # 词典数为10, 词向量维度为8 embedding = nn.Embedding(10, 8) # 定义Transformer,注意一定要改成eval模型,否则每次输出结果不一样 transformer = nn.Transformer(d_model=8, batch_first=True).eval() |
接下来定义我们的src和tgt:
1 2 3 4 | # 词典数为10, 词向量维度为8 embedding = nn.Embedding(10, 8) # 定义Transformer,注意一定要改成eval模型,否则每次输出结果不一样 transformer = nn.Transformer(d_model=8, batch_first=True).eval() |
然后我们将[4]送给Transformer进行预测,模拟推理时的第一步:
1 2 3 4 5 6 | transformer(embedding(src), embedding(tgt[:, :1]), # 这个就是用来生成阶梯式的mask的 tgt_mask=nn.Transformer.generate_square_subsequent_mask(1)) tensor([[[ 1.4053, -0.4680, 0.8110, 0.1218, 0.9668, -1.4539, -1.4427, 0.0598]]], grad_fn=<NativeLayerNormBackward0>) |
然后我们将[4, 3]送给Transformer,模拟推理时的第二步:
1 2 3 4 5 6 7 | transformer(embedding(src), embedding(tgt[:, :2]), tgt_mask=nn.Transformer.generate_square_subsequent_mask(2)) tensor([[[ 1.4053, -0.4680, 0.8110, 0.1218, 0.9668, -1.4539, -1.4427, 0.0598], [ 1.2726, -0.3516, 0.6584, 0.3297, 1.1161, -1.4204, -1.5652, -0.0396]]], grad_fn=<NativeLayerNormBackward0>) |
这个时候你有没有发现,输出的第一个向量和上面那个一模一样。
最后我们再将tgt一次性送给transformer,模拟训练过程:
1 2 3 4 5 6 7 8 9 10 11 12 | transformer(embedding(src), embedding(tgt), tgt_mask=nn.Transformer.generate_square_subsequent_mask(5)) tensor([[[ 1.4053, -0.4680, 0.8110, 0.1218, 0.9668, -1.4539, -1.4427, 0.0598], [ 1.2726, -0.3516, 0.6584, 0.3297, 1.1161, -1.4204, -1.5652, -0.0396], [ 1.4799, -0.3575, 0.8310, 0.1642, 0.8811, -1.3140, -1.5643, -0.1204], [ 1.4359, -0.6524, 0.8377, 0.1742, 1.0521, -1.3222, -1.3799, -0.1454], [ 1.3465, -0.3771, 0.9107, 0.1636, 0.8627, -1.5061, -1.4732, 0.0729]]], grad_fn=<NativeLayerNormBackward0>) |
看到没,前两个tensor和模拟推理时的输出结果一模一样。所以使用mask时,我们可以保证前面的词不会具备后面词的信息,这样就可以保证Transformer的输出不会因为传入词的多少而改变,从而我们就可以做到在训练时一次将tgt全部给到Transformer,却不会出现问题。这也就是人们常说的,防止网络训练时看到不该看到的内容。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· 终于写完轮子一部分:tcp代理 了,记录一下
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理