I'm trying to replace a `copy_tensor` operation with a DMA builtin function to perform an async copy of data. The expectation is to replace one instance of `copy_tensor` with a pair of `dma_copy` and `dma_wait` calls (performed using `R.call_builtin_with_ctx`).
Original IR:
```python
a: R.Tensor([32, 64], dtype="int32") = R.vm.alloc_tensor(
    vtcm_obj, offset=0, shape=R.shape([32, 64]), dtype="int32"
)
__: R.Tuple = R.vm.copy_tensor(x, a)
b: R.Tensor([32, 64], dtype="int32") = R.vm.alloc_tensor(
    vtcm_obj, offset=8192, shape=R.shape([32, 64]), dtype="int32"
)
```
Expected IR:
```python
a: R.Tensor([32, 64], dtype="int32") = R.vm.alloc_tensor(
    vtcm_obj, offset=0, shape=R.shape([32, 64]), dtype="int32"
)
__: R.Tuple = R.call_builtin_with_ctx(
    "vm.builtin.hexagon.dma_copy",
    [x, a, 0, True],
    sinfo_args=[],
)
__: R.Tuple = R.call_builtin_with_ctx(
    "vm.builtin.hexagon.dma_wait",
    [0, 1],
    sinfo_args=[],
)
b: R.Tensor([32, 64], dtype="int32") = R.vm.alloc_tensor(
    vtcm_obj, offset=8192, shape=R.shape([32, 64]), dtype="int32"
)
```
To replace the `copy_tensor` call, I am using the relax `rewrite_call` function. However, since the builtin functions do not return anything (they are void calls), it is difficult to replace one `copy_tensor` operation with the `dma_copy` and `dma_wait` operators: both are void, they cannot be associated with each other as call nodes, and so they cannot both be returned from the rewriter function passed to `rewrite_call`.
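For concreteness, here is a rough sketch of what I am attempting (op and builtin names are taken from the IR above; the scalar DMA arguments and struct-info details are simplified, so this is not a working implementation):

```python
from tvm import relax
from tvm.relax.dpl import is_op, wildcard, rewrite_call

# Match any copy_tensor(src, dst) call (op name assumed from the printed IR).
src = wildcard()
dst = wildcard()
pattern = is_op("relax.vm.copy_tensor")(src, dst)

def rewriter(expr, matches):
    # Build the two replacement calls; both are void.
    dma_copy = relax.op.call_builtin_with_ctx(
        "vm.builtin.hexagon.dma_copy",
        relax.Tuple([matches[src], matches[dst],
                     relax.PrimValue(0), relax.PrimValue(1)]),
        sinfo_args=[],
    )
    dma_wait = relax.op.call_builtin_with_ctx(
        "vm.builtin.hexagon.dma_wait",
        relax.Tuple([relax.PrimValue(0), relax.PrimValue(1)]),
        sinfo_args=[],
    )
    # rewrite_call expects a single replacement expression, so there is no
    # way to emit dma_copy and dma_wait as two separate bindings here.
    return dma_copy  # dma_wait would be dropped

mod["main"] = rewrite_call(pattern, rewriter, mod["main"])
```

As the sketch shows, the rewriter can only hand back one expression, so one of the two void calls would have to be dropped.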
Can you please suggest a way to handle such a scenario? Thanks in advance!