RoPE
概
原作者的博客已经讲得非常到位了: [here] and [there].
RoPE
-
RoPE 是一种相对位置编码, 特点是它可以像绝对位置编码一样, 在 embedding 上操作后再进行 attention 的运算, 而不限定于在 score 矩阵上操作.
-
具体的, 假设 为位置 上的两个 embedding, 令:
则
就是吸收了相对位置信息 的 score.
-
其中 是旋转矩阵, 它作用在向量是相当于对两个两个维度地进行旋转. 和最普通的 Sinusoidal 编码保持一致.
-
有一种更加高效的方式:
-
下面是 LLaMA 中的实现方式:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# x: B, S, H, D
# freqs_cis: S, D // 2
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # (B, S, H, D // 2)
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # (B, S, H, D // 2)
freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # (1, S, 1, 1, D // 2)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 园子的第一款AI主题卫衣上架——"HELLO! HOW CAN I ASSIST YOU TODAY
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
2022-07-24 Commute Time Distance
2022-07-24 HOP-Rec: High-Order Proximity for Implicit Recommendation