Pre-trained Model Summary
🥥 Table of Contents
- 00 - Overview
- 01 - Tokenization
- 02 - Position Encoding
- 03 - Word Embedding
- 04 - Transformer Block (Attention + FFN)
- 05 - Model Head for Downstream Tasks
- 06 - BERT Family
- 07 - Decoder-only Model
🥑 Get Started!
00 - Overview
1st Paradigm: Fully Supervised Learning (Non-Neural Network)
2nd Paradigm: Fully Supervised Learning (Neural Network)
3rd Paradigm: Pre-train (self-supervised) + Fine-tune (supervised)
4th Paradigm: Pre-train (self-supervised) + Prompt/Instruct + Predict
What is pre-training?
Self-supervised learning on a large set of unlabeled data.
NLP vs. NLU vs. NLG: the differences between three natural language processing concepts
Pre-trained Language Model: GLM
Pre-trained Model Architecture | Pre-training Task | Task Type | Examples |
---|---|---|---|
Encoder-only (Auto-Encoding) | Masked Language Model | NLU | BERT Family |
Decoder-only (Auto-Regressive) | Causal Language Model or Prefix Language Model | NLG | GPT, Llama, Bloom |
Encoder-Decoder (Seq2Seq) | Sequence-to-Sequence (e.g. span corruption in T5, denoising in BART) | Conditional NLG | T5, BART |
Resource 1: Self-supervised Learning: Generative or Contrastive
Resource 2: Generative Self-supervised Learning in LLM Pre-Training task
Resource 3: Understanding the fundamental difference between the GPT family and BERT: autoregressive vs. autoencoding language models explained
Resource 4: The Transformer model family | Hugging Face
Generative Pre-training Tasks
- Auto-Encoder (AE) Models: BERT (MLM & NSP)
- Auto-Regressive (AR) Models: GPT
- Encoder-Decoder: T5
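The two main objectives differ mainly in how (input, label) pairs are built. Below is a minimal pure-Python sketch, assuming a toy list of token ids, a hypothetical `MASK_ID` and `VOCAB_SIZE`, and BERT-style masking (select about 15% of tokens; of those, 80% become `[MASK]`, 10% a random token, 10% unchanged). Using `-100` as the "ignore this position" label follows a common PyTorch convention; NSP is omitted, and the causal-LM pair is shifted explicitly here for clarity.

```python
import random

IGNORE_INDEX = -100  # label ignored by the loss (common PyTorch convention)
MASK_ID = 4          # hypothetical [MASK] token id
VOCAB_SIZE = 50      # hypothetical vocabulary size


def mlm_example(ids, mask_prob=0.15, seed=0):
    """Masked LM (BERT-style): corrupt some tokens, predict them from both sides."""
    rng = random.Random(seed)
    inputs, labels = list(ids), [IGNORE_INDEX] * len(ids)
    for i, tok in enumerate(ids):
        if rng.random() < mask_prob:
            labels[i] = tok                    # loss is computed only at these positions
            r = rng.random()
            if r < 0.8:
                inputs[i] = MASK_ID            # 80%: replace with [MASK]
            elif r < 0.9:
                inputs[i] = rng.randrange(VOCAB_SIZE)  # 10%: random token
            # remaining 10%: keep the original token
    return inputs, labels


def clm_example(ids):
    """Causal LM (GPT-style): predict token t+1 from tokens 0..t."""
    return ids[:-1], ids[1:]


ids = [10, 11, 12, 13, 14, 15]
print(mlm_example(ids, mask_prob=0.5))
print(clm_example(ids))  # ([10, 11, 12, 13, 14], [11, 12, 13, 14, 15])
```

The MLM labels are mostly ignored positions, which is why MLM extracts less training signal per sequence than CLM, where every position has a target.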
Downstream Tasks
- Auto-Encoder (AE) Models: Text Understanding (Text Classification, Token Classification, Extractive Question Answering, Extractive Summarization)
- Auto-Regressive (AR) Models: Text Generation
- Encoder-Decoder: Text Translation, Abstractive Summarization
01 - Tokenization
Resource 1: Summary of tokenizers | Hugging Face
Resource 2: Do you need to put EOS and BOS tokens in autoencoder transformers? | StackOverflow
Text --> Tokens --> input_ids
output_ids --> Tokens --> Text
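The two pipelines above can be sketched with a word-level tokenizer. This is only an illustration: the vocabulary is hypothetical, and whether `[BOS]`/`[EOS]` are added depends on the model (see Resource 2).

```python
# Hypothetical toy vocabulary; real tokenizers learn their vocabulary from data.
vocab = {"[PAD]": 0, "[UNK]": 1, "[BOS]": 2, "[EOS]": 3,
         "the": 4, "cat": 5, "sat": 6, "on": 7, "mat": 8}
id_to_token = {i: t for t, i in vocab.items()}
SPECIAL = {"[PAD]", "[BOS]", "[EOS]"}


def encode(text, add_special_tokens=True):
    """Text --> Tokens --> input_ids (word-level: lowercase + whitespace split)."""
    tokens = text.lower().split()
    if add_special_tokens:
        tokens = ["[BOS]"] + tokens + ["[EOS]"]
    # Out-of-vocabulary words fall back to [UNK].
    return [vocab.get(t, vocab["[UNK]"]) for t in tokens]


def decode(ids):
    """output_ids --> Tokens --> Text (special tokens dropped)."""
    return " ".join(id_to_token[i] for i in ids if id_to_token[i] not in SPECIAL)


ids = encode("The cat sat on the mat")
print(ids)          # [2, 4, 5, 6, 7, 4, 8, 3]
print(decode(ids))  # the cat sat on the mat
```

The `[UNK]` fallback is the weakness of word-level tokenization that subword methods are designed to avoid.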
- Character-based Tokenization
- Word-based Tokenization
- Subword-based Tokenization
  - WordPiece: BERT, DistilBERT
  - Byte-Pair Encoding (BPE): GPT-2, RoBERTa
  - Unigram: XLNet, ALBERT
  - SentencePiece (a library that trains BPE or Unigram models directly on raw text): ChatGLM, BLOOM, PaLM
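To show how a subword vocabulary is learned, here is a minimal sketch of classic character-level BPE training with an end-of-word marker `</w>`. It repeatedly merges the most frequent adjacent symbol pair. GPT-2 uses a byte-level variant, and WordPiece scores candidate pairs by likelihood rather than raw frequency, so treat this as an illustration of the BPE idea only.

```python
from collections import Counter


def get_pair_counts(words):
    """Count adjacent symbol pairs, weighted by word frequency."""
    pairs = Counter()
    for symbols, freq in words.items():
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    return pairs


def merge_pair(words, pair):
    """Replace every occurrence of `pair` with a single merged symbol."""
    merged = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] = merged.get(tuple(out), 0) + freq
    return merged


def learn_bpe(corpus, num_merges):
    """Learn an ordered list of BPE merge rules from a whitespace-split corpus."""
    words = Counter(tuple(w) + ("</w>",) for w in corpus.split())
    merges = []
    for _ in range(num_merges):
        pairs = get_pair_counts(words)
        if not pairs:
            break
        best = max(pairs, key=pairs.get)  # most frequent pair; ties go to the first seen
        merges.append(best)
        words = merge_pair(words, best)
    return merges, words


merges, words = learn_bpe("low low low lower", num_merges=2)
print(merges)  # [('l', 'o'), ('lo', 'w')]
print(words)
```

At encoding time the learned merges are replayed in the same order on new words, so frequent words become single tokens while rare words split into smaller known pieces.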
02 - Position Encoding
<1> Absolute Position Encoding
\(\begin{align*} Q_iK_j^T & = (X_iW_Q^T)(X_jW_K^T)^T\\ & = (e_i+p_i)W_Q^TW_K(e_j+p_j)^T\\ & = e_iW_Q^TW_Ke_j^T + e_iW_Q^TW_Kp_j^T + p_iW_Q^TW_Ke_j^T + p_iW_Q^TW_K p_j^T \end{align*}\)
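The derivation treats \(p_i\) as given. One common fixed choice is the sinusoidal encoding from the original Transformer paper (BERT and GPT-2 instead learn \(p_i\) as parameters). A minimal pure-Python sketch, with made-up token embeddings standing in for \(e_i\):

```python
import math


def sinusoidal_position_encoding(num_positions, d_model):
    """Fixed absolute encodings p_i:
    PE[pos, 2k] = sin(pos / 10000^(2k/d)), PE[pos, 2k+1] = cos(pos / 10000^(2k/d)).
    """
    pe = [[0.0] * d_model for _ in range(num_positions)]
    for pos in range(num_positions):
        for k in range(0, d_model, 2):          # k is the even index 2i
            angle = pos / (10000 ** (k / d_model))
            pe[pos][k] = math.sin(angle)
            if k + 1 < d_model:
                pe[pos][k + 1] = math.cos(angle)
    return pe


# X_i = e_i + p_i: add the encoding to (hypothetical) token embeddings e_i.
embeddings = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]
pe = sinusoidal_position_encoding(len(embeddings), 4)
x = [[e + p for e, p in zip(e_row, p_row)] for e_row, p_row in zip(embeddings, pe)]
print(x)
```

Because \(p_i\) is simply added to \(e_i\), the expanded score above contains the mixed terms \(e_iW_Q^TW_Kp_j^T\) and \(p_iW_Q^TW_Ke_j^T\), which is what relative schemes try to simplify.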
<2> Relative Position Encoding
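\((e_i+u)W_Q^TW_K(e_j+r_{i-j})^T\)
Compared with the absolute form, the query-side \(p_i\) is replaced by a learnable vector \(u\) and the key-side \(p_j\) by a relative position embedding \(r_{i-j}\), so position enters the score only through \(i-j\).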
<3> Rotary Position Embedding (RoPE)
Article: RoFormer: Enhanced Transformer with Rotary Position Embedding
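\(\langle f_q(q,m), f_k(k,n)\rangle = g(q,k,m-n)\)
\(f_q(q,m) = R(m\theta)\,q = \begin{bmatrix} \cos m\theta & -\sin m\theta\\ \sin m\theta & \cos m\theta \end{bmatrix}\begin{bmatrix} q_0\\ q_1 \end{bmatrix}, \quad f_k(k,n) = R(n\theta)\,k = \begin{bmatrix} \cos n\theta & -\sin n\theta\\ \sin n\theta & \cos n\theta \end{bmatrix}\begin{bmatrix} k_0\\ k_1 \end{bmatrix}\)
\(\langle f_q(q,m), f_k(k,n)\rangle = q^TR(m\theta)^TR(n\theta)\,k = q^TR[(n-m)\theta]\,k = g(q,k,m-n)\)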
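Euler's formula (viewing a 2-D vector as a complex number, multiplying it by \(e^{im\theta}\) rotates it by \(m\theta\), which is exactly what \(R(m\theta)\) does):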
\(e^{ix} = \cos x + i\sin x\)
\(e^{im\theta} = \cos m\theta + i\sin m\theta\)
\(Q_iR(i\theta) = x_iW_Q^TR(i\theta) = (e_i+p_i)W_Q^TR(i\theta)\)
\(K_jR(j\theta) = x_jW_K^TR(j\theta) = (e_j+p_j)W_K^TR(j\theta)\)
\(\begin{align*} f_q(x_i,i)f_k(x_j,j) & = [Q_iR(i\theta)][K_jR(j\theta)]^T\\ & = [x_iW_Q^TR(i\theta)][x_jW_K^TR(j\theta)]^T\\ & = (e_i+p_i)W_Q^TR(i\theta)R(j\theta)^TW_K(e_j+p_j)^T\\ & = (e_i+p_i)W_Q^TR(i\theta)R(-j\theta)W_K(e_j+p_j)^T\\ & = (e_i+p_i)W_Q^TR[(i-j)\theta]W_K(e_j+p_j)^T\\ & = g(x_i, x_j, i-j) \end{align*}\)
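The derivation can be checked numerically. The sketch below applies RoPE with complex multiplication (Euler's formula), rotating each pair \((v_{2k}, v_{2k+1})\) by \(pos\cdot\theta_k\) with \(\theta_k = 10000^{-2k/d}\) as in the RoFormer paper, then confirms that the score depends only on the offset. The vectors are made up, and real implementations differ in how they pair dimensions (some rotate the first half against the second half), so this is a sketch of the idea rather than any particular library's code.

```python
import cmath
import math


def rope(vec, pos, base=10000.0):
    """Rotate each pair (vec[2k], vec[2k+1]) by angle pos * theta_k,
    theta_k = base^(-2k/d), via multiplication by e^{i * pos * theta_k}."""
    d = len(vec)
    assert d % 2 == 0, "RoPE needs an even dimension"
    out = []
    for k in range(d // 2):
        theta = base ** (-2 * k / d)
        z = complex(vec[2 * k], vec[2 * k + 1]) * cmath.exp(1j * pos * theta)
        out.extend([z.real, z.imag])
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [0.3, -1.2, 0.7, 0.5]   # toy query vector
k = [1.1, 0.4, -0.6, 0.9]   # toy key vector
# Same offset i - j = 2 at different absolute positions gives the same score.
s1 = dot(rope(q, 5), rope(k, 3))
s2 = dot(rope(q, 12), rope(k, 10))
print(math.isclose(s1, s2, abs_tol=1e-9))  # True
```

Rotation preserves vector norms, so RoPE injects position without changing the magnitude of \(Q_i\) or \(K_j\).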
03 - Word Embedding
Source 1: Glossary of Deep Learning: Word Embedding | Medium
Source 2: Word2Vec | Bilibili
- BoW (One-Hot, TF-IDF, TextRank)
- Word2Vec (CBOW, Skip-gram)
- GloVe
- FastText
- nn.Embedding()
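As a concrete example from the BoW family, here is a minimal TF-IDF sketch using the plain, unsmoothed definitions (tf = count / document length, idf = log(N / df)). Libraries often use smoothed variants, so exact numbers will differ from this.

```python
import math
from collections import Counter


def tf_idf(docs):
    """Return one {word: weight} dict per document."""
    tokenized = [doc.lower().split() for doc in docs]
    n_docs = len(tokenized)
    # Document frequency: in how many documents each word appears.
    df = Counter(word for tokens in tokenized for word in set(tokens))
    weights = []
    for tokens in tokenized:
        counts = Counter(tokens)
        weights.append({
            word: (c / len(tokens)) * math.log(n_docs / df[word])
            for word, c in counts.items()
        })
    return weights


docs = ["the cat sat", "the dog sat", "the cat ran"]
for w in tf_idf(docs):
    print(w)
```

A word that appears in every document ("the") gets weight 0, which illustrates why BoW vectors are sparse and carry no notion of word similarity, unlike the dense vectors learned by Word2Vec, GloVe, or `nn.Embedding()`.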
04 - Transformer Block
Multi-Head Attention + Feed-Forward Network (Linear + Activation + Linear) + Residual Addition + Layer Normalization
<1> Attention
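- Bidirectional attention: Encoder
- Unidirectional (one-way) attention: Decoder
In the decoder, the attention score matrix is lower-triangular: each token attends only to itself and earlier tokens.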
Attention mask
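A minimal single-head sketch of scaled dot-product attention with an optional causal mask. Future positions get a score of \(-\infty\) before the softmax, so their weights become exactly 0 and the weight matrix is lower-triangular. Inputs are toy lists of row vectors; padding masks, multiple heads, and the Q/K/V projections are left out.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]   # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V, causal=False):
    """softmax(Q K^T / sqrt(d)) V, optionally with a causal (lower-triangular) mask."""
    d = len(Q[0])
    weights, outputs = [], []
    for i, q in enumerate(Q):
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        if causal:
            # Mask future positions j > i so softmax gives them weight 0.
            scores = [s if j <= i else float("-inf") for j, s in enumerate(scores)]
        w = softmax(scores)
        weights.append(w)
        outputs.append([sum(w[j] * V[j][c] for j in range(len(V)))
                        for c in range(len(V[0]))])
    return outputs, weights


X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # 3 tokens, d = 2 (toy values)
_, causal_w = attention(X, X, X, causal=True)
for row in causal_w:
    print([round(v, 3) for v in row])       # zeros above the diagonal
```

With `causal=False` every weight is positive, which corresponds to the bidirectional encoder case.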
Flash Attention
<2> Normalization
Group by the type
- Batch Normalization
- Layer Normalization
- RMS Normalization
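In Transformers, the normalization layer is usually LayerNorm, which normalizes each token's feature vector \(x \in \mathbb{R}^d\). The LayerNorm formula is:
\(\text{LayerNorm}(x) = \frac{x-\mu}{\sqrt{\sigma^2+\epsilon}} \odot \gamma + \beta,\quad \mu = \frac{1}{d}\sum_{i=1}^{d}x_i,\quad \sigma^2 = \frac{1}{d}\sum_{i=1}^{d}(x_i-\mu)^2\)
RMSNorm is a variant of LayerNorm that skips the mean computation and drops the bias \(\beta\):
\(\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d}x_i^2+\epsilon}} \odot \gamma\)
where \(\gamma\) and \(\beta\) are learnable parameters and \(\epsilon\) is a small constant for numerical stability.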
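Both formulas translate directly into code. A minimal pure-Python sketch for a single feature vector; the `eps` defaults are assumptions (implementations pick their own values).

```python
import math


def layer_norm(x, gamma, beta, eps=1e-5):
    """LayerNorm: subtract the mean, divide by the std, then scale and shift."""
    d = len(x)
    mu = sum(x) / d
    var = sum((v - mu) ** 2 for v in x) / d
    return [g * (v - mu) / math.sqrt(var + eps) + b for v, g, b in zip(x, gamma, beta)]


def rms_norm(x, gamma, eps=1e-6):
    """RMSNorm: no mean subtraction and no bias; divide by the root mean square."""
    d = len(x)
    rms = math.sqrt(sum(v * v for v in x) / d + eps)
    return [g * v / rms for v, g in zip(x, gamma)]


x = [1.0, 2.0, 3.0, 4.0]
ones, zeros = [1.0] * 4, [0.0] * 4
print(layer_norm(x, ones, zeros))  # mean ~0, variance ~1
print(rms_norm(x, ones))           # root mean square ~1, mean not centered
```

Skipping the mean saves one reduction per token, which is the efficiency argument usually given for RMSNorm.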
Group by the position
- Post-Norm
- Pre-Norm
- Sandwich-Norm
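Post-LN:
Position: LayerNorm is applied after the residual addition, i.e. \(x_{l+1} = \text{LN}(x_l + F(x_l))\).
Drawback: with Post-LN, gradient norms grow in the deeper layers, so deep Post-LN Transformers are prone to unstable training.
Pre-LN:
Position: LayerNorm is applied inside the residual branch, before the sublayer, i.e. \(x_{l+1} = x_l + F(\text{LN}(x_l))\).
Advantage: compared with Post-LN, Pre-LN keeps gradient norms roughly equal across layers, so deep Pre-LN Transformers train more stably and training instability is alleviated.
Drawback: compared with Post-LN, Pre-LN models usually perform slightly worse.
Sandwich-LN:
Position: builds on Pre-LN by inserting an extra LayerNorm on the sublayer output, i.e. \(x_{l+1} = x_l + \text{LN}(F(\text{LN}(x_l)))\).
Advantage: used by CogView to avoid value explosion.
Drawback: training can still be unstable and may collapse.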
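The three placements differ only in where LayerNorm sits relative to the residual addition. A minimal sketch with a parameter-free LayerNorm and a hypothetical `toy_sublayer` standing in for attention or the FFN:

```python
import math


def layer_norm(x, eps=1e-5):
    """Parameter-free LayerNorm (gamma=1, beta=0) to keep the sketch short."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]


def add(a, b):
    return [u + v for u, v in zip(a, b)]


def post_ln(x, sublayer):
    # Post-LN: normalize after the residual addition.
    return layer_norm(add(x, sublayer(x)))


def pre_ln(x, sublayer):
    # Pre-LN: normalize the sublayer input; the residual path is left untouched.
    return add(x, sublayer(layer_norm(x)))


def sandwich_ln(x, sublayer):
    # Sandwich-LN: Pre-LN plus an extra LayerNorm on the sublayer output.
    return add(x, layer_norm(sublayer(layer_norm(x))))


def toy_sublayer(v):
    # Hypothetical stand-in for attention or the FFN.
    return [3.0 * u + 1.0 for u in v]


x = [1.0, 2.0, 3.0, 4.0]
for block in (post_ln, pre_ln, sandwich_ln):
    print(block.__name__, block(x, toy_sublayer))
```

In Pre-LN and Sandwich-LN the input `x` reaches the output unchanged through the residual path, which is the intuition behind their more stable gradients; Post-LN renormalizes that path at every layer.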
Model | Normalization |
---|---|
GPT-3 | Pre Layer Norm |
Llama | Pre RMS Norm |
baichuan | Pre RMS Norm |
ChatGLM-6B | Post Deep Norm |
ChatGLM2-6B | Post RMS Norm |
Bloom | Pre Layer Norm |
05 - Model Head for Downstream Tasks
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
hidden states: the model's contextual representation of the input sentences, with shape [batch_size, token_len, hidden_size] (hidden_size is the output size of the last linear layer in each model block).
Tokenizer(Text) --> input_ids
AutoModel(input_ids) --> return hidden states
Model Head(hidden states) --> return the logits based on the given num_labels
softmax(logits) --> the prob distribution based on the given num_labels
argmax(prob distribution) --> 0 or 1 (or 2 ...)
id2label(1 or 0) --> The final output
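The steps from hidden states to the final label can be sketched in pure Python. Everything here is hypothetical: `hidden_size=4`, `num_labels=2`, made-up weights, and a head that applies one linear layer (like `nn.Linear(config.hidden_size, config.num_labels)`) to the first token's hidden state while skipping any intermediate dense or pooler layer a real model may have.

```python
import math

# Hypothetical head parameters: W has shape [num_labels, hidden_size].
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
W = [[0.5, -0.2, 0.1, 0.0],
     [-0.3, 0.4, 0.2, 0.6]]
b = [0.0, 0.1]


def classification_head(cls_hidden):
    """Linear layer hidden_size -> num_labels, returning logits."""
    return [sum(w * h for w, h in zip(row, cls_hidden)) + bias for row, bias in zip(W, b)]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# hidden_states for one sentence: [token_len, hidden_size]; row 0 plays the role of [CLS].
hidden_states = [[0.2, 0.9, -0.1, 0.7],
                 [0.0, 0.3, 0.5, -0.4]]
logits = classification_head(hidden_states[0])          # Model Head
probs = softmax(logits)                                 # prob distribution
pred = max(range(len(probs)), key=probs.__getitem__)    # argmax
print(logits, probs, id2label[pred])                    # ... POSITIVE
```

Since softmax is monotonic, `argmax(logits)` gives the same label; the softmax step is only needed when the probabilities themselves are wanted.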
AutoModel # Base model: returns hidden states
AutoModelForCausalLM # Text Generation
AutoModelForMaskedLM # Fill-Mask
AutoModelForSeq2SeqLM # Text Translation, Summarization
AutoModelForMultipleChoice # Multiple Choice
AutoModelForQuestionAnswering # Extractive Question Answering
AutoModelForSequenceClassification # Text Classification
AutoModelForTokenClassification # Text Tagging (NER or POS)
06 - BERT Family
DeBERTa
- Disentangled attention: each attention score is the sum of the content-to-content, content-to-position, and position-to-content matrices
- Enhanced mask decoder (EMD): takes H (hidden states) and I (any information needed for decoding, e.g. hidden states, absolute position embeddings, or the output of the previous EMD layer)
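A minimal sketch of the disentangled attention score as described in the DeBERTa paper: \(A_{ij} = Q^c_i\cdot K^c_j + Q^c_i\cdot K^r_{\delta(i,j)} + K^c_j\cdot Q^r_{\delta(j,i)}\), scaled by \(1/\sqrt{3d}\), where \(\delta\) clips the relative distance into \([0, 2k)\). The projections are made-up toy vectors; softmax, multiple heads, and the EMD are omitted.

```python
import math


def rel_index(i, j, k):
    """Relative distance delta(i, j), clipped into [0, 2k)."""
    if i - j <= -k:
        return 0
    if i - j >= k:
        return 2 * k - 1
    return i - j + k


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def disentangled_scores(Qc, Kc, Qr, Kr, k):
    """Raw scores A[i][j] = c2c + c2p + p2c, scaled by 1/sqrt(3d).

    Qc, Kc: content projections of the hidden states, one row per token.
    Qr, Kr: projections of the 2k shared relative-position embeddings.
    """
    n, d = len(Qc), len(Qc[0])
    scale = 1 / math.sqrt(3 * d)
    A = []
    for i in range(n):
        row = []
        for j in range(n):
            c2c = dot(Qc[i], Kc[j])                    # content-to-content
            c2p = dot(Qc[i], Kr[rel_index(i, j, k)])   # content-to-position
            p2c = dot(Kc[j], Qr[rel_index(j, i, k)])   # position-to-content
            row.append((c2c + c2p + p2c) * scale)
        A.append(row)
    return A


Qc = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
Kc = [[0.5, 0.5], [1.0, -1.0], [0.0, 2.0]]
k = 2                                                    # max relative distance
Qr = [[0.1, 0.0], [0.0, 0.1], [0.2, 0.2], [0.3, -0.1]]   # 2k rows
Kr = [[0.0, 0.2], [0.1, 0.1], [-0.1, 0.0], [0.2, 0.3]]
for row in disentangled_scores(Qc, Kc, Qr, Kr, k):
    print([round(v, 3) for v in row])
```

The scale uses \(3d\) rather than \(d\) because three dot-product terms are summed.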
07 - Decoder-only Model
Resource 1: LLaMA
Resource 2: Paper reading: Mixtral + MoE - 王几行 | Zhihu
Article 1: Mixtral of Experts
Article 2: Mistral 7B
Article 3: GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
LLaMA 2
- Pre RMS Norm
- Rotary Positional Embeddings (RoPE)
- Multi-Head Attention (7B, 13B), Grouped-Query Attention (34B, 70B)
- KV Cache
- SwiGLU
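A minimal sketch of a SwiGLU feed-forward block: \(\text{FFN}(x) = W_{down}(\text{SiLU}(W_{gate}x) \odot W_{up}x)\), with no biases. The weight names loosely follow the `gate_proj`/`up_proj`/`down_proj` naming used in Hugging Face's LLaMA code; the shapes and values below are hypothetical.

```python
import math


def silu(z):
    """SiLU (Swish with beta=1): z * sigmoid(z), written to avoid overflow."""
    if z >= 0:
        return z / (1.0 + math.exp(-z))
    e = math.exp(z)
    return z * e / (1.0 + e)


def matvec(W, x):
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def swiglu_ffn(x, W_gate, W_up, W_down):
    """W_down (SiLU(W_gate x) * W_up x): a gated FFN with three weight matrices."""
    gate = [silu(g) for g in matvec(W_gate, x)]
    up = matvec(W_up, x)
    return matvec(W_down, [g * u for g, u in zip(gate, up)])


# hidden_size = 2, intermediate_size = 3 (toy sizes and weights)
W_gate = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W_up = [[0.5, 0.5], [1.0, -1.0], [0.0, 2.0]]
W_down = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]
print(swiglu_ffn([1.0, -1.0], W_gate, W_up, W_down))
```

Compared with the two-matrix FFN (Linear + Activation + Linear), the extra gate matrix adds parameters, which is why models using SwiGLU typically shrink the intermediate size to keep the parameter count comparable.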
Mistral 7B
- Pre RMS Norm
- Rotary Positional Embeddings (RoPE)
- Grouped-Query Attention + Sliding Window Attention
- Rolling Buffer Cache
- SwiGLU
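Sliding window attention and the rolling buffer cache fit together: if a token only attends to the last `window` positions, the KV cache never needs more than `window` slots, and position `i` can be stored in slot `i % window`. A minimal sketch with hypothetical names (`sliding_window_mask`, `RollingBufferCache`) and strings standing in for key/value tensors:

```python
def sliding_window_mask(seq_len, window):
    """mask[i][j] is True if token i may attend to token j:
    causal (j <= i) and within the last `window` positions (i - j < window)."""
    return [[0 <= i - j < window for j in range(seq_len)] for i in range(seq_len)]


class RollingBufferCache:
    """Fixed-size KV cache: position pos is stored in slot pos % window,
    so entries older than `window` positions are overwritten."""

    def __init__(self, window):
        self.window = window
        self.slots = [None] * window

    def put(self, pos, kv):
        self.slots[pos % self.window] = (pos, kv)

    def visible(self):
        """Cached (pos, kv) pairs, ordered by position."""
        return sorted((e for e in self.slots if e is not None), key=lambda e: e[0])


for row in sliding_window_mask(5, 3):
    print("".join("x" if m else "." for m in row))

cache = RollingBufferCache(window=3)
for pos, token in enumerate(["The", "cat", "sat", "on", "the"]):
    cache.put(pos, f"kv({token})")
print(cache.visible())  # [(2, 'kv(sat)'), (3, 'kv(on)'), (4, 'kv(the)')]
```

Information from beyond the window can still flow forward indirectly, because each layer's window stacks on the previous layer's.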
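Mixtral 8x7B (only about 12.9B parameters are active per token)
- Mistral 7B + Mixture of Experts (MoE): each FFN block is replaced by 8 experts, and a router selects 2 of them per token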
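A minimal sketch of top-2 sparse MoE routing in the spirit of Mixtral: the router scores every expert, keeps the top 2, applies softmax over just those two logits, and returns the weighted sum of the two experts' outputs. Only the selected experts run, which is why active parameters are far fewer than total parameters. The router weights and the experts (simple scaling functions instead of SwiGLU FFNs) are hypothetical.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def moe_layer(x, router_W, experts, top_k=2):
    """y = sum over the top-k experts of softmax(top-k router logits)_i * E_i(x)."""
    logits = [sum(w * v for w, v in zip(row, x)) for row in router_W]
    top = sorted(range(len(experts)), key=lambda i: logits[i], reverse=True)[:top_k]
    gates = softmax([logits[i] for i in top])
    y = [0.0] * len(x)
    for g, i in zip(gates, top):
        out = experts[i](x)            # only the selected experts are evaluated
        y = [acc + g * o for acc, o in zip(y, out)]
    return y, top


# 8 toy experts: expert i scales its input by (i + 1).
experts = [lambda v, s=s: [s * u for u in v] for s in range(1, 9)]
router_W = [[0.1 * i, 0.0] for i in range(8)]   # hypothetical router weights
y, top = moe_layer([1.0, 0.5], router_W, experts)
print(top, y)  # experts 7 and 6 are selected
```

Different tokens can pick different expert pairs, so across a batch all experts get used even though each token touches only two.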
ChatGLM