Use relax builtin_with_ctx operator in `rewrite_call`

I’m trying to replace a copy_tensor operation with a DMA builtin function to perform an async copy of data.

The expectation is to replace one instance of copy_tensor with a pair of dma_copy and dma_wait calls (performed using R.call_builtin_with_ctx).

Original IR:

a: R.Tensor([32, 64], dtype="int32") = R.vm.alloc_tensor(
    vtcm_obj, offset=0, shape=R.shape([32, 64]), dtype="int32"
)
__: R.Tuple = R.vm.copy_tensor(x, a)
b: R.Tensor([32, 64], dtype="int32") = R.vm.alloc_tensor(
    vtcm_obj, offset=8192, shape=R.shape([32, 64]), dtype="int32"
)

Expected IR:

a: R.Tensor([32, 64], dtype="int32") = R.vm.alloc_tensor(
    vtcm_obj, offset=0, shape=R.shape([32, 64]), dtype="int32"
)
__: R.Tuple = R.call_builtin_with_ctx(
                "vm.builtin.hexagon.dma_copy",
                [x, a, 0, True],
                sinfo_args=[],
)
__: R.Tuple = R.call_builtin_with_ctx(
                "vm.builtin.hexagon.dma_wait",
                [0, 1],
                sinfo_args=[],
)
b: R.Tensor([32, 64], dtype="int32") = R.vm.alloc_tensor(
    vtcm_obj, offset=8192, shape=R.shape([32, 64]), dtype="int32"
)

To replace the copy_tensor call, I am using the relax rewrite_call function. However, since the builtin functions do not return anything (they are void calls), it is difficult to replace one copy_tensor operation with the dma_copy and dma_wait pair: both are void operators, so they cannot be associated with each other as call nodes, and hence the pair cannot be returned from the rewriter function passed to rewrite_call.

Can you please suggest a way to handle such a scenario? Thanks in advance! :slight_smile:

CC: @masahi @tqchen @Lunderberg

Maybe what you need is rewrite_bindings. It is more difficult to use, but it doesn’t require returning a CallNode.

We discussed the complexity of rewrite_bindings a bit, and we actually want to encourage rewrite_call when possible (since it has lower complexity at the moment). Perhaps we can also consider enhancing rewrite_call to allow a builder to be used to emit previous bindings.

Thanks @masahi and @tqchen for your inputs!

My relax function is already lowered and does not contain a dataflow block, so rewrite_bindings does not help here.

Could you please share more insight into this? IIUC, you are recommending passing a block builder as an argument to rewrite_call (which would eventually be passed on to the user-supplied rewriter function) and then using it to emit bindings that can be altered in the rewrite function.

Here’s how I’ve solved this problem.

Since the relax function I was using wasn’t a dataflow block, I wrote a mutator class in Python containing a visitor for binding blocks: visit_binding_block_

This function takes a relax BindingBlock and allows it to be modified. Below is a high-level implementation of visit_binding_block_:

def visit_binding_block_(self, binding_block_i):
    binding_list = []  # List to hold the new bindings
    # Iterate through the existing bindings
    for binding in binding_block_i.bindings:
        if isinstance(
            binding.value, tvm.relax.expr.Call
        ) and binding.value.op == tvm.ir.Op.get("relax.vm.copy_tensor"):
            # If a copy_tensor op is found, replace it with the DMA bindings.
            # dma_copy_op and dma_wait_op stand in for the two
            # call_builtin_with_ctx bindings constructed elsewhere.
            binding_list.append(dma_copy_op)
            binding_list.append(dma_wait_op)
        else:
            # Keep the existing binding as-is
            binding_list.append(binding)

    # Return the new binding block
    return tvm.relax.BindingBlock(binding_list)