Problem with freeing unused buffers

Hi all, I found an interesting problem while writing some TIR examples.

First, I tried to build the example below:

import tvm

# Baseline reproducer: a PrimFunc whose only parameter is the data pointer of `var`.
# Declare a float32 buffer of shape (32,) named 'var'.
var = tvm.tir.buffer.decl_buffer((32,), dtype='float32', name='var')
# Store the constant 1 at index 1 through var's data pointer.
# NOTE: the predicate is the constant 2 (not 1); this is what the build-time check complains about.
body = tvm.tir.Store(buffer_var=var.data, value=tvm.tir.const(1), index=tvm.tir.const(1), predicate=tvm.tir.const(2))
new_func = tvm.tir.PrimFunc(params=[var.data], body=body)
# Per the post, this build fails with "Check failed: (is_one(op->predicate)) is false: 2".
tvm.build(new_func)

It tells me that `Check failed: (is_one(op->predicate)) is false: 2`, i.e. the invalid predicate is reported with a proper error message, so everything works as expected at this point. However, after I introduce unused variables and buffers into the function, the program crashes:

Here is the code example to trigger the crash:

import tvm

# Crash reproducer: same Store body as the baseline example, plus three extra
# handle parameters that are mapped to buffers but never referenced in the body.
var = tvm.tir.buffer.decl_buffer((32,), dtype='float32', name='var')
# h1..h3 and buf1..buf3 are not used in the body; they only appear in
# `params` and `buffer_map` of the PrimFunc below.
h1 = tvm.tir.Var('h1', dtype='handle')
buf1 = tvm.tir.buffer.decl_buffer((1, 1000), dtype='float32', name='buf1')
h2 = tvm.tir.Var('h2', dtype='handle')
buf2 = tvm.tir.buffer.decl_buffer((1, 1000), dtype='float32', name='buf2')
h3 = tvm.tir.Var('h3', dtype='handle')
buf3 = tvm.tir.buffer.decl_buffer((1, 1000), dtype='float32', name='buf3')

# Store the constant 1 at index 1 of `var`; the predicate is the constant 2, as in the baseline.
body = tvm.tir.Store(buffer_var=var.data, value=tvm.tir.const(1), index=tvm.tir.const(1), predicate=tvm.tir.const(2))
# Bind each unused handle to its unused buffer via buffer_map.
new_func = tvm.tir.PrimFunc(params=[var.data, h1, h2, h3], body=body, buffer_map={h1:buf1, h2:buf2, h3:buf3})
# Per the post, this build crashes instead of reporting the predicate check failure.
tvm.build(new_func)

I’m curious why the program crashes. I used GDB to trace the execution, and I found that something goes wrong when we try to free these unused buffers.

Interestingly, I found another crash when I tried to add unused buffers into the function. Here is the code example to trigger this crash:

import tvm
from tvm.script import tir as T

# TVMScript module used as a second crash reproducer: `main` takes four handles,
# matches each to an int32 buffer, and never reads or writes any of those buffers.
@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(h1: T.handle, h2: T.handle, h3: T.handle, h4: T.handle) -> None:
        # N is a uint32 variable used as every dimension of all four buffer shapes.
        N = T.var("uint32")
        # b1..b4 are bound to the handles but are not referenced in the body below.
        b1 = T.match_buffer(h1, [N, N, N], dtype="int32")
        b2 = T.match_buffer(h2, [N, N], dtype="int32")
        b3 = T.match_buffer(h3, [N], dtype="int32")
        b4 = T.match_buffer(h4, [N], dtype="int32")
        # body
        # A parallel loop whose body is a `while` with a constant-True condition
        # that only evaluates the constant 0.
        for v in T.parallel(1, 4):
            while T.cast(True, "bool"):
                T.evaluate(0)

# Build only the 'main' PrimFunc of the module; per the post, this triggers the crash.
ir_module = MyModule
tvm.build(ir_module['main'])    
1 Like

The IR before codegen is:

@tvm.script.ir_module
class Module:
    @tir.prim_func
    def main(args: tir.handle, arg_type_ids: tir.handle, num_args: tir.int32, out_ret_value: tir.handle, out_ret_tcode: tir.handle, resource_handle: tir.handle) -> tir.int32:
        # function attr dict
        tir.func_attr({"target": , "tir.noalias": True, "global_symbol": "main", "tir.is_entry_func": True, "calling_conv": 1})
        # body
        assert num_args == 4, "main: num_args should be 4"
        arg0: tir.handle = tir.tvm_struct_get(args, 0, 12, dtype="handle")
        arg0_code: tir.int32 = tir.load("int32", arg_type_ids, 0)
        arg1: tir.handle = tir.tvm_struct_get(args, 1, 12, dtype="handle")
        arg1_code: tir.int32 = tir.load("int32", arg_type_ids, 1)
        arg2: tir.handle = tir.tvm_struct_get(args, 2, 12, dtype="handle")
        arg2_code: tir.int32 = tir.load("int32", arg_type_ids, 2)
        arg3: tir.handle = tir.tvm_struct_get(args, 3, 12, dtype="handle")
        arg3_code: tir.int32 = tir.load("int32", arg_type_ids, 3)
        b1: tir.Ptr[global tir.int32] = tir.tvm_struct_get(arg0, 0, 1, dtype="handle")
        tir.attr(b1, "storage_alignment", 128)
        arg0_shape: tir.handle = tir.tvm_struct_get(arg0, 0, 2, dtype="handle")
        N: tir.uint32 = tir.cast(tir.load("int64", arg0_shape, 0), "uint32")
        arg0_strides: tir.handle = tir.tvm_struct_get(arg0, 0, 3, dtype="handle")
        dev_id: tir.int32 = tir.tvm_struct_get(arg0, 0, 9, dtype="int32")
        b2: tir.Ptr[global tir.int32] = tir.tvm_struct_get(arg1, 0, 1, dtype="handle")
        tir.attr(b2, "storage_alignment", 128)
        arg1_shape: tir.handle = tir.tvm_struct_get(arg1, 0, 2, dtype="handle")
        arg1_strides: tir.handle = tir.tvm_struct_get(arg1, 0, 3, dtype="handle")
        b3: tir.Ptr[global tir.int32] = tir.tvm_struct_get(arg2, 0, 1, dtype="handle")
        tir.attr(b3, "storage_alignment", 128)
        arg2_shape: tir.handle = tir.tvm_struct_get(arg2, 0, 2, dtype="handle")
        arg2_strides: tir.handle = tir.tvm_struct_get(arg2, 0, 3, dtype="handle")
        b4: tir.Ptr[global tir.int32] = tir.tvm_struct_get(arg3, 0, 1, dtype="handle")
        tir.attr(b4, "storage_alignment", 128)
        arg3_shape: tir.handle = tir.tvm_struct_get(arg3, 0, 2, dtype="handle")
        arg3_strides: tir.handle = tir.tvm_struct_get(arg3, 0, 3, dtype="handle")
        assert arg0_code == 3 or arg0_code == 13 or arg0_code == 7 or arg0_code == 4, "main: Expect arg[0] to be pointer"
        assert arg1_code == 3 or arg1_code == 13 or arg1_code == 7 or arg1_code == 4, "main: Expect arg[1] to be pointer"
        assert arg2_code == 3 or arg2_code == 13 or arg2_code == 7 or arg2_code == 4, "main: Expect arg[2] to be pointer"
        assert arg3_code == 3 or arg3_code == 13 or arg3_code == 7 or arg3_code == 4, "main: Expect arg[3] to be pointer"
        assert 3 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 3"
        assert 3 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 3"
        assert tir.tvm_struct_get(arg0, 0, 5, dtype="uint8") == tir.uint8(0) and tir.tvm_struct_get(arg0, 0, 6, dtype="uint8") == tir.uint8(32) and tir.tvm_struct_get(arg0, 0, 7, dtype="uint16") == tir.uint16(1), "arg0.dtype is expected to be int32"
        assert N == tir.cast(tir.load("int64", arg0_shape, 1), "uint32"), "Argument arg0.shape[1] has an unsatisfied constraint: (N == uint32(arg0.shape[1]))"
        assert N == tir.cast(tir.load("int64", arg0_shape, 2), "uint32"), "Argument arg0.shape[2] has an unsatisfied constraint: (N == uint32(arg0.shape[2]))"
        if not(tir.isnullptr(arg0_strides, dtype="bool")):
            assert tir.uint32(1) == tir.cast(tir.load("int64", arg0_strides, 2), "uint32") and N == tir.cast(tir.load("int64", arg0_strides, 1), "uint32") and N * N == tir.cast(tir.load("int64", arg0_strides, 0), "uint32"), "arg0.strides: expected to be compact array"
            tir.evaluate(0)
        assert tir.uint64(0) == tir.tvm_struct_get(arg0, 0, 8, dtype="uint64"), "Argument arg0.byte_offset has an unsatisfied constraint: ((uint64)0 == tir.tvm_struct_get(arg0, 0, 8))"
        assert 1 == tir.tvm_struct_get(arg0, 0, 10, dtype="int32"), "Argument arg0.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg0, 0, 10))"
        assert 2 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 2"
        assert 2 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 2"
        assert tir.tvm_struct_get(arg1, 0, 5, dtype="uint8") == tir.uint8(0) and tir.tvm_struct_get(arg1, 0, 6, dtype="uint8") == tir.uint8(32) and tir.tvm_struct_get(arg1, 0, 7, dtype="uint16") == tir.uint16(1), "arg1.dtype is expected to be int32"
        assert N == tir.cast(tir.load("int64", arg1_shape, 0), "uint32"), "Argument arg1.shape[0] has an unsatisfied constraint: (N == uint32(arg1.shape[0]))"
        assert N == tir.cast(tir.load("int64", arg1_shape, 1), "uint32"), "Argument arg1.shape[1] has an unsatisfied constraint: (N == uint32(arg1.shape[1]))"
        if not(tir.isnullptr(arg1_strides, dtype="bool")):
            assert tir.uint32(1) == tir.cast(tir.load("int64", arg1_strides, 1), "uint32") and N == tir.cast(tir.load("int64", arg1_strides, 0), "uint32"), "arg1.strides: expected to be compact array"
            tir.evaluate(0)
        assert tir.uint64(0) == tir.tvm_struct_get(arg1, 0, 8, dtype="uint64"), "Argument arg1.byte_offset has an unsatisfied constraint: ((uint64)0 == tir.tvm_struct_get(arg1, 0, 8))"
        assert 1 == tir.tvm_struct_get(arg1, 0, 10, dtype="int32"), "Argument arg1.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg1, 0, 10))"
        assert dev_id == tir.tvm_struct_get(arg1, 0, 9, dtype="int32"), "Argument arg1.device_id has an unsatisfied constraint: (dev_id == tir.tvm_struct_get(arg1, 0, 9))"
        assert 1 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 1"
        assert 1 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 1"
        assert tir.tvm_struct_get(arg2, 0, 5, dtype="uint8") == tir.uint8(0) and tir.tvm_struct_get(arg2, 0, 6, dtype="uint8") == tir.uint8(32) and tir.tvm_struct_get(arg2, 0, 7, dtype="uint16") == tir.uint16(1), "arg2.dtype is expected to be int32"
        assert N == tir.cast(tir.load("int64", arg2_shape, 0), "uint32"), "Argument arg2.shape[0] has an unsatisfied constraint: (N == uint32(arg2.shape[0]))"
        if not(tir.isnullptr(arg2_strides, dtype="bool")):
            assert tir.uint32(1) == tir.cast(tir.load("int64", arg2_strides, 0), "uint32"), "arg2.strides: expected to be compact array"
            tir.evaluate(0)
        assert tir.uint64(0) == tir.tvm_struct_get(arg2, 0, 8, dtype="uint64"), "Argument arg2.byte_offset has an unsatisfied constraint: ((uint64)0 == tir.tvm_struct_get(arg2, 0, 8))"
        assert 1 == tir.tvm_struct_get(arg2, 0, 10, dtype="int32"), "Argument arg2.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg2, 0, 10))"
        assert dev_id == tir.tvm_struct_get(arg2, 0, 9, dtype="int32"), "Argument arg2.device_id has an unsatisfied constraint: (dev_id == tir.tvm_struct_get(arg2, 0, 9))"
        assert 1 == tir.tvm_struct_get(arg3, 0, 4, dtype="int32"), "arg3.ndim is expected to equal 1"
        assert 1 == tir.tvm_struct_get(arg3, 0, 4, dtype="int32"), "arg3.ndim is expected to equal 1"
        assert tir.tvm_struct_get(arg3, 0, 5, dtype="uint8") == tir.uint8(0) and tir.tvm_struct_get(arg3, 0, 6, dtype="uint8") == tir.uint8(32) and tir.tvm_struct_get(arg3, 0, 7, dtype="uint16") == tir.uint16(1), "arg3.dtype is expected to be int32"
        assert N == tir.cast(tir.load("int64", arg3_shape, 0), "uint32"), "Argument arg3.shape[0] has an unsatisfied constraint: (N == uint32(arg3.shape[0]))"
        if not(tir.isnullptr(arg3_strides, dtype="bool")):
            assert tir.uint32(1) == tir.cast(tir.load("int64", arg3_strides, 0), "uint32"), "arg3.strides: expected to be compact array"
            tir.evaluate(0)
        assert tir.uint64(0) == tir.tvm_struct_get(arg3, 0, 8, dtype="uint64"), "Argument arg3.byte_offset has an unsatisfied constraint: ((uint64)0 == tir.tvm_struct_get(arg3, 0, 8))"
        assert 1 == tir.tvm_struct_get(arg3, 0, 10, dtype="int32"), "Argument arg3.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg3, 0, 10))"
        assert dev_id == tir.tvm_struct_get(arg3, 0, 9, dtype="int32"), "Argument arg3.device_id has an unsatisfied constraint: (dev_id == tir.tvm_struct_get(arg3, 0, 9))"
        tir.attr(0, "compute_scope", "main_compute_")
        for v in tir.parallel(1, 4):
            while True:
                tir.evaluate(0)

Thanks for your reply! Does the IR before codegen look normal? I have no idea why it can trigger a crash in codegen :face_with_monocle:

TBH I didn't find any problem at first glance… Let me dig deeper.