[DISCUSS] Inplace Update in Dataflow Block

Hi @slyubomirsky @tqchen, can we enable call_tir_inplace to have multiple outputs, some updated in place and some freshly allocated?

We have a use case in MLC-LLM that fuses rotary embedding with FlashAttention, and we imagine the programming interface would look something like this:

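```python
# Proposed interface (sketch): k is updated in place, o is a brand-new output.
@T.prim_func
def fused_rotary_flashattention(k: T.Buffer(...), q: T.Buffer(...), v: T.Buffer(...), o: T.Buffer(...)):
    ...

# Relax-level call (R., not T.): returns the in-place k followed by the new o.
updated_k, o = R.call_tir_inplace(fused_rotary_flashattention(k, q, v), inplace_shapes=(k_shape,), output_shapes=(o_shape,))
```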

Here the update to k happens in place, while o is a brand-new output tensor.
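To make the intended semantics concrete, here is a minimal pure-Python emulation. It is a sketch under stated assumptions, not TVM code: `call_tir_inplace_with_outputs`, `toy_fused_kernel`, and the flat-list `Buffer` type are hypothetical. The kernel is assumed to follow destination-passing style (inputs first, then freshly allocated outputs), and its arithmetic is only a stand-in for rotary embedding and FlashAttention.

```python
from typing import Callable, List, Sequence, Tuple

Buffer = List[float]  # hypothetical: a flat, mutable list standing in for a tensor


def call_tir_inplace_with_outputs(
    kernel: Callable[..., None],
    args: Sequence[Buffer],
    inplace_indices: Sequence[int],
    output_sizes: Sequence[int],
) -> Tuple[Buffer, ...]:
    """Hypothetical emulation of the proposed operator's semantics.

    The kernel receives all inputs followed by freshly allocated output
    buffers and writes its results into them (destination-passing style).
    Results are returned as (in-place results..., new outputs...).
    """
    # Freshly allocate one zero-filled buffer per new (non-in-place) output.
    new_outputs = [[0.0] * n for n in output_sizes]
    kernel(*args, *new_outputs)
    # In-place results alias the corresponding input buffers; nothing is copied.
    inplace_results = [args[i] for i in inplace_indices]
    return (*inplace_results, *new_outputs)


def toy_fused_kernel(k: Buffer, q: Buffer, v: Buffer, o: Buffer) -> None:
    # Stand-in for the rotary step (NOT a real rotary embedding): mutate k in place.
    for i in range(len(k)):
        k[i] = 2.0 * k[i]
    # Stand-in for attention (NOT real attention): write a new result into o.
    for i in range(len(o)):
        o[i] = q[i] * k[i] + v[i]


k = [1.0, 2.0, 3.0]
q = [1.0, 1.0, 1.0]
v = [0.5, 0.5, 0.5]
updated_k, o = call_tir_inplace_with_outputs(
    toy_fused_kernel, [k, q, v], inplace_indices=[0], output_sizes=[3]
)
print(updated_k is k)  # True: the in-place result aliases the input buffer
print(o)               # [2.5, 4.5, 6.5]
```

With an empty `output_sizes`, the helper returns only aliases of its inputs, which matches the all-in-place behavior of the standard call_tir_inplace.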

To distinguish this from the standard call_tir_inplace, whose outputs are all tensors updated in place, we could perhaps name this variant call_tir_inplace_with_outputs.
