[fastllm] CUDA kernels source-code walkthrough
Following the architecture analysis in the first article, this post looks at how fastllm writes its CUDA kernels. The most frequently used kernels are gemv_int4, gemv_int8, gemm_int8, RMSNorm, softmax, RotatePosition2D, swiglu, and so on. Among these, gemm is compute-bound, while most of the others are memory-bound. The main speedup comes from computing directly on the quantized bits, which is faster than the native torch path of converting back to float first; on the other hand, no operator fusion is applied yet, so there is still room for optimization.
gemv_int4 kernels walkthrough
Purpose: compute a float32 * int4 GEMV, where the quantization offset stored per row is the minimum value (the "no zero-point" scheme).
template <int THREAD_PER_BLOCK, int PART>
__global__ void FastllmGemvInt4NoZeroKernel2(float *A, uint8_t *B, float *C,
float *bias, float *scales, float *mins,
int m, int k) {
__shared__ float sdata[THREAD_PER_BLOCK];
unsigned int tid = threadIdx.x;
// 1. compute
int st = blockIdx.x * PART;
int end = st + PART;
for (int p = st; p < end; p++) {
sdata[tid] = 0;
float minv = mins[p] / scales[p];
for (int i = tid; i < m / 2; i += THREAD_PER_BLOCK) {
uint8_t now = B[p * m / 2 + i];
sdata[tid] += (A[i * 2] * (minv + (now >> 4)) + A[i * 2 + 1] * (minv + (now & 15)));
}
__syncthreads();
for (unsigned int s = 1; s < THREAD_PER_BLOCK; s *= 2) {
if ((tid & (2 * s - 1)) == 0) {
sdata[tid] += sdata[tid + s];
}
__syncthreads();
}
if (tid == 0) {
C[p] = sdata[0] * scales[p] + bias[p];
}
__syncthreads();
}
}
For an n*m weight matrix and an m*1 vector as the second operand, the rows are split into tiles of PART rows; different tiles (blocks) run in parallel. Within a row, the threads stride over the m/2 packed bytes, each byte holding the two int4 values for two consecutive columns.
The per-row minimum is loaded from mins and divided by the scale to get minv; every value of row p shares this minv (in particular the two int4 values packed into one uint8), and each step accumulates the partial sum of two float * int4 products.
The per-thread partial sums are then reduced into sdata[0], and the output for row p is written as sdata[0] * scales[p] + bias[p]; each block thus produces one tile of PART outputs, and the final result vector is n*1.
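To make the block-to-rows mapping concrete, here is a minimal host-side launch sketch; THREAD_PER_BLOCK = 256 and PART = 1 are assumed values for illustration, not necessarily the configuration the fastllm caller actually uses:
// Hypothetical launch: k / PART blocks, each block produces PART rows of the output.
// (Assumed template/launch parameters, for illustration only.)
const int threadsPerBlock = 256, part = 1;
dim3 grid(k / part);
FastllmGemvInt4NoZeroKernel2<256, 1> <<< grid, threadsPerBlock >>> (A, B, C, bias, scales, mins, m, k);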
template <int THREAD_PER_BLOCK, int PART>
__global__ void FastllmGemvInt4Kernel2(float *A, uint8_t *B, float *C,
float *bias, float *scales, uint8_t *zeros,
int m, int k) {
__shared__ float sdata[THREAD_PER_BLOCK];
unsigned int tid = threadIdx.x;
// 1. compute
int st = blockIdx.x * PART;
int end = st + PART;
for (int p = st; p < end; p++) {
sdata[tid] = 0;
uint8_t zero = zeros[p];
for (int i = tid; i < m / 2; i += THREAD_PER_BLOCK) {
uint8_t now = B[p * m / 2 + i];
sdata[tid] += (A[i * 2] * ((now >> 4) - zero) + A[i * 2 + 1] * ((now & 15) - zero));
}
__syncthreads();
for (unsigned int s = 1; s < THREAD_PER_BLOCK; s *= 2) {
if ((tid & (2 * s - 1)) == 0) {
sdata[tid] += sdata[tid + s];
}
__syncthreads();
}
if (tid == 0) {
C[p] = sdata[0] * scales[p] + bias[p];
}
__syncthreads();
}
}
Essentially the same as the no-zero version above, except that the offset changes from the stored minv to a zero point: the kernel accumulates (q - zero) and only multiplies by the scale at the end.
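The difference between the two variants is just the int4 dequantization convention. A scalar sketch with hypothetical helper names (not fastllm functions):
#include <cstdint>
// min/scale ("no zero") variant: w = scale * q + min; the kernel folds min / scale into the sum
// and multiplies by scale only once at the end.
inline float DequantInt4Min(uint8_t q, float scale, float minv) {
    return scale * q + minv;
}
// zero-point variant: w = scale * (q - zero).
inline float DequantInt4Zero(uint8_t q, float scale, uint8_t zero) {
    return scale * (float)((int)q - (int)zero);
}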
GEMV int8 kernels
Purpose: for an n*m matrix and an m*1 vector, compute the matrix-vector product.
template <int THREAD_PER_BLOCK, int PART>
__global__ void FastllmGemvInt8Kernel2(float *A, uint8_t *B, float *C,
float *bias, float *scales, uint8_t *zeros,
int m, int k) {
__shared__ float sdata[THREAD_PER_BLOCK];
unsigned int tid = threadIdx.x;
// 1. load fdata
/*for (int i = tid; i < m; i += THREAD_PER_BLOCK) {
fdata[i] = A[i];
}
__syncthreads();*/
// 2. compute
int st = blockIdx.x * PART;
int end = st + PART;
for (int p = st; p < end; p++) {
sdata[tid] = 0;
uint8_t zero = zeros[p];
for (int i = tid; i < m; i += THREAD_PER_BLOCK) {
sdata[tid] += A[i] * (B[p * m + i] - zero);
}
__syncthreads();
for (unsigned int s = 1; s < THREAD_PER_BLOCK; s *= 2) {
if ((tid & (2 * s - 1)) == 0) {
sdata[tid] += sdata[tid + s];
}
__syncthreads();
}
if (tid == 0) {
C[p] = sdata[0] * scales[p] + bias[p];
}
__syncthreads();
}
}
n is split into tiles of PART rows, and the tiles (blocks) run in parallel.
Each slot of the shared array holds the partial sum of one thread's strided (congruent-index) elements; note that the per-row zero point must be subtracted from B, and the uint8 values are implicitly promoted to float32 in the multiply.
The per-thread partial sums are reduced to the full dot product for the current row, the outputs of the tile are written one by one, and the final result has shape n*1.
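For reference, a plain C++ version of the same int8 GEMV (a verification sketch, not fastllm code; it assumes B is row-major n*m with one scale/zero per row):
#include <cstdint>
// Reference GEMV: C[p] = scales[p] * sum_i A[i] * (B[p][i] - zeros[p]) + bias[p]
void GemvInt8Ref(const float *A, const uint8_t *B, float *C,
                 const float *bias, const float *scales, const uint8_t *zeros,
                 int n, int m) {
    for (int p = 0; p < n; p++) {
        float sum = 0.0f;
        for (int i = 0; i < m; i++) {
            sum += A[i] * (float)((int)B[p * m + i] - (int)zeros[p]);
        }
        C[p] = sum * scales[p] + bias[p];
    }
}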
template <int THREAD_PER_BLOCK, int PART>
__global__ void FastllmGemvFp32Fp16Kernel2(float *A, half *B, float *C, float *bias, int m, int k) {
__shared__ float sdata[THREAD_PER_BLOCK];
unsigned int tid = threadIdx.x;
// 1. compute
int st = blockIdx.x * PART;
int end = st + PART;
for (int p = st; p < end; p++) {
sdata[tid] = 0;
for (int i = tid; i < m; i += THREAD_PER_BLOCK) {
sdata[tid] += A[i] * (float)B[p * m + i];
}
__syncthreads();
for (unsigned int s = 1; s < THREAD_PER_BLOCK; s *= 2) {
if ((tid & (2 * s - 1)) == 0) {
sdata[tid] += sdata[tid + s];
}
__syncthreads();
}
if (tid == 0) {
C[p] = sdata[0] + bias[p];
}
__syncthreads();
}
}
Same structure as the int8 version, but the weights are stored in fp16.
GEMM int8 kernels
Purpose: for an n*m matrix and an m*k matrix, compute their matrix product (the int8 operand is laid out row-major as k*m in memory, with one scale/zero per output column).
template <int NBlock, int MBlock, int KBlock>
__global__ void FastllmCudaBaseGemmKernelInt8(float *A, uint8_t *B, float *C,
float *bias, float *scales, uint8_t *zeros,
int n, int m, int k) {
int nStart = blockIdx.x * NBlock, nEnd = nStart + NBlock;
int kStart = blockIdx.y * KBlock, kEnd = kStart + KBlock;
int id = kStart + threadIdx.x;
__shared__ float shareA[NBlock * MBlock];
__shared__ float shareB[KBlock * MBlock];
float localSum[NBlock] = {0.0f};
uint8_t zero = zeros[id];
int idx = threadIdx.x >> 3;
int idy = threadIdx.x & 7;
for (int l = 0; l < m; l += MBlock) {
if (threadIdx.x < MBlock) {
for (int i = nStart; i < nEnd; i++) {
if (i < n && l + threadIdx.x < m) {
shareA[(i - nStart) * MBlock + threadIdx.x] = A[i * m + l + threadIdx.x];
} else {
shareA[(i - nStart) * MBlock + threadIdx.x] = 0.0f;
}
}
}
__syncthreads();
if (threadIdx.x < MBlock) {
for (int i = kStart; i < kEnd; i++) {
if (i < k && l + threadIdx.x < m) {
shareB[(i - kStart) * MBlock + threadIdx.x] = B[i * m + l + threadIdx.x];
} else {
shareB[(i - kStart) * MBlock + threadIdx.x] = 0.0f;
}
}
}
__syncthreads();
for (int mStart = 0; mStart < MBlock; mStart += 4) {
float curA[32] = {0.0f}, curB[32] = {0.0f};
for (int i = 0; i < 8; i++) {
for (int x = l + mStart; x < l + mStart + 4 && x < m; x++) {
curA[i * 4 + (x - l - mStart)] = shareA[(idx * 8 + i) * MBlock + (x - l)];
}
}
for (int j = 0; j < 4; j++) {
zero = zeros[kStart + (idy * 4 + j)];
for (int x = l + mStart; x < l + mStart + 4 && x < m; x++) {
curB[j * 4 + (x - l - mStart)] = shareB[(idy * 4 + j) * MBlock + (x - l)] - zero;
}
}
for (int i = 0; i < 8; i++) {
for (int j = 0; j < 4; j++) {
int cur = i * 4 + j;
localSum[cur] += curA[i * 4 + 0] * curB[j * 4 + 0];
localSum[cur] += curA[i * 4 + 1] * curB[j * 4 + 1];
localSum[cur] += curA[i * 4 + 2] * curB[j * 4 + 2];
localSum[cur] += curA[i * 4 + 3] * curB[j * 4 + 3];
}
}
__syncthreads();
}
__syncthreads();
}
for (int i = 0; i < 8; i++) {
for (int j = 0; j < 4; j++) {
if ((nStart + idx * 8 + i) < n && (kStart + idy * 4 + j) < k) {
C[(nStart + idx * 8 + i) * k + (kStart + idy * 4 + j)] =
localSum[i * 4 + j] * scales[(kStart + idy * 4 + j)] + bias[(kStart + idy * 4 + j)];
}
}
}
}
shareA and shareB live in per-block shared memory, while localSum is a per-thread register array (not global memory); the threads are split into groups of eight via idx = threadIdx.x >> 3 (row group) and idy = threadIdx.x & 7 (column group).
A tile of A is prefetched into shareA (NBlock*MBlock) and a tile of B into shareB (KBlock*MBlock); out-of-range entries are zero-filled, which wastes a little work at the edges.
localSum is then accumulated four m-elements at a time: each thread copies an 8*4 slice of shareA into curA and a 4*4 slice of shareB (with the zero point subtracted) into curB, then performs four multiply-adds per output element, so each thread owns an 8*4 tile of the result. This still feels like it has tuning headroom, since the hard-coded 32-wide tiling will not match every GPU's block/shared-memory configuration equally well.
Finally the output is written back tile by tile, each thread storing its 8*4 block of C, scaled and biased per output column.
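As a semantic reference (not fastllm code), the whole kernel computes the following plain C++ loop, assuming A is n*m float, B is stored row-major as k*m uint8 with one scale/zero per row of B, and C is n*k:
#include <cstdint>
// Reference GEMM: C[i][j] = scales[j] * sum_l A[i][l] * (B[j][l] - zeros[j]) + bias[j]
void GemmInt8Ref(const float *A, const uint8_t *B, float *C,
                 const float *bias, const float *scales, const uint8_t *zeros,
                 int n, int m, int k) {
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < k; j++) {
            float sum = 0.0f;
            for (int l = 0; l < m; l++) {
                sum += A[i * m + l] * (float)((int)B[j * m + l] - (int)zeros[j]);
            }
            C[i * k + j] = sum * scales[j] + bias[j];
        }
    }
}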
layerNorm implementation
Purpose: implement LayerNorm, \(y = \gamma \cdot (x-\mu)/\sigma + \beta\).
template <int THREAD_PER_BLOCK>
__global__ void FastllmLayerNormKernelInner1(float *input, float *gamma, float *beta, float *output, int outer, int channels) {
int o = blockIdx.x;
input = input + o * channels;
output = output + o * channels;
__shared__ float sdata[THREAD_PER_BLOCK];
__shared__ float sdata2[THREAD_PER_BLOCK];
__shared__ float mean;
__shared__ float var;
// 1. each thread computes a partial sum
unsigned int tid = threadIdx.x;
float sum = 0.0, sum2 = 0.0;
for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
float x = input[i];
sum += x;
sum2 += x * x;
}
sdata[tid] = sum;
sdata2[tid] = sum2;
__syncthreads();
// 2. reduce the sums
for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
sdata[tid] += sdata[tid + s];
sdata2[tid] += sdata2[tid + s];
}
__syncthreads();
}
// 3. compute the statistics
if (tid == 0) {
mean = sdata[0] / channels;
var = sdata2[0] + mean * mean * channels - 2 * mean * channels * mean;
var = sqrt(var / channels + 1e-10);
}
__syncthreads();
for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
output[i] = (input[i] - mean) / var * gamma[i] + beta[i];
}
}
sdata holds the per-thread partial sums used for the mean, and sdata2 holds the per-thread partial sums of squares used for the variance.
Each thread accumulates its strided (congruent-index) elements, the block reduction then yields the global sum and sum of squares, and thread 0 derives mean and var from them.
Finally every element of output is written according to the LayerNorm formula.
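Step 3 relies on the one-pass variance identity, which is why a single scan collecting both the sum and the sum of squares is enough: \(\sum_i (x_i-\mu)^2 = \sum_i x_i^2 - 2\mu\sum_i x_i + n\mu^2 = \sum_i x_i^2 - n\mu^2\). That is exactly sdata2[0] + mean * mean * channels - 2 * mean * channels * mean in the code, with the last two terms collapsing to -channels * mean * mean.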
template <int THREAD_PER_BLOCK>
__global__ void FastllmLayerNormKernelTop1(float *input, float *output, int channels) {
__shared__ float idData[THREAD_PER_BLOCK];
__shared__ float maxData[THREAD_PER_BLOCK];
float *inputData = input + blockIdx.x * channels;
float *outputData = output + blockIdx.x * 2;
int tid = threadIdx.x;
maxData[tid] = -1e100;
for (int j = tid; j < channels; j += THREAD_PER_BLOCK) {
if (inputData[j] > maxData[tid]) {
maxData[tid] = inputData[j];
idData[tid] = j;
}
}
__syncthreads();
for (unsigned int s = THREAD_PER_BLOCK / 2; s > 0; s >>= 1) {
if (tid < s) {
if (maxData[tid] < maxData[tid + s]) {
maxData[tid] = maxData[tid + s];
idData[tid] = idData[tid + s];
}
}
__syncthreads();
}
if (tid == 0) {
outputData[0] = idData[0];
outputData[1] = maxData[0];
}
}
This one does not actually have much to do with LayerNorm: it is an argmax kernel, and the reduction is essentially the same as the sum reduction, just with max plus index tracking.
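Its semantics are easier to see from a scalar reference (illustrative only, not fastllm code): for each row it writes the argmax index and the max value side by side.
// Reference "top-1": output[o * 2 + 0] = argmax index (stored as float), output[o * 2 + 1] = max value.
void Top1Ref(const float *input, float *output, int outer, int channels) {
    for (int o = 0; o < outer; o++) {
        const float *row = input + o * channels;
        int bestId = 0;
        for (int j = 1; j < channels; j++) {
            if (row[j] > row[bestId]) bestId = j;
        }
        output[o * 2 + 0] = (float)bestId;
        output[o * 2 + 1] = row[bestId];
    }
}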
RMS kernels walkthrough
Purpose: implement RMSNorm, \(\mathrm{RMSNorm}(x)_i = \dfrac{x_i}{\sqrt{\frac{1}{n}\sum_j x_j^2 + \epsilon}} \cdot w_i\).
template <int THREAD_PER_BLOCK>
__global__ void FastllmRMSNormKernelInner1(float *input, float *weight, float *output, int outer, int channels, float eps) {
int o = blockIdx.x;
input = input + o * channels;
output = output + o * channels;
__shared__ float sdata2[THREAD_PER_BLOCK];
__shared__ float scale;
// 1. each thread computes a partial sum
unsigned int tid = threadIdx.x;
float sum2 = 0.0;
for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
float x = input[i];
sum2 += x * x;
}
sdata2[tid] = sum2;
__syncthreads();
// 2. reduce the sums
for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
sdata2[tid] += sdata2[tid + s];
}
__syncthreads();
}
// 3. compute the scale
if (tid == 0) {
scale = 1.0 / sqrt(sdata2[0] / channels + eps);
}
__syncthreads();
for (int i = tid; i < channels; i += THREAD_PER_BLOCK) {
output[i] = (input[i] * scale * weight[i]);
}
}
Each slot of the shared array holds one thread's partial sum over its strided (congruent-index) elements; sum2 is that thread's accumulated sum of squares.
Reducing within the block yields \(\sum_i x_i^2\), which ends up in sdata2[0].
From that total, thread 0 computes scale, and every output element is then written as input[i] * scale * weight[i].
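A plain C++ reference of the same computation (a verification sketch, not fastllm code):
#include <cmath>
// Reference RMSNorm: y[i] = x[i] / sqrt(mean(x^2) + eps) * weight[i], per row of length channels.
void RMSNormRef(const float *input, const float *weight, float *output,
                int outer, int channels, float eps) {
    for (int o = 0; o < outer; o++) {
        const float *x = input + o * channels;
        float *y = output + o * channels;
        float sum2 = 0.0f;
        for (int i = 0; i < channels; i++) sum2 += x[i] * x[i];
        float scale = 1.0f / std::sqrt(sum2 / channels + eps);
        for (int i = 0; i < channels; i++) y[i] = x[i] * scale * weight[i];
    }
}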
softmax kernels walkthrough
Purpose: compute the softmax of the inputs, \(\mathrm{softmax}(x)_i = \dfrac{\exp(x_i - \max(x))}{\sum_j \exp(x_j - \max(x))}\).
template <int THREAD_PER_BLOCK>
__global__ void FastllmSoftmaxKernelInner1(float* input, float *output, int outer, int channels) {
int o = blockIdx.x;
input = input + o * channels;
output = output + o * channels;
__shared__ float sdata[THREAD_PER_BLOCK];
__shared__ float maxV;
// 1. each thread scans its own contiguous chunk
unsigned int tid = threadIdx.x;
unsigned int per = (channels / THREAD_PER_BLOCK);
unsigned int id = threadIdx.x * per;
unsigned int len = per;
if (tid == blockDim.x - 1) {
len += (channels - per * THREAD_PER_BLOCK);
}
float maxValue = input[id];
for (int i = 0; i < len; i++) {
maxValue = max(maxValue, input[id + i]);
}
sdata[tid] = maxValue;
__syncthreads();
// 2. reduce to the max
for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
sdata[tid] = max(sdata[tid], sdata[tid + s]);
}
__syncthreads();
}
// 3. record the max
if (tid == 0) {
maxV = sdata[0];
}
__syncthreads();
// 4. exponentiate and sum
float sum = 0;
for (int i = 0; i < len; i++) {
output[id + i] = exp(input[id + i] - maxV);
sum += output[id + i];
}
sdata[tid] = sum;
__syncthreads();
for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
sdata[tid] += sdata[tid + s];
}
__syncthreads();
}
if (tid == 0) {
if (fabs(sdata[0]) < 1e-6) {
sdata[0] = 0.1;
}
}
__syncthreads();
for (int i = 0; i < len; i++) {
output[id + i] /= sdata[0];
}
}
First each thread scans its part of the row to find a local max (note that here each thread owns a contiguous chunk of channels / THREAD_PER_BLOCK elements, unlike the strided pattern used above); sdata holds the per-thread maxima, and the reduction gives the row max.
Each input is then exponentiated with the max subtracted, and a second reduction of the same form produces the sum; if the sum is close to zero (everything underflowed), it is replaced by a small constant (0.1) so the division below stays safe.
Finally the normalized results are written to output.
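For comparison, the numerically stable softmax the kernel implements, written as a plain C++ reference (illustrative only, without the 0.1 underflow fallback):
#include <cmath>
// Reference softmax over each row of length channels: y[i] = exp(x[i] - max(x)) / sum_j exp(x[j] - max(x))
void SoftmaxRef(const float *input, float *output, int outer, int channels) {
    for (int o = 0; o < outer; o++) {
        const float *x = input + o * channels;
        float *y = output + o * channels;
        float maxV = x[0];
        for (int i = 1; i < channels; i++) if (x[i] > maxV) maxV = x[i];
        float sum = 0.0f;
        for (int i = 0; i < channels; i++) { y[i] = std::exp(x[i] - maxV); sum += y[i]; }
        for (int i = 0; i < channels; i++) y[i] /= sum;
    }
}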
RotatePosition2D Kernels walkthrough
Purpose: apply rotary position embeddings (RoPE).
__global__ void FastllmLlamaRotatePosition2DKernel(float *data, float *positionIds, float *sin, float *cos,
int len, int bs, int spatial, int n, int m, int partStride, int sinCosStride, int rotateDim) {
int o = (blockIdx.x / n);
int l = o % len;
int b = o / len;
int j = threadIdx.x;
int index = (int) (positionIds[b * partStride + l]);
float curSin = sin[index * sinCosStride + j];
float curCos = cos[index * sinCosStride + j];
float *d = (float *) data + o * spatial + j;
int i = blockIdx.x % n;
float va = d[i * m], vb = d[i * m + m / 2];
d[i * m] = va * curCos - vb * curSin;
d[i * m + m / 2] = va * curSin + vb * curCos;
}
__global__ void FastllmNearlyRotatePosition2DKernel(float *data, float *positionIds, float *sin, float *cos,
int len, int bs, int spatial, int n, int m, int partStride, int sinCosStride, int rotateDim) {
int o = (blockIdx.x / n);
int l = o / bs;
int b = o % bs;
int j = threadIdx.x;
int index = (int) (positionIds[b * 2 * partStride + l]);
float curSin = sin[index * sinCosStride + j];
float curCos = cos[index * sinCosStride + j];
float *d = (float *) data + o * spatial + j * 2;
int i = blockIdx.x % n;
float va = d[i * m], vb = d[i * m + 1];
d[i * m] = va * curCos - vb * curSin;
d[i * m + 1] = va * curSin + vb * curCos;
}
__global__ void FastllmRotatePosition2DKernel(float *data, float *positionIds, float *sin, float *cos,
int len, int bs, int spatial, int n, int m, int partStride, int sinCosStride, int rotateDim) {
int o = (blockIdx.x / n) / 2;
int l = o / bs;
int b = o % bs;
int part = (blockIdx.x / n) % 2;
int j = threadIdx.x;
int index = (int) (positionIds[(b * 2 + part) * partStride + l]);
float curSin = sin[index * sinCosStride + j];
float curCos = cos[index * sinCosStride + j];
float *d = (float *) data + o * spatial + part * m / 2 + j;
int i = blockIdx.x % n;
float va = d[i * m], vb = d[i * m + m / 4];
d[i * m] = va * curCos - vb * curSin;
d[i * m + m / 4] = va * curSin + vb * curCos;
}
The thread/block indexing here fully unrolls the work, effectively flattening three nested loops (the outer index o, the per-head index i = blockIdx.x % n, and the in-head dimension j = threadIdx.x).
On the differences between the three: LlamaRotatePosition2D rotates element i of a head against element i + m/2 (first half against second half); NearlyRotatePosition2D rotates adjacent pairs (2j, 2j+1); RotatePosition2D first splits the data into two parts and within each part rotates elements m/4 apart. The rotation math is identical; only the pairing of positions differs.
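The rotation applied to each pair is the same 2x2 rotation in all three kernels; only which two elements form a pair changes. A tiny helper makes this explicit (an illustrative sketch, not a fastllm function):
// Rotate one pair (a, b) by the angle whose cosine/sine are c/s:
//   LlamaRotatePosition2D   pairs d[i]  with d[i + m/2]
//   NearlyRotatePosition2D  pairs d[2j] with d[2j + 1]
//   RotatePosition2D        pairs d[i]  with d[i + m/4] inside each half
static inline void RotatePair(float &a, float &b, float c, float s) {
    float va = a, vb = b;
    a = va * c - vb * s;
    b = va * s + vb * c;
}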
AttentionMask Kernels walkthrough
Purpose: set the masked positions of the attention scores to maskValue according to the mask.
template <int THREAD_PER_BLOCK>
__global__ void FastllmAttentionMaskKernel(float* a, float *b, float maskValue, int n, int m, int spatial) {
int on = blockIdx.x / m;
int om = blockIdx.x % m;
int o = on * m + om;
int idx = threadIdx.x;
for (int i = idx; i < spatial; i += THREAD_PER_BLOCK) {
if (b[on * spatial + i] > 0.99) {
a[o * spatial + i] = maskValue;
}
}
}
template <int THREAD_PER_BLOCK>
__global__ void FastllmAlibiMaskKernel(float* a, float *b, float maskValue, int n, int m, int spn, int spm, int spatial) {
int on = blockIdx.x / m;
int om = blockIdx.x % m;
int o = on * m + om;
int idx = threadIdx.x;
float now = b[om];
for (int i = idx; i < spatial; i += THREAD_PER_BLOCK) {
int idi = i / spm, idj = i % spm;
if (idj <= spm - spn + idi) {
a[o * spatial + i] += now * idj;
} else {
a[o * spatial + i] = maskValue;
}
}
}
The work is split across blocks as n*m; the plain mask kernel simply overwrites masked positions (mask value > 0.99) with maskValue, while the Alibi variant adds a linear positional bias now * idj (the slope b[om] times the column index) to the positions that stay visible and writes maskValue to the rest.
I am still not fully sure about the exact meaning of spatial here; my reading, used in the reference sketch below, is that it is the size of one attention map.
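Under that reading, the plain mask kernel is equivalent to the following C++ sketch (illustrative only, not fastllm code; n would be the batch, m the head count, and spatial one query_len * key_len attention map):
// Reference mask: positions where the mask is ~1 are overwritten with maskValue.
void AttentionMaskRef(float *a, const float *b, float maskValue,
                      int n, int m, int spatial) {
    for (int on = 0; on < n; on++) {
        for (int om = 0; om < m; om++) {
            float *dst = a + (on * m + om) * spatial;
            const float *mask = b + on * spatial;   // the mask is shared across the m heads
            for (int i = 0; i < spatial; i++) {
                if (mask[i] > 0.99f) dst[i] = maskValue;
            }
        }
    }
}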
swiglu kernels walkthrough
Purpose: several common activation functions.
__global__ void FastllmSwigluKernel(float* a, float *b, int len, int spatial, int mid) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < len) {
int id = idx / mid * spatial + idx % mid;
float x = a[id], y = a[id + mid];
b[idx] = (x / (1.0 + expf(-x))) * y;
}
}
__global__ void FastllmGeluKernel(float* a, float *b, int len) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < len) {
float x = a[idx];
b[idx] = 0.5f * x * (1.0f + tanhf(0.7978845608028654f * x * (1.0f + 0.044715f * x * x)));
}
}
__global__ void FastllmSiluKernel(float* a, float *b, int len) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < len) {
float x = a[idx];
b[idx] = x / (1.0 + expf(-x));
}
}
Not much to say about these: they are straightforward element-wise kernels, there are enough threads to cover every element, so the formulas are simply written out.
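As a sanity check on the indexing of the SwiGLU kernel, here is an equivalent plain C++ version (illustrative, not fastllm code; it assumes spatial = 2 * mid, i.e. each input row holds the value half followed by the gate half):
#include <cmath>
// Reference SwiGLU: b[r][i] = silu(a[r][i]) * a[r][i + mid], rows of width spatial -> rows of width mid.
void SwigluRef(const float *a, float *b, int rows, int spatial, int mid) {
    for (int r = 0; r < rows; r++) {
        const float *x = a + r * spatial;
        float *y = b + r * mid;
        for (int i = 0; i < mid; i++) {
            float v = x[i];
            y[i] = (v / (1.0f + std::exp(-v))) * x[i + mid];
        }
    }
}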