Hello,
in order to better understand how to use tensorization I tried working on a dummy exmaple with a conv2d kernel. I get the following result after tvm.lower
:
@I.ir_module
class Module:
@T.prim_func
def main(input: T.Buffer((2, 32, 32, 32), "int8"), kernel: T.Buffer((3, 3, 32, 64), "int8"), bias: T.Buffer((64,), "int32"), output: T.Buffer((2, 32, 32, 64), "int8")):
T.func_attr({"from_legacy_te_schedule": T.bool(True), "global_symbol": "main", "tir.noalias": T.bool(True)})
res = T.allocate([131136], "int8", "local.accumulator")
input_local_scratchpad = T.allocate([5184], "int8", "local.scratchpad")
kernel_local_scratchpad_weight = T.allocate([2304], "int8", "local.scratchpad_weight")
for b_o_outer, h_o_outer, w_o_outer, c_o_outer in T.grid(2, 2, 2, 4):
res_1 = T.Buffer((131072,), "int8", data=res, scope="local.accumulator", align=4)
for h_o_inner_init, w_o_inner_init, c_o_inner_init in T.grid(16, 16, 16):
res_1[b_o_outer * 65536 + h_o_outer * 32768 + h_o_inner_init * 2048 + w_o_outer * 1024 + w_o_inner_init * 64 + c_o_outer * 16 + c_o_inner_init] = T.int8(0)
for ric_outer in range(2):
input_local_scratchpad_1 = T.Buffer((5184,), "int8", data=input_local_scratchpad, scope="local.scratchpad", align=4)
for ax1 in range(18):
if T.likely(ax1 // 16 + h_o_outer < 2):
for ax2 in range(18):
if T.likely(ax2 // 16 + w_o_outer < 2):
input_1 = T.Buffer((65536,), "int8", data=input.data)
for ax3 in range(16):
input_local_scratchpad_1[ax1 * 288 + ax2 * 16 + ax3] = input_1[b_o_outer * 32768 + h_o_outer * 16384 + ax1 * 1024 + w_o_outer * 512 + ax2 * 32 + ric_outer * 16 + ax3]
kernel_local_scratchpad_weight_1 = T.Buffer((2304,), "int8", data=kernel_local_scratchpad_weight, scope="local.scratchpad_weight", align=4)
for ax0, ax1, ax2, ax3 in T.grid(3, 3, 16, 16):
kernel_1 = T.Buffer((18432,), "int8", data=kernel.data)
kernel_local_scratchpad_weight_1[ax0 * 768 + ax1 * 256 + ax2 * 16 + ax3] = kernel_1[ax0 * 6144 + ax1 * 2048 + ric_outer * 1024 + ax2 * 64 + c_o_outer * 16 + ax3]
res_2 = T.Buffer((16,), "int32", data=res, scope="local.accumulator", align=1)
for ax0 in range(16):
bias_1 = T.Buffer((64,), "int32", data=bias.data)
res_2[ax0 + 32768] = bias_1[c_o_outer * 16 + ax0]
for h_o_inner, w_o_inner, c_o_inner, rkw_inner, rkh_inner, ric_inner in T.grid(16, 16, 16, 3, 3, 16):
cse_var_1: T.int32 = b_o_outer * 65536 + h_o_outer * 32768 + h_o_inner * 2048 + w_o_outer * 1024 + w_o_inner * 64 + c_o_outer * 16 + c_o_inner
res_1[cse_var_1] = res_1[cse_var_1] + (input_local_scratchpad_1[h_o_inner * 288 + rkh_inner * 288 + w_o_inner * 16 + rkw_inner * 16 + ric_inner] * kernel_local_scratchpad_weight_1[rkh_inner * 768 + rkw_inner * 256 + ric_inner * 16 + c_o_inner] + T.Cast("int8", res_2[c_o_inner + 32768]))
I want to fold everything after the ric_outer
loop into an external kernel. The definition of the intrinsic looks like this:
def intrin_conv():
oh = 16
ow = 16
ih = 18
iw = 18
oc = 16
kh = 3
kw = 3
ic = 16
ishape = (1, ih, iw, ic)
kshape = (kh, kw, ic, oc)
bshape = (oc,)
oshape = (1, oh, ow, oc)
ifmap = te.placeholder(ishape, dtype=ENV.inp_dtype, name="ifmap_tile")
kernel = te.placeholder(kshape, dtype=ENV.wgt_dtype, name="kernel_tile")
bias = te.placeholder(bshape, dtype=ENV.inp_dtype, name="bias_tile")
out = te.placeholder(oshape, ENV.inp_dtype, name="out_tile")
ric = te.reduce_axis((0, ic), name="ric")
rkh = te.reduce_axis((0, kh), name="rkh")
rkw = te.reduce_axis((0, kw), name="rkw")
ifmap_buf = tvm.tir.decl_buffer(ifmap.shape, ifmap.dtype, name="Ifmap_buf", offset_factor=1)
kernel_buf = tvm.tir.decl_buffer(kernel.shape, kernel.dtype, name="Kernel_buf", offset_factor=1)
bias_buf = tvm.tir.decl_buffer(bias.shape, bias.dtype, name="Bias_buf", offset_factor=1)
out_buf = tvm.tir.decl_buffer(out.shape, out.dtype, name="Out_buf", offset_factor=1)
res = te.compute(
oshape,
lambda b_o, h_o, w_o, c_o: te.sum(
data[b_o, h_o + rkh, w_o + rkw, ric].astype(ENV.inp_dtype)
* kernel[rkh, rkw, ric, c_o].astype(ENV.inp_dtype)
+ bias[oc].astype(ENV.inp_dtype),
axis=[rkh, rkw, ric],
),
name="res"
)
def intrin_func(ins, outs):
ib = tvm.tir.ir_builder.create()
ifm, ker, bias = ins
res = outs[0]
ib.emit(
tvm.tir.call_extern(
"",
"conv_kernel",
ic,
kh,
kw,
oc,
oh,
ow,
ifm.access_ptr("r"),
ker.access_ptr("r"),
bias.access_ptr("r"),
res.access_ptr("w")
)
)
return ib.get()
return te.decl_tensor_intrin(res.op, intrin_func, binds={ifmap: ifmap_buf, kernel: kernel_buf, bias: bias_buf, out: out_buf})
But I try to tensorize my computation I get the following error:
Check failed: (expr_equal(lhs, rhs)) is false: Failed to match the compute with TensorIntrin tensor_intrin's declaration
I don’t understand where this error is coming from. The buffers should be split into tiles according to the values defined inside intrin_func. I can imagine that the parameters might not always match the defined values. But if that is the case, is there a better way to handle this? Being able to adapt the parameters oh, ow, etc. inside the intrinsic function would also be interesting, but I don’t know how to do that with tensorize
.