使用Auto TensorCore CodeGen优化Matmul

使用Auto TensorCore CodeGen优化Matmul

本文将演示如何使用TVM Auto TensorCore CodeGen在Volta/Turing GPU上编写高性能的matmul调度。这是一种生成TensorCore内核的方案，其中大多数变换都由IR pass自动完成。用户也可以使用tensorize原语编写调度来生成TensorCore代码。两种方案使用相同的TensorCore内部函数（intrinsics）。更多详细信息，请参阅教程《如何使用TensorCores优化卷积》。

准备和算法

支持两种输入数据类型：float16和int8。对于float16，累加器为float32；对于int8，累加器为int32。对于数据布局，“N”表示不转置，“T”表示转置。（代码中还包含对int4和int1的实验性支持，累加器同样为int32，但仅支持TN布局，并要求计算能力为7.5的GPU。）

import logging
import sys

import numpy as np
import tvm
import tvm.testing
from tvm import te

from tvm import autotvm
from tvm.contrib import nvcc
 
 
def matmul_nn(A, B, L, dtype="float16", layout="NN"):
    """Declare the matmul compute ``C[i, j] = sum_k A * B`` for one layout.

    Operand shapes per layout code (as created by ``test_gemm``):

    ======  ========  ========
    layout  A shape   B shape
    ======  ========  ========
    "NN"    (N, L)    (L, M)
    "NT"    (L, N)    (L, M)
    "TN"    (N, L)    (M, L)
    "TT"    (L, N)    (M, L)
    ======  ========  ========

    Parameters
    ----------
    A, B : te.Tensor
        Input placeholders shaped as in the table above.
    L : int
        Length of the reduction axis.
    dtype : str
        Input dtype. "float16" accumulates in "float" (float32);
        "int8", "int4" and "int1" accumulate in "int" (int32).
    layout : str
        One of "NN", "NT", "TN", "TT".

    Returns
    -------
    te.Tensor
        The (N, M) output; N and M are taken from the operand shapes.

    Raises
    ------
    ValueError
        If ``dtype`` or ``layout`` is not supported.
    """
    accum_types = {"float16": "float", "int8": "int", "int4": "int", "int1": "int"}
    if dtype not in accum_types:
        raise ValueError("Unsupported dtype: %s" % dtype)
    if layout not in ("NN", "NT", "TN", "TT"):
        raise ValueError("Unsupported layout: %s" % layout)
    out_type = accum_types[dtype]

    # Second layout letter "T": A is stored as (L, N) instead of (N, L).
    a_transposed = layout[1] == "T"
    # First layout letter "T": B is stored as (M, L) instead of (L, M).
    b_transposed = layout[0] == "T"
    # Read N and M from the operands rather than from module-level globals,
    # so the output always agrees with the placeholders it is built from.
    n = A.shape[1] if a_transposed else A.shape[0]
    m = B.shape[0] if b_transposed else B.shape[1]

    k = te.reduce_axis((0, L), name="k")

    def body(i, j):
        # The layout choice is resolved here in Python, at declaration time.
        a = A[k, i] if a_transposed else A[i, k]
        b = B[j, k] if b_transposed else B[k, j]
        return te.sum(a.astype(out_type) * b.astype(out_type), axis=k)

    return te.compute((n, m), body)

调度计算

此调度与GPU上不使用TensorCore的普通Matmul调度并无不同。请参阅《如何在CPU上优化GEMM》文档，以了解调度优化的基础知识。设置“tensor_core”编译指示（pragma）后，“rewrite for tensorcore” IR pass会自动将调度转换为TensorCore代码生成；否则将生成性能较低但功能相同的普通CUDA代码。

TensorCore的要求

请注意，在以下两种情况下，即使设置了“tensor_core”编译指示，TVM仍会回退为生成普通CUDA代码：(1) 输入矩阵的m、n或k不是16的倍数；(2) 在CUDA 9上warp tile大小不是16x16x16，或在CUDA 10.0及以上版本上不是{16x16x16、32x8x16、8x32x16}之一。

