摘要: ### 计算原理 $RMSNorm = x * (sqrt(1/n * (x_i)^2 + eps)) * g$ ### torch实现 ```python class RMSNorm(torch.nn.Module): def __init__(self, dim: int, eps: float 阅读全文
posted @ 2023-08-20 11:12 wildkid1024 阅读(827) 评论(0) 推荐(0) 编辑