LLaMA 2
0 Introduction
What's new
- Rotary Position Embedding (RoPE)
- RMS Norm
- Grouped Query Attention + KV Cache
- SwiGLU
Figure: overview diagram of the LLaMA 2 architecture
1 Model Architecture
1.1 Rotary Position Embedding
Paper: RoFormer: Enhanced Transformer with Rotary Position Embedding
RoPE encodes position by rotating queries and keys so that their inner product depends only on the relative position \(m - n\):
\(\langle f_q(q, m),\, f_k(k, n)\rangle = g(q, k, m - n)\)
\(
f_q(q, m) = \begin{bmatrix}
\cos m\theta & -\sin m\theta\\
\sin m\theta & \cos m\theta
\end{bmatrix} q,
\qquad
f_k(k, n) = \begin{bmatrix}
\cos n\theta & -\sin n\theta\\
\sin n\theta & \cos n\theta
\end{bmatrix} k
\)
Euler's formula
\(e^{ix} = \cos x + i\sin x\)
\(e^{im\theta} = \cos m\theta + i\sin m\theta\)
Treating each pair of dimensions as a complex number, rotating by \(m\theta\) is therefore just multiplication by \(e^{im\theta}\).
Applying the rotation \(R(\cdot)\) to the projected queries and keys at positions \(i\) and \(j\):
\(Q_i R(i\theta) = x_i W_Q^\top R(i\theta) = (e_i + p_i) W_Q^\top R(i\theta)\)
\(K_j R(j\theta) = x_j W_K^\top R(j\theta) = (e_j + p_j) W_K^\top R(j\theta)\)
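To make the matrix form concrete, here is a minimal sketch of applying the rotation to pairs of query/key dimensions, assuming the standard RoPE frequency schedule \(\theta_t = 10000^{-2t/d}\); apply_rope is a hypothetical helper, not LLaMA 2's actual code.

import torch

def apply_rope(x, theta_base=10000.0):
    # x: (seq_len, head_dim). Each consecutive pair of dimensions is
    # rotated by a position-dependent angle, matching the 2x2 rotation
    # matrix form above.
    seq_len, head_dim = x.shape
    # frequencies theta_t = theta_base^(-2t/d) for each dimension pair t
    theta = theta_base ** (-torch.arange(0, head_dim, 2).float() / head_dim)
    angles = torch.outer(torch.arange(seq_len).float(), theta)  # (seq_len, head_dim/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin  # [cos -sin; sin cos] applied per pair
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

Because the same rotation is applied to both Q and K, the attention score \(q_i \cdot k_j\) depends only on the offset \(i - j\), as required above.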
1.2 RMS Norm
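RMSNorm simplifies LayerNorm by dropping the mean-centering step and normalizing by the root mean square alone, with a learnable gain \(g\):
\(\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \odot g\)
A minimal PyTorch sketch (the eps value is an assumption; LLaMA 2's implementation may differ in details):

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # learnable gain g
        self.eps = eps

    def forward(self, x):
        # scale by the reciprocal root mean square over the last dimension
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * rms)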
1.3 Grouped Query Attention + KV Cache
<1> Grouped Query Attention
GQA trades off efficiency against accuracy by sharing each key/value head across a group of query heads (see the sketch after the figure below):
- Efficiency: MHA < GQA < MQA
- Accuracy: MHA > GQA > MQA
Figures from "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints"
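A minimal sketch of the head sharing, assuming num_heads is divisible by num_kv_heads (hypothetical shapes; causal masking omitted for brevity):

import torch

def grouped_query_attention(q, k, v, num_kv_heads):
    # q: (num_heads, seq_len, head_dim); k, v: (num_kv_heads, seq_len, head_dim)
    num_heads, seq_len, head_dim = q.shape
    group = num_heads // num_kv_heads
    # each KV head serves `group` query heads
    # (MHA: group == 1; MQA: num_kv_heads == 1)
    k = k.repeat_interleave(group, dim=0)
    v = v.repeat_interleave(group, dim=0)
    scores = q @ k.transpose(-2, -1) / head_dim ** 0.5
    return torch.softmax(scores, dim=-1) @ v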
<2> KV Cache
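During autoregressive decoding, every step attends over all previous tokens, so the keys and values already computed for past positions are cached and reused; each step only computes and appends K/V for the newest token. A minimal single-step sketch (decode_step and the cache layout are assumptions, not LLaMA 2's actual interface):

import torch

def decode_step(q_new, k_new, v_new, cache=None):
    # q_new, k_new, v_new: (num_heads, 1, head_dim) for the newest token
    if cache is None:
        k_all, v_all = k_new, v_new
    else:
        # reuse cached K/V; only the new token's K/V are appended
        k_all = torch.cat([cache[0], k_new], dim=1)
        v_all = torch.cat([cache[1], v_new], dim=1)
    scores = q_new @ k_all.transpose(-2, -1) / q_new.shape[-1] ** 0.5
    out = torch.softmax(scores, dim=-1) @ v_all
    return out, (k_all, v_all)  # the updated cache feeds the next step

Combined with GQA, the cache stores only num_kv_heads heads per layer, which is exactly why GQA shrinks the KV cache memory footprint.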
1.4 SwiGLU
SwiGLU combines the Swish activation (equivalent to SiLU) with a Gated Linear Unit (GLU): the SiLU-activated gate projection elementwise-scales the up projection, \(\mathrm{SwiGLU}(x) = \mathrm{SiLU}(x W_{\text{gate}}) \odot (x W_{\text{up}})\), before a down projection maps back to the hidden size. It is used in the feed-forward networks of LLaMA 2, Mistral 7B, and Mixtral 8×7B.
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()  # required before registering submodules
        # LLaMA 2's feed-forward projections carry no bias terms
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x):
        # F.silu takes no dim argument; the activated gate
        # elementwise-scales the up projection
        hidden_states = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
        return hidden_states
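Usage, with a hypothetical minimal config carrying LLaMA 2 7B's sizes (hidden 4096, intermediate 11008):

import torch
from types import SimpleNamespace

config = SimpleNamespace(hidden_size=4096, intermediate_size=11008)
mlp = MLP(config)
x = torch.randn(2, 16, config.hidden_size)  # (batch, seq_len, hidden)
print(mlp(x).shape)  # torch.Size([2, 16, 4096])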
References
Video 1: Llama 2 Model Architecture Explained (Llama 2 模型结构解析) - CodeLearner | Bilibili
Blog 1: Llama 2 Explained in Detail (Llama 2详解) - CodeLearner | Zhihu
Blog 2: Understanding Llama2: KV Cache, Grouped Query Attention, Rotary Embedding and More
Video 2: A Review of Progress in Transformer Position Encoding (Transformer的位置编码(Position Encoding)进展梳理)
Blog 3: 2D Rotation Matrices and Vector Rotation (二维旋转矩阵与向量旋转)