[Pytorch] [Quantization] Error during quantization

Hi, first let me thank you for TVM/VTA stack!

I want to deploy PyTorch models on the VTA accelerator (PyTorch -> Relay -> Quantization -> VTA). I did not find any tutorial for this, so I am using these two tutorials (which I can execute successfully) as inspiration:

https://tvm.apache.org/docs/tutorials/frontend/from_pytorch.html#sphx-glr-tutorials-frontend-from-pytorch-py

https://tvm.apache.org/docs/vta/tutorials/frontend/deploy_classification.html#sphx-glr-download-vta-tutorials-frontend-deploy-classification-py

I wrote this (simple) code based on these tutorials. I use it in simulation only for now.

import tvm
from tvm import rpc, autotvm, relay

import vta

import torch
import torchvision

# Load a pretrained PyTorch model

model_name = 'resnet18'
model = getattr(torchvision.models, model_name)(pretrained=True)
model = model.eval()

input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()

# Import the graph to Relay

input_name = 'input0'
img_shape = (1, 3, 224, 224)
shape_list = [(input_name, img_shape)]

# The model is imported later to follow the VTA tutorial
#mod, params = relay.frontend.from_pytorch(scripted_model,
#                                          shape_list)

# Define the platform and model targets (VTA)

env = vta.get_env()
target = env.target

# Obtain an execution remote (simulation only for now)

remote = rpc.LocalSession()
ctx = remote.ext_dev(0)

# Build the inference graph runtime

with autotvm.tophub.context(target):

    mod, params = relay.frontend.from_pytorch(scripted_model,
                                              shape_list)

    with tvm.transform.PassContext(opt_level=3):

        with relay.quantize.qconfig(global_scale=8.0,
                                    skip_conv_layers=[0]):

            mod = relay.quantize.quantize(mod, params=params)

First, I get a lot of warning messages saying that my tensors are untyped (as in the tutorial) and will be converted to float. But my real problem is an error raised while executing the quantization (the last line of the code above).

Traceback (most recent call last):
  File "/home/julien/recherche/vta/pytorch/pytorch_relay_vta.py", line 65, in <module>
    mod = relay.quantize.quantize(mod, params=params)
  File "/home/julien/outils/tvm/python/tvm/relay/quantize/quantize.py", line 343, in quantize
    mod = prerequisite_optimize(mod, params)
  File "/home/julien/outils/tvm/python/tvm/relay/quantize/quantize.py", line 316, in prerequisite_optimize
    mod = optimize(mod)
  File "/home/julien/outils/tvm/python/tvm/ir/transform.py", line 130, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/home/julien/outils/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 225, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /home/julien/outils/tvm/build/libtvm.so(tvm::relay::fold_scale_axis::ForwardFoldScaleAxis(tvm::RelayExpr const&)+0xc8) [0x7f566af52028]
  [bt] (7) /home/julien/outils/tvm/build/libtvm.so(tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)+0x7b) [0x7f566b0cd98b]
  [bt] (6) /home/julien/outils/tvm/build/libtvm.so(tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)+0x7b) [0x7f566b08788b]
  [bt] (5) /home/julien/outils/tvm/build/libtvm.so(tvm::relay::fold_scale_axis::ForwardPrep::VisitExpr_(tvm::relay::FunctionNode const*)+0x21) [0x7f566af59d31]
  [bt] (4) /home/julien/outils/tvm/build/libtvm.so(tvm::relay::ExprVisitor::VisitExpr_(tvm::relay::FunctionNode const*)+0xeb) [0x7f566b0caaab]
  [bt] (3) /home/julien/outils/tvm/build/libtvm.so(tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)+0x7b) [0x7f566b0cd98b]
  [bt] (2) /home/julien/outils/tvm/build/libtvm.so(tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)+0x7b) [0x7f566b08788b]
  [bt] (1) /home/julien/outils/tvm/build/libtvm.so(tvm::relay::fold_scale_axis::ForwardPrep::VisitExpr_(tvm::relay::LetNode const*)+0x3b) [0x7f566af5610b]
  [bt] (0) /home/julien/outils/tvm/build/libtvm.so(+0x121d532) [0x7f566af4c532]
  File "/home/julien/outils/tvm/src/relay/transforms/fold_scale_axis.cc", line 246
TVMError: FoldScaleAxis only accept dataflow-form

Process finished with exit code 1

If I understand correctly, the PyTorch -> Relay API is quite recent; previously it was necessary to go through NNVM and ONNX, right? So am I the first one to hit this issue?

Thank you all in advance :slight_smile:

I’ve also hit this issue. See [Relay][Pass]Do we want to allow LetNode in FoldScaleAxis pass?

For now you can try changing FATAL to WARN in the LOG(FATAL) that raises this error (src/relay/transforms/fold_scale_axis.cc, line 246 in your traceback).

cc @kevinthesun

Thanks a lot @masahi!

I just replaced FATAL by WARNING (WARN raises an error) and it works. Also, do not forget to recompile TVM for the change to take effect.
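As a side note, an alternative workaround that might avoid editing the C++ source is to convert the module to dataflow form before quantization. This is only a minimal sketch, assuming relay.transform.ToGraphNormalForm turns the let bindings emitted by the PyTorch frontend into the dataflow form that FoldScaleAxis expects; I have not verified it on this model.

# Hedged sketch: rewrite let-bound expressions into graph (dataflow) normal form
# before quantization, since FoldScaleAxis rejects non-dataflow IR.
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
mod = relay.transform.ToGraphNormalForm()(mod)

with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]):
    mod = relay.quantize.quantize(mod, params=params)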

Now I hit another issue.

Traceback (most recent call last):
  File "/home/julien/recherche/vta/pytorch/pytorch_relay_vta.py", line 65, in <module>
    mod = relay.quantize.quantize(mod, params=params)
  File "/home/julien/outils/tvm/python/tvm/relay/quantize/quantize.py", line 360, in quantize
    mod = quantize_seq(mod)
  File "/home/julien/outils/tvm/python/tvm/ir/transform.py", line 130, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/home/julien/outils/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 225, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /home/julien/outils/tvm/build/libtvm.so(tvm::relay::ExprVisitor::VisitExpr_(tvm::relay::MatchNode const*)+0x80) [0x7fe4d685fa60]
  [bt] (7) /home/julien/outils/tvm/build/libtvm.so(tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)+0x7b) [0x7fe4d6862adb]
  [bt] (6) /home/julien/outils/tvm/build/libtvm.so(tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)+0x7b) [0x7fe4d681c9db]
  [bt] (5) /home/julien/outils/tvm/build/libtvm.so(tvm::relay::ExprVisitor::VisitExpr_(tvm::relay::CallNode const*)+0x2e) [0x7fe4d685f75e]
  [bt] (4) /home/julien/outils/tvm/build/libtvm.so(tvm::relay::ExprVisitor::VisitExpr(tvm::RelayExpr const&)+0x7b) [0x7fe4d6862adb]
  [bt] (3) /home/julien/outils/tvm/build/libtvm.so(tvm::relay::ExprFunctor<void (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)+0x7b) [0x7fe4d681c9db]
  [bt] (2) /home/julien/outils/tvm/build/libtvm.so(tvm::relay::TypeVarEVisitor::VisitExpr_(tvm::ConstructorNode const*)+0x4a) [0x7fe4d66711fa]
  [bt] (1) /home/julien/outils/tvm/build/libtvm.so(tvm::IRModuleNode::LookupTypeDef(tvm::GlobalTypeVar const&) const+0x112) [0x7fe4d5fd70b2]
  [bt] (0) /home/julien/outils/tvm/build/libtvm.so(+0xb0f5d2) [0x7fe4d5fd35d2]
  File "/home/julien/outils/tvm/src/ir/module.cc", line 299
TVMError: Check failed: it != type_definitions.end(): There is no definition of tensor_int64_t

Process finished with exit code 1

The error is not the same every time:

There is no definition of tensor_int64_t
There is no definition of tensor_int32_t
There is no definition of tensor_int8_t
There is no definition of List

Can this issue be linked to the previous one? If not, I will start another thread.

This is not related, but I can see that this error and the previous one are introduced by the same feature added to the PyTorch frontend. It is related to dynamic model support and what is called the Prelude module. tensor_int64_t etc. are defined in the Prelude, but the error says their definitions cannot be found.
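For reference, here is a minimal sketch of how one could check which ADT definitions are registered in the imported module (get_global_type_vars and Prelude(mod) are assumed to behave as described; this is illustrative, not a verified fix):

from tvm.relay.prelude import Prelude

mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

# List the ADT definitions currently registered in mod; if the graph references
# List / tensor_int64_t but they are missing here, you get the
# "There is no definition of ..." error downstream.
print([tv.name_hint for tv in mod.get_global_type_vars()])

# Loading the Prelude registers List and the tensor_* types into the module
# (illustrative only; not verified to fix this quantization error).
Prelude(mod)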

Can you show me your script? I want to know when this error is raised.

Thanks again @masahi for your help :slight_smile:

It is still the same script. I put it here again for your convenience:

The error is raised when I execute the last line:

mod = relay.quantize.quantize(mod, params=params)

Which commit are you on? I first tried your script on an older commit (2ec7caa07) and it worked. Then I tried with the latest commit and got the same error.

The result of from_pytorch is different. I get more free variable warnings now. I’ll take a look.

OK, I bisected; if you undo this commit, it should work.

cc @seanlatias

Thanks a lot @masahi!

I reset to commit afaa9e492 (just before https://github.com/apache/incubator-tvm/pull/5768) and it solves the second error, but not the first one.

I also confirm that, using the latest commit, I get a lot of free variable warnings while calling from_pytorch. But if you undo the changes from https://github.com/apache/incubator-tvm/pull/5768, it works.

Yes, you need the two fixes on top of the latest master for your script to work. I found that the free variable warnings come from a different change I made to the PyTorch frontend recently. They should not cause anything bad, so the warnings can be ignored.