如何使用TensorCores优化卷积

如何使用TensorCores优化卷积

本文将演示如何在TVM中使用TensorCores编写高性能的卷积调度(schedule)。本文假设卷积输入的批量(batch)较大。强烈建议先阅读《如何在GPU上优化卷积》教程。

TensorCore简介

每个Tensor Core都提供一个4x4x4的矩阵处理阵列,该阵列可以执行运算 $D = A \times B + C$,其中A、B、C和D都是4x4矩阵,如图所示。矩阵乘法的输入A和B是FP16矩阵,而累加矩阵C和D可以是FP16或FP32矩阵。

但是,CUDA程序员只能使用warp级原语 wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag) 在Tensor Core上执行16x16x16的半精度矩阵乘法。在调用矩阵乘法之前,程序员必须使用原语 wmma::load_matrix_sync 显式地将数据从内存加载到寄存器中。NVCC编译器会将该原语转换为多条内存加载指令。在运行时(runtime),每个线程从矩阵A加载16个元素,从矩阵B加载16个元素。

准备和算法

输入张量使用固定尺寸:256个通道,空间尺寸为14 x 14,批大小(batch size)为256。卷积核包含512个大小为3 x 3的滤波器。卷积使用步长1和填充1。示例中使用NHWCnc内存布局。以下代码定义了TVM中的卷积算法。

import tvm
from tvm import te
import numpy as np
from tvm.contrib import nvcc
 
# The sizes of inputs and filters
batch_size = 256
height = 14
width = 14
in_channels = 256
out_channels = 512
kernel_h = 3
kernel_w = 3
pad_h = 1
pad_w = 1
stride_h = 1
stride_w = 1

# TensorCore shape: batch and channel dimensions are split into tiles of
# this width (16x16 fragments).
block_size = 16

# The blocked NHWCnc layout below requires these dimensions to split evenly
# into block_size tiles.
assert batch_size % block_size == 0
assert in_channels % block_size == 0
assert out_channels % block_size == 0

# Spatial size of the convolution output, derived from kernel, padding and
# stride instead of reusing (height, width); the two only coincide for
# configurations such as 3x3 / pad 1 / stride 1 used here (-> 14 x 14).
out_height = (height + 2 * pad_h - kernel_h) // stride_h + 1
out_width = (width + 2 * pad_w - kernel_w) // stride_w + 1

# Input feature map: (N, H, W, IC, n, ic)
data_shape = (
    batch_size // block_size,
    height,
    width,
    in_channels // block_size,
    block_size,
    block_size,
)
# Kernel: (H, W, IC, OC, ic, oc)
kernel_shape = (
    kernel_h,
    kernel_w,
    in_channels // block_size,
    out_channels // block_size,
    block_size,
    block_size,
)
# Output feature map: (N, H, W, OC, n, oc)
output_shape = (
    batch_size // block_size,
    out_height,
    out_width,
    out_channels // block_size,
    block_size,
    block_size,
)

