Hi, I stumbled upon this issue while trying to lower Relay → TIR through the meta_schedule backend. All of the loop variables are cast to T.int64(),
like this:
@I.ir_module
class Module:
@T.prim_func
def main(p0: T.Buffer((T.int64(16), T.int64(64)), "int8"), res: T.Buffer((T.int64(16), T.int64(16)), "int8")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
fused_constant_1 = T.allocate_const([-5258, -5606, -4890, -5116, -4851, -5553, -4500, -6011, -5436, -5604, -5510, -5580, -4799, -4997, -4977, -5611], "int32", [16])
fused_constant = T.allocate_const([43, -117, -59, -29, -44, 95, -10, 51, 5, 24, -51, -50, 35, 89, -10, 12, 63, -57, -38, -125, -87, -3, -7, 114, 92, 29, 70, 28, 99, -100, -40, -107, -84, -86, -84, 41, 43, -31, 103, 79, -75, -15, -47, 10, 28, -70, -114, -46, -119, -23, 42, 125, -13, -69, -56, 27, 16, -117, 120, 1, -56, 15, 127, 27, -45, -92, 46, -96, 9, -105, -42, 119, 54, -37, -13, 13, -42, -17, 69, 64, 43, 96, -111, 9, 103, 3, -106, 83, -84, 16, -33, 96, -124, 119, -58, -113, -91, -4, -112, -96, 37, 72, -45, -120, -64, -99, -15, 37, -31, 113, 114, -103, 4, 16, 51, -119, 116, -74, -21, -116, 0, -80, 120, -1, -24, 2, -102, -107, -5, 57, 65, -9, 101, 95, 20, -51, 65, -78, 109, -32, -43, -97, -87, 39, -14, 120, -94, 116, 67, 108, -9, -73, 47, 91, -112, -3, 37, 4, 20, 91, 38, -62, 4, 76, -124, 79, 19, -36, -107, 87, -102, 52, -124, 68, 49, -96, -83, -16, 122, -82, -126, -40, 104, 77, -125, 102, 68, 122, 87, -44, -110, -34, 39, 47, -84, -102, -11, -94, 20, -1, 106], "int8", [64, 16])
for x_o, y_o, k_o in T.grid(T.int64(16), T.int64(16), T.int64(64)):
with T.block("res"):
v_x_o, v_y_o, v_k_o = T.axis.remap("SSR", [x_o, y_o, k_o])
fused_constant_2 = T.Buffer((64, 16), "int8", data=fused_constant)
fused_constant_1_1 = T.Buffer((16,), "int32", data=fused_constant_1)
T.reads(p0[v_x_o, v_k_o], fused_constant_2[v_k_o, v_y_o], fused_constant_1_1[v_y_o])
T.writes(res[v_x_o, v_y_o])
with T.init():
res[v_x_o, v_y_o] = T.int8(0)
res[v_x_o, v_y_o] = res[v_x_o, v_y_o] + (p0[v_x_o, v_k_o] * fused_constant_2[v_k_o, v_y_o] + T.Cast("int8", fused_constant_1_1[v_y_o]))
Is there a way to prevent this? This behavior leads to errors when calling sch.blockize:
InternalError: Check failed: (ret_ex.dtype() == var.dtype()) is false: substituting v0:int32 -> k_o_0 * T.int64(32) + ax0_0 * T.int64(16) + ax0_1:int64