Hi all, I think I have found an interesting bug in the TIR pass compact_buffer_region.
Code
Here is a code example that reproduces the bug. The program crashes at the final step, when it prints the function body.
```python
import tvm
from tvm import tir

# A one-element buffer; its data pointer var is reused below.
buf = tir.decl_buffer((1,), name='buf')
var = buf.data
zero = tir.const(0)
buf_load = tir.expr.BufferLoad(buffer=buf, indices=[zero])
then_case = tir.Store(buffer_var=var, value=buf_load, index=zero)
for_stmt = tir.For(loop_var=var, min_val=0, extent=0, kind=1, body=then_case)
# Note: the function body contains no tir.Block at all.
y = tir.IfThenElse(then_case=then_case, else_case=for_stmt, condition=tir.Cast('bool', zero))
f = tir.PrimFunc(body=y, params=[var])
mod = tvm.IRModule({'main': f})
mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
mod = tir.transform.CompactBufferAllocation()(mod)
# Printing touches the rewritten buffer, which is where the crash happens.
print(mod['main'].body)
```
Bug Analysis
This bug is located in the TIR pass compact_buffer_region. First, the pass initializes the class BufferCompactor with the map buffer_info. From its source code, we can see that each buffer_alloc_info in buffer_info is created without setting new_buffer; that is, new_buffer is left as nullptr (0x0).
new_buffer is assigned only in the function RewriteAllocBuffer.
However, RewriteAllocBuffer is invoked only from VisitStmt_(const BlockNode* op).
What happens if the function contains no Block statements, only other statements and expressions such as BufferLoad? According to the source code, the pass then executes RewriteBufferAccess.
RewriteBufferAccess replaces the accessed buffer with info.new_buffer. Because RewriteAllocBuffer was never executed, info.new_buffer is still nullptr (0x0), so the program now holds a null buffer. Any later access to that buffer (for example, printing it via print(mod['main'].body)) crashes the program.
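To make the mechanism concrete, here is a minimal pure-Python model of the analysis above. It is not TVM code: `Buffer`, `BufferAllocInfo`, `BufferCompactorModel` and their methods are hypothetical stand-ins that only mirror the control flow described (new_buffer unset at construction, assigned only when a block is visited, and read back unconditionally when an access is rewritten), with Python's `None` playing the role of nullptr.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Buffer:
    name: str
    shape: List[int]


@dataclass
class BufferAllocInfo:
    region: List[int]
    # Mirrors the C++ default: new_buffer starts out unset (nullptr -> None).
    new_buffer: Optional[Buffer] = None


class BufferCompactorModel:
    """Hypothetical, simplified model of BufferCompactor (not TVM code)."""

    def __init__(self, buffer_info: Dict[str, BufferAllocInfo]):
        self.buffer_info = buffer_info

    def rewrite_alloc_buffer(self, buf: Buffer) -> None:
        # The only place where new_buffer gets assigned.
        info = self.buffer_info[buf.name]
        info.new_buffer = Buffer(buf.name, list(info.region))

    def visit_block(self, alloc_buffers: List[Buffer]) -> None:
        # Only block visits call rewrite_alloc_buffer.
        for buf in alloc_buffers:
            self.rewrite_alloc_buffer(buf)

    def rewrite_buffer_access(self, buf: Buffer) -> Optional[Buffer]:
        # Swaps in info.new_buffer without checking that it was ever set.
        return self.buffer_info[buf.name].new_buffer


buf = Buffer("buf", [1])

# Function without any block: the access is rewritten to None ("nullptr").
no_block = BufferCompactorModel({"buf": BufferAllocInfo(region=[1])})
print(no_block.rewrite_buffer_access(buf))  # prints None

# Function with a block that allocates buf: new_buffer is set first.
with_block = BufferCompactorModel({"buf": BufferAllocInfo(region=[1])})
with_block.visit_block([buf])
print(with_block.rewrite_buffer_access(buf))
```

In the model the missing assignment just surfaces as `None`; in the real pass it is a C++ nullptr, so the first code path that dereferences it (such as the printer) crashes instead.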
This pass only applies to TIR in which blocks are expected to be present. PrimFuncs lowered from TE or created directly should be marked with the from_legacy_te_schedule attribute so that the pass skips them.
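The guard described above can be sketched the same way. `apply_unless_legacy` and `fake_compact` below are hypothetical helpers, not TVM functions; the sketch only assumes the behavior stated in the reply, namely that a function tagged from_legacy_te_schedule is left untouched. With TVM itself, the tag is typically attached through `PrimFunc.with_attr`, e.g. `f = f.with_attr("from_legacy_te_schedule", True)`, but check this against the TVM version you are using.

```python
from typing import Any, Callable, Dict


def apply_unless_legacy(
    attrs: Dict[str, Any],
    body: Any,
    compact: Callable[[Any], Any],
) -> Any:
    """Hypothetical model of the pass-level guard (not TVM code).

    Functions tagged with from_legacy_te_schedule are returned unchanged,
    so the block-based rewrite never runs on them.
    """
    if attrs.get("from_legacy_te_schedule", False):
        return body
    return compact(body)


def fake_compact(body: Any) -> Any:
    # Stands in for the block-based rewrite that would hit the null buffer.
    raise RuntimeError("compaction would dereference a null new_buffer")


# Tagged function: skipped, body comes back as-is.
skipped = apply_unless_legacy({"from_legacy_te_schedule": True}, "body", fake_compact)

# Untagged, block-free function: the rewrite runs and fails.
try:
    apply_unless_legacy({}, "body", fake_compact)
    crashed = False
except RuntimeError:
    crashed = True
```

The design choice here is to opt out at the function level rather than patch every access site: the pass assumes block structure, so functions that lack it are simply not its input.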