[cuda] RMSNorm Kernel Explained

Computation principle

\(\mathrm{RMSNorm}(x)_i = \dfrac{x_i}{\sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^2 + \epsilon}} \cdot g_i\)

where \(n\) is the feature dimension, \(g\) is a learnable per-feature gain, and \(\epsilon\) is a small constant for numerical stability.
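
As a quick toy example (numbers chosen here purely for illustration), take \(x = (1, 2, 2)\) with \(n = 3\), unit gain \(g = 1\), and \(\epsilon \approx 0\):

\(\frac{1}{3}\sum_j x_j^2 = \frac{1 + 4 + 4}{3} = 3, \qquad \frac{1}{\sqrt{3}} \approx 0.577, \qquad \mathrm{RMSNorm}(x) \approx (0.577,\ 1.155,\ 1.155)\)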

torch implementation

import torch
from torch import nn

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # learnable per-feature gain g, initialized to ones
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x * 1/sqrt(mean(x^2) + eps), reduced over the last (feature) dimension
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # normalize in float32 for accuracy, then cast back to the input dtype
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

The normalization factor is computed first and the result is then multiplied by the gain g; the factor is the reciprocal square root of the mean of squares plus eps. Note that the input is cast to float before the normalization, and the normalized result is cast back to the input's dtype before the weight is applied.

cuda implementation

// One thread block processes one token (row): blockIdx.x selects the row and
// the threads of the block stride over hidden_size.
template <typename scalar_t>
__global__ void rms_norm_kernel(
  scalar_t* __restrict__ out,             // [num_tokens, hidden_size]
  const scalar_t* __restrict__ input,     // [num_tokens, hidden_size]
  const scalar_t* __restrict__ weight,    // [hidden_size]
  const float epsilon,
  const int num_tokens,
  const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  // each thread accumulates a partial sum of squares over its strided elements
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float) input[blockIdx.x * hidden_size + idx];
    variance += x * x;
  }
  // combine the per-thread partial sums into the row's total sum of squares
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    // thread 0 stores the scale 1/sqrt(mean + eps) in shared memory
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // every thread normalizes its elements and applies the per-feature gain
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
  }
}
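
The kernel's indexing implies one block per token, with the block's threads striding over hidden_size. A hypothetical host-side launcher consistent with that (the name launch_rms_norm and the exact configuration are my own sketch, assuming the kernel is templated on scalar_t as shown above) might look like this:

#include <algorithm>

// One block per token; at most 1024 threads per block (the CUDA limit),
// each thread striding over the hidden dimension.
template <typename scalar_t>
void launch_rms_norm(scalar_t* out, const scalar_t* input, const scalar_t* weight,
                     float epsilon, int num_tokens, int hidden_size,
                     cudaStream_t stream = 0) {
  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
      out, input, weight, epsilon, num_tokens, hidden_size);
}

Because an entire row is reduced inside a single block, no inter-block synchronization or atomics are needed, which is what makes the one-token-per-block layout a natural fit here.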

In the kernel above, each block handles one token (row). Every thread first accumulates a partial sum of squares over the elements it visits with a stride of blockDim.x; blockReduceSum then combines those partial sums into the total sum of squares for the row, and thread 0 writes the scale, the reciprocal square root of the mean plus epsilon, into shared memory. After the barrier, all threads of the block write the normalized, weight-scaled outputs for that row.
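
blockReduceSum itself is not part of the excerpt above. The sketch below shows one common way to implement it with warp shuffles; the helper name warpReduceSum and all implementation details here are my own assumptions rather than the original source.

template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
  // butterfly reduction: after log2(warpSize) steps every lane holds the warp's sum
  for (int offset = warpSize / 2; offset > 0; offset /= 2)
    val += __shfl_xor_sync(0xffffffffu, val, offset);
  return val;
}

template <typename T>
__inline__ __device__ T blockReduceSum(T val) {
  static __shared__ T shared[32];      // one slot per warp (at most 1024/32 warps)
  const int lane = threadIdx.x & 31;
  const int wid  = threadIdx.x >> 5;

  val = warpReduceSum<T>(val);         // step 1: reduce within each warp
  if (lane == 0) shared[wid] = val;    // warp leaders publish their partial sums
  __syncthreads();

  // step 2: the first warp reduces the per-warp partial sums
  const int num_warps = (blockDim.x + 31) >> 5;
  val = (threadIdx.x < num_warps) ? shared[lane] : T(0);
  if (wid == 0) val = warpReduceSum<T>(val);
  return val;                          // the caller only consumes thread 0's value
}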

Results comparison

torch run time: 0.5862712860107422 ms
torch.Size([200, 2048])
cuda run time: 0.06341934204101562 ms
torch.Size([200, 2048])
tensor([[ 0.2473, -0.4733, -1.5234,  ..., -1.0379,  0.2188, -1.7629],
        [-0.0408, -0.9154,  0.6396,  ...,  0.1713, -1.1047,  0.7188],
        [-1.0582, -0.0282,  0.7803,  ...,  1.4090,  1.4131,  1.7266],
        ...,
        [ 0.4701,  0.2073,  1.7602,  ..., -0.4985, -1.0406, -0.4027],
        [ 0.0527, -1.2559,  0.2172,  ..., -0.2953, -1.3365,  0.2298],
        [ 1.0274,  2.4901, -0.2216,  ...,  0.5723,  1.3783,  0.6167]],
       device='cuda:0', grad_fn=<MulBackward0>)
tensor([[ 0.2473, -0.4733, -1.5234,  ..., -1.0379,  0.2188, -1.7629],
        [-0.0408, -0.9154,  0.6396,  ...,  0.1713, -1.1047,  0.7188],
        [-1.0582, -0.0282,  0.7803,  ...,  1.4090,  1.4131,  1.7266],
        ...,
        [ 0.4701,  0.2073,  1.7602,  ..., -0.4985, -1.0406, -0.4027],
        [ 0.0527, -1.2559,  0.2172,  ..., -0.2953, -1.3365,  0.2298],
        [ 1.0274,  2.4901, -0.2216,  ...,  0.5723,  1.3783,  0.6167]],
       device='cuda:0')
max diff:  tensor(4.7684e-07, device='cuda:0', grad_fn=<MaxBackward1>)

The test input has shape [200, 2048], i.e. 200 tokens with a feature dimension of 2048. The torch version takes about 0.58 ms while the CUDA kernel takes about 0.06 ms, roughly an order of magnitude faster, and the maximum difference is on the order of 1e-7, well within an acceptable numerical error.

posted @ 2023-08-20 11:12  wildkid1024