// A direct Conv2D vs. implicit GEMM implementation

#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <vector>

using namespace std;


// Input shape: batch N, input channels IC, input height/width IH/IW.
size_t N = 1, IC = 3, IH = 2, IW = 2;
// Filter shape: output channels OC, filter height/width FH/FW.
size_t OC = 2, FH = 2, FW = 2;
// Convolution stride and zero padding, applied identically to both spatial dimensions.
size_t stride = 1, pad = 1;

// Output spatial size. Each dimension uses its own input and filter extent, so
// non-square inputs/filters produce the correct output width.
size_t OH = (IH + 2 * pad - FH) / stride + 1;
size_t OW = (IW + 2 * pad - FW) / stride + 1;


// Element-wise comparison of two float buffers of length N.
// Prints "[PASSED]result eq check" when every |a[i] - b[i]| <= 1e-5; otherwise prints
// "[FAILED]result eq check" and terminates the process with exit(-1).
void assert_eq(float *a, float *b, size_t N) {
    const float eps = 1e-5f;
    for (size_t i = 0; i < N; i++) {
        // Use the absolute difference so mismatches in either direction are caught.
        // The negated comparison also treats NaN as a mismatch.
        if (!(std::fabs(a[i] - b[i]) <= eps)) {
            printf("[FAILED]result eq check\n");
            exit(-1);
        }
    }
    printf("[PASSED]result eq check\n");
}

// Fill a dense row-major a x b x c x d buffer with the repeating pattern 1, 2, 3, 0, 1, ...
// (the element at linear index k receives (k + 1) % 4), which keeps results easy to check.
// Writes nothing when any extent is non-positive.
void assign(float *ptr, int a, int b, int c, int d) {
    if (a <= 0 || b <= 0 || c <= 0 || d <= 0) {
        return;
    }
    // Walking i, j, p, q in nested order visits linear indices 0, 1, 2, ... in sequence.
    const int total = a * b * c * d;
    for (int k = 0; k < total; ++k) {
        ptr[k] = static_cast<float>((k + 1) % 4);
    }
}

// Print a dense row-major a x b x c x d buffer to stdout: one line per innermost row of d
// values (each formatted "%.2f "), followed by a single trailing blank line.
void printout(float *ptr, int a, int b, int c, int d) {
    // Rows are indexed by the flattened (i, j, p) triple; none exist if any outer extent <= 0.
    const int rows = (a > 0 && b > 0 && c > 0) ? a * b * c : 0;
    for (int row = 0; row < rows; ++row) {
        for (int col = 0; col < d; ++col) {
            printf("%.2f ", ptr[row * d + col]);
        }
        printf("\n");
    }
    printf("\n");
}

// NxICxIHxIW, OCxICxFHxFW, NxOCxOHxOW
void direct_conv(float *feat, float *weight, float *out) {
    for (int n = 0; n < N; n++) {
        for (int oc = 0; oc < OC; oc++) {
            for (int oh = 0; oh < OH; oh++) {
                for (int ow = 0; ow < OW; ow++) {
                    float value = 0;
                    // feature map中左上角坐标
                    auto ih_start = oh * stride - pad;
                    auto iw_start = ow * stride - pad;

                    for (int ic = 0; ic < IC; ic++) {
                        for (int fh = 0; fh < FH; fh++) {
                            for (int fw = 0; fw < FW; fw++) {

                                float f = 0;
                                // 边界检查,没有到pad处
                                if (ih_start + fh >= 0 && iw_start + fw >= 0 && ih_start + fh < IH &&
                                    iw_start + fw < IW) {
                                    f = *(feat + n * (IC * (IH) * (IW)) + ic * ((IH) * (IW)) + (ih_start + fh) * (IW) +
                                          iw_start + fw);
                                }
                                auto w = *(weight + oc * (IC * FH * FW) + ic * (FH * FW) + fh * FW + fw);
                                value += f * w;
                                printf("feat@{%d,%d,%d,%d}, val:%f\n", n, ic, ih_start + fh, iw_start + fw, f);
                                printf("weight@{%d,%d,%d,%d}, val:%f\n", oc, ic, fh, fw, w);
                            }
                        }
                    }
                    *(out + n * (OC * OH * OW) + oc * (OH * OW) + oh * OW + ow) = value;
                }
            }
        }
    }
}

void implicit_gemm_cutlass_offcical(float *feat, float *weight, float *out) {
    /*
     * 我理解这里只要定好MNK的维度分别表示conv的维度是什么,然后就按照NCHW写就行了*/
    int GEMM_M = N * OH * OW;
    int GEMM_N = OC;
    int GEMM_K = IC * FH * FW;
    // 在N,OC,OH,OW上iter
    for (auto i = 0; i < GEMM_M; i++) {
        for (auto j = 0; j < GEMM_N; j++) {
            float value = 0;
            // 偏移计算都根据NCHW这个layout来的, 和物理没关系
            int n = i / (OH * OW);
            int oh_ow_res = i % (OH * OW);
            int oh = oh_ow_res / OW;
            int ow = oh_ow_res % OW;
            int oc = j;
            // 在filter上iter
            for (auto k = 0; k < GEMM_K; k++) {
                // 偏移计算都根据NCHW这个layout来的, 和物理没关系
                int ic = k / (FH * FW);
                int fh_fw_res = k % (FH * FW);
                int fh = fh_fw_res / (FW);
                int fw = fh_fw_res % (FW);
                int ih = oh * stride - pad + fh;
                int iw = ow * stride - pad + fw;
                float _w = 0, _f = 0;
                // 只有下面操作内存时,才和物理相关。
                _w = *(weight + oc * (IC * FH * FW) + ic * (FH * FW) + fh * FW + fw); // acc to NCHW
                if (ih >= 0 && ih < IH && iw >= 0 && iw < FW) {
                    _f = *(feat + n * (IC * IH * IW) + ic * (IH * IW) + ih * IW + iw); // acc to NCHW
                }
                value += _f * _w;
            }
            *(out + n * (OC * OH * OW) + oc * (OH * OW) + oh * OW + ow) = value; // acc to NCHW
        }
    }
}

void implicit_gemm(float *feat, float *weight, float *out) {
    int GEMM_M = OC;
    int GEMM_N = N * OH * OW;
    int GEMM_K = IC * FH * FW;
    for (int i = 0; i < GEMM_M; i++) {
        int oc = i;
        for (int j = 0; j < GEMM_N; j++) {
            int value = 0;
            int n = j / (OH * OW);
            int j_res = j % (OH * OW);
            int oh = j_res / OW;
            int ow = j_res % OW;
            for (int k = 0; k < GEMM_K; k++) {
                int ic = k / (FH * FW);
                int k_res = k % (FH * FW);
                int fh = k_res / FW;
                int fw = k_res % FW;
                int ih = oh * stride - pad + fh;
                int iw = ow * stride - pad + fw;
                float _f = 0, _w = 0;
                _w = *(weight + oc * IC * FH * FW + ic * FH * FW + fh * FW + fw);
                if (ih >= 0 && iw >= 0 && ih < IH && iw < IW) {
                    _f = *(feat + n * IC * IH * IW + ic * IH * IW + ih * IW + iw);
                }
                value += _f * _w;
            }
            *(out + n * OC * OH * OW + oc * OH * OW + oh * OW + ow) = value;
        }
    }
}

// Builds deterministic inputs, runs the direct convolution and both implicit-GEMM variants,
// prints every tensor, and cross-checks the results with assert_eq (which exits on mismatch).
int main() {
    const size_t feat_count = N * IC * IH * IW;
    const size_t weight_count = OC * IC * FH * FW;
    const size_t out_count = N * OC * OH * OW;

    // std::vector value-initializes (zeroes) every element and frees its storage on scope
    // exit, so no manual memset/free is needed.
    std::vector<float> feat(feat_count);
    assign(feat.data(), N, IC, IH, IW);
    printf("feat\n");
    printout(feat.data(), N, IC, IH, IW);
    printf("\n");

    printf("weight\n");
    std::vector<float> weight(weight_count);
    assign(weight.data(), OC, IC, FH, FW);
    printout(weight.data(), OC, IC, FH, FW);
    printf("\n");

    // Reference result.
    std::vector<float> out(out_count);
    direct_conv(feat.data(), weight.data(), out.data());
    printf("out\n");
    printout(out.data(), N, OC, OH, OW);
    printf("\n");

    std::vector<float> implicit_gemm_out(out_count);
    implicit_gemm(feat.data(), weight.data(), implicit_gemm_out.data());
    printf("implicit gemm out\n");
    printout(implicit_gemm_out.data(), N, OC, OH, OW);
    assert_eq(out.data(), implicit_gemm_out.data(), out_count);
    printf("\n");

    std::vector<float> implicit_gemm_out_cutlass(out_count);
    implicit_gemm_cutlass_offcical(feat.data(), weight.data(), implicit_gemm_out_cutlass.data());
    printf("implicit gemm cutlass_impl out\n");
    printout(implicit_gemm_out_cutlass.data(), N, OC, OH, OW);
    assert_eq(implicit_gemm_out_cutlass.data(), implicit_gemm_out.data(), out_count);
    printf("\n");
    return 0;
}
// Originally posted @ 2023-04-07 14:29 by ijpq (views: 57, comments: 0)