# Reduction axes: spatial kernel taps (kh, kw), input-channel tiles (ic) and
# the input channel inside a tile (ii).
kh = te.reduce_axis((0, kernel_h), name="kh")
kw = te.reduce_axis((0, kernel_w), name="kw")
ic = te.reduce_axis((0, in_channels // block_size), name="ic")
ii = te.reduce_axis((0, block_size), name="ii")

# Algorithm
A = te.placeholder(data_shape, name="A", dtype="float16")
W = te.placeholder(kernel_shape, name="W", dtype="float16")
# Zero-padded input: H and W grow by pad_h / pad_w on each side, and
# out-of-range positions read a float16 zero.
Apad = te.compute(
    (
        batch_size // block_size,
        height + 2 * pad_h,
        width + 2 * pad_w,
        in_channels // block_size,
        block_size,
        block_size,
    ),
    lambda n, h, w, i, nn, ii: tvm.tir.if_then_else(
        tvm.tir.all(h >= pad_h, h - pad_h < height, w >= pad_w, w - pad_w < width),
        A[n, h - pad_h, w - pad_w, i, nn, ii],
        tvm.tir.const(0.0, "float16"),
    ),
    name="Apad",
)
# Direct convolution: float16 operands are upcast to float32 before the
# multiply, and the sum is accumulated in float32.
Conv = te.compute(
    output_shape,
    lambda n, h, w, o, nn, oo: te.sum(
        Apad[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("float32")
        * W[kh, kw, ic, o, ii, oo].astype("float32"),
        axis=[ic, kh, kw, ii],
    ),
    name="Conv",
)

s = te.create_schedule(Conv.op)
# Padding is folded into the consumer instead of materialized as a buffer.
s[Apad].compute_inline()

存储范围

在传统的GPU调度(schedule)中,有全局(global)、共享(shared)和本地(local)三种内存范围。为了支持TensorCores,TVM添加了另外三个特殊的存储范围:wmma.matrix_a、wmma.matrix_b和wmma.accumulator。在硬件上,所有片段(fragment)作用域都存储在片上寄存器级别,与本地内存位于同一位置。

# Designate the memory hierarchy
# Input data path: global -> shared memory (AS, WS) -> WMMA fragments (AF, WF).
# Because Apad was inlined, AS reads (padded) values straight from A.
AS = s.cache_read(Apad, "shared", [Conv])
WS = s.cache_read(W, "shared", [Conv])
AF = s.cache_read(AS, "wmma.matrix_a", [Conv])
WF = s.cache_read(WS, "wmma.matrix_b", [Conv])
# Partial sums are accumulated in a wmma.accumulator stage (ConvF) and only
# then written out to Conv.
ConvF = s.cache_write(Conv, "wmma.accumulator")

定义张量内部函数(Tensor Intrinsic)

实际上,TensorCore是一种特殊的硬件操作,因此可以使用tensorize将计算单元替换为TensorCore指令。首先,需要定义张量内部函数(tensor intrinsic)。

TensorCore有四种基本操作:fill_fragment、load_matrix、mma_sync和store_matrix。由于fill_fragment和mma_sync都用于矩阵乘法,因此只需编写以下三个内部函数。

def intrin_wmma_load_matrix(scope):
    """Declare a tensor intrinsic loading a 16x16 float16 tile into a WMMA fragment.

    The source tile lives in shared memory; the destination is a fragment in
    ``scope`` (e.g. ``"wmma.matrix_a"`` or ``"wmma.matrix_b"``).

    Returns:
        A TensorIntrin whose body is one ``tir.tvm_load_matrix_sync`` call.
    """
    tile = 16
    # Elements per fragment; buffers are aligned on whole fragments, so the
    # fragment index is elem_offset // frag_elems.
    frag_elems = tile * tile

    src = te.placeholder((tile, tile), name="A", dtype="float16")
    src_buf = tvm.tir.decl_buffer(
        src.shape, src.dtype, scope="shared", data_alignment=32, offset_factor=frag_elems
    )
    dst = te.compute((tile, tile), lambda i, j: src[i, j], name="C")
    dst_buf = tvm.tir.decl_buffer(
        dst.shape, dst.dtype, scope=scope, data_alignment=32, offset_factor=frag_elems
    )

    def emit_load(ins, outs):
        shared_tile = ins[0]
        fragment = outs[0]
        builder = tvm.tir.ir_builder.create()
        load_call = tvm.tir.call_intrin(
            "handle",
            "tir.tvm_load_matrix_sync",
            fragment.data,
            tile,
            tile,
            tile,
            fragment.elem_offset // frag_elems,
            shared_tile.access_ptr("r"),
            tile,
            "row_major",
        )
        builder.emit(load_call)
        return builder.get()

    return te.decl_tensor_intrin(dst.op, emit_load, binds={src: src_buf, dst: dst_buf})
 
 
def intrin_wmma_gemm():
    """Declare a tensor intrinsic for a 16x16x16 fragment multiply-accumulate.

    Computes ``C[i, j] += A[i, k] * B[k, j]`` with float16 operands held in
    ``wmma.matrix_a`` / ``wmma.matrix_b`` fragments and a float32
    ``wmma.accumulator`` fragment.

    Returns:
        A TensorIntrin whose (body, reduce_init, reduce_update) triple is
        (mma_sync, fill_fragment with 0.0, mma_sync).
    """
    tile = 16
    frag_elems = tile * tile  # elements per fragment

    lhs = te.placeholder((tile, tile), name="A", dtype="float16")
    rhs = te.placeholder((tile, tile), name="B", dtype="float16")
    k = te.reduce_axis((0, tile), name="k")
    acc = te.compute(
        (tile, tile),
        lambda ii, jj: te.sum(lhs[ii, k].astype("float") * rhs[k, jj].astype("float"), axis=k),
        name="C",
    )

    def fragment_buffer(tensor, buf_name, buf_scope):
        # Every operand is a single fragment aligned on a whole-fragment offset.
        return tvm.tir.decl_buffer(
            tensor.shape,
            tensor.dtype,
            name=buf_name,
            scope=buf_scope,
            data_alignment=32,
            offset_factor=frag_elems,
        )

    lhs_buf = fragment_buffer(lhs, "BA", "wmma.matrix_a")
    rhs_buf = fragment_buffer(rhs, "BB", "wmma.matrix_b")
    acc_buf = fragment_buffer(acc, "BC", "wmma.accumulator")

    def emit(ins, outs):
        a_frag, b_frag = ins
        (c_frag,) = outs

        def single_call(*call_args):
            # Wrap one handle-typed intrinsic call into its own statement.
            builder = tvm.tir.ir_builder.create()
            builder.emit(tvm.tir.call_intrin("handle", *call_args))
            return builder.get()

        def zero_fill():
            return single_call(
                "tir.tvm_fill_fragment",
                c_frag.data,
                tile,
                tile,
                tile,
                c_frag.elem_offset // frag_elems,
                0.0,
            )

        def mma():
            return single_call(
                "tir.tvm_mma_sync",
                c_frag.data,
                c_frag.elem_offset // frag_elems,
                a_frag.data,
                a_frag.elem_offset // frag_elems,
                b_frag.data,
                b_frag.elem_offset // frag_elems,
                c_frag.data,
                c_frag.elem_offset // frag_elems,
            )

        # (body, reduce_init, reduce_update), built in left-to-right order.
        return mma(), zero_fill(), mma()

    return te.decl_tensor_intrin(acc.op, emit, binds={lhs: lhs_buf, rhs: rhs_buf, acc: acc_buf})
 
 
def intrin_wmma_store_matrix():
    """Declare a tensor intrinsic storing a 16x16 float32 accumulator fragment.

    The fragment lives in ``wmma.accumulator``; the destination buffer is
    declared in the ``"global"`` scope.

    Returns:
        A TensorIntrin whose body is one ``tir.tvm_store_matrix_sync`` call.
    """
    tile = 16
    frag_elems = tile * tile  # elements per fragment; fragment index = offset // this

    frag = te.placeholder((tile, tile), name="A", dtype="float32")
    frag_buf = tvm.tir.decl_buffer(
        frag.shape, frag.dtype, scope="wmma.accumulator", data_alignment=32, offset_factor=frag_elems
    )
    out = te.compute((tile, tile), lambda i, j: frag[i, j], name="C")
    out_buf = tvm.tir.decl_buffer(
        out.shape, out.dtype, scope="global", data_alignment=32, offset_factor=frag_elems
    )

    def emit_store(ins, outs):
        builder = tvm.tir.ir_builder.create()
        accumulator = ins[0]
        target = outs[0]
        store_call = tvm.tir.call_intrin(
            "handle",
            "tir.tvm_store_matrix_sync",
            accumulator.data,
            tile,
            tile,
            tile,
            accumulator.elem_offset // frag_elems,
            target.access_ptr("w"),
            tile,
            "row_major",
        )
        builder.emit(store_call)
        return builder.get()

    return te.decl_tensor_intrin(out.op, emit_store, binds={frag: frag_buf, out: out_buf})

调度计算

要在TVM中使用TensorCores,必须将计算调度到特定的结构中,以匹配张量内部函数。与传统的GPU程序一样,可以使用共享内存来提高速度。如果对分块(blocking)和共享内存有任何疑问,请参阅《如何在GPU上优化卷积》。

在此示例中,每个块包含4x2个warp(block_row_warps x block_col_warps),每个warp调用2x4个TensorCore指令(warp_row_tiles x warp_col_tiles)。因此,每个warp的输出形状为32x64(批量方向 x 输出通道方向),每个块输出128x128的图块。由于共享内存空间的限制,一次只能加载2个块(2x128x128的图块)。

warp操作

请注意,所有TensorCore指令都是warp级指令,这意味着一个warp中的所有32个线程应同时执行该指令。令threadIdx.x的范围(extent)为32是解决此问题的最简单方法之一。然后,可以将threadIdx.x绑定到除直接或间接包含TensorCore内部函数的循环以外的任何循环。还要注意,这不是唯一的解决方案,唯一要做的是确保warp中的所有线程可以同时调用TensorCore。

# Define tiling sizes
block_row_warps = 4  # warps per block along the batch-tile (n) axis
block_col_warps = 2  # warps per block along the output-channel-tile (oc) axis
warp_row_tiles = 2  # 16x16 fragments per warp along n
warp_col_tiles = 4  # 16x16 fragments per warp along oc
warp_size = 32
chunk = 2  # input-channel tiles (ic) handled per outer reduction step

block_x = te.thread_axis("blockIdx.x")
block_y = te.thread_axis("blockIdx.y")
block_z = te.thread_axis("blockIdx.z")
thread_x = te.thread_axis("threadIdx.x")
thread_y = te.thread_axis("threadIdx.y")
thread_z = te.thread_axis("threadIdx.z")

# Output tiling: blockIdx.z walks the fused (h, w) positions, blockIdx.x/y
# select a group of output tiles along n/oc, and threadIdx.y/z select the
# warp inside the block.
nc, hc, wc, oc, nnc, ooc = Conv.op.axis
block_k = s[Conv].fuse(hc, wc)
s[Conv].bind(block_k, block_z)
nc, nci = s[Conv].split(nc, factor=warp_row_tiles)
block_i, nc = s[Conv].split(nc, factor=block_row_warps)
oc, oci = s[Conv].split(oc, factor=warp_col_tiles)
block_j, oc = s[Conv].split(oc, factor=block_col_warps)
s[Conv].reorder(block_k, block_i, block_j, nc, oc, nci, oci, nnc, ooc)
s[Conv].bind(block_i, block_x)
s[Conv].bind(block_j, block_y)
s[Conv].bind(nc, thread_y)
s[Conv].bind(oc, thread_z)

# Schedule local computation
s[ConvF].compute_at(s[Conv], oc)
n, h, w, o, nnf, oof = ConvF.op.axis
ko, ki = s[ConvF].split(ic, factor=chunk)
s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii)

# Move intermediate computation into each output compute tile
s[AF].compute_at(s[ConvF], kw)
s[WF].compute_at(s[ConvF], kw)

# The shared-memory stages below unpack their axes into dedicated names so
# the module-level reduction axes (kh, kw, ic, ii) keep referring to the
# Conv reduction, and each name reflects the thread axis it is bound to.

# Schedule for A's share memory: rows are split across the block's warps
# (threadIdx.y/z) and each fused 16x16 tile across the 32 lanes (threadIdx.x).
s[AS].compute_at(s[ConvF], kh)
as_n, as_h, as_w, as_i, as_nn, as_ii = AS.op.axis
as_warp_row, as_rest = s[AS].split(as_n, nparts=block_row_warps)
as_warp_col, _ = s[AS].split(as_rest, nparts=block_col_warps)
as_tile = s[AS].fuse(as_nn, as_ii)
_, as_lane = s[AS].split(as_tile, factor=warp_size)
s[AS].bind(as_warp_row, thread_y)
s[AS].bind(as_warp_col, thread_z)
s[AS].bind(as_lane, thread_x)

# Schedule for W's share memory: same warp split over output-channel tiles;
# each lane copies a contiguous, vectorized slice of the 16x16 tile
# (256 elements / 32 lanes = 8 float16 values).
s[WS].compute_at(s[ConvF], kh)
ws_kh, ws_kw, ws_ic, ws_o, ws_ii, ws_oo = WS.op.axis
ws_warp_row, ws_rest = s[WS].split(ws_o, nparts=block_row_warps)
ws_warp_col, _ = s[WS].split(ws_rest, nparts=block_col_warps)
ws_tile = s[WS].fuse(ws_ii, ws_oo)
ws_lane, ws_vec = s[WS].split(ws_tile, nparts=warp_size)
s[WS].bind(ws_warp_row, thread_y)
s[WS].bind(ws_warp_col, thread_z)
s[WS].bind(ws_lane, thread_x)
s[WS].vectorize(ws_vec)
print(tvm.lower(s, [A, W, Conv], simple_mode=True))

输出:

primfn(A_1: handle, W_1: handle, Conv_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {Conv: Buffer(Conv_2: Pointer(float32), float32, [16, 14, 14, 32, 16, 16], []),
             W: Buffer(W_2: Pointer(float16), float16, [3, 3, 16, 32, 16, 16], []),
             A: Buffer(A_2: Pointer(float16), float16, [16, 14, 14, 16, 16, 16], [])}
  buffer_map = {A_1: A, W_1: W, Conv_1: Conv} {
  attr [IterVar(blockIdx.z: int32, (nullptr), "ThreadIndex", "blockIdx.z")] "thread_extent" = 196;
  attr [Conv.wmma.accumulator: Pointer(float32)] "storage_scope" = "wmma.accumulator";
  allocate(Conv.wmma.accumulator, float32, [2048]);
  attr [Apad.shared: Pointer(float16)] "storage_scope" = "shared";
  allocate(Apad.shared, float16, [12288]);
  attr [W.shared: Pointer(float16)] "storage_scope" = "shared";
  allocate(W.shared, float16, [12288]);
  attr [Apad.shared.wmma.matrix_a: Pointer(float16)] "storage_scope" = "wmma.matrix_a";
  allocate(Apad.shared.wmma.matrix_a, float16, [512]);
  attr [W.shared.wmma.matrix_b: Pointer(float16)] "storage_scope" = "wmma.matrix_b";
  allocate(W.shared.wmma.matrix_b, float16, [1024]);
  attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 2;
  attr [IterVar(blockIdx.y: int32, (nullptr), "ThreadIndex", "blockIdx.y")] "thread_extent" = 4;
  attr [IterVar(threadIdx.y: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 4;
  attr [IterVar(threadIdx.z: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 2 {
    for (n.c.init: int32, 0, 2) {
      for (o.c.init: int32, 0, 4) {
        for (nn.c.init: int32, 0, 16) {
          for (oo.c.init: int32, 0, 16) {
            Conv.wmma.accumulator[((((n.c.init*1024) + (o.c.init*256)) + (nn.c.init*16)) + oo.c.init)] = 0f32
          }
        }
      }
    }
    for (ic.outer: int32, 0, 8) {
      for (kh: int32, 0, 3) {
        for (ax2: int32, 0, 3) {
          for (ax3: int32, 0, 2) {
            for (ax4.ax5.fused.outer: int32, 0, 8) {
              attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
              Apad.shared[((((((threadIdx.y*3072) + (threadIdx.z*1536)) + (ax2*512)) + (ax3*256)) + (ax4.ax5.fused.outer*32)) + threadIdx.x)] = @tir.if_then_else(((((1 <= (floordiv(blockIdx.z, 14) + kh)) && ((floordiv(blockIdx.z, 14) + kh) < 15)) && (1 <= (ax2 + floormod(blockIdx.z, 14)))) && ((ax2 + floormod(blockIdx.z, 14)) < 15)), (float16*)A_2[(((((((((((blockIdx.x*6422528) + (threadIdx.y*1605632)) + (threadIdx.z*802816)) + (kh*57344)) + (blockIdx.z*4096)) + (ax2*4096)) + (ic.outer*512)) + (ax3*256)) + (ax4.ax5.fused.outer*32)) + threadIdx.x) - 61440)], 0f16, dtype=float16)
            }
          }
        }
        for (ax1: int32, 0, 3) {
          for (ax2_1: int32, 0, 2) {
            attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
            W.shared[ramp((((((ax1*4096) + (ax2_1*2048)) + (threadIdx.y*512)) + (threadIdx.z*256)) + (threadIdx.x*8)), 1, 8)] = (float16x8*)W_2[ramp(((((((((kh*393216) + (ax1*131072)) + (ic.outer*16384)) + (ax2_1*8192)) + (blockIdx.y*2048)) + (threadIdx.y*512)) + (threadIdx.z*256)) + (threadIdx.x*8)), 1, 8)]
          }
        }
        for (ic.inner: int32, 0, 2) {
          for (kw: int32, 0, 3) {
            for (ax0: int32, 0, 2) {
              for (ax4: int32, 0, 16) {
                for (ax5: int32, 0, 16) {
                  Apad.shared.wmma.matrix_a[(((ax0*256) + (ax4*16)) + ax5)] = (float16*)Apad.shared[((((((threadIdx.y*3072) + (ax0*1536)) + (kw*512)) + (ic.inner*256)) + (ax4*16)) + ax5)]
                }
              }
            }
            for (ax3_1: int32, 0, 4) {
              for (ax4_1: int32, 0, 16) {
                for (ax5_1: int32, 0, 16) {
                  W.shared.wmma.matrix_b[(((ax3_1*256) + (ax4_1*16)) + ax5_1)] = (float16*)W.shared[((((((kw*4096) + (ic.inner*2048)) + (threadIdx.z*1024)) + (ax3_1*256)) + (ax4_1*16)) + ax5_1)]
                }
              }
            }
            for (n.c: int32, 0, 2) {
              for (o.c: int32, 0, 4) {
                for (nn.c: int32, 0, 16) {
                  for (oo.c: int32, 0, 16) {
                    for (ii: int32, 0, 16) {
                      Conv.wmma.accumulator[((((n.c*1024) + (o.c*256)) + (nn.c*16)) + oo.c)] = ((float32*)Conv.wmma.accumulator[((((n.c*1024) + (o.c*256)) + (nn.c*16)) + oo.c)] + (cast(float32, (float16*)Apad.shared.wmma.matrix_a[(((n.c*256) + (nn.c*16)) + ii)])*cast(float32, (float16*)W.shared.wmma.matrix_b[(((o.c*256) + (ii*16)) + oo.c)])))
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
    for (n.inner: int32, 0, 2) {
      for (o.inner: int32, 0, 4) {
        for (nn: int32, 0, 16) {
          for (oo: int32, 0, 16) {
            Conv_2[(((((((((blockIdx.x*12845056) + (threadIdx.y*3211264)) + (n.inner*1605632)) + (blockIdx.z*8192)) + (blockIdx.y*2048)) + (threadIdx.z*1024)) + (o.inner*256)) + (nn*16)) + oo)] = (float32*)Conv.wmma.accumulator[((((n.inner*1024) + (o.inner*256)) + (nn*16)) + oo)]
          }
        }
      }
    }
  }
}

将计算降低(Lower)到TensorCore

最后一个阶段是通过将2D卷积映射到张量内部函数,把计算循环降低(lower)为TensorCore硬件内部函数。

# Replace each innermost 16x16 loop nest with its WMMA intrinsic: fragment
# loads for AF/WF (from their second-to-last axis), the fragment store for
# Conv (from nnc), and fill_fragment/mma_sync for the ConvF reduction
# (from nnf, which covers nnf, oof and ii after the reorder above).
s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix("wmma.matrix_a"))
s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix("wmma.matrix_b"))
s[Conv].tensorize(nnc, intrin_wmma_store_matrix())
s[ConvF].tensorize(nnf, intrin_wmma_gemm())
print(tvm.lower(s, [A, W, Conv], simple_mode=True))

输出:

primfn(A_1: handle, W_1: handle, Conv_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {Conv: Buffer(Conv_2: Pointer(float32), float32, [16, 14, 14, 32, 16, 16], []),
             W: Buffer(W_2: Pointer(float16), float16, [3, 3, 16, 32, 16, 16], []),
             A: Buffer(A_2: Pointer(float16), float16, [16, 14, 14, 16, 16, 16], [])}
  buffer_map = {A_1: A, W_1: W, Conv_1: Conv} {
  attr [IterVar(blockIdx.z: int32, (nullptr), "ThreadIndex", "blockIdx.z")] "thread_extent" = 196;
  attr [Conv.wmma.accumulator: Pointer(float32)] "storage_scope" = "wmma.accumulator";
  allocate(Conv.wmma.accumulator, float32, [2048]);
  attr [Apad.shared: Pointer(float16)] "storage_scope" = "shared";
  allocate(Apad.shared, float16, [12288]);
  attr [W.shared: Pointer(float16)] "storage_scope" = "shared";
  allocate(W.shared, float16, [12288]);
  attr [Apad.shared.wmma.matrix_a: Pointer(float16)] "storage_scope" = "wmma.matrix_a";
  allocate(Apad.shared.wmma.matrix_a, float16, [512]);
  attr [W.shared.wmma.matrix_b: Pointer(float16)] "storage_scope" = "wmma.matrix_b";
  allocate(W.shared.wmma.matrix_b, float16, [1024]);
  attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 2;
  attr [IterVar(blockIdx.y: int32, (nullptr), "ThreadIndex", "blockIdx.y")] "thread_extent" = 4;
  attr [IterVar(threadIdx.y: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 4;
  attr [IterVar(threadIdx.z: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 2 {
    for (n.c.init: int32, 0, 2) {
      for (o.c.init: int32, 0, 4) {
        @tir.tvm_fill_fragment(Conv.wmma.accumulator, 16, 16, 16, ((n.c.init*4) + o.c.init), 0f32, dtype=handle)
      }
    }
    for (ic.outer: int32, 0, 8) {
      for (kh: int32, 0, 3) {
        for (ax2: int32, 0, 3) {
          for (ax3: int32, 0, 2) {
            for (ax4.ax5.fused.outer: int32, 0, 8) {
              attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
              Apad.shared[((((((threadIdx.y*3072) + (threadIdx.z*1536)) + (ax2*512)) + (ax3*256)) + (ax4.ax5.fused.outer*32)) + threadIdx.x)] = @tir.if_then_else(((((1 <= (floordiv(blockIdx.z, 14) + kh)) && ((floordiv(blockIdx.z, 14) + kh) < 15)) && (1 <= (ax2 + floormod(blockIdx.z, 14)))) && ((ax2 + floormod(blockIdx.z, 14)) < 15)), (float16*)A_2[(((((((((((blockIdx.x*6422528) + (threadIdx.y*1605632)) + (threadIdx.z*802816)) + (kh*57344)) + (blockIdx.z*4096)) + (ax2*4096)) + (ic.outer*512)) + (ax3*256)) + (ax4.ax5.fused.outer*32)) + threadIdx.x) - 61440)], 0f16, dtype=float16)
            }
          }
        }
        for (ax1: int32, 0, 3) {
          for (ax2_1: int32, 0, 2) {
            attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
            W.shared[ramp((((((ax1*4096) + (ax2_1*2048)) + (threadIdx.y*512)) + (threadIdx.z*256)) + (threadIdx.x*8)), 1, 8)] = (float16x8*)W_2[ramp(((((((((kh*393216) + (ax1*131072)) + (ic.outer*16384)) + (ax2_1*8192)) + (blockIdx.y*2048)) + (threadIdx.y*512)) + (threadIdx.z*256)) + (threadIdx.x*8)), 1, 8)]
          }
        }
        for (ic.inner: int32, 0, 2) {
          for (kw: int32, 0, 3) {
            for (ax0: int32, 0, 2) {
              @tir.tvm_load_matrix_sync(Apad.shared.wmma.matrix_a, 16, 16, 16, ax0, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), Apad.shared, ((((threadIdx.y*3072) + (ax0*1536)) + (kw*512)) + (ic.inner*256)), 256, 1, dtype=handle), 16, "row_major", dtype=handle)
            }
            for (ax3_1: int32, 0, 4) {
              @tir.tvm_load_matrix_sync(W.shared.wmma.matrix_b, 16, 16, 16, ax3_1, @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float16), W.shared, ((((kw*4096) + (ic.inner*2048)) + (threadIdx.z*1024)) + (ax3_1*256)), 256, 1, dtype=handle), 16, "row_major", dtype=handle)
            }
            for (n.c: int32, 0, 2) {
              for (o.c: int32, 0, 4) {
                @tir.tvm_mma_sync(Conv.wmma.accumulator, ((n.c*4) + o.c), Apad.shared.wmma.matrix_a, n.c, W.shared.wmma.matrix_b, o.c, Conv.wmma.accumulator, ((n.c*4) + o.c), dtype=handle)
              }
            }
          }
        }
      }
    }
    for (n.inner: int32, 0, 2) {
      for (o.inner: int32, 0, 4) {
        @tir.tvm_store_matrix_sync(Conv.wmma.accumulator, 16, 16, 16, ((n.inner*4) + o.inner), @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float32), Conv_2, (((((((blockIdx.x*12845056) + (threadIdx.y*3211264)) + (n.inner*1605632)) + (blockIdx.z*8192)) + (blockIdx.y*2048)) + (threadIdx.z*1024)) + (o.inner*256)), 256, 2, dtype=handle), 16, "row_major", dtype=handle)
      }
    }
  }
}

生成CUDA内核

最后,使用TVM生成和编译CUDA内核,并评估卷积的延迟。由于TensorCores仅在具有Compute Capability 7.0或更高版本的NVIDIA GPU中受支持,因此它可能无法在构建服务器上运行

# TensorCores need compute capability 7.0 or newer, so the kernel is only
# built and timed when GPU 0 reports support for them.
ctx = tvm.gpu(0)
if nvcc.have_tensorcore(ctx.compute_version):
    unroll_config = {"tir.UnrollLoop": {"auto_max_step": 16}}
    with tvm.transform.PassContext(config=unroll_config):
        func = tvm.build(s, [A, W, Conv], "cuda")

    # Random inputs cast to the placeholders' dtypes, plus a zero-filled
    # output buffer, all copied to the device.
    a_np = np.random.uniform(size=data_shape).astype(A.dtype)
    w_np = np.random.uniform(size=kernel_shape).astype(W.dtype)
    a, w = (tvm.nd.array(host_arr, ctx) for host_arr in (a_np, w_np))
    c = tvm.nd.array(np.zeros(output_shape, dtype=Conv.dtype), ctx)

    # Mean run time reported by time_evaluator (number=10), in milliseconds.
    evaluator = func.time_evaluator(func.entry_name, ctx, number=10)
    mean_ms = evaluator(a, w, c).mean * 1e3
    print("conv2d with tensor core: %f ms" % mean_ms)

输出:

conv2d with tensor core: 8.329637 ms

概要

本文演示了如何使用TVM调度原语在特定GPU上调用TensorCore。

https://tvm.apache.org/docs/tutorials/optimize/opt_conv_tensorcore.html

posted @ 2020-12-22 08:50  吴建明wujianming  阅读(431)  评论(0编辑  收藏  举报