Issue with R.reshape() and R.flatten() when decomposing torch.masked_select

I’m trying to decompose torch.masked_select (see the PyTorch 2.6 documentation for torch.masked_select) in TVM using tvm.relax. My approach involves the following steps:

  1. Flatten both the input tensor and the mask tensor.
  2. Create a new tensor that is 0 wherever the corresponding mask value is 0 and keeps the original tensor value elsewhere.
  3. Use relax.op.nonzero() to find the indices of the nonzero elements.
  4. Flatten the indices tensor so it can be used with R.gather_elements().
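For reference, here is a small NumPy sketch of what steps 1-4 compute, next to the behaviour torch.masked_select should have. NumPy is only a stand-in for the Relax ops here (this is not TVM code), the gather is emulated with plain indexing, and input and mask are assumed to have the same shape (torch.masked_select also broadcasts them). One thing it shows: taking nonzero of the masked *values* drops elements that are genuinely 0 in the input, whereas taking nonzero of the mask keeps them.

```python
import numpy as np


def masked_select_reference(x: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Expected semantics: elements of x where mask != 0, in row-major order."""
    x_flat = x.reshape(-1)
    mask_flat = mask.reshape(-1)
    # Indices come from the mask itself, so genuine zeros in x are kept.
    idx = np.nonzero(mask_flat)[0]
    return x_flat[idx]


def masked_select_via_values(x: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Mirrors steps 1-4: zero out masked-off values, then take nonzero *values*."""
    x_flat = x.reshape(-1)
    mask_flat = mask.reshape(-1)
    lvl0 = np.where(mask_flat != 0, x_flat, np.zeros_like(x_flat))
    # Indices come from the values, so a selected element equal to 0 is lost.
    idx = np.nonzero(lvl0)[0]
    return lvl0[idx]


x = np.array([[0.0, 1.0], [2.0, 3.0]], dtype=np.float32)
mask = np.array([[1, 0], [1, 1]], dtype=np.int32)

print(masked_select_reference(x, mask))   # [0. 2. 3.]
print(masked_select_via_values(x, mask))  # [2. 3.]
```

Running nonzero on the flattened mask rather than on lvl0 would avoid the dropped-zero case without changing the rest of the pipeline.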

Code:

```python
import tvm
import torch
from tvm import relax
from tvm.relax import op
from tvm.script import ir as I
from tvm.script import relax as R


@I.ir_module
class Module:
    @R.function
    def main(
        tensor: R.Tensor((4, 4), dtype="float32"),
        mask: R.Tensor((4, 4), dtype="int32"),
    ):
        # Flatten input tensors
        tensor_flat = R.flatten(tensor)
        mask_flat = R.flatten(mask)

        # Apply mask: keep values where mask != 0, use 0 elsewhere
        lvl0 = R.where(
            R.not_equal(mask_flat, R.zeros_like(mask_flat)),
            tensor_flat,
            R.zeros_like(tensor_flat),
        )

        # Find indices of nonzero elements
        indices = op.nonzero(lvl0)

        # Flatten indices tensor
        indices_flat = R.reshape(indices, (-1,))  # Error occurs here

        res = R.gather_elements(lvl0, indices_flat, 0)
        return res


mod = Module
mod.show()

# Lower high-level Relax ops to TIR before compiling
mod = relax.transform.LegalizeOps()(mod)

ex = tvm.compile(mod, "llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())

input_tensor = torch.rand(4, 4)
mask_tensor = input_tensor.ge(0.5).int()

input_tvm = tvm.nd.array(input_tensor.numpy(), tvm.cpu())
mask_tvm = tvm.nd.array(mask_tensor.numpy(), tvm.cpu())

output = vm["main"](input_tvm, mask_tvm)
```

Errors Encountered:

  1. Using R.reshape(indices, (-1, )):

     TVMError: Check failed: (data_sinfo->shape.defined()) is false:
     Reshape expects the input tensor to have known shape when there is some dimension length to infer.

  2. Using R.flatten(indices):

     VMError: CodeGenVM cannot handle this intrinsic now: Op(relax.flatten)
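Both errors probably have the same root: the number of nonzero elements, and therefore the shape of the nonzero output, depends on runtime values, so it is unknown when the module is built. Reshape with -1 needs a known input shape, and (my assumption) LegalizeOps leaves R.flatten un-lowered when its input shape is unknown, so the op reaches CodeGenVM as-is. A NumPy illustration of the data-dependent shape; `np.array(np.nonzero(...))` uses an (ndim, num_nonzero) layout, which I assume matches relax.op.nonzero's output, so check that against the Relax docs:

```python
import numpy as np

x = np.arange(16, dtype=np.float32).reshape(4, 4)
mask_a = (x >= 8).astype(np.int32).reshape(-1)
mask_b = (x >= 12).astype(np.int32).reshape(-1)

# Same input shape (16,), different nonzero counts -> different output shapes.
idx_a = np.array(np.nonzero(mask_a))
idx_b = np.array(np.nonzero(mask_b))
print(idx_a.shape)  # (1, 8)
print(idx_b.shape)  # (1, 4)

# Flattening (1, N) -> (N,) only works once N is known, i.e. at runtime.
print(idx_a.reshape(-1).shape)  # (8,)
```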

Is there a better approach to implementing torch.masked_select in TVM?

Since I’m new to TVM, any guidance on debugging or restructuring my approach would be greatly appreciated. Thanks!

Relay supported masked_select in its PyTorch frontend. You can use the logic here to achieve something similar in the Relax frontend. I’m curious why it isn’t already supported…
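My understanding (an assumption, worth checking against the Relay PyTorch frontend source) is that the Relay converter computed a NumPy-style nonzero on the *mask*, i.e. one index tensor per dimension, and then applied advanced indexing to the input. That avoids both the value-based nonzero and the explicit reshape. A NumPy sketch of the idea, assuming the input and mask already share a shape; the function name is hypothetical:

```python
import numpy as np


def masked_select_adv_index(x: np.ndarray, mask: np.ndarray) -> np.ndarray:
    # One index array per dimension of the mask (NumPy-style nonzero).
    per_dim_indices = np.nonzero(mask)
    # Advanced indexing with all of them selects the masked elements
    # in row-major order and returns a 1-D result.
    return x[per_dim_indices]


x = np.arange(6, dtype=np.float32).reshape(2, 3)
mask = (x >= 2).astype(np.int32)
print(masked_select_adv_index(x, mask))  # [2. 3. 4. 5.]
```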


Thank you very much for the kind reply and information @AQSingh.

I’ll try to implement the same logic in Relax and fx_translator.