[DISCUSS][torch.fx] Support pytorch's new frontend torch.fx

Currently, loading a PyTorch module into TVM (Relay) follows the path torch.nn.Module -> torch.jit.trace -> TorchScript -> tvm.relay, which works well for most vision and NLP models. However, this conversion has its own limitations. One is that tracing into customized modules is limited. For example,

import torch
import torch.nn.functional as F
from torch import autograd

def fn1(x, w):
    return F.linear(x, w)

x = torch.randn(10, 15)
w = torch.randn(20, 15)

# works 
ts = torch.jit.trace(fn1, (x, w))

class MyLinear(autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x.mm(w.t())

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        dydx = grad_output.mm(weight)
        dydw = grad_output.t().mm(input)
        return dydx, dydw

def fn2(x, w):
    return MyLinear.apply(x, w)

# does not work: the custom op is recorded as an opaque prim::PythonOp
ts = torch.jit.trace(fn2, (x, w))
print(ts.code)

In many DL tasks, such as vision and NLP, defining custom forward and backward functions is common in research and development. The missing support for customized ops prevents more people from using TVM.

Fortunately, PyTorch recently introduced a new IR frontend, torch.fx, which can trace into customized ops, and its API has become stable in recent releases. I suggest adding support for torch.fx, and I can help draft a PR.

Feel free to comment and share your thoughts.

How does that work? Does it effectively “inline” the custom op?

If it is possible to share the op conversion table with Torchscript, we can certainly add fx support. Otherwise, we need to develop the PT frontend 2.0 from scratch, I don’t think that’s worth it.

I experimented with FX a bit in https://github.com/apache/tvm/pull/10091, my impression was that its “symbolic tracing” is currently very limited, so it only works on simple and clean models.

cc @comaniac @t-vi

Same opinion (and same impression of torch.fx). It would be great to

  1. Provide sufficient data points to prove that torch.fx could cover current Relay PyTorch frontend use cases, and could support more use cases (e.g., custom ops).
  2. Discuss the performance impact. My general concern is whether the torch.fx IR would cause any performance regression. For example, if certain ops are missing or decomposed in torch.fx, it may affect end-to-end performance. It would be great to compare the converted Relay IR from TorchScript and torch.fx and deliver some insights.
  3. Illustrate the efforts (i.e., a completely separate frontend such as TF and TF2, or just a mode in the current PyTorch frontend).

How does that work? Does it effectively “inline” the custom op?

Yes, most ops built from ATen primitives can be traced correctly. I attach an example below.

import torch
from torch import autograd
from torch.fx import symbolic_trace

x = torch.randn(10, 15)
w = torch.randn(20, 15)

class SomeOPs(autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        # More calculations can be added here
        w = (w + 10) * 2
        return x.mm(w.t())

def fn2(x, w):
    return SomeOPs.apply(x, w)

ts = torch.jit.trace(fn2, (x, w))
symbolic_traced : torch.fx.GraphModule = symbolic_trace(fn2)
print(symbolic_traced.graph)

'''
graph():
    %x : [#users=1] = placeholder[target=x]
    %w : [#users=1] = placeholder[target=w]
    %add : [#users=1] = call_function[target=operator.add](args = (%w, 10), kwargs = {})
    %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
    %t : [#users=1] = call_method[target=t](args = (%mul,), kwargs = {})
    %mm : [#users=1] = call_method[target=mm](args = (%x, %t), kwargs = {})
    return mm
'''

In TorchScript, such a customized op is simply named prim::PythonOp and its implementation details are lost. With torch.fx, as long as the op is implemented at the Python level, the DAG can be traced properly.

If it is possible to share the op conversion table with Torchscript, we can certainly add fx support. Otherwise, we need to develop the PT frontend 2.0 from scratch, I don’t think that’s worth it.

fx records primitive ATen operations, so it can surely share the op mapping table with TorchScript. I don’t think we need to develop and maintain a separate table for an fx backend.

I experimented with FX a bit in https://github.com/apache/tvm/pull/10091 , my impression was that its “symbolic tracing” is currently very limited, so it only works on simple and clean models.

One known limitation of symbolic tracing is the if-else statement. But as for “complex models”, I find that torch.fx handles most models I have recently played with (e.g., ViT-based models, quantized transformers). I agree there might be some corner cases we need to resolve, but for most cases (i.e., models shipped with torchvision / huggingface), this should work as expected.

PS: sorry for the late update, I was busy with conference experiments.

Provide sufficient data points to prove that torch.fx could cover current Relay PyTorch frontend use cases, and could support more use cases (e.g., custom ops).

When a model does not contain customized ops (e.g., the ResNet families shipped with torchvision), torch.fx can definitely cover the TorchScript use cases, since torch.fx has been available since torch 1.8.2.

Any performance regression from torch.fx IR.

Good question. This depends on the use case.

  1. If a model contains customized ops, the existing PyTorch model (Python) -> TorchScript -> TVM Relay flow fails at the last step, because such an op is tagged as prim::PythonOp, which is not supported in the current map: https://github.com/apache/tvm/blob/main/python/tvm/relay/frontend/pytorch.py#L2968. There is no performance to compare in that case.
  2. If a model does not contain customized ops, torch.fx faithfully records every torch.nn.functional operation, just as TorchScript does. We can let the two frontend importers share the same torch op → Relay op map to ensure the conversions are identical.

Let me draft an implementation and show some comparisons between the Relay IR converted from TorchScript and from torch.fx.

Illustrate the efforts (i.e., a completely separate frontend such as TF and TF2, or just a mode in the current PyTorch frontend).

Fortunately, torch.fx is closer to the latter: a new IR mode in current PyTorch that shares the same low-level ATen primitives. So the nightmare between TF and TF2 will not happen again in this case.

What happens if we run fx trace on a custom op, and then run jit.trace? If that flow removes the custom op or PythonOp after jit trace, we don’t have to develop FX frontend.

By complex models, I mean something like SSD or MaskRCNN in torchvision. By design, FX never works on these models, which have data-dependent control flow. That’s a well-known limitation, but from my experience, FX tracing also breaks as soon as it encounters if/else/for statements that are not truly data dependent (for example, a for loop over an image list, which is irrelevant if we always use a single image input, or an if condition checking an input shape, which is irrelevant if the input shape is always fixed). The latter pain point was what turned me away from FX. In contrast, JIT trace doesn’t have issues with those non-data-dependent if / for statements, and it fits perfectly with typical TVM use cases (fixed input shape etc).
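
The shape-check case described above can be reproduced with a small sketch. The function `shape_branch` and the input size are hypothetical, chosen only for illustration; the branch depends on the input shape, not on tensor values. Under these assumptions, `symbolic_trace` is expected to raise a `TraceError` at the `if`, while `torch.jit.trace` simply follows the branch taken for the example input.

```python
import warnings

import torch
from torch.fx import symbolic_trace
from torch.fx.proxy import TraceError


def shape_branch(x):
    # Hypothetical example: the condition depends only on the input shape,
    # which is constant when the model is always fed fixed-shape inputs.
    if x.shape[0] > 1:
        return x * 2
    return x + 1


x = torch.randn(4, 3)

# FX symbolic tracing: the condition is a Proxy, which cannot be
# converted to a Python bool, so tracing stops at the `if`.
try:
    symbolic_trace(shape_branch)
    fx_failed = False
except TraceError:
    fx_failed = True

# JIT tracing: runs the function on the example input and records
# only the branch that was actually taken (possibly with a TracerWarning).
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    jit_traced = torch.jit.trace(shape_branch, (x,))

print("FX tracing failed:", fx_failed)
print("JIT trace matches eager:", torch.equal(jit_traced(x), shape_branch(x)))
```

Note that the JIT trace bakes in the branch taken for the example input, so it is only valid for inputs that take the same branch, which matches the fixed-shape use case mentioned above.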

Transformer models are simple in terms of model conversion (no control flow or dynamic shape). IMO, being able to import dynamic models like MaskRCNN is what distinguishes our PT frontend from other frontends, so that’s why I’m not really excited about FX in general.

What happens if we run fx trace on a custom op, and then run jit.trace ?

Can you show a simple case? I am not clear on how we can use two different tracing methods together.

If that flow removes the custom op or PythonOp after jit trace, we don’t have to develop FX frontend.

Agree, if PythonOp does not appear in the actual dataflow, then we do not need to bother. But in many codebases, customizing forward / backward in Python is a very common way to implement specific algorithms (not everyone codes in CUDA), and those PythonOps are indeed used in the dataflow.

It’s possible to run torch.jit.trace on the output of FX tracing. We do that in https://github.com/apache/tvm/blob/95aac9224eb5ef30ab5bb67471c1b6ddfd6e1d6e/tests/python/frontend/pytorch/test_fx_quant.py#L37-L40. You can try that after symbolic_traced : torch.fx.GraphModule = symbolic_trace(fn2) in your example.
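
A minimal sketch of that two-step flow is shown below. To keep it version-independent, it uses a plain Python function, `scaled_linear` (a hypothetical stand-in with the same arithmetic as the forward of `SomeOPs` above), rather than an `autograd.Function`; whether a given custom op is inlined by the FX step is an assumption that should be checked on the PyTorch version in use.

```python
import torch
from torch.fx import symbolic_trace


def scaled_linear(x, w):
    # Hypothetical stand-in mirroring the forward of SomeOPs in the thread.
    w = (w + 10) * 2
    return x.mm(w.t())


x = torch.randn(10, 15)
w = torch.randn(20, 15)

# Step 1: FX symbolic tracing produces a GraphModule (an nn.Module).
gm = symbolic_trace(scaled_linear)

# Step 2: JIT tracing the GraphModule produces a TorchScript module,
# which is the input format the existing TVM PyTorch frontend expects.
ts = torch.jit.trace(gm, (x, w))

print(ts.graph)
print("Outputs match:", torch.allclose(ts(x, w), scaled_linear(x, w)))
```

If the resulting TorchScript graph contains no `prim::PythonOp` node, it should be importable through the existing frontend without an FX-specific path.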

Got your point. I agree that current FX tracing is not good at if/else/for statements (most Python-level symbolic tracing methods have similar issues). But given the recent efforts the PyTorch community has made on torch.fx (continual development since 1.8, and the extension of fx with functorch), these issues should improve gradually.

I am not sure which one will be the main choice for PyTorch IR in the future. I suggest supporting fx in advance so users can choose which one to adopt (fx for models with customized ops, TorchScript for detection models).

If FX trace → JIT trace flow works, I think users can run FX trace to “remove” custom ops, and then run JIT trace on the model without the custom op to make it ready for TVM import. That way I don’t see a need to support FX models in TVM at all.

Hi, thank you for the interesting discussion.

fx records primitive ATen operations and surely it can share op mapping table as torchscript. I don’t think we need to develop & maintian a separate table for fx backend.

@Lyken17, I’m interested in extracting ATen op records from the fx graph. Would you mind sharing your thoughts on how we can do this? I’m new to torch.fx, and this information seems to be abstracted away in the FX graph, based on my short experience.

jit.trace after fx.symbolic_trace that @masahi suggested could be one way to achieve this, but I’m wondering if there is a way to do this purely within Fx.
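
For reference, here is a minimal sketch of reading op records directly from an FX graph by walking `graph.nodes`. The function `scaled_linear` is a hypothetical example. As in the graph dump earlier in the thread, the recorded targets are Python-level callables and method names (such as `operator.add` or `"mm"`), not `aten::` names, so a converter keyed on them would need its own lookup step.

```python
import operator

import torch
from torch.fx import symbolic_trace


def scaled_linear(x, w):
    # Hypothetical example function, used only to produce a small graph.
    w = (w + 10) * 2
    return x.mm(w.t())


gm = symbolic_trace(scaled_linear)

# Each node carries an opcode (placeholder, call_function, call_method,
# call_module, get_attr, output) and a target describing what is called.
records = [(node.op, node.target) for node in gm.graph.nodes]

for op, target in records:
    print(f"{op:15s} {target}")

# Keep only the nodes that perform computation.
compute_records = [
    (op, target)
    for op, target in records
    if op in ("call_function", "call_method", "call_module")
]
```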

Sure, let me post an example after the meeting.