在此调度中，storage_align用于减少共享内存的bank冲突。其用法请参考storage_align原语的相关文档。我们需要为某些共享内存缓冲区添加偏移量，以减少bank冲突。根据wmma文档，load_matrix_sync的步幅（stride）必须是16字节的倍数。因此float16选择8作为偏移量（8×2字节），int8选择16作为偏移量（16×1字节）。

使用AutoTVM在此调度的搜索空间中搜索最佳配置。

@autotvm.template("tutorial/auto_tensorcore/test_gemm")
def test_gemm(N, L, M, dtype, layout):
    """AutoTVM template: schedule an (N, M) matmul for TensorCore codegen.

    The schedule is an ordinary tiled GPU matmul: shared-memory and local
    cache stages, block/thread bindings and vectorized global->shared
    copies. The ``tensor_core`` pragma on the outer reduction loop lets the
    TensorCore rewrite pass map it to wmma intrinsics. When that pass does
    not apply, plain CUDA code is generated instead.

    Parameters
    ----------
    N, L, M : int
        Output rows, reduction length and output columns.
    dtype : str
        "float16", "int8", "int4" or "int1".
    layout : str
        "NN", "NT", "TN" or "TT"; selects the operand shapes below.

    Returns
    -------
    tuple
        ``(schedule, [A, B, C])``.

    Raises
    ------
    ValueError
        If ``layout`` is not supported.
    """
    if layout == "NN":
        shape_a = (N, L)
        shape_b = (L, M)
    elif layout == "NT":
        shape_a = (L, N)
        shape_b = (L, M)
    elif layout == "TN":
        shape_a = (N, L)
        shape_b = (M, L)
    elif layout == "TT":
        shape_a = (L, N)
        shape_b = (M, L)
    else:
        # Raise instead of sys.exit(): a template must not terminate the
        # process that is running the tuner.
        raise ValueError("Unsupported layout: %s" % layout)
    A = te.placeholder(shape_a, name="A", dtype=dtype)
    B = te.placeholder(shape_b, name="B", dtype=dtype)
    C = matmul_nn(A, B, L, dtype, layout)

    s = te.create_schedule(C.op)
    y, x = s[C].op.axis
    k = s[C].op.reduce_axis[0]

    # storage_align params: pad shared-memory rows by `offset` elements
    # (16 bytes for every dtype) to reduce bank conflicts.
    factor = 16
    offset = 8
    if dtype == "int8":
        factor = 32
        offset = 16
    elif dtype == "int4":
        factor = 64
        offset = 32
    elif dtype == "int1":
        factor = 256
        offset = 128

    # create cache stages
    AA = s.cache_read(A, "shared", [C])
    if layout == "NN" or layout == "TN":
        s[AA].storage_align(AA.op.axis[0], factor, offset)
    AL = s.cache_read(AA, "local", [C])
    BB = s.cache_read(B, "shared", [C])
    if layout == "TT" or layout == "NT":
        s[BB].storage_align(BB.op.axis[0], factor, offset)
    BL = s.cache_read(BB, "local", [C])
    CL = s.cache_write(C, "local")

    # autotvm search space definition
    cfg = autotvm.get_config()

    cfg.define_knob("bx", [2, 4, 8])
    cfg.define_knob("by", [8, 16, 32, 64])
    cfg.define_knob("step_k", [1, 2, 4, 8, 16, 32])
    cfg.define_knob("v", [4, 8, 16, 32])
    by = cfg["by"].val
    bx = cfg["bx"].val
    step_k = cfg["step_k"].val
    v = cfg["v"].val

    # thread tile
    TX = 8
    TY = 1
    if dtype == "int4" or dtype == "int1":
        TX = 2
    # warp tile
    warp_tile_m = 16  # it could also be 8 or 32 on CUDA version >= 10.0
    warp_tile_k = 16  # it must be 16 for fp16/int8 data type
    if dtype == "int4":
        warp_tile_m = 8
        warp_tile_k = 32
    elif dtype == "int1":
        warp_tile_m = 8
        warp_tile_k = 128
    # block tile
    tile_x = bx * TX
    tile_y = by * TY

    yo, ty = s[C].split(y, tile_y)
    ty, yi = s[C].split(ty, TY)

    # schedule for C stage
    xo, xi = s[C].split(x, tile_x)
    WX = min(warp_tile_m, tile_x)
    tz, xi = s[C].split(xi, WX)
    tx, xi = s[C].split(xi, TX)
    s[C].reorder(yo, xo, tz, ty, tx, yi, xi)
    s[C].bind(yo, te.thread_axis("blockIdx.y"))
    s[C].bind(xo, te.thread_axis("blockIdx.x"))
    s[C].bind(ty, te.thread_axis("threadIdx.y"))
    s[C].bind(tz, te.thread_axis("threadIdx.z"))
    s[C].bind(tx, te.thread_axis("threadIdx.x"))

    # schedule for CL stage: split the reduction into step_k warp-sized chunks
    ko, ki = s[CL].split(k, step_k * warp_tile_k)
    kl, ki = s[CL].split(ki, warp_tile_k)
    s[CL].compute_at(s[C], tx)
    yo, xo = CL.op.axis
    s[CL].reorder(ko, kl, ki, yo, xo)

    # schedule for AA stage: cooperative, vectorized global -> shared copy
    s[AA].compute_at(s[CL], ko)
    xo, xi = s[AA].split(s[AA].op.axis[1], factor=bx * v)
    tz, tx = s[AA].split(xi, factor=(WX // TX) * v)
    tx, vec = s[AA].split(tx, factor=v)
    fused = s[AA].fuse(s[AA].op.axis[0], xo)
    _, ty = s[AA].split(fused, factor=by)
    s[AA].bind(ty, te.thread_axis("threadIdx.y"))
    s[AA].bind(tz, te.thread_axis("threadIdx.z"))
    s[AA].bind(tx, te.thread_axis("threadIdx.x"))
    # vectorization is very important for float16/int8 inputs
    s[AA].vectorize(vec)

    # schedule for BB stage (mirrors AA)
    s[BB].compute_at(s[CL], ko)
    xo, xi = s[BB].split(s[BB].op.axis[1], factor=bx * v)
    tz, tx = s[BB].split(xi, factor=(WX // TX) * v)
    tx, vec = s[BB].split(tx, factor=v)
    fused = s[BB].fuse(s[BB].op.axis[0], xo)
    _, ty = s[BB].split(fused, factor=by)
    s[BB].bind(ty, te.thread_axis("threadIdx.y"))
    s[BB].bind(tz, te.thread_axis("threadIdx.z"))
    s[BB].bind(tx, te.thread_axis("threadIdx.x"))
    s[BB].vectorize(vec)

    s[AL].compute_at(s[CL], kl)
    s[BL].compute_at(s[CL], kl)

    # set the 'tensor_core' pragma for tensorcore codegen
    s[CL].pragma(ko, "tensor_core")

    return s, [A, B, C]

自动调优与测试（AutoTune and Test）

使用调优器对调度进行调优，用最佳配置生成代码，然后运行内核并与numpy的结果进行比较，以检查结果是否正确。

# Check that CUDA is available and the GPU has TensorCores. The cheap
# runtime check runs first, before any device is probed.
if not tvm.runtime.enabled("cuda") or not tvm.gpu(0).exist:
    raise Exception("skip building this tutorial because cuda is not enabled..")

ctx = tvm.gpu()
if not nvcc.have_tensorcore(ctx.compute_version):
    raise Exception("the gpu has no tensorcore, skipping...")

# Problem size, dtype and layout; optionally overridden from the command line:
#   python this_script.py M N L [dtype [layout]]
M, N, L = 512, 32, 512
dtype = "float16"
layout = "NN"
if len(sys.argv) >= 4:
    M, N, L = int(sys.argv[1]), int(sys.argv[2]), int(sys.argv[3])
if len(sys.argv) >= 5:
    dtype = sys.argv[4]
if len(sys.argv) >= 6:
    layout = sys.argv[5]

# Check whether the current GPU arch supports wmma codegen for this dtype.
# ctx.compute_version is the public accessor for the value previously read via
# the private tvm.runtime._ffi_api.GetDeviceAttr(2, 0, 4) call.
# Explicit raises (not asserts) so the checks survive `python -O`.
major, minor = nvcc.parse_compute_version(ctx.compute_version)
if dtype == "int8":
    # int8 wmma is available from compute capability 7.2 upward.
    if (major, minor) < (7, 2):
        raise RuntimeError(
            "int8 wmma codegen requires compute capability >= 7.2, got %d.%d" % (major, minor)
        )
elif dtype == "int4" or dtype == "int1":
    # int4/int1 only support layout TN, and only on sm_75 in this tutorial
    if layout != "TN":
        raise ValueError("%s only supports layout TN, got %s" % (dtype, layout))
    if (major, minor) != (7, 5):
        raise RuntimeError(
            "%s wmma codegen requires compute capability 7.5, got %d.%d" % (dtype, major, minor)
        )
 
 
def _operand_shapes(N, L, M, layout):
    """Return ``(shape_a, shape_b)`` for a layout code, matching test_gemm."""
    shapes = {
        "NN": ((N, L), (L, M)),
        "NT": ((L, N), (L, M)),
        "TN": ((N, L), (M, L)),
        "TT": ((L, N), (M, L)),
    }
    if layout not in shapes:
        raise ValueError("Unsupported layout: %s" % layout)
    return shapes[layout]


def _pack_subbyte(values, bits):
    """Pack each row of ``bits``-wide ints into int32 words, first element in the MSBs.

    ``values`` is a 2-D integer array whose column count is a multiple of
    ``32 // bits``. Every value is masked to its low ``bits`` bits, which gives
    two's-complement nibbles for negative int4 values. The result has shape
    ``(rows, cols * bits // 32)`` and dtype int32. It has the same bit
    patterns as OR-ing ``(v & mask) << ((per_word - 1 - k) * bits)`` together
    element by element.
    """
    per_word = 32 // bits
    rows, cols = values.shape
    if cols % per_word:
        raise ValueError("column count %d is not a multiple of %d" % (cols, per_word))
    mask = (1 << bits) - 1
    grouped = (values.astype(np.int64) & mask).reshape(rows, cols // per_word, per_word)
    shifts = np.arange(per_word - 1, -1, -1, dtype=np.int64) * bits
    # The shifted fields never overlap, so summing them equals OR-ing them.
    words = (grouped << shifts).sum(axis=-1)
    # Reinterpret the unsigned 32-bit pattern as int32 without overflow.
    return words.astype(np.uint32).view(np.int32)


def _make_reference(shape_a, shape_b, dtype, layout):
    """Create random inputs for A and B and the expected result.

    Returns
    -------
    tuple
        ``(a_np, b_np, c_np, c_np_type)``: the arrays to upload for A and B,
        the reference output, and the numpy dtype of the kernel's output.
    """
    if dtype in ("float16", "int8"):
        if dtype == "float16":
            a_np = np.random.uniform(size=shape_a).astype(np.float16)
            b_np = np.random.uniform(size=shape_b).astype(np.float16)
            # The kernel accumulates in float32, so compute the reference in
            # float32 as well instead of rounding it to float16.
            c_np_type = np.float32
        else:
            # randint's ``high`` is exclusive: 128 covers the full int8 range.
            a_np = np.random.randint(low=-128, high=128, size=shape_a).astype(np.int8)
            b_np = np.random.randint(low=-128, high=128, size=shape_b).astype(np.int8)
            c_np_type = np.int32
        a_ref = a_np.astype(c_np_type)
        b_ref = b_np.astype(c_np_type)
        # Second layout letter "T": A is stored (L, N); first letter "T": B is stored (M, L).
        if layout[1] == "T":
            a_ref = a_ref.T
        if layout[0] == "T":
            b_ref = b_ref.T
        return a_np, b_np, np.dot(a_ref, b_ref), c_np_type

    if dtype not in ("int4", "int1"):
        raise ValueError("Unsupported dtype: %s" % dtype)
    if layout != "TN":
        # Sub-byte inputs are only handled in TN layout: A is (N, L), B is (M, L).
        raise ValueError("%s only supports layout TN, got %s" % (dtype, layout))
    if dtype == "int4":
        bits = 4
        # Full signed 4-bit range [-8, 7]; ``high`` is exclusive.
        low, high = -8, 8
    else:
        bits = 1
        # NOTE(review): ``high`` is exclusive, so this draws only zeros and the
        # int1 check is vacuous. Kept as-is because widening it to 2 assumes
        # the generated binary MMA computes sum(a * b), while CUDA's bmma_sync
        # defaults to popc(a XOR b). Confirm what TVM emits before changing.
        low, high = 0, 1
    a_int = np.random.randint(low=low, high=high, size=shape_a).astype(np.int32)
    b_int = np.random.randint(low=low, high=high, size=shape_b).astype(np.int32)
    c_np = np.dot(a_int, b_int.T)
    return _pack_subbyte(a_int, bits), _pack_subbyte(b_int, bits), c_np, np.int32


def tune_and_evaluate(M, N, L, dtype, layout):
    """Tune the test_gemm template, build the best config and verify it.

    Runs 1000 XGBTuner trials, logging results to ``matmul.log``. It then
    builds the best schedule and prints the lowered IR and the CUDA source.
    Finally it checks the kernel against a numpy reference and prints its
    mean runtime. Uses the module-level ``ctx`` device.

    Raises
    ------
    ValueError
        For an unsupported dtype/layout combination.
    """
    task = autotvm.task.create(
        "tutorial/auto_tensorcore/test_gemm", args=(N, L, M, dtype, layout), target="cuda"
    )
    print(task.config_space)

    logging.getLogger("autotvm").setLevel(logging.DEBUG)
    logging.getLogger("autotvm").addHandler(logging.StreamHandler(sys.stdout))

    measure_option = autotvm.measure_option(builder="local", runner=autotvm.LocalRunner(number=5))

    tuner = autotvm.tuner.XGBTuner(task)
    tuner.tune(
        n_trial=1000,
        measure_option=measure_option,
        callbacks=[autotvm.callback.log_to_file("matmul.log")],
    )

    dispatch_context = autotvm.apply_history_best("matmul.log")
    best_config = dispatch_context.query(task.target, task.workload)
    print("\nBest config:")
    print(best_config)
    with autotvm.apply_history_best("matmul.log"):
        with tvm.target.Target("cuda"):
            s, arg_bufs = test_gemm(N, L, M, dtype, layout)
            print(tvm.lower(s, arg_bufs, simple_mode=True))
            func = tvm.build(s, arg_bufs)
    dev_module = func.imported_modules[0]
    print(dev_module.get_source())

    # check correctness
    shape_a, shape_b = _operand_shapes(N, L, M, layout)
    a_np, b_np, c_np, c_np_type = _make_reference(shape_a, shape_b, dtype, layout)

    c_tvm = tvm.nd.array(np.zeros(c_np.shape, dtype=c_np_type), ctx=ctx)
    a_tvm = tvm.nd.array(a_np, ctx=ctx)
    b_tvm = tvm.nd.array(b_np, ctx=ctx)
    func(a_tvm, b_tvm, c_tvm)

    tvm.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-3)

    evaluator = func.time_evaluator(func.entry_name, ctx, number=100)
    print("Time cost of this operator: %f" % evaluator(a_tvm, b_tvm, c_tvm).mean)
 
 
# We do not run the tuning in our webpage server since it takes some time.
# Uncomment the following line to run it by yourself.
 
# tune_and_evaluate(M, N, L, dtype, layout)

示例输出（Sample Output）

Best config:

[('bx', 4), ('by', 32), ('step_k', 16), ('v', 8)],,None,40

Finish loading 162 records

produce compute {

  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 1

  // attr [compute.local] storage_scope = "wmma.accumulator"

  allocate compute.local[float32 * 256]

  // attr [A.shared] storage_scope = "shared"

  allocate A.shared[float16 * 8448]

  // attr [B.shared] storage_scope = "shared"

  allocate B.shared[float16 * 8192]

  // attr [A.shared.local] storage_scope = "wmma.matrix_b"

  allocate A.shared.local[float16 * 256]

  // attr [B.shared.local] storage_scope = "wmma.matrix_a"

  allocate B.shared.local[float16 * 256]

  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 16

  // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 2

  // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 32

  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 2

  produce compute.local {

    for (j.c.init, 0, 1) {

      tvm_fill_fragment(compute.local, 16, 16, 16, 0, 0f)

    }

    // attr [iter_var(k.outer, )] pragma_tensor_core = 1

    for (k.outer, 0, 2) {

      produce A.shared {

        for (ax0.ax1.outer.fused.outer, 0, 8) {

          // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 32

          // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 2

          // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 2

          A.shared[ramp((((((ax0.ax1.outer.fused.outer*1056) + (floordiv(threadIdx.y, 8)*264)) + (floormod(threadIdx.y, 8)*32)) + (threadIdx.z*16)) + (threadIdx.x*8)), 1, 8)] = A[ramp(((((((ax0.ax1.outer.fused.outer*2048) + (floordiv(threadIdx.y, 8)*512)) + (k.outer*256)) + (floormod(threadIdx.y, 8)*32)) + (threadIdx.z*16)) + (threadIdx.x*8)), 1, 8)]

        }

      }

      produce B.shared {

        for (ax0.ax1.outer.fused.outer, 0, 8) {

          // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 32

          // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 2

          // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 2

          B.shared[ramp(((((ax0.ax1.outer.fused.outer*1024) + (threadIdx.y*32)) + (threadIdx.z*16)) + (threadIdx.x*8)), 1, 8)] = B[ramp(((((((k.outer*131072) + (ax0.ax1.outer.fused.outer*16384)) + (threadIdx.y*512)) + (blockIdx.x*32)) + (threadIdx.z*16)) + (threadIdx.x*8)), 1, 8)]

        }

      }

      for (k.inner.outer, 0, 16) {

        produce A.shared.local {

          for (ax1, 0, 1) {

            tvm_load_matrix_sync(A.shared.local, 16, 16, 16, 0, &(A.shared[(((threadIdx.y/16)*4224) + (k.inner.outer*16))]), 264, "col_major")

          }

        }

        produce B.shared.local {

          for (ax0, 0, 1) {

            for (ax1, 0, 1) {

              tvm_load_matrix_sync(B.shared.local, 16, 16, 16, 0, &(B.shared[((k.inner.outer*512) + (threadIdx.z*16))]), 32, "col_major")

            }

          }

        }

        for (k.inner.inner, 0, 1) {

          for (j.c, 0, 1) {

            tvm_mma_sync(compute.local, 0, B.shared.local, 0, A.shared.local, 0, compute.local, 0)

          }

        }

      }

    }

  }

  for (j.inner.inner.inner, 0, 1) {

    tvm_store_matrix_sync(compute.local, 16, 16, 16, 0, &(compute[((((threadIdx.y/16)*8192) + (blockIdx.x*32)) + (threadIdx.z*16))]), 512, "col_major")

  }

}

 

#include <cuda_fp16.h>

__device__ half max(const half a, const half b)

{

  return __hgt(__half(a), __half(b)) ? a : b;

}

__device__ half min(const half a, const half b)

{

  return __hlt(__half(a), __half(b)) ? a : b;

}

__device__ half operator+(const volatile __half &a,  const volatile __half &b)

{

  return __hadd(a, b);

}

__device__ half operator<=(const volatile __half &a,  const volatile __half &b)

{

  return __hlt(a, b);

}

__device__ half operator*(const volatile __half &a,  const volatile __half &b)

{

  return __hmul(a, b);

}

#include <mma.h>

extern "C" __global__ void default_function_kernel0( half* __restrict__ A,  half* __restrict__ B,  float* __restrict__ compute) {

  nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, float> compute_local[1];

  __shared__ half A_shared[8448];

  __shared__ half B_shared[8192];

  nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> A_shared_local[1];

  nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> B_shared_local[1];

  for (int j_c_init = 0; j_c_init < 1; ++j_c_init) {

    (void)nvcuda::wmma::fill_fragment(compute_local[0], 0.000000e+00f);

  }

  for (int k_outer = 0; k_outer < 2; ++k_outer) {

    __syncthreads();

    for (int ax0_ax1_outer_fused_outer = 0; ax0_ax1_outer_fused_outer < 8; ++ax0_ax1_outer_fused_outer) {

      ((__shared__ float4*)(A_shared + (((((ax0_ax1_outer_fused_outer * 1056) + ((((int)threadIdx.y) >> 3) * 264)) + ((((int)threadIdx.y) & 7) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8))))[0] = (( float4*)(A + ((((((ax0_ax1_outer_fused_outer * 2048) + ((((int)threadIdx.y) >> 3) * 512)) + (k_outer * 256)) + ((((int)threadIdx.y) & 7) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8))))[0];

    }

    for (int ax0_ax1_outer_fused_outer1 = 0; ax0_ax1_outer_fused_outer1 < 8; ++ax0_ax1_outer_fused_outer1) {

      ((__shared__ float4*)(B_shared + ((((ax0_ax1_outer_fused_outer1 * 1024) + (((int)threadIdx.y) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8))))[0] = (( float4*)(B + ((((((k_outer * 131072) + (ax0_ax1_outer_fused_outer1 * 16384)) + (((int)threadIdx.y) * 512)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.z) * 16)) + (((int)threadIdx.x) * 8))))[0];

    }

    __syncthreads();

    for (int k_inner_outer = 0; k_inner_outer < 16; ++k_inner_outer) {

      for (int ax1 = 0; ax1 < 1; ++ax1) {

        (void)nvcuda::wmma::load_matrix_sync(A_shared_local[0], &(A_shared[(((((int)threadIdx.y) / 16) * 4224) + (k_inner_outer * 16))]), 264);

      }

      for (int ax0 = 0; ax0 < 1; ++ax0) {

        for (int ax11 = 0; ax11 < 1; ++ax11) {

          (void)nvcuda::wmma::load_matrix_sync(B_shared_local[0], &(B_shared[((k_inner_outer * 512) + (((int)threadIdx.z) * 16))]), 32);

        }

      }

      for (int k_inner_inner = 0; k_inner_inner < 1; ++k_inner_inner) {

        for (int j_c = 0; j_c < 1; ++j_c) {

          (void)nvcuda::wmma::mma_sync(compute_local[0], B_shared_local[0], A_shared_local[0], compute_local[0]);

        }

      }

    }

  }

  for (int j_inner_inner_inner = 0; j_inner_inner_inner < 1; ++j_inner_inner_inner) {

    (void)nvcuda::wmma::store_matrix_sync(&(compute[((((((int)threadIdx.y) / 16) * 8192) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.z) * 16))]), compute_local[0], 512, nvcuda::wmma::mem_col_major);

  }

}

Time cost of this operator: 0.000008

摘要

本文演示了如何使用TVM的Auto TensorCore CodeGen生成TensorCore内核。

https://tvm.apache.org/docs/tutorials/optimize/opt_matmul_auto_tensorcore.html#sphx-glr-tutorials-optimize-opt-matmul-auto-tensorcore-py

 

posted @ 2020-12-23 06:19  吴建明wujianming  阅读(263)  评论(0编辑  收藏  举报