I’m adding a TIR pass that runs after the `lower_tvm_builtin` (`LowerTVMBuiltin`) pass. In the new pass, I need to detect the Buffer declaration for `stack_tcode` in the following TIR, which is produced by the `LowerTVMBuiltin` pass:
def tvmgen_default_fused_nn_pad(args: T.handle, arg_type_ids: T.handle("int32"), num_args: T.int32, out_ret_value: T.handle("void"), out_ret_tcode: T.handle("int32"), resource_handle: T.handle) -> T.int32:
T.func_attr({"calling_conv": 1, "from_legacy_te_schedule": T.bool(True), "global_symbol": "tvmgen_default_fused_nn_pad", "hash": "062660d518bba263", "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}), "tir.noalias": T.bool(True)})
stack_tcode: T.handle("int32") = T.tvm_stack_alloca("arg_tcode", 6)
stack_value: T.handle = T.tvm_stack_alloca("arg_value", 6)
assert num_args == 3, "tvmgen_default_fused_nn_pad: num_args should be 3"
arg_p0: T.handle = T.tvm_struct_get(args, 0, 12, "handle")
arg_type_ids_1 = T.Buffer((3,), "int32", data=arg_type_ids)
arg_p0_code: T.int32 = arg_type_ids_1[0]
arg_p1: T.handle = T.tvm_struct_get(args, 1, 12, "handle")
arg_p1_code: T.int32 = arg_type_ids_1[1]
arg_T_pad: T.handle = T.tvm_struct_get(args, 2, 12, "handle")
arg_T_pad_code: T.int32 = arg_type_ids_1[2]
p0: T.handle("float16") = T.tvm_struct_get(arg_p0, 0, 1, "handle")
T.attr(p0, "storage_alignment", 64)
tvmgen_default_fused_nn_pad_arg_p0_shape: T.handle("int64") = T.tvm_struct_get(arg_p0, 0, 2, "handle")
tvmgen_default_fused_nn_pad_arg_p0_strides: T.handle("int64") = T.tvm_struct_get(arg_p0, 0, 3, "handle")
dev_id: T.int32 = T.tvm_struct_get(arg_p0, 0, 9, "int32")
p1: T.handle("float16") = T.tvm_struct_get(arg_p1, 0, 1, "handle")
T.attr(p1, "storage_alignment", 64)
tvmgen_default_fused_nn_pad_arg_p1_shape: T.handle("int64") = T.tvm_struct_get(arg_p1, 0, 2, "handle")
tvmgen_default_fused_nn_pad_arg_p1_strides: T.handle("int64") = T.tvm_struct_get(arg_p1, 0, 3, "handle")
T_pad: T.handle("float16") = T.tvm_struct_get(arg_T_pad, 0, 1, "handle")
T.attr(T_pad, "storage_alignment", 64)
tvmgen_default_fused_nn_pad_arg_T_pad_shape: T.handle("int64") = T.tvm_struct_get(arg_T_pad, 0, 2, "handle")
tvmgen_default_fused_nn_pad_arg_T_pad_strides: T.handle("int64") = T.tvm_struct_get(arg_T_pad, 0, 3, "handle")
tvmgen_default_fused_nn_pad_arg_p0_shape_1 = T.Buffer((4,), "int64", data=tvmgen_default_fused_nn_pad_arg_p0_shape)
if not T.isnullptr(tvmgen_default_fused_nn_pad_arg_p0_strides):
tvmgen_default_fused_nn_pad_arg_p0_strides_1 = T.Buffer((0,), "int64", data=tvmgen_default_fused_nn_pad_arg_p0_strides)
T.evaluate(0)
tvmgen_default_fused_nn_pad_arg_T_pad_shape_1 = T.Buffer((4,), "int64", data=tvmgen_default_fused_nn_pad_arg_T_pad_shape)
if not T.isnullptr(tvmgen_default_fused_nn_pad_arg_T_pad_strides):
tvmgen_default_fused_nn_pad_arg_T_pad_strides_1 = T.Buffer((0,), "int64", data=tvmgen_default_fused_nn_pad_arg_T_pad_strides)
T.evaluate(0)
T.tvm_struct_set(stack_value, 0, 12, T.Cast("int64", 7))
stack_tcode_1 = T.Buffer((T.uint64(6),), "int32", data=stack_tcode) //how to detect this Stmt/Expr?
stack_tcode_1[0] = 0
T.tvm_struct_set(stack_value, 1, 12, T.Cast("int64", dev_id))
stack_tcode_1[1] = 0
T.call_packed_lowered("__tvm_set_device", stack_value, stack_tcode, 0, 2)
T.attr(0, "compute_scope", "tvmgen_default_fused_nn_pad_compute_")
T.tvm_struct_set(stack_value, 0, 12, T_pad)
stack_tcode_1[0] = 3
T.tvm_struct_set(stack_value, 1, 12, p0)
stack_tcode_1[1] = 3
T.tvm_struct_set(stack_value, 2, 12, p1)
stack_tcode_1[2] = 3
T.tvm_struct_set(stack_value, 3, 12, T.Cast("int64", 256))
stack_tcode_1[3] = 0
T.tvm_struct_set(stack_value, 4, 12, T.Cast("int64", 256))
stack_tcode_1[4] = 0
T.call_packed_lowered("tvmgen_default_fused_nn_pad_kernel0", stack_value, stack_tcode, 0, 5)
In the TIR above, I am trying to detect the following statement:
stack_tcode_1 = T.Buffer((T.uint64(6),), "int32", data=stack_tcode)
I tried overriding the following visitor methods in my pass, but none of them detects the buffer declaration in this statement:
Stmt VisitStmt_(const AllocateNode* op)
Stmt VisitStmt_(const DeclBufferNode* op)
Stmt VisitStmt_(const LetStmtNode* op)
My question is: how can I detect the buffer declaration in this “stack_tcode_1 = T.Buffer…” statement? Thanks.