Great discussion, and sorry for being late on this. I think one way to address this issue would be to have some form of contiguous TIR call inlining pass after lowering; let us call it FuseAndUnpackContinuousTIRCalls.
The following code shows one possible Before/After of this transformation. The pass will:
- Recognize a contiguous region of low-level TIR calls that only involves types it recognizes, e.g.:
  - Storage
  - Tensor
  - Calls into functions that we know won't retain the memory of a Tensor
- Lift that contiguous region into another TIR function, `fused_bulk_call`
- Create unpacked variants of the TIR functions involved and redirect calls to them as unpacked calls
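As a rough illustration of the first step, the region-recognition part of the pass could be sketched as below. This is plain Python over a toy binding representation, not the actual Relax IR or pass infrastructure; all names here are illustrative:

```python
# Hypothetical sketch of the region-grouping step of
# FuseAndUnpackContinuousTIRCalls. Bindings are modeled as
# (kind, payload) tuples rather than real Relax IR nodes.

def group_contiguous_tir_calls(bindings):
    """Split a binding list into maximal runs of low-level TIR calls
    (candidates for lifting into a fused_bulk_call) and everything else."""
    groups = []
    current = []
    for kind, payload in bindings:
        if kind == "call_tir_lowered":
            # Extend the current contiguous region of TIR calls.
            current.append(payload)
        else:
            # A non-TIR binding ends the region: emit it, then the binding.
            if current:
                groups.append(("fused_region", current))
                current = []
            groups.append((kind, payload))
    if current:
        groups.append(("fused_region", current))
    return groups
```

For example, a sequence `alloc_storage, call_tir_lowered(add), call_tir_lowered(relu), return` would yield one `fused_region` containing `add` and `relu`, which the pass would then lift into `fused_bulk_call`.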
```python
@I.module
class Before:
    @T.prim_func
    def add(
        A: T.Buffer((10, 20), "float32"),
        B: T.Buffer((10, 20), "float32"),
    ):
        ...

    @T.prim_func
    def relu(
        A: T.Buffer((10, 20), "float32"),
        B: T.Buffer((10, 20), "float32"),
    ):
        ...

    @R.function
    def main(X: R.Tensor((10, 20), "float32")):
        # code after memory planning
        s0 = R.memory.alloc_storage((200,), "float32")
        s1 = R.memory.alloc_storage((200,), "float32")
        lv0 = R.memory.alloc_tensor(s0, R.Tensor((10, 20), "float32"))
        lv1 = R.memory.alloc_tensor(s1, R.Tensor((10, 20), "float32"))
        R.call_tir_lowered(add, [X, lv0])
        R.call_tir_lowered(relu, [lv0, lv1])
        return lv1
```
```python
@I.module
class After:
    @T.prim_func
    def add_unpacked(
        Adata: T.handle,
        Bdata: T.handle,
    ):
        A = T.Buffer(Adata, (10, 20), "float32")
        B = T.Buffer(Bdata, (10, 20), "float32")
        ...

    @T.prim_func
    def relu_unpacked(
        Adata: T.handle,
        Bdata: T.handle,
    ):
        A = T.Buffer(Adata, (10, 20), "float32")
        B = T.Buffer(Bdata, (10, 20), "float32")
        ...

    @T.prim_func
    def fused_bulk_call(
        A: T.Buffer((10, 20), "float32"),
        s0: T.Buffer((200,), "float32"),
        B: T.Buffer((10, 20), "float32"),
    ):
        # or have another intrinsic to directly call unpacked
        T.call_extern(add_unpacked, A.data, s0.data)
        T.call_extern(relu_unpacked, s0.data, B.data)

    @R.function
    def main(X: R.Tensor((10, 20), "float32")):
        # code after memory planning
        s0 = R.memory.alloc_storage((200,), "float32")
        s1 = R.memory.alloc_storage((200,), "float32")
        lv0 = R.memory.alloc_tensor(s0, R.Tensor((200,), "float32"))
        lv1 = R.memory.alloc_tensor(s1, R.Tensor((10, 20), "float32"))
        R.call_tir_lowered(fused_bulk_call, [X, lv0, lv1])
        return lv1
```
The main tension we need to resolve here is between two needs:
- N0: the need to remain flexible and support a composite set of generic types
- N1: the need to be efficient, knowing that low-level TIR codegen generally does not manage memory or reference counting.
The solution aims to split the program by recognizing that some of the high-level code genuinely needs the generality of N0, while most of the code only involves simpler types. We can therefore afford to perform allocation outside in the Relax VM, and then call into a fused function that simply operates on the already-allocated memory.
If done correctly, this could bring the benefit of efficiency while still enabling generality for the code that needs it.
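To make the efficiency point concrete, here is a toy sketch (not TVM runtime code) contrasting a packed call, which re-boxes and re-validates every argument on each invocation, with an unpacked call that passes raw data handles directly once shapes and dtypes have been fixed at fusion time. All names and the dict-based tensor model are illustrative assumptions:

```python
# Toy model: tensors are dicts carrying data plus metadata.

def packed_call(func, tensors):
    # Packed convention: box every argument together with its metadata
    # and re-check the dtype on each call.
    boxed = []
    for t in tensors:
        assert t["dtype"] == "float32"  # per-call validation overhead
        boxed.append((t["data"], t["shape"], t["dtype"]))
    return func(*(b[0] for b in boxed))

def unpacked_call(func, handles):
    # Unpacked convention: shapes/dtypes were checked once when the
    # fused function was created, so only raw handles are passed.
    return func(*handles)

def add_one(data):
    # Stand-in for a lowered TIR kernel operating on raw data.
    return [x + 1 for x in data]
```

In the After module, the per-call boxing and checking happens once at the Relax level when `fused_bulk_call` is invoked, while the inner `add_unpacked`/`relu_unpacked` calls take the cheap unpacked path.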