[BYOC] 'InferSpecialValues' pass?

I am trying to offload the relay.reshape op to my BYOC backend. However, my target runtime's Reshape op does not support special values such as -1 in its newshape parameter, so I need to write a custom legalization pass that maps these special values to explicit, all-positive dimensions.

I found that the InferNewShape function in src/relay/op/tensor/transform.cc does exactly what I need. However, it seems to be used only internally and is not accessible from Python. It would be really convenient if TVM offered a pass that infers the special values in the newshape and axes parameters of the many transform and reduce ops.

InferNewShape is not a pass but just a utility function, so it doesn't make sense to expose it to Python. It seems to me that your problem can be solved by writing a simple Relay pass in Python, like

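import tvm
from tvm import relay
class LegalizeReshape(relay.ExprMutator):
    def visit_call(self, call):
        new_call = super().visit_call(call)  # mutate the arguments first
        if isinstance(call.op, tvm.ir.Op) and call.op.name == "reshape":  # op may be a Function/GlobalVar
            # Write InferNewShape here, or reuse the output shape type inference already resolved
            # (run relay.transform.InferType() first; assumes static shapes so each dim converts to int).
            newshape = [int(d) for d in call.checked_type.shape]
            return relay.reshape(new_call.args[0], newshape=newshape)
        return new_call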
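If you would rather port the InferNewShape logic itself to Python, the sketch below may serve as a starting point. It is not the TVM implementation: it follows the special-value semantics documented for relay.reshape (0 copies an input dimension, -1 infers one dimension, -2 copies the remaining dimensions, -3 merges two consecutive dimensions, -4 splits one dimension into the two values that follow it), and it assumes static shapes and the default allowzero=False. The name infer_new_shape is hypothetical. Rather than tracking which dimensions each special value consumed, it resolves a top-level -1 as the total element count divided by the product of the other output dimensions, which gives the same answer for valid static shapes.

```python
from math import prod


def infer_new_shape(data_shape, newshape):
    """Hypothetical helper: resolve reshape special values to positive dims.

    Assumes static shapes and allowzero=False (0 means "copy the input dim").
    Entries may be Python ints or anything int() accepts (e.g. tvm IntImm).
    """
    data_shape = [int(d) for d in data_shape]
    out = []
    src = 0          # index of the next input dimension to consume
    infer_idx = -1   # output position of a top-level -1, if any
    i = 0
    while i < len(newshape):
        v = int(newshape[i])
        if v > 0:
            out.append(v)
            src += 1
        elif v == 0:
            # Copy the corresponding input dimension.
            out.append(data_shape[src])
            src += 1
        elif v == -1:
            if infer_idx >= 0:
                raise ValueError("at most one dimension can be -1")
            infer_idx = len(out)
            out.append(1)  # placeholder, resolved after the loop
            src += 1
        elif v == -2:
            # Copy all remaining input dimensions.
            out.extend(data_shape[src:])
            src = len(data_shape)
        elif v == -3:
            # Merge two consecutive input dimensions into one.
            out.append(data_shape[src] * data_shape[src + 1])
            src += 2
        elif v == -4:
            # Split one input dimension into the next two values (one may be -1).
            d0 = data_shape[src]
            d1, d2 = int(newshape[i + 1]), int(newshape[i + 2])
            if d1 == -1 and d2 == -1:
                raise ValueError("split dims cannot both be -1")
            if d1 == -1:
                d1 = d0 // d2
            elif d2 == -1:
                d2 = d0 // d1
            out.extend([d1, d2])
            src += 1
            i += 2  # skip the two split values
        else:
            raise ValueError(f"unsupported special value: {v}")
        i += 1
    if infer_idx >= 0:
        # The placeholder is 1, so prod(out) is the product of the known dims.
        out[infer_idx] = prod(data_shape) // prod(out)
    return out
```

For example, infer_new_shape((2, 3, 4), (6, 1, -1)) returns [6, 1, 4], and the result can be passed as the newshape of a fresh relay.reshape call inside the mutator.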