The IR before codegen is:
@tvm.script.ir_module
class Module:
    # IRModule holding a single packed-API entry function ("main"), as printed
    # by the TVMScript printer immediately before code generation.
    @tir.prim_func
    def main(args: tir.handle, arg_type_ids: tir.handle, num_args: tir.int32, out_ret_value: tir.handle, out_ret_tcode: tir.handle, resource_handle: tir.handle) -> tir.int32:
        # Entry function that receives its inputs as packed-call arrays.
        # It unpacks four tensor handles from `args` / `arg_type_ids` and checks
        # each one's rank, dtype, shape, strides, byte offset and device.
        # It then enters the "main_compute_" scope.
        # out_ret_value, out_ret_tcode and resource_handle are not referenced
        # anywhere in this dump.
        # function attr dict
        # NOTE(review): the "target" entry has no value in this dump, so the
        # func_attr line below is not valid syntax as printed. This is presumably
        # a printer artifact; confirm the real target before re-parsing.
        # NOTE(review): calling_conv 1 presumably selects the C packed-function
        # calling convention -- confirm against tvm::CallingConv.
        tir.func_attr({"target": , "tir.noalias": True, "global_symbol": "main", "tir.is_entry_func": True, "calling_conv": 1})
        # body
        assert num_args == 4, "main: num_args should be 4"
        # Fetch each argument handle from `args` and its runtime type code from
        # `arg_type_ids`. Struct field 12 presumably reads the TVMValue content
        # of args[i] -- confirm against the runtime's struct-field-kind enum.
        arg0: tir.handle = tir.tvm_struct_get(args, 0, 12, dtype="handle")
        arg0_code: tir.int32 = tir.load("int32", arg_type_ids, 0)
        arg1: tir.handle = tir.tvm_struct_get(args, 1, 12, dtype="handle")
        arg1_code: tir.int32 = tir.load("int32", arg_type_ids, 1)
        arg2: tir.handle = tir.tvm_struct_get(args, 2, 12, dtype="handle")
        arg2_code: tir.int32 = tir.load("int32", arg_type_ids, 2)
        arg3: tir.handle = tir.tvm_struct_get(args, 3, 12, dtype="handle")
        arg3_code: tir.int32 = tir.load("int32", arg_type_ids, 3)
        # arg0 fields: data pointer (field 1) with a storage_alignment attr of 128,
        # shape array (field 2) and strides array (field 3).
        # N is taken from arg0.shape[0]. It is the size constraint for every other
        # dimension checked below.
        # dev_id (field 9) is taken from arg0; arg1..arg3 must match it.
        b1: tir.Ptr[global tir.int32] = tir.tvm_struct_get(arg0, 0, 1, dtype="handle")
        tir.attr(b1, "storage_alignment", 128)
        arg0_shape: tir.handle = tir.tvm_struct_get(arg0, 0, 2, dtype="handle")
        N: tir.uint32 = tir.cast(tir.load("int64", arg0_shape, 0), "uint32")
        arg0_strides: tir.handle = tir.tvm_struct_get(arg0, 0, 3, dtype="handle")
        dev_id: tir.int32 = tir.tvm_struct_get(arg0, 0, 9, dtype="int32")
        # Same data/shape/strides unpacking for arg1..arg3.
        b2: tir.Ptr[global tir.int32] = tir.tvm_struct_get(arg1, 0, 1, dtype="handle")
        tir.attr(b2, "storage_alignment", 128)
        arg1_shape: tir.handle = tir.tvm_struct_get(arg1, 0, 2, dtype="handle")
        arg1_strides: tir.handle = tir.tvm_struct_get(arg1, 0, 3, dtype="handle")
        b3: tir.Ptr[global tir.int32] = tir.tvm_struct_get(arg2, 0, 1, dtype="handle")
        tir.attr(b3, "storage_alignment", 128)
        arg2_shape: tir.handle = tir.tvm_struct_get(arg2, 0, 2, dtype="handle")
        arg2_strides: tir.handle = tir.tvm_struct_get(arg2, 0, 3, dtype="handle")
        b4: tir.Ptr[global tir.int32] = tir.tvm_struct_get(arg3, 0, 1, dtype="handle")
        tir.attr(b4, "storage_alignment", 128)
        arg3_shape: tir.handle = tir.tvm_struct_get(arg3, 0, 2, dtype="handle")
        arg3_strides: tir.handle = tir.tvm_struct_get(arg3, 0, 3, dtype="handle")
        # Every argument's type code must be 3, 13, 7 or 4. The messages call
        # these "pointer" codes. They are presumably opaque handle, NDArray handle,
        # DLTensor handle and nullptr respectively -- confirm against the
        # runtime's TVMArgTypeCode enum.
        assert arg0_code == 3 or arg0_code == 13 or arg0_code == 7 or arg0_code == 4, "main: Expect arg[0] to be pointer"
        assert arg1_code == 3 or arg1_code == 13 or arg1_code == 7 or arg1_code == 4, "main: Expect arg[1] to be pointer"
        assert arg2_code == 3 or arg2_code == 13 or arg2_code == 7 or arg2_code == 4, "main: Expect arg[2] to be pointer"
        assert arg3_code == 3 or arg3_code == 13 or arg3_code == 7 or arg3_code == 4, "main: Expect arg[3] to be pointer"
        # arg0 checks:
        # - ndim (field 4) == 3;
        # - dtype int32, i.e. type code (field 5) == 0, bits (field 6) == 32,
        #   lanes (field 7) == 1;
        # - shape == (N, N, N);
        # - strides, if present, == (N*N, N, 1), i.e. compact;
        # - byte_offset (field 8) == 0;
        # - device_type (field 10) == 1, presumably CPU -- confirm.
        # NOTE(review): the ndim assertion is emitted twice verbatim for every
        # argument; the second copy is redundant.
        assert 3 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 3"
        assert 3 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 3"
        assert tir.tvm_struct_get(arg0, 0, 5, dtype="uint8") == tir.uint8(0) and tir.tvm_struct_get(arg0, 0, 6, dtype="uint8") == tir.uint8(32) and tir.tvm_struct_get(arg0, 0, 7, dtype="uint16") == tir.uint16(1), "arg0.dtype is expected to be int32"
        assert N == tir.cast(tir.load("int64", arg0_shape, 1), "uint32"), "Argument arg0.shape[1] has an unsatisfied constraint: (N == uint32(arg0.shape[1]))"
        assert N == tir.cast(tir.load("int64", arg0_shape, 2), "uint32"), "Argument arg0.shape[2] has an unsatisfied constraint: (N == uint32(arg0.shape[2]))"
        # A null strides pointer skips the compactness check.
        if not(tir.isnullptr(arg0_strides, dtype="bool")):
            assert tir.uint32(1) == tir.cast(tir.load("int64", arg0_strides, 2), "uint32") and N == tir.cast(tir.load("int64", arg0_strides, 1), "uint32") and N * N == tir.cast(tir.load("int64", arg0_strides, 0), "uint32"), "arg0.strides: expected to be compact array"
            tir.evaluate(0)
        assert tir.uint64(0) == tir.tvm_struct_get(arg0, 0, 8, dtype="uint64"), "Argument arg0.byte_offset has an unsatisfied constraint: ((uint64)0 == tir.tvm_struct_get(arg0, 0, 8))"
        assert 1 == tir.tvm_struct_get(arg0, 0, 10, dtype="int32"), "Argument arg0.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg0, 0, 10))"
        # arg1 checks: ndim == 2, dtype int32, shape == (N, N), compact strides
        # (N, 1), byte_offset == 0, device_type == 1, device_id == dev_id.
        assert 2 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 2"
        assert 2 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 2"
        assert tir.tvm_struct_get(arg1, 0, 5, dtype="uint8") == tir.uint8(0) and tir.tvm_struct_get(arg1, 0, 6, dtype="uint8") == tir.uint8(32) and tir.tvm_struct_get(arg1, 0, 7, dtype="uint16") == tir.uint16(1), "arg1.dtype is expected to be int32"
        assert N == tir.cast(tir.load("int64", arg1_shape, 0), "uint32"), "Argument arg1.shape[0] has an unsatisfied constraint: (N == uint32(arg1.shape[0]))"
        assert N == tir.cast(tir.load("int64", arg1_shape, 1), "uint32"), "Argument arg1.shape[1] has an unsatisfied constraint: (N == uint32(arg1.shape[1]))"
        if not(tir.isnullptr(arg1_strides, dtype="bool")):
            assert tir.uint32(1) == tir.cast(tir.load("int64", arg1_strides, 1), "uint32") and N == tir.cast(tir.load("int64", arg1_strides, 0), "uint32"), "arg1.strides: expected to be compact array"
            tir.evaluate(0)
        assert tir.uint64(0) == tir.tvm_struct_get(arg1, 0, 8, dtype="uint64"), "Argument arg1.byte_offset has an unsatisfied constraint: ((uint64)0 == tir.tvm_struct_get(arg1, 0, 8))"
        assert 1 == tir.tvm_struct_get(arg1, 0, 10, dtype="int32"), "Argument arg1.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg1, 0, 10))"
        assert dev_id == tir.tvm_struct_get(arg1, 0, 9, dtype="int32"), "Argument arg1.device_id has an unsatisfied constraint: (dev_id == tir.tvm_struct_get(arg1, 0, 9))"
        # arg2 checks: ndim == 1, dtype int32, shape == (N,), compact strides (1,),
        # byte_offset == 0, device_type == 1, device_id == dev_id.
        assert 1 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 1"
        assert 1 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 1"
        assert tir.tvm_struct_get(arg2, 0, 5, dtype="uint8") == tir.uint8(0) and tir.tvm_struct_get(arg2, 0, 6, dtype="uint8") == tir.uint8(32) and tir.tvm_struct_get(arg2, 0, 7, dtype="uint16") == tir.uint16(1), "arg2.dtype is expected to be int32"
        assert N == tir.cast(tir.load("int64", arg2_shape, 0), "uint32"), "Argument arg2.shape[0] has an unsatisfied constraint: (N == uint32(arg2.shape[0]))"
        if not(tir.isnullptr(arg2_strides, dtype="bool")):
            assert tir.uint32(1) == tir.cast(tir.load("int64", arg2_strides, 0), "uint32"), "arg2.strides: expected to be compact array"
            tir.evaluate(0)
        assert tir.uint64(0) == tir.tvm_struct_get(arg2, 0, 8, dtype="uint64"), "Argument arg2.byte_offset has an unsatisfied constraint: ((uint64)0 == tir.tvm_struct_get(arg2, 0, 8))"
        assert 1 == tir.tvm_struct_get(arg2, 0, 10, dtype="int32"), "Argument arg2.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg2, 0, 10))"
        assert dev_id == tir.tvm_struct_get(arg2, 0, 9, dtype="int32"), "Argument arg2.device_id has an unsatisfied constraint: (dev_id == tir.tvm_struct_get(arg2, 0, 9))"
        # arg3 checks: identical to arg2 (ndim 1, int32, shape (N,), compact).
        assert 1 == tir.tvm_struct_get(arg3, 0, 4, dtype="int32"), "arg3.ndim is expected to equal 1"
        assert 1 == tir.tvm_struct_get(arg3, 0, 4, dtype="int32"), "arg3.ndim is expected to equal 1"
        assert tir.tvm_struct_get(arg3, 0, 5, dtype="uint8") == tir.uint8(0) and tir.tvm_struct_get(arg3, 0, 6, dtype="uint8") == tir.uint8(32) and tir.tvm_struct_get(arg3, 0, 7, dtype="uint16") == tir.uint16(1), "arg3.dtype is expected to be int32"
        assert N == tir.cast(tir.load("int64", arg3_shape, 0), "uint32"), "Argument arg3.shape[0] has an unsatisfied constraint: (N == uint32(arg3.shape[0]))"
        if not(tir.isnullptr(arg3_strides, dtype="bool")):
            assert tir.uint32(1) == tir.cast(tir.load("int64", arg3_strides, 0), "uint32"), "arg3.strides: expected to be compact array"
            tir.evaluate(0)
        assert tir.uint64(0) == tir.tvm_struct_get(arg3, 0, 8, dtype="uint64"), "Argument arg3.byte_offset has an unsatisfied constraint: ((uint64)0 == tir.tvm_struct_get(arg3, 0, 8))"
        assert 1 == tir.tvm_struct_get(arg3, 0, 10, dtype="int32"), "Argument arg3.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg3, 0, 10))"
        assert dev_id == tir.tvm_struct_get(arg3, 0, 9, dtype="int32"), "Argument arg3.device_id has an unsatisfied constraint: (dev_id == tir.tvm_struct_get(arg3, 0, 9))"
        # Compute body. Everything after this attr belongs to the "main_compute_"
        # scope.
        # NOTE(review): the data pointers b1..b4 are never referenced below.
        # The inner `while True` has a no-op body and no exit, so as printed
        # this loop never terminates and the function never finishes.
        tir.attr(0, "compute_scope", "main_compute_")
        for v in tir.parallel(1, 4):
            while True:
                tir.evaluate(0)