[DISCUSS] Adding a PyTorch Frontend

Currently, the path to support PyTorch in TVM isn’t ideal. A previous issue ended with PyTorch/TVM being a suitable solution for the time being, but that project doesn’t follow TVM’s other frontend implementations, which provide a Python function that takes a model and returns the corresponding Relay module and converted parameters. Instead, one must build the project, which bundles its own pointer to a TVM repo (currently Facebook’s own fork) as a sub-directory built by the project’s build script, and then either use to_relay (to grab the Relay module) or call an enable function in the Python script where the model is defined (this uses TVM to optimize and update the PyTorch graph). The other option, which is probably used more often, is PyTorch -> ONNX -> Relay.
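For reference, a rough sketch of what that workflow looks like today, based on the pytorch/tvm README; the exact API surface is theirs, and the to_relay usage mentioned in the comment is an assumption, not verified against the current code:

import torch
import torch_tvm  # the pytorch/tvm project, built separately with its bundled TVM fork

torch_tvm.enable()  # from here on, TorchScript compilation tries to offload supported subgraphs to TVM

@torch.jit.script
def mul_add(a, b, c):
    return a * b + c

# Alternatively, the project exposes a to_relay helper to pull out the Relay expression
# for a scripted/traced function; its exact signature is assumed here and may differ:
# relay_expr = torch_tvm.to_relay(mul_add, example_inputs)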

I propose we add a simple frontend to enable PyTorch for TVM in line with the other frameworks. I have currently implemented a simplified version which accepts traced models and has all image classification models in TorchVision working except for shufflenet (so resnet, squeezenet, vgg, mobilenet, densenet, inception, alexnet, googlenet, and mnasnet); the PR is in Amazon’s fork. This implementation is pretty similar to the other frontends and allows us to easily map new operators for more coverage down the line. It supports PyTorch 1.3 and below, and also seemed to support the nightly version of 1.4 as of a week ago (I haven’t tested recently).
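To make the intended usage concrete, here is a minimal sketch of how such a frontend could be invoked; the entry-point name, signature, and input naming below are assumptions for illustration, not the final API from the PR:

import torch
import torchvision
from tvm import relay

# Trace an off-the-shelf TorchVision model with a dummy input; tracing records the executed graph.
model = torchvision.models.resnet18(pretrained=True).eval()
dummy = torch.randn(1, 3, 224, 224)
traced = torch.jit.trace(model, dummy)

# Hypothetical entry point mirroring the other frontends (cf. relay.frontend.from_onnx);
# it would return a Relay module and the converted parameters.
mod, params = relay.frontend.from_pytorch(traced, {"input0": (1, 3, 224, 224)})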

Later on, we can also add support for control flow (through scripting), which will lead to a wider range of model coverage. One issue I can see is making sure the parser stays up to date with PyTorch, but I think all of the frontends have the same issue. Let me know your thoughts and whether there’s any interest in upstreaming this; I opened the PR in Amazon’s fork as we will probably want to have this in regardless.

For discussions of technical choices such as this, it would be great to have a comprehensive discussion about the technical path, the pros/cons of both sides of the choices, and a feasibility study (e.g. of control flow).

We would also like to make the discussion as inclusive as possible and get input from broader community members. Finally, we certainly want to reach consensus and reduce the amount of duplicated effort, if possible.

Yes, definitely agree with all of the above. Perhaps we can use this thread to discuss all of these? I think some discussion that can be had off the bat is with respect to future control flow support and how to consume models (i.e. eager mode, traces, scripts). I haven’t looked too much into control flow, but my initial thought is that it seems relatively straightforward in that we have to implement the control operators and the rest can probably be borrowed from the other frontends which have control flow.

Since we also had a similar discussion in #1984 when we added the Caffe2 frontend, and most of the necessary bits have been implemented already, I’d say why not. I can certainly see a good use case for this: being able to leverage state-of-the-art PyTorch tech without waiting for ONNX standardization.

Facebook’s work is about using TVM from PyTorch, so I think we are talking about a different use case here. It would be nice if we could extract an ATen-to-Relay converter that could be used by both projects.

Before we focus on the action items, it would be great to see more discussion of the technical direction.

In particular, how can we deal with the additional control flow (TorchScript) related things that are not yet covered? Having a tracing-based frontend may not be as interesting, since ONNX already covers some of that. It would be more interesting to get something like a TorchScript-to-Relay translation.

@masahi It would be nice to share the operator mappings used in both projects. From briefly looking at their project, some initial hurdles would come from the Python (this) vs. C++ (theirs) implementations.

@t-vi Thanks for the response and for providing your perspective as primarily a PyTorch user. Are you saying that ONNX (and this) are playing catch-up to PyTorch/TVM? Or that this would be playing catch-up to ONNX?

For the first case, I think both the ONNX frontend and this frontend are slightly more mature than PyTorch/TVM in terms of operator support and model coverage. I’m basing this off the README in their repo, but it looks like only a few ops are supported in the current version. For the latter (ONNX having more support than PyTorch in TVM), this might be true.

I definitely agree that the bail-out mechanism they provide in their repo, and ultimately how it’s used (just call torch_tvm.enable() in the same Python script where the model is defined and it’ll automatically optimize the graph!), is very attractive to PyTorch users. I foresee that adding something similar to this frontend, and other frontends in TVM, would be pretty challenging and not too feasible.

I would still argue that this serves a separate use case, which is why I think it’s good to have in TVM. I know I don’t want to build a separate project outside of TVM just for a specific frontend (and I’ve had difficulties doing so). Serializing a PyTorch model to another format is also not as intuitive as just having a parser.

@tqchen

Fair enough, though I’m not sure I’ll have time to explore these before the end of the year :frowning: Exploring the scripting primitive rather than tracing in TorchScript is probably the way to go in the future.

@t-vi you might be interested in the recent proposal by @jonso, who shares similar concerns with yours.

Thanks, @masahi. It seems that a lot of people would like the “fallback” support, as “operator not supported” is one of the main blockers for running models. Even though it would be ideal to add all ops in TVM, playing some level of catch-up is unavoidable as the frameworks we import from are constantly being updated.

I would encourage more comments on the RFC post so we can discuss the implementation options.

@tqchen @alexwong Translating TorchScript to Relay sounds very interesting. TorchScript also seems to be the direction PyTorch devs are going forward. Quoting from a comment by one of the PyTorch devs:

TorchScript is intended as a replacement for PyTorch -> ONNX -> Caffe2 conversion. We think the experience is overall better, as we can precisely preserve the semantics of your model code and you don’t have to work with two separate frameworks. That said, we will continue to support a PyTorch → ONNX conversion step, as it is a standard that we want PyTorch as a framework to interoperate with well. ONNX covers CoreML, ONNX.js etc. which are pretty useful (from what we heard).

For me, if I have to go through ONNX, I have to wait for three pieces in place:

  • ONNX standardization of operators that I care about
  • PyTorch to ONNX exporter
  • ONNX to Relay frontend

If instead we consider a direct PyTorch to TVM route, we only need to implement Relay frontend. This alone is a deal breaker for me. The particular use case I have is quantization - I want to leverage PyTorch’s quantization capability, but I don’t know how long I have to wait before the three pieces above become ready for new quantization methods that PyTorch community would develop.

@t-vi

I don’t think PyTorch -> TVM vs. PyTorch -> ONNX is the fair comparison here (it should rather be against PyTorch -> ONNX -> TVM). This removes a step from the existing PyTorch -> ONNX -> TVM path. I should also clarify that it’s definitely not my intention to dismiss the PyTorch/TVM project; I just think it’s a separate enough use case that it’s worth having a native PyTorch parser in TVM.

As for how this will fit in with @jonso’s fallback feature, let me read through the proposal and get a better idea :slight_smile:

I think @masahi’s comment above illustrates the pain point. MLIR, ONNX, and Relay are all some sort of IR, and it would definitely be nicer to go directly from PyTorch rather than through ONNX. For future PyTorch operator support, we just have to update this parser rather than hope ONNX covers it and then update the ONNX parser.

I’m also not an ONNX person, but the ONNX parser in TVM doesn’t support control flow, and from looking at some docs and code (very quickly, so correct me if I’m wrong!) https://pytorch.org/docs/stable/onnx.html, https://github.com/pytorch/pytorch/blob/66f2bba8527d941c7d73d3cae9e9576c601587a6/torch/jit/init.py#L249, the PyTorch ONNX exporter uses a trace to export the graph, which means it doesn’t preserve control flow. This gives more reason to move to a native parser for PyTorch, as we can implement some support in the future rather than wait for the PyTorch -> ONNX path to be updated.
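As a small illustration of that point, PyTorch’s ONNX export requires an example input precisely because it runs the model once and records the trace (a sketch):

import torch
import torchvision

model = torchvision.models.resnet18().eval()
dummy = torch.randn(1, 3, 224, 224)
# torch.onnx.export executes the model on the dummy input and records the resulting trace,
# so any data-dependent control flow is frozen to the branches the dummy input happened to take.
torch.onnx.export(model, dummy, "resnet18.onnx")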

@alexwong Since traced models and TorchScript-ed models share the same PyTorch IR (correct me if I am wrong @t-vi), we can use your WIP implementation as is, and add control flow ops on top to support loading TorchScript-ed models, is that right?

TracedModule is a subclass of ScriptModule, so if you have conversion from TracedModule working, you are already converting a subset of ScriptModule.
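A quick way to see that relationship, assuming a recent PyTorch build:

import torch

class Double(torch.nn.Module):
    def forward(self, x):
        return x * 2

traced = torch.jit.trace(Double(), torch.randn(3))
# The traced module comes back as a ScriptModule subclass, so code that walks
# ScriptModule graphs already consumes traced models as a special case.
print(isinstance(traced, torch.jit.ScriptModule))  # True
print(traced.graph)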

Yes, that is how I envision it. Both tracing and scripting are ways to go from eager-mode PyTorch to TorchScript. While the current implementation grabs a trace, in the future we can switch to scripting, which among other things preserves control flow ops that we can then support.
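A toy example of the difference (a hedged sketch; the key point is that scripting keeps the prim::If node while tracing folds it away):

import torch

class AbsLike(torch.nn.Module):
    def forward(self, x):
        # Data-dependent branch: scripting compiles the Python and preserves it,
        # while tracing only records the branch taken for the example input.
        if x.sum() > 0:
            return x
        return -x

scripted = torch.jit.script(AbsLike())
traced = torch.jit.trace(AbsLike(), torch.randn(3))  # emits a TracerWarning about the branch
print(scripted.graph)  # contains prim::If
print(traced.graph)    # control flow folded into a single path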

We can also consider loading a serialized *.pt ScriptModule (similar to the ONNX workflow, where we assume serialization is already done by users). That way we don’t have to care whether the model comes from script or trace.
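In that workflow the frontend would start from the serialized artifact only, e.g. (sketch):

import torch

class Scale(torch.nn.Module):
    def forward(self, x):
        return x * 3

torch.jit.save(torch.jit.script(Scale()), "model.pt")  # user-side serialization, script or trace alike
loaded = torch.jit.load("model.pt")                     # frontend-side: always a ScriptModule
print(loaded.graph)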

@masahi That is something I thought about and certainly an option.

@t-vi I would PM this but can’t for some reason (maybe because basic-trust-level to basic-trust-level PMs are not allowed). I noticed you withdrew your comments, which is too bad. I thought it was a good discussion to have and it would be helpful for those joining the conversation to read through.

Thanks everyone for the good discussion. I want to say that the community welcomes different opinions, even conflicting views. This is exactly what the Apache way is about: resolving technical discussions in a diplomatic way and getting what is best for the users.

Specifically, we would like to separate how we do it (the technical pros and cons) from the decision (what is finally decided).

Most of the technical discussion can proceed with as many opinions as possible, and people can agree on the result. The decision is then something that can be suggested by summarizing the points from both sides.

This is good work.

I agree this part is not ideal. I think the custom TVM repo thing is totally unnecessary, and I’ll check whether we can just go back to using the upstream TVM repo. As for the conversion code, ideally I think we should make the JIT IR -> Relay conversion code a reusable library, preferably in C++. Then it’s easier to bind to Python. But that’s not the status of the pytorch/tvm repo now.

I can see where this proposal is coming from. There are different usage scenarios, and for us we would like to have the PyTorch framework drive the outer layer of the inference run and use TVM as a sort of acceleration framework (https://sampl.cs.washington.edu/tvmconf/slides/2019/E05-Hao-Lu-Ansha-Yu.pdf). But I can totally see a use case where we just take a PyTorch model, convert it, and run it in the Relay VM without any other runtime dependency.

I think the WIP PR is good. And since we really just have TorchScript, where tracing and scripting are simply different ways of converting the Python program to TorchScript, starting from the simpler traced models is a good plan.

@alexwong @tqchen

I’ve added support for translating TorchScript If and Loop nodes to Relay in my modified version of the PyTorch parser based on @alexwong’s PR. I uploaded test cases along with a standalone PyTorch frontend that can be run with the TVM upstream master branch.

For example, given a PyTorch module with a conditional and a loop,

import torch

class LoopWithIf(torch.nn.Module):
    def forward(self, inp):
        a = inp
        for i in range(inp.size(0)):
            b = a * 2
            b = a + b
            if b.sum() > 0.0:
                a += b
            else:
                a -= b
        return a
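
The IR dump can be reproduced by scripting the module and printing its graph (a sketch of that step, assuming the class above):

scripted = torch.jit.script(LoopWithIf())
print(scripted.graph)  # shows the prim::Loop and prim::If nodes referenced below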

The TorchScript compiler generates the following IR. Note the prim::If and prim::Loop nodes.

graph(%self : __torch__.LoopWithIf,
      %inp.1 : Tensor):
  %2 : None = prim::Constant()
  %3 : int = prim::Constant[value=1]()
  %4 : bool = prim::Constant[value=1]() # dynamic_test.py:64:8
  %5 : int = prim::Constant[value=0]() # dynamic_test.py:64:32
  %6 : int = prim::Constant[value=2]() # dynamic_test.py:65:20
  %7 : float = prim::Constant[value=0]() # dynamic_test.py:67:25
  %8 : int = aten::size(%inp.1, %5) # dynamic_test.py:64:23
  %a : Tensor = prim::Loop(%8, %4, %inp.1) # dynamic_test.py:64:8
    block0(%i : int, %a.15 : Tensor):
      %b.1 : Tensor = aten::mul(%a.15, %6) # dynamic_test.py:65:16
      %b.3 : Tensor = aten::add(%a.15, %b.1, %3) # dynamic_test.py:66:16
      %14 : Tensor = aten::sum(%b.3, %2) # dynamic_test.py:67:15
      %15 : Tensor = aten::gt(%14, %7) # dynamic_test.py:67:15
      %16 : bool = aten::Bool(%15) # dynamic_test.py:67:15
      %a.14 : Tensor = prim::If(%16) # dynamic_test.py:67:12
        block0():
          %a.4 : Tensor = aten::add_(%a.15, %b.3, %3) # dynamic_test.py:68:16
          -> (%a.4)
        block1():
          %a.7 : Tensor = aten::sub_(%a.15, %b.3, %3) # dynamic_test.py:70:16
          -> (%a.7)
      -> (%4, %a.14)
  return (%a)

My parser can translate the above IR into the Relay equivalent below.

v0.0.4
def @main(%X: Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] {
  %9 = (
    let %while_loop: fn (int32, Tensor[(10, 20), float32]) -> (int32, Tensor[(10, 20), float32]) = fn (%i: int32, %a.15: Tensor[(10, 20), float32]) -> (int32, Tensor[(10, 20), float32]) {
      %0 = greater_equal(%i, 1 /* ty=int32 */) /* ty=bool */;
      %1 = less_equal(%i, 10 /* ty=int32 */) /* ty=bool */;
      %2 = logical_and(%0, %1) /* ty=bool */;
      if (%2) {
        %3 = add(%i, 1 /* ty=int32 */) /* ty=int32 */;
        %4 = multiply(%a.15, 2f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
        %5 = add(%a.15, %4) /* ty=Tensor[(10, 20), float32] */;
        %6 = sum(%5) /* ty=float32 */;
        %7 = greater(%6, 0f /* ty=float32 */) /* ty=bool */;
        %8 = if (%7) {
          add(%a.15, %5) /* ty=Tensor[(10, 20), float32] */
        } else {
          subtract(%a.15, %5) /* ty=Tensor[(10, 20), float32] */
        };
        %while_loop(%3, %8) /* ty=(int32, Tensor[(10, 20), float32]) */
      } else {
        (%i, %a.15)
      }
    };
    %while_loop
  );
  %10 = %9(1 /* ty=int32 */, %X) /* ty=(int32, Tensor[(10, 20), float32]) */;
  %10.1
}

The conditional maps to Relay’s If, and the loop is translated to a conditional plus tail recursion via relay.loops.while_loop(...). The translation is straightforward and only takes about 50 lines in my parser.
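For reference, a minimal sketch of that pattern (not the parser’s actual code; it assumes relay.loops.while_loop with a (cond, loop_vars, bodies) style interface and fixes the tensor shape for illustration):

from tvm import relay
from tvm.relay.loops import while_loop

def counted_loop(trip_count, init_state, body_fn):
    # Build the Relay equivalent of a counted prim::Loop as tail recursion.
    # trip_count is assumed to be a Relay expression of dtype int32.
    i = relay.var("i", shape=(), dtype="int32")
    state = relay.var("state", shape=(10, 20), dtype="float32")

    def cond(i, state):
        # Keep iterating while the trip counter is below the bound.
        return relay.less(i, trip_count)

    def body(i, state):
        # Advance the counter and let the caller-supplied body_fn update the state
        # (body_fn is where nested prim::If nodes would become relay.If expressions).
        return [relay.add(i, relay.const(1, "int32")), body_fn(i, state)]

    loop = while_loop(cond, [i, state], body)
    out = loop(relay.const(0, "int32"), init_state)  # returns a (counter, state) tuple
    return relay.TupleGetItem(out, 1)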

Really nice work @masahi!
