I’m trying to decompose torch.masked_select (see the PyTorch 2.6 documentation for torch.masked_select) in TVM using tvm.relax. My approach involves the following steps:
- Flatten both the input tensor and mask tensor.
- Create a new tensor whose values are 0 where the corresponding mask value is 0, and otherwise keep the original tensor values.
- Use relax.op.nonzero() to find the indices of nonzero elements.
- Flatten the indices tensor so it can be used with R.gather_elements().
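For reference, the steps above can be sketched in NumPy. NumPy is used here only as a stand-in for the Relax ops, and both helper names are hypothetical. The sketch also shows a caveat of this decomposition: an element that is selected by the mask but is itself 0 is dropped by the nonzero step, whereas taking the indices from the mask keeps it.

```python
import numpy as np

def masked_select_via_nonzero(tensor, mask):
    """Hypothetical helper mirroring the steps above."""
    tensor_flat = tensor.reshape(-1)
    mask_flat = mask.reshape(-1)
    # Zero out every element whose mask entry is 0.
    masked = np.where(mask_flat != 0, tensor_flat, np.zeros_like(tensor_flat))
    # Indices of the nonzero elements of the masked tensor.
    indices = np.nonzero(masked)[0]
    return masked[indices]

def masked_select_from_mask(tensor, mask):
    """Hypothetical variant: take the indices from the mask itself."""
    indices = np.nonzero(mask.reshape(-1))[0]
    return tensor.reshape(-1)[indices]

# The first element is selected by the mask but has the value 0.
tensor = np.array([[0.0, 2.0], [3.0, 4.0]], dtype="float32")
mask = np.array([[1, 0], [1, 1]], dtype="int32")

via_nonzero = masked_select_via_nonzero(tensor, mask)
from_mask = masked_select_from_mask(tensor, mask)
print(via_nonzero.tolist())  # the selected 0.0 is missing
print(from_mask.tolist())    # the selected 0.0 is kept
```

With random inputs in (0, 1) the two variants normally agree, so the difference only shows up when a selected value is exactly zero.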
Code:
from tvm.script import relax as R
from tvm.relax import op
from tvm.script import ir as I
from tvm.script import tir as T
@I.ir_module
class Module:
    @R.function
    def main(
        tensor: R.Tensor((4, 4), dtype="float32"),
        mask: R.Tensor((4, 4), dtype="int32")
    ):
        # Flatten input tensors
        tensor_flat = R.flatten(tensor)
        mask_flat = R.flatten(mask)
        # Apply mask
        lvl0 = R.where(R.not_equal(mask_flat, R.zeros_like(mask_flat)), tensor_flat, R.zeros_like(tensor_flat))
        # Find indices of nonzero elements
        indices = op.nonzero(lvl0)
        # Flatten indices tensor
        indices_flat = R.reshape(indices, (-1, ))  # Error occurs here
        res = R.gather_elements(lvl0, indices_flat, 0)
        return res
mod = Module
mod.show()
from tvm import relax
mod = relax.transform.LegalizeOps()(mod)
import tvm
import torch
exec = tvm.compile(mod, "llvm")
vm = tvm.relax.VirtualMachine(exec, tvm.cpu())
input_tensor = torch.rand(4, 4)
mask_tensor = input_tensor.ge(0.5).int()
input_tvm = tvm.nd.array(input_tensor, tvm.cpu())
mask_tvm = tvm.nd.array(mask_tensor, tvm.cpu())
output = vm['main'](input_tvm, mask_tvm)
Errors Encountered:
- Using R.reshape(indices, (-1, )):
TVMError: Check failed: (data_sinfo->shape.defined()) is false: Reshape expects the input tensor to have known shape when there is some dimension length to infer.
- Using R.flatten(indices):
VMError: CodeGenVM cannot handle this intrinsic now: Op(relax.flatten)
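A likely reading of the first error, stated as an assumption rather than confirmed TVM behavior: the number of columns in the nonzero result depends on the input values, so its shape is not known at compile time, and a reshape with -1 has nothing to infer from. The NumPy sketch below only illustrates that data dependence. The helper name is hypothetical, and the (ndim, count) layout is chosen to mimic a 2-D index tensor.

```python
import numpy as np

def nonzero_indices(flat):
    # Hypothetical helper: stack the per-axis index arrays into a
    # 2-D array of shape (ndim, number_of_nonzero_elements).
    return np.stack(np.nonzero(flat))

# Same input shape, different values.
a = np.array([0.0, 1.5, 0.0, 2.5], dtype="float32")
b = np.array([1.0, 1.5, 0.5, 2.5], dtype="float32")

shape_a = nonzero_indices(a).shape
shape_b = nonzero_indices(b).shape
print(shape_a, shape_b)  # the second dimension differs between a and b
```

Because the inputs have the same static shape but produce index arrays of different sizes, the output size can only be determined at run time.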
Is there a better approach to implementing torch.masked_select in TVM?
Since I’m new to TVM, any guidance on debugging or restructuring my approach would be greatly appreciated. Thanks!