[cuda] RMSNorm Kernel Analysis
Computation principle
\(\mathrm{RMSNorm}(x)_i = \dfrac{x_i}{\sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^2 + \epsilon}} \cdot g_i\)
Torch implementation
import torch
from torch import nn

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # learnable per-dimension scale g, initialized to ones
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # multiply by the reciprocal root of the mean of squares over the last dim
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # normalize in float32, then cast back to the input dtype and scale
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
The normalized value is computed first and then scaled by g. The normalization factor is the reciprocal square root of the mean of the squared elements along the last dimension, plus eps. Note that the input is cast to float before the norm is computed, and the result is cast back to the input's dtype before being multiplied by the weight.
CUDA implementation
template <typename scalar_t>
__global__ void rms_norm_kernel(
    scalar_t* __restrict__ out,           // [num_tokens, hidden_size]
    const scalar_t* __restrict__ input,   // [num_tokens, hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon,
    const int num_tokens,
    const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  // Each block handles one token (one row); threads stride over hidden_size.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float) input[blockIdx.x * hidden_size + idx];
    variance += x * x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] =
        ((scalar_t) (x * s_variance)) * weight[idx];
  }
}
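The kernel calls a blockReduceSum helper that is not shown in the snippet above. Below is a minimal sketch of one way to implement it with warp shuffle intrinsics; the warpReduceSum name and the assumption that blockDim.x is at most 1024 are mine, and for this kernel only thread 0's return value needs to be correct.

template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
  // Sum val across the 32 lanes of a warp with shuffle-down;
  // lane 0 ends up holding the warp's total.
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  return val;
}

template <typename T>
__inline__ __device__ T blockReduceSum(T val) {
  static __shared__ T shared[32];   // one partial sum per warp (<= 32 warps per block)
  const int lane = threadIdx.x & 31;
  const int wid = threadIdx.x >> 5;

  val = warpReduceSum<T>(val);      // reduce within each warp
  if (lane == 0) {
    shared[wid] = val;              // lane 0 of each warp publishes its partial sum
  }
  __syncthreads();

  // The first warp adds up the per-warp partial sums; thread 0 holds the result.
  const int num_warps = (blockDim.x + 31) >> 5;
  val = (threadIdx.x < num_warps) ? shared[lane] : (T) 0.0f;
  if (wid == 0) {
    val = warpReduceSum<T>(val);
  }
  return val;
}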
Each block processes one token, i.e. one row of the input. In the first loop, every thread accumulates a partial sum of squares over the positions it owns (the indices congruent to its threadIdx.x modulo blockDim.x). blockReduceSum then adds these partial sums into the total sum of squares for the row, and thread 0 computes the reciprocal square root of the mean plus epsilon and stores it in shared memory. After __syncthreads(), all threads of the block use this factor to normalize and scale their elements of the output row.
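The host-side launch is not shown in the post. A plausible sketch follows, assuming one block per token and at most 1024 threads per block; the rms_norm wrapper name and the float instantiation are illustrative, not taken from the original code.

#include <algorithm>
#include <cuda_runtime.h>

void rms_norm(float* out, const float* input, const float* weight,
              float epsilon, int num_tokens, int hidden_size,
              cudaStream_t stream) {
  // One block per token row; cap the block size at 1024 threads.
  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  rms_norm_kernel<float><<<grid, block, 0, stream>>>(
      out, input, weight, epsilon, num_tokens, hidden_size);
}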
Result comparison
torch run time: 0.5862712860107422 ms
torch.Size([200, 2048])
cuda run time: 0.06341934204101562 ms
torch.Size([200, 2048])
tensor([[ 0.2473, -0.4733, -1.5234, ..., -1.0379, 0.2188, -1.7629],
[-0.0408, -0.9154, 0.6396, ..., 0.1713, -1.1047, 0.7188],
[-1.0582, -0.0282, 0.7803, ..., 1.4090, 1.4131, 1.7266],
...,
[ 0.4701, 0.2073, 1.7602, ..., -0.4985, -1.0406, -0.4027],
[ 0.0527, -1.2559, 0.2172, ..., -0.2953, -1.3365, 0.2298],
[ 1.0274, 2.4901, -0.2216, ..., 0.5723, 1.3783, 0.6167]],
device='cuda:0', grad_fn=<MulBackward0>)
tensor([[ 0.2473, -0.4733, -1.5234, ..., -1.0379, 0.2188, -1.7629],
[-0.0408, -0.9154, 0.6396, ..., 0.1713, -1.1047, 0.7188],
[-1.0582, -0.0282, 0.7803, ..., 1.4090, 1.4131, 1.7266],
...,
[ 0.4701, 0.2073, 1.7602, ..., -0.4985, -1.0406, -0.4027],
[ 0.0527, -1.2559, 0.2172, ..., -0.2953, -1.3365, 0.2298],
[ 1.0274, 2.4901, -0.2216, ..., 0.5723, 1.3783, 0.6167]],
device='cuda:0')
max diff: tensor(4.7684e-07, device='cuda:0', grad_fn=<MaxBackward1>)
The test input has shape [200, 2048], i.e. 200 tokens with a feature dimension of 2048. The torch implementation takes about 0.59 ms, while the CUDA kernel takes about 0.06 ms, roughly an order-of-magnitude speedup. The maximum element-wise difference is about 4.8e-7, which is well within acceptable numerical error.
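The post does not show how the CUDA time was measured. One common way to time a kernel launch on the device is with CUDA events; the sketch below is illustrative (buffer names and the float instantiation are assumptions), not the author's benchmark harness.

#include <algorithm>
#include <cuda_runtime.h>

// Returns the elapsed time of one kernel launch in milliseconds.
// d_out, d_in, d_w are assumed to be device buffers that are already
// allocated and filled ([num_tokens, hidden_size] and [hidden_size]).
float time_rms_norm(float* d_out, const float* d_in, const float* d_w,
                    int num_tokens, int hidden_size) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  cudaEventRecord(start);
  rms_norm_kernel<float><<<num_tokens, std::min(hidden_size, 1024)>>>(
      d_out, d_in, d_w, 1e-6f, num_tokens, hidden_size);
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);

  float ms = 0.0f;
  cudaEventElapsedTime(&ms, start, stop);
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return ms;
}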