tvm实现卷积操作

 

https://blog.csdn.net/sinat_31425585/article/details/103797339

import tvm
import numpy as np
import mxnet as mx


def padding(X, ph, pw):
    """Zero-pad the two trailing axes of a TE tensor.

    Args:
        X: tensor with at least two dimensions; all leading axes are kept
            as they are.
        ph: number of zero rows added on each side of the second-to-last axis.
        pw: number of zero columns added on each side of the last axis.

    Returns:
        A compute tensor named 'PaddedX' whose last two extents are
        ``nh + 2 * ph`` and ``nw + 2 * pw``.
    """
    assert len(X.shape) >= 2
    height, width = X.shape[-2], X.shape[-1]
    padded_shape = (*X.shape[0:-2], height + ph * 2, width + pw * 2)

    def pad_or_copy(*idx):
        # Only the last two indices are shifted; leading indices pass through.
        row, col = idx[-2], idx[-1]
        in_border = tvm.te.any(
            row < ph, row >= height + ph, col < pw, col >= width + pw)
        source = X[idx[:-2] + (row - ph, col - pw)]
        return tvm.te.if_then_else(in_border, 0, source)

    return tvm.te.compute(padded_shape, pad_or_copy, name='PaddedX')


def conv_out_size(n, k, p, s):
    """Return the output extent of a convolution along one axis.

    Args:
        n: input size.
        k: kernel size.
        p: padding added on each side.
        s: stride.
    """
    padded_extent = n + 2 * p
    # Number of whole strides the kernel can move, plus its starting position.
    return (padded_extent - k) // s + 1


def conv(oc, ic, nh, nw, kh, kw, ph=0, pw=0, sh=1, sw=1):
    """Declare a 2-D convolution over a single (ic, nh, nw) input in TVM TE.

    Args:
        oc, ic: number of output and input channels.
        nh, nw: input height and width.
        kh, kw: kernel height and width.
        ph, pw: zero padding added on each side of the height / width axis.
        sh, sw: stride along the height / width axis.

    Returns:
        A tuple ``(X, K, Y, PaddedX)``: the input placeholder of shape
        (ic, nh, nw), the kernel placeholder of shape (oc, ic, kh, kw), the
        output tensor of shape (oc, oh, ow), and the tensor that Y reads its
        input from (the padded input, or X itself when no padding is asked
        for).
    """
    # reduction axes
    ric = tvm.te.reduce_axis((0, ic), name='ric')
    rkh = tvm.te.reduce_axis((0, kh), name='rkh')
    rkw = tvm.te.reduce_axis((0, kw), name='rkw')

    # output height and width
    oh = conv_out_size(nh, kh, ph, sh)
    ow = conv_out_size(nw, kw, pw, sw)

    # pad x and then compute y
    X = tvm.te.placeholder((ic, nh, nw), name='x')
    K = tvm.te.placeholder((oc, ic, kh, kw), name='k')
    # Pad whenever either axis asks for it. Testing `ph * pw != 0` would skip
    # padding when only one of the two is non-zero, while oh/ow above still
    # assume the padded size, so Y would index past the end of X.
    PaddedX = padding(X, ph, pw) if ph != 0 or pw != 0 else X
    Y = tvm.te.compute(
        (oc, oh, ow),
        lambda c, i, j: tvm.te.sum(
            PaddedX[ric, i * sh + rkh, j * sw + rkw] * K[c, ric, rkh, rkw],
            axis=[ric, rkh, rkw]
        ), name='Y'
    )

    return X, K, Y, PaddedX


def get_conv_data(oc, ic, n, k, p=0, s=1, constructor=None):
    """Create seeded random convolution operands plus an output buffer.

    Args:
        oc, ic: number of output and input channels.
        n: input height and width (square input).
        k: kernel height and width (square kernel).
        p: padding, used only to size the output buffer.
        s: stride, used only to size the output buffer.
        constructor: optional callable applied to each of the three NumPy
            arrays before they are returned.

    Returns:
        ``(data, weight, out)`` with shapes (ic, n, n), (oc, ic, k, k) and
        (oc, on, on), all float32. ``out`` is allocated with ``np.empty`` and
        therefore holds arbitrary values.
    """
    # Re-seeding NumPy's global RNG makes every call produce the same operands.
    np.random.seed(0)
    inputs = np.random.normal(size=(ic, n, n)).astype('float32')
    kernel = np.random.normal(size=(oc, ic, k, k)).astype('float32')
    side = conv_out_size(n, k, p, s)
    result = np.empty((oc, side, side), dtype='float32')

    if not constructor:
        return inputs, kernel, result
    return constructor(inputs), constructor(kernel), constructor(result)


# Problem size: 4 output channels, 6 input channels, a 12x12 input,
# a 3x3 kernel, padding 1 and stride 1.
oc, ic, n, k, p, s = 4, 6, 12, 3, 1, 1
# Declare the convolution; the fourth return value (the padded input) is unused.
X, K, Y, _ = conv(oc, ic, n, n, k, k, p, p, s, s)
# Build with the default schedule and print the lowered IR for inspection.
sch = tvm.te.create_schedule(Y.op)
mod = tvm.build(sch, [X, K, Y])
print(tvm.lower(sch, [X, K, Y], simple_mode=True))

# Wrap the NumPy test data as TVM arrays and run the compiled function;
# `out` is passed in the position of Y.
data, weight, out = get_conv_data(oc, ic, n, k, p, s, tvm.nd.array)
mod(data, weight, out)


def get_conv_data_mxnet(oc, ic, n, k, p, s, ctx='cpu'):
    """Build the convolution test operands as MXNet NDArrays.

    Args:
        oc, ic, n, k, p, s: forwarded unchanged to ``get_conv_data``.
        ctx: name of the context factory looked up on the ``mx`` module
            (default 'cpu').

    Returns:
        ``(data, weight, bias, out)`` where ``data`` and ``out`` carry an
        extra leading axis of length 1 and ``bias`` is a zero vector with
        one entry per output channel.
    """
    device = getattr(mx, ctx)()

    def to_mx(array):
        return mx.nd.array(array, ctx=device)

    data, weight, out = get_conv_data(oc, ic, n, k, p, s, to_mx)
    # Prepend a batch axis of length 1 to the input and the output buffer.
    data = data.expand_dims(axis=0)
    out = out.expand_dims(axis=0)
    bias = mx.nd.zeros(out.shape[1], ctx=device)
    return data, weight, bias, out


def conv_mxnet(data, weight, bias, out, k, p, s):
    """Run ``mx.nd.Convolution`` with a square kernel, stride and padding.

    The operator is given ``out`` as its output array; nothing is returned.
    The number of filters is taken from ``out.shape[1]``.
    """
    options = dict(
        kernel=(k, k),
        stride=(s, s),
        pad=(p, p),
        num_filter=out.shape[1],
        out=out,
    )
    mx.nd.Convolution(data, weight, bias, **options)


# Regenerate the operands as MXNet arrays (get_conv_data re-seeds NumPy with 0)
# and run the MXNet convolution as the reference.
data, weight, bias, out_mx = get_conv_data_mxnet(oc, ic, n, k, p, s)
conv_mxnet(data, weight, bias, out_mx, k, p, s)
# out_mx[0] drops the leading batch axis so the shapes match the TVM output.
np.testing.assert_allclose(out_mx[0].asnumpy(), out.asnumpy(), atol=1e-5)

 

posted @ 2024-05-26 22:36  小丑_jk  阅读(7)  评论(0)  编辑  收藏  举报