tvm实现卷积操作
https://blog.csdn.net/sinat_31425585/article/details/103797339
import tvm
import numpy as np
import mxnet as mx


def padding(X, ph, pw):
    """Zero-pad the last two axes of a TVM tensor.

    Args:
        X: TVM tensor with at least two dimensions; the last two are treated
            as height and width.
        ph: number of zeros added on each side of the height axis.
        pw: number of zeros added on each side of the width axis.

    Returns:
        A compute tensor named ``'PaddedX'`` of shape
        ``(*X.shape[:-2], nh + 2 * ph, nw + 2 * pw)``.
    """
    assert len(X.shape) >= 2
    nh, nw = X.shape[-2], X.shape[-1]
    return tvm.te.compute(
        (*X.shape[0:-2], nh + ph * 2, nw + pw * 2),
        # Border positions read 0; interior positions read the element of X
        # shifted back by (ph, pw).
        lambda *i: tvm.te.if_then_else(
            tvm.te.any(i[-2] < ph, i[-2] >= nh + ph,
                       i[-1] < pw, i[-1] >= nw + pw),
            0,
            X[i[:-2] + (i[-2] - ph, i[-1] - pw)],
        ),
        name='PaddedX',
    )


def conv_out_size(n, k, p, s):
    """Return the output length of a convolution along one axis.

    Args:
        n: input size.
        k: kernel size.
        p: padding added on each side.
        s: stride.
    """
    return (n - k + 2 * p) // s + 1


def conv(oc, ic, nh, nw, kh, kw, ph=0, pw=0, sh=1, sw=1):
    """Declare a single-image 2D convolution with TVM tensor expressions.

    Args:
        oc: number of output channels.
        ic: number of input channels.
        nh, nw: input height and width.
        kh, kw: kernel height and width.
        ph, pw: zero padding on each side of the height / width axes.
        sh, sw: stride along the height / width axes.

    Returns:
        ``(X, K, Y, PaddedX)``:
        - ``X``: input placeholder of shape ``(ic, nh, nw)``.
        - ``K``: kernel placeholder of shape ``(oc, ic, kh, kw)``.
        - ``Y``: output compute tensor of shape ``(oc, oh, ow)``.
        - ``PaddedX``: the padded input, or ``X`` itself when no padding is
          requested.
    """
    # reduction axes
    ric = tvm.te.reduce_axis((0, ic), name='ric')
    rkh = tvm.te.reduce_axis((0, kh), name='rkh')
    rkw = tvm.te.reduce_axis((0, kw), name='rkw')
    # output height and width
    oh = conv_out_size(nh, kh, ph, sh)
    ow = conv_out_size(nw, kw, pw, sw)
    # pad x and then compute y
    X = tvm.te.placeholder((ic, nh, nw), name='x')
    K = tvm.te.placeholder((oc, ic, kh, kw), name='k')
    # Pad the input whenever EITHER axis needs it. Testing ``ph * pw`` would
    # skip padding when only one of them is non-zero, while oh/ow above still
    # assume padding, so Y would read past the edge of X.
    PaddedX = padding(X, ph, pw) if ph != 0 or pw != 0 else X
    Y = tvm.te.compute(
        (oc, oh, ow),
        lambda c, i, j: tvm.te.sum(
            PaddedX[ric, i * sh + rkh, j * sw + rkw] * K[c, ric, rkh, rkw],
            axis=[ric, rkh, rkw],
        ),
        name='Y',
    )
    return X, K, Y, PaddedX


def get_conv_data(oc, ic, n, k, p=0, s=1, constructor=None):
    """Generate random test data for a square convolution.

    Re-seeds NumPy's global RNG with 0 on every call, so repeated calls
    return the same values.

    Args:
        oc, ic: output / input channel counts.
        n: input height and width.
        k: kernel height and width.
        p: padding on each side.
        s: stride.
        constructor: optional callable applied to each returned array
            (e.g. ``tvm.nd.array``).

    Returns:
        ``(data, weight, out)``:
        - ``data``: float32, shape ``(ic, n, n)``.
        - ``weight``: float32, shape ``(oc, ic, k, k)``.
        - ``out``: uninitialised float32 output buffer, shape ``(oc, on, on)``.
    """
    np.random.seed(0)
    data = np.random.normal(size=(ic, n, n)).astype('float32')
    weight = np.random.normal(size=(oc, ic, k, k)).astype('float32')
    on = conv_out_size(n, k, p, s)
    out = np.empty((oc, on, on), dtype='float32')
    if constructor is not None:
        data, weight, out = (constructor(x) for x in [data, weight, out])
    return data, weight, out


# Demo sizes: output channels, input channels, input size, kernel size,
# padding, stride.
oc, ic, n, k, p, s = 4, 6, 12, 3, 1, 1
X, K, Y, _ = conv(oc, ic, n, n, k, k, p, p, s, s)
sch = tvm.te.create_schedule(Y.op)
mod = tvm.build(sch, [X, K, Y])
print(tvm.lower(sch, [X, K, Y], simple_mode=True))
data, weight, out = get_conv_data(oc, ic, n, k, p, s, tvm.nd.array)
mod(data, weight, out)


def get_conv_data_mxnet(oc, ic, n, k, p, s, ctx='cpu'):
    """Build the ``get_conv_data`` test data as MXNet NDArrays.

    Args:
        oc, ic, n, k, p, s: same meaning as in ``get_conv_data``.
        ctx: name of an MXNet context factory, resolved with
            ``getattr(mx, ctx)()`` (e.g. ``'cpu'``).

    Returns:
        ``(data, weight, bias, out)``:
        - ``data`` and ``out`` gain a leading batch axis of size 1.
        - ``bias`` is a zero vector of length ``out.shape[1]``.
    """
    ctx = getattr(mx, ctx)()
    data, weight, out = get_conv_data(
        oc, ic, n, k, p, s, lambda x: mx.nd.array(x, ctx=ctx))
    data, out = data.expand_dims(axis=0), out.expand_dims(axis=0)
    bias = mx.nd.zeros(out.shape[1], ctx=ctx)
    return data, weight, bias, out


def conv_mxnet(data, weight, bias, out, k, p, s):
    """Run MXNet's reference convolution, writing the result into ``out``."""
    mx.nd.Convolution(data, weight, bias, kernel=(k, k), stride=(s, s),
                      pad=(p, p), num_filter=out.shape[1], out=out)


data, weight, bias, out_mx = get_conv_data_mxnet(oc, ic, n, k, p, s)
conv_mxnet(data, weight, bias, out_mx, k, p, s)
# The TVM result must match MXNet's reference convolution.
np.testing.assert_allclose(out_mx[0].asnumpy(), out.asnumpy(), atol=1e-5)