How to detect BufferNode in a TIR pass?

I’m adding a TIR pass after the lower_tvm_builtin pass. In the new pass, I need to detect the Buffer declaration for stack_tcode in the following TIR, which is generated by the lower_tvm_builtin pass:

def tvmgen_default_fused_nn_pad(args: T.handle, arg_type_ids: T.handle("int32"), num_args: T.int32, out_ret_value: T.handle("void"), out_ret_tcode: T.handle("int32"), resource_handle: T.handle) -> T.int32:
    T.func_attr({"calling_conv": 1, "from_legacy_te_schedule": T.bool(True), "global_symbol": "tvmgen_default_fused_nn_pad", "hash": "062660d518bba263", "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}), "tir.noalias": T.bool(True)})
    stack_tcode: T.handle("int32") = T.tvm_stack_alloca("arg_tcode", 6)
    stack_value: T.handle = T.tvm_stack_alloca("arg_value", 6)
    assert num_args == 3, "tvmgen_default_fused_nn_pad: num_args should be 3"
    arg_p0: T.handle = T.tvm_struct_get(args, 0, 12, "handle")
    arg_type_ids_1 = T.Buffer((3,), "int32", data=arg_type_ids)
    arg_p0_code: T.int32 = arg_type_ids_1[0]
    arg_p1: T.handle = T.tvm_struct_get(args, 1, 12, "handle")
    arg_p1_code: T.int32 = arg_type_ids_1[1]
    arg_T_pad: T.handle = T.tvm_struct_get(args, 2, 12, "handle")
    arg_T_pad_code: T.int32 = arg_type_ids_1[2]
    p0: T.handle("float16") = T.tvm_struct_get(arg_p0, 0, 1, "handle")
    T.attr(p0, "storage_alignment", 64)
    tvmgen_default_fused_nn_pad_arg_p0_shape: T.handle("int64") = T.tvm_struct_get(arg_p0, 0, 2, "handle")
    tvmgen_default_fused_nn_pad_arg_p0_strides: T.handle("int64") = T.tvm_struct_get(arg_p0, 0, 3, "handle")
    dev_id: T.int32 = T.tvm_struct_get(arg_p0, 0, 9, "int32")
    p1: T.handle("float16") = T.tvm_struct_get(arg_p1, 0, 1, "handle")
    T.attr(p1, "storage_alignment", 64)
    tvmgen_default_fused_nn_pad_arg_p1_shape: T.handle("int64") = T.tvm_struct_get(arg_p1, 0, 2, "handle")
    tvmgen_default_fused_nn_pad_arg_p1_strides: T.handle("int64") = T.tvm_struct_get(arg_p1, 0, 3, "handle")
    T_pad: T.handle("float16") = T.tvm_struct_get(arg_T_pad, 0, 1, "handle")
    T.attr(T_pad, "storage_alignment", 64)
    tvmgen_default_fused_nn_pad_arg_T_pad_shape: T.handle("int64") = T.tvm_struct_get(arg_T_pad, 0, 2, "handle")
    tvmgen_default_fused_nn_pad_arg_T_pad_strides: T.handle("int64") = T.tvm_struct_get(arg_T_pad, 0, 3, "handle")
    tvmgen_default_fused_nn_pad_arg_p0_shape_1 = T.Buffer((4,), "int64", data=tvmgen_default_fused_nn_pad_arg_p0_shape)
    if not T.isnullptr(tvmgen_default_fused_nn_pad_arg_p0_strides):
        tvmgen_default_fused_nn_pad_arg_p0_strides_1 = T.Buffer((0,), "int64", data=tvmgen_default_fused_nn_pad_arg_p0_strides)
        T.evaluate(0)
    tvmgen_default_fused_nn_pad_arg_T_pad_shape_1 = T.Buffer((4,), "int64", data=tvmgen_default_fused_nn_pad_arg_T_pad_shape)
    if not T.isnullptr(tvmgen_default_fused_nn_pad_arg_T_pad_strides):
        tvmgen_default_fused_nn_pad_arg_T_pad_strides_1 = T.Buffer((0,), "int64", data=tvmgen_default_fused_nn_pad_arg_T_pad_strides)
        T.evaluate(0)
    T.tvm_struct_set(stack_value, 0, 12, T.Cast("int64", 7))
    stack_tcode_1 = T.Buffer((T.uint64(6),), "int32", data=stack_tcode)  # how to detect this Stmt/Expr?
    stack_tcode_1[0] = 0
    T.tvm_struct_set(stack_value, 1, 12, T.Cast("int64", dev_id))
    stack_tcode_1[1] = 0
    T.call_packed_lowered("__tvm_set_device", stack_value, stack_tcode, 0, 2)
    T.attr(0, "compute_scope", "tvmgen_default_fused_nn_pad_compute_")
    T.tvm_struct_set(stack_value, 0, 12, T_pad)
    stack_tcode_1[0] = 3
    T.tvm_struct_set(stack_value, 1, 12, p0)
    stack_tcode_1[1] = 3
    T.tvm_struct_set(stack_value, 2, 12, p1)
    stack_tcode_1[2] = 3
    T.tvm_struct_set(stack_value, 3, 12, T.Cast("int64", 256))
    stack_tcode_1[3] = 0
    T.tvm_struct_set(stack_value, 4, 12, T.Cast("int64", 256))
    stack_tcode_1[4] = 0
    T.call_packed_lowered("tvmgen_default_fused_nn_pad_kernel0", stack_value, stack_tcode, 0, 5)

In the TIR above, I am trying to detect the following statement:

    stack_tcode_1 = T.Buffer((T.uint64(6),), "int32", data=stack_tcode) 

I tried the following visitor methods, but none of them detects the buffer declaration in this statement:

Stmt VisitStmt_(const AllocateNode* op)

Stmt VisitStmt_(const DeclBufferNode* op)

Stmt VisitStmt_(const LetStmtNode* op)

My question is: how to detect the buffer declaration in this “stack_tcode_1 = T.Buffer…” statement? Thanks.

Hi @samwyi. It looks like the buffer stack_tcode_1 has no declaration in the IR. Instead, it exists only as the BufferNode held by the statements and expressions that reference it, such as stack_tcode_1[0] = 0.

The code you mentioned

stack_tcode_1 = T.Buffer((T.uint64(6),), "int32", data=stack_tcode) 

is not a declaration in the TIR. The printer emits this line only to show the attributes of stack_tcode_1 (shape, dtype, and data pointer), which means that every reference to stack_tcode_1 points to a BufferNode with those attributes.

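As for detection, there is no declaration statement in the TIR to visit, so you need to detect the BufferNode through the nodes that reference it: for example, override VisitStmt_(const BufferStoreNode* op) and VisitExpr_(const BufferLoadNode* op) and inspect op->buffer.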
