Model with torch.where function breaks TVM interpreter

Hello, I have the following PyTorch code:

import torch
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

class PytorchModel(torch.nn.Module):
    def __init__(self):
        super(PytorchModel, self).__init__()
        self.one = torch.ones(1, dtype=torch.int32)
        self.one = torch.as_strided(self.one, (10 * 10,), (0,))
        self.zero = torch.zeros(1, dtype=torch.int32)
        self.zero = torch.as_strided(self.zero, (10 * 10,), (0,))

    def forward(self, vector1, vector2):
        resized_vector1 = vector1.view(1, -1)
        resized_vector2 = vector2.view(-1, 1)
        count = torch.where((resized_vector1 ==
            resized_vector2).view(-1), self.one, self.zero).sum()
        return count


vector1 = torch.zeros([10], dtype=torch.int32)
vector2 = torch.zeros([10], dtype=torch.int32)
for i in range(10):
    vector1[i] = i
    vector2[i] = i

init_model = PytorchModel()

scripted_model = torch.jit.script(init_model)
matches = scripted_model(vector1, vector2)

input1_name = 'input0'
input2_name = 'input1'
shape_list = [(input1_name, (10,)),
              (input2_name, (10,))]
target = 'llvm'
ctx = tvm.cpu()
vector1_tvm = tvm.nd.array(vector1, ctx)
vector2_tvm = tvm.nd.array(vector2, ctx)

model, params = relay.frontend.from_pytorch(scripted_model, shape_list,
        default_dtype="int32")
with tvm.transform.PassContext(opt_level=3):
    executor = relay.create_executor("vm", mod=model, ctx=ctx, target=target)
    tvm_model = executor.evaluate()

matches = tvm_model(input0=vector1_tvm, input1=vector2_tvm)

My library versions are:

torch==1.6.0
tvm==0.7.dev1

When I run the code, I get the following error:

WARNING:root:Untyped Tensor found, assume it is int32
WARNING:root:Untyped Tensor found, assume it is int32
WARNING:root:Untyped Tensor found, assume it is int32
WARNING:root:Untyped Tensor found, assume it is int32
WARNING:root:Untyped Tensor found, assume it is int32
WARNING:root:Untyped Tensor found, assume it is int32
WARNING:root:Untyped Tensor found, assume it is int32
Traceback (most recent call last):
  File "reproduce_bug.py", line 51, in <module>
    matches = tvm_model(input0=vector1_tvm, input1=vector2_tvm)
  File "/home/dkoutsou/env/lib/python3.6/site-packages/tvm-0.7.dev1-py3.6-linux-x86_64.egg/tvm/relay/backend/vm.py", line 264, in _vm_wrapper
    args = self._convert_args(main, args, kwargs)
  File "/home/dkoutsou/env/lib/python3.6/site-packages/tvm-0.7.dev1-py3.6-linux-x86_64.egg/tvm/relay/backend/interpreter.py", line 120, in _convert_args
    cargs.append(kwargs[name])
KeyError: 'one'

Do you have any idea how I should debug this? At first I thought the problem was the naming of the variables, but renaming self.one to something else didn’t help. The same code works fine with vanilla TVM, but it breaks with the TVM VM.

Thanks very much for any help!

This is a consequence of using torch.jit.script: the scripted module keeps self.one and self.zero as free parameters, so you need to supply inputs corresponding to them. If you dump the Relay model, you can see that it expects inputs one and zero:

fn (%input0: Tensor[(10), int32], %input1: Tensor[(10), int32], %one: Tensor[(100), int32], %zero: Tensor[(100), int32]) -> int32 {
  %0 = reshape(%input0, newshape=[1, -1]) /* ty=Tensor[(1, 10), int32] */;
  %1 = reshape(%input1, newshape=[-1, 1]) /* ty=Tensor[(10, 1), int32] */;
  %2 = equal(%0, %1) /* ty=Tensor[(10, 10), bool] */;
  %3 = reshape(%2, newshape=[-1]) /* ty=Tensor[(100), bool] */;
  %4 = where(%3, %one, %zero) /* ty=Tensor[(100), int32] */;
  sum(%4) /* ty=int32 */
}
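
If you want to stay with scripting, a minimal workaround (a sketch, not tested against this exact TVM version; the buffer contents are reconstructed with NumPy rather than pulled from the module) is to pass those two tensors explicitly by name, which should satisfy the _convert_args lookup shown in your traceback:

import numpy as np

# The Relay function above expects %one and %zero as extra inputs:
# shape (100,) buffers of all ones and all zeros, respectively.
one_tvm = tvm.nd.array(np.ones(100, dtype='int32'), ctx)
zero_tvm = tvm.nd.array(np.zeros(100, dtype='int32'), ctx)
matches = tvm_model(input0=vector1_tvm, input1=vector2_tvm,
                    one=one_tvm, zero=zero_tvm)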

Unless you have a good reason to use torch.jit.script, please use torch.jit.trace instead, as I explained in the GitHub issue you opened yesterday. I recommend reading up on the difference between scripting and tracing on the PyTorch website.
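
For this model, the tracing route would look roughly like this (a sketch reusing the names from your script; with tracing, self.one and self.zero should be baked into the graph as constants, so no extra inputs are needed):

# Trace with concrete example inputs instead of scripting.
traced_model = torch.jit.trace(init_model, (vector1, vector2))
model, params = relay.frontend.from_pytorch(traced_model, shape_list,
        default_dtype="int32")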

Thanks for your reply! Reading the TVM documentation, I understood that scripting is better (meaning more performant) for the TVM VM and tracing for vanilla TVM. Have I understood this wrongly?

No, where did you read that? Scripting makes sense if your model has input-dependent control flow (if, for loops, etc.). The TVM VM can run models with control flow, while the graph runtime (what you referred to as vanilla TVM) cannot.

So if you are using scripting for a good reason, you also want the VM. But if your model doesn’t have any control flow, tracing is enough. The VM can run every model vanilla TVM can, but it may have performance overhead compared to vanilla TVM. Here is the kind of input-dependent control flow that only scripting can capture:
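
This is a hypothetical module, not from the thread, just to illustrate the distinction:

class Clamp(torch.nn.Module):
    def forward(self, x):
        # The branch depends on the input values, so a trace would
        # freeze whichever branch the example input happens to take;
        # torch.jit.script preserves the if itself.
        if x.sum() > 0:
            return x * 2
        return x - 1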

Also, converting models produced by torch.jit.script is more challenging than converting traced models, because scripting, by definition, preserves all Python constructs in your model definition. You want scripting to preserve Python control flow, but it also preserves other Python baggage, like exception raising.
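
For example (again a hypothetical module, not from the thread), scripting keeps a raise statement in the graph, and the Relay frontend then has to deal with it somehow:

class Checked(torch.nn.Module):
    def forward(self, x):
        # torch.jit.script preserves this raise in the graph;
        # torch.jit.trace would never record it for valid inputs.
        if x.numel() == 0:
            raise ValueError("empty input")
        return x + 1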

Thank you very much! It seems I have to do some reading on scripting and tracing, as you suggested 🙂