A direct Conv2D vs. implicit GEMM implementation
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <vector>
using namespace std;
// Problem dimensions shared by every routine in this file.
// Input feature map is indexed as N x IC x IH x IW.
size_t N = 1, IC = 3, IH = 2, IW = 2;
// Filter bank is indexed as OC x IC x FH x FW.
size_t OC = 2, FH = 2, FW = 2;
// Stride and zero padding, applied identically to both spatial axes.
size_t stride = 1, pad = 1;
// Output spatial extent, derived per axis so that non-square inputs or
// filters (IH != IW or FH != FW) still produce the right width.
size_t OH = (IH + 2 * pad - FH) / stride + 1;
size_t OW = (IW + 2 * pad - FW) / stride + 1;
// Checks that the first N floats of a and b agree element-wise within 1e-5.
// On success prints "[PASSED]result eq check". On the first mismatch prints
// "[FAILED]result eq check" and terminates the process with exit(-1).
void assert_eq(float *a, float *b, size_t N) {
    const float eps = 1e-5f;
    for (size_t i = 0; i < N; i++) {
        // Compare the magnitude of the difference so that b[i] > a[i] is
        // caught as well. The negated "<=" also treats NaN as a mismatch.
        if (!(std::fabs(a[i] - b[i]) <= eps)) {
            printf("[FAILED]result eq check\n");
            exit(-1);
        }
    }
    printf("[PASSED]result eq check\n");
}
// Fills a contiguous a x b x c x d tensor (last index varies fastest) with the
// repeating sequence 1, 2, 3, 0, 1, 2, ... in storage order.
void assign(float *ptr, int a, int b, int c, int d) {
    float *cursor = ptr;
    int counter = 1;
    for (int n = 0; n < a; ++n) {
        for (int ch = 0; ch < b; ++ch) {
            for (int row = 0; row < c; ++row) {
                for (int col = 0; col < d; ++col) {
                    // Small repeating values keep the results easy to check by hand.
                    *cursor = counter % 4;
                    ++cursor;
                    ++counter;
                }
            }
        }
    }
}
// Dumps a contiguous a x b x c x d tensor to stdout: each run of d values is
// printed on its own line with two decimals, and a blank line ends the dump.
void printout(float *ptr, int a, int b, int c, int d) {
    const float *cursor = ptr;
    for (int n = 0; n < a; ++n) {
        for (int ch = 0; ch < b; ++ch) {
            for (int row = 0; row < c; ++row) {
                for (int col = 0; col < d; ++col) {
                    printf("%.2f ", *cursor);
                    ++cursor;
                }
                printf("\n");
            }
        }
    }
    printf("\n");
}
// NxICxIHxIW, OCxICxFHxFW, NxOCxOHxOW
void direct_conv(float *feat, float *weight, float *out) {
for (int n = 0; n < N; n++) {
for (int oc = 0; oc < OC; oc++) {
for (int oh = 0; oh < OH; oh++) {
for (int ow = 0; ow < OW; ow++) {
float value = 0;
// feature map中左上角坐标
auto ih_start = oh * stride - pad;
auto iw_start = ow * stride - pad;
for (int ic = 0; ic < IC; ic++) {
for (int fh = 0; fh < FH; fh++) {
for (int fw = 0; fw < FW; fw++) {
float f = 0;
// 边界检查,没有到pad处
if (ih_start + fh >= 0 && iw_start + fw >= 0 && ih_start + fh < IH &&
iw_start + fw < IW) {
f = *(feat + n * (IC * (IH) * (IW)) + ic * ((IH) * (IW)) + (ih_start + fh) * (IW) +
iw_start + fw);
}
auto w = *(weight + oc * (IC * FH * FW) + ic * (FH * FW) + fh * FW + fw);
value += f * w;
printf("feat@{%d,%d,%d,%d}, val:%f\n", n, ic, ih_start + fh, iw_start + fw, f);
printf("weight@{%d,%d,%d,%d}, val:%f\n", oc, ic, fh, fw, w);
}
}
}
*(out + n * (OC * OH * OW) + oc * (OH * OW) + oh * OW + ow) = value;
}
}
}
}
}
void implicit_gemm_cutlass_offcical(float *feat, float *weight, float *out) {
/*
* 我理解这里只要定好MNK的维度分别表示conv的维度是什么,然后就按照NCHW写就行了*/
int GEMM_M = N * OH * OW;
int GEMM_N = OC;
int GEMM_K = IC * FH * FW;
// 在N,OC,OH,OW上iter
for (auto i = 0; i < GEMM_M; i++) {
for (auto j = 0; j < GEMM_N; j++) {
float value = 0;
// 偏移计算都根据NCHW这个layout来的, 和物理没关系
int n = i / (OH * OW);
int oh_ow_res = i % (OH * OW);
int oh = oh_ow_res / OW;
int ow = oh_ow_res % OW;
int oc = j;
// 在filter上iter
for (auto k = 0; k < GEMM_K; k++) {
// 偏移计算都根据NCHW这个layout来的, 和物理没关系
int ic = k / (FH * FW);
int fh_fw_res = k % (FH * FW);
int fh = fh_fw_res / (FW);
int fw = fh_fw_res % (FW);
int ih = oh * stride - pad + fh;
int iw = ow * stride - pad + fw;
float _w = 0, _f = 0;
// 只有下面操作内存时,才和物理相关。
_w = *(weight + oc * (IC * FH * FW) + ic * (FH * FW) + fh * FW + fw); // acc to NCHW
if (ih >= 0 && ih < IH && iw >= 0 && iw < FW) {
_f = *(feat + n * (IC * IH * IW) + ic * (IH * IW) + ih * IW + iw); // acc to NCHW
}
value += _f * _w;
}
*(out + n * (OC * OH * OW) + oc * (OH * OW) + oh * OW + ow) = value; // acc to NCHW
}
}
}
void implicit_gemm(float *feat, float *weight, float *out) {
int GEMM_M = OC;
int GEMM_N = N * OH * OW;
int GEMM_K = IC * FH * FW;
for (int i = 0; i < GEMM_M; i++) {
int oc = i;
for (int j = 0; j < GEMM_N; j++) {
int value = 0;
int n = j / (OH * OW);
int j_res = j % (OH * OW);
int oh = j_res / OW;
int ow = j_res % OW;
for (int k = 0; k < GEMM_K; k++) {
int ic = k / (FH * FW);
int k_res = k % (FH * FW);
int fh = k_res / FW;
int fw = k_res % FW;
int ih = oh * stride - pad + fh;
int iw = ow * stride - pad + fw;
float _f = 0, _w = 0;
_w = *(weight + oc * IC * FH * FW + ic * FH * FW + fh * FW + fw);
if (ih >= 0 && iw >= 0 && ih < IH && iw < IW) {
_f = *(feat + n * IC * IH * IW + ic * IH * IW + ih * IW + iw);
}
value += _f * _w;
}
*(out + n * OC * OH * OW + oc * OH * OW + oh * OW + ow) = value;
}
}
}
// Builds a small test problem, runs the direct convolution and both implicit
// GEMM variants on it, prints every tensor and cross-checks the results with
// assert_eq (which terminates the process on a mismatch).
int main() {
    const size_t feat_count = N * IC * IH * IW;
    const size_t weight_count = OC * IC * FH * FW;
    const size_t out_count = N * OC * OH * OW;

    // std::vector zero-initialises its elements and releases the storage on
    // scope exit, so no manual memset or free is needed.
    std::vector<float> feat(feat_count);
    assign(feat.data(), N, IC, IH, IW);
    printf("feat\n");
    printout(feat.data(), N, IC, IH, IW);
    printf("\n");

    printf("weight\n");
    std::vector<float> weight(weight_count);
    assign(weight.data(), OC, IC, FH, FW);
    printout(weight.data(), OC, IC, FH, FW);
    printf("\n");

    std::vector<float> out(out_count);
    direct_conv(feat.data(), weight.data(), out.data());
    printf("out\n");
    printout(out.data(), N, OC, OH, OW);
    printf("\n");

    std::vector<float> implicit_gemm_out(out_count);
    implicit_gemm(feat.data(), weight.data(), implicit_gemm_out.data());
    printf("implicit gemm out\n");
    printout(implicit_gemm_out.data(), N, OC, OH, OW);
    assert_eq(out.data(), implicit_gemm_out.data(), out_count);
    printf("\n");

    std::vector<float> implicit_gemm_out_cutlass(out_count);
    implicit_gemm_cutlass_offcical(feat.data(), weight.data(), implicit_gemm_out_cutlass.data());
    printf("implicit gemm cutlass_impl out\n");
    printout(implicit_gemm_out_cutlass.data(), N, OC, OH, OW);
    assert_eq(implicit_gemm_out_cutlass.data(), implicit_gemm_out.data(), out_count);
    printf("\n");
    return 0;
}
本文来自博客园,作者:ijpq,转载请注明原文链接:https://www.cnblogs.com/ijpq/p/17296058.html