TypeError: convert "48.0" with `const` first

Hello, I’m trying to convert a fairly complex model (a StyleGAN2-ADA generator) to TVM so that I can quantize and/or auto-tune it.

I’ve already patched a couple of unsupported operations (aten::randn and aten::flip) to make the model compatible; the patches are custom converters passed to relay.frontend.from_pytorch through its custom_convert_map argument, roughly like the sketch below. With those in place I’m now running into a cryptic error (full traceback after the sketch):
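A minimal sketch of the kind of converter I mean (not my exact implementation; it assumes the requested shape arrives as concrete Python ints, as it does in my trace, and just substitutes zeros, since the noise values don’t matter for compiling or tuning):

import numpy as np
from tvm import relay

def randn(inputs, input_types):
    # inputs[0] holds the shape list from the TorchScript aten::randn call;
    # assuming concrete ints here, emit a constant placeholder tensor instead of noise
    shape = [int(dim) for dim in inputs[0]]
    return relay.const(np.zeros(shape, dtype="float32"))

mod, params = relay.frontend.from_pytorch(
    generator, [("input", data_shape)], custom_convert_map={"aten::randn": randn}
)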

Traceback (most recent call last):
  File "/home/hans/.conda/envs/hans/lib/python3.8/runpy.py", line 193, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/hans/.conda/envs/hans/lib/python3.8/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/hans/code/maua-stylegan2/nvsg2a/quantization/tvm.py", line 40, in <module>
    mod, params = relay.frontend.from_pytorch(generator, [("input", data_shape)], {"aten::randn": randn})
  File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line 3288, in from_pytorch
    ret = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)[0]
  File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line 2709, in convert_operators
    relay_out = relay_op(
  File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line 557, in reciprocal
    return _expr.const(1.0, dtype=input_types[0]) / data
  File "/home/hans/code/tvm/python/tvm/relay/expr.py", line 145, in __truediv__
    return self.__div__(other)
  File "/home/hans/code/tvm/python/tvm/relay/expr.py", line 135, in __div__
    raise TypeError('convert "%s" with `const` first' % str(other))
TypeError: convert "48.0" with `const` first

This error shows up pretty deep in the network (around the 12th block), which is odd because all the blocks before it seem to convert fine.

Module code
def forward(self,
    x: Tensor,
    x0: Tensor) -> Tensor:
  _0 = self.bias
  _1 = self.resample_filter
  _2 = self.weight
  _3 = self.noise_strength
  _4 = (self.affine).forward(x, )
  _5 = ops.prim.NumToTensor(torch.size(x0, 0))
  _6 = torch.randn([int(_5), 1, 256, 256], dtype=6, layout=None, device=torch.device("cuda:0"), pin_memory=False)
  noise = torch.mul(_6, _3)
  batch_size = ops.prim.NumToTensor(torch.size(x0, 0))
  _7 = int(batch_size)
  _8 = int(batch_size)
  in_channels = ops.prim.NumToTensor(torch.size(_2, 1))
  _9 = int(in_channels)
  kh = ops.prim.NumToTensor(torch.size(_2, 2))
  _10 = int(kh)
  kw = ops.prim.NumToTensor(torch.size(_2, 3))
  _11 = int(kw)
  _12 = torch.mul(torch.reciprocal(CONSTANTS.c0), CONSTANTS.c1)
  _13 = torch.div(_12, torch.norm(_2, inf, [1, 2, 3], True))
  weight = torch.mul(_2, _13)
  styles = torch.div(_4, torch.norm(_4, inf, [1], True))
  w = torch.unsqueeze(weight, 0)
  _14 = torch.reshape(styles, [_8, 1, -1, 1, 1])
  w0 = torch.mul(w, _14)
  _15 = torch.sum(torch.mul(w0, w0), [2, 3, 4], False, dtype=None)
  dcoefs = torch.rsqrt(torch.add(_15, CONSTANTS.c2, alpha=1))
  _16 = torch.reshape(dcoefs, [_7, -1, 1, 1, 1])
  w1 = torch.mul(w0, _16)
  _17 = ops.prim.NumToTensor(torch.size(x0, 2))
  _18 = int(_17)
  _19 = ops.prim.NumToTensor(torch.size(x0, 3))
  input = torch.reshape(x0, [1, -1, _18, int(_19)])
  w2 = torch.reshape(w1, [-1, _9, _10, _11])
  w3 = torch.to(w2, 5, False, False, None)
  w4 = torch.transpose(w3, 0, 1)
  x1 = torch._convolution(input, w4, None, [2, 2], [0, 0], [1, 1], True, [0, 0], 1, False, False, True, True)
  batch_size0 = ops.prim.NumToTensor(torch.size(x1, 0))
  _20 = int(batch_size0)
  _21 = int(batch_size0)
  num_channels = ops.prim.NumToTensor(torch.size(x1, 1))
  _22 = int(num_channels)
  _23 = int(num_channels)
  _24 = int(num_channels)
  in_height = ops.prim.NumToTensor(torch.size(x1, 2))
  _25 = int(in_height)
  in_width = ops.prim.NumToTensor(torch.size(x1, 3))
  input0 = torch.reshape(x1, [_21, _24, _25, 1, int(in_width), 1])
  x2 = torch.constant_pad_nd(input0, [0, 0, 0, 0, 0, 0], 0)
  _26 = int(torch.mul(in_height, CONSTANTS.c1))
  _27 = int(torch.mul(in_width, CONSTANTS.c1))
  input1 = torch.reshape(x2, [_20, _23, _26, _27])
  x3 = torch.constant_pad_nd(input1, [1, 1, 1, 1], 0)
  _28 = ops.prim.NumToTensor(torch.size(x3, 2))
  _29 = int(torch.sub(_28, CONSTANTS.c3, alpha=1))
  _30 = ops.prim.NumToTensor(torch.size(x3, 3))
  _31 = int(torch.sub(_30, CONSTANTS.c3, alpha=1))
  _32 = torch.slice(x3, 0, 0, 9223372036854775807, 1)
  _33 = torch.slice(_32, 1, 0, 9223372036854775807, 1)
  input2 = torch.slice(torch.slice(_33, 2, 0, _29, 1), 3, 0, _31, 1)
  f = torch.mul(_1, CONSTANTS.c4)
  x4 = torch.to(f, 5, False, False, None)
  _34 = torch.slice(x4, 0, 0, 9223372036854775807, 1)
  f0 = torch.slice(_34, 1, 0, 9223372036854775807, 1)
  _35 = torch.unsqueeze(torch.unsqueeze(f0, 0), 1)
  weight0 = torch.repeat(_35, [_22, 1, 1, 1])
  x5 = torch._convolution(input2, weight0, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 128, False, False, True, True)
  _36 = torch.slice(x5, 0, 0, 9223372036854775807, 1)
  _37 = torch.slice(_36, 1, 0, 9223372036854775807, 1)
  _38 = torch.slice(_37, 2, 0, 9223372036854775807, 1)
  x6 = torch.slice(_38, 3, 0, 9223372036854775807, 1)
  _39 = ops.prim.NumToTensor(torch.size(x6, 2))
  _40 = int(_39)
  _41 = ops.prim.NumToTensor(torch.size(x6, 3))
  x7 = torch.reshape(x6, [1, -1, _40, int(_41)])
  x8 = torch.add_(x7, noise, alpha=1)
  b = torch.to(_0, 5, False, False, None)
  input3 = torch.add(x8, torch.reshape(b, [1, -1, 1, 1]), alpha=1)
  x9 = torch.leaky_relu(input3, 0.20000000000000001)
  x10 = torch.mul(x9, CONSTANTS.c5)
  return torch.clamp(x10, -256., 256.)
Module graph

The offending op seems to be %63

graph(%self.53 : __torch__.torch_utils.persistence.___torch_mangle_37.SynthesisLayer,
      %x.185 : Float(1:9216, 512:1, requires_grad=0, device=cuda:0),
      %x.186 : Half(1:4194304, 256:16384, 128:128, 128:1, requires_grad=0, device=cuda:0)):
  %3 : Tensor = prim::GetAttr[name="bias"](%self.53)
  %4 : Tensor = prim::GetAttr[name="resample_filter"](%self.53)
  %5 : Tensor = prim::GetAttr[name="weight"](%self.53)
  %6 : Tensor = prim::GetAttr[name="noise_strength"](%self.53)
  %7 : __torch__.torch_utils.persistence.___torch_mangle_36.FullyConnectedLayer = prim::GetAttr[name="affine"](%self.53)
  %514 : Tensor = prim::CallMethod[name="forward"](%7, %x.185)
  %9 : int = prim::Constant[value=0](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:333:0
  %10 : int = aten::size(%x.186, %9), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:333:0
  %11 : Long(device=cpu) = prim::NumToTensor(%10), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0
  %12 : int = aten::Int(%11), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0
  %22 : int = prim::Constant[value=1](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:333:0
  %23 : int = prim::Constant[value=256](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:333:0
  %24 : int = prim::Constant[value=256](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:333:0
  %25 : int[] = prim::ListConstruct(%12, %22, %23, %24), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0
  %26 : int = prim::Constant[value=6](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:333:0
  %27 : None = prim::Constant(), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0
  %28 : Device = prim::Constant[value="cuda:0"](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:333:0
  %29 : bool = prim::Constant[value=0](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:333:0
  %30 : Float(1:65536, 1:65536, 256:256, 256:1, requires_grad=0, device=cuda:0) = aten::randn(%25, %26, %27, %28, %29), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:333:0
  %noise.12 : Float(1:65536, 1:65536, 256:256, 256:1, requires_grad=0, device=cuda:0) = aten::mul(%30, %6), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:333:0
  %32 : int = prim::Constant[value=0](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:43:0
  %33 : int = aten::size(%x.186, %32), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:43:0
  %batch_size.28 : Long(device=cpu) = prim::NumToTensor(%33), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0
  %35 : int = aten::Int(%batch_size.28), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0
  %36 : int = aten::Int(%batch_size.28), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0
  %49 : int = prim::Constant[value=1](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:44:0
  %50 : int = aten::size(%5, %49), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:44:0
  %in_channels.18 : Long(device=cpu) = prim::NumToTensor(%50), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0
  %52 : int = aten::Int(%in_channels.18), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0
  %53 : int = prim::Constant[value=2](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:44:0
  %54 : int = aten::size(%5, %53), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:44:0
  %kh.18 : Long(device=cpu) = prim::NumToTensor(%54), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0
  %56 : int = aten::Int(%kh.18), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0
  %57 : int = prim::Constant[value=3](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:44:0
  %58 : int = aten::size(%5, %57), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:44:0
  %kw.18 : Long(device=cpu) = prim::NumToTensor(%58), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0
  %60 : int = aten::Int(%kw.18), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0
  %63 : Double(requires_grad=0, device=cpu) = prim::Constant[value={48}](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # /home/hans/.conda/envs/hans/lib/python3.8/site-packages/torch/tensor.py:519:0
  %64 : Double(requires_grad=0, device=cpu) = aten::reciprocal(%63), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # /home/hans/.conda/envs/hans/lib/python3.8/site-packages/torch/tensor.py:519:0
  %65 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # /home/hans/.conda/envs/hans/lib/python3.8/site-packages/torch/tensor.py:519:0
  %66 : Double(requires_grad=0, device=cpu) = aten::mul(%64, %65), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # /home/hans/.conda/envs/hans/lib/python3.8/site-packages/torch/tensor.py:519:0
  %67 : float = prim::Constant[value=inf](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # /home/hans/.conda/envs/hans/lib/python3.8/site-packages/torch/functional.py:1337:0
  %68 : int = prim::Constant[value=1](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # /home/hans/.conda/envs/hans/lib/python3.8/site-packages/torch/functional.py:1337:0
  %69 : int = prim::Constant[value=2](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # /home/hans/.conda/envs/hans/lib/python3.8/site-packages/torch/functional.py:1337:0
  %70 : int = prim::Constant[value=3](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # /home/hans/.conda/envs/hans/lib/python3.8/site-packages/torch/functional.py:1337:0
  %71 : int[] = prim::ListConstruct(%68, %69, %70), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0
  %72 : bool = prim::Constant[value=1](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # /home/hans/.conda/envs/hans/lib/python3.8/site-packages/torch/functional.py:1337:0
  %73 : Float(128:1, 1:1, 1:1, 1:1, requires_grad=0, device=cuda:0) = aten::norm(%5, %67, %71, %72), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # /home/hans/.conda/envs/hans/lib/python3.8/site-packages/torch/functional.py:1337:0
  %74 : Float(128:1, 1:1, 1:1, 1:1, requires_grad=0, device=cuda:0) = aten::div(%66, %73), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:52:0
  %weight.29 : Float(128:2304, 256:9, 3:3, 3:1, requires_grad=0, device=cuda:0) = aten::mul(%5, %74), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:51:0
  %76 : float = prim::Constant[value=inf](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # /home/hans/.conda/envs/hans/lib/python3.8/site-packages/torch/functional.py:1337:0
  %77 : int = prim::Constant[value=1](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # /home/hans/.conda/envs/hans/lib/python3.8/site-packages/torch/functional.py:1337:0
  %78 : int[] = prim::ListConstruct(%77), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0
  %79 : bool = prim::Constant[value=1](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # /home/hans/.conda/envs/hans/lib/python3.8/site-packages/torch/functional.py:1337:0
  %80 : Float(1:1, 1:1, requires_grad=0, device=cuda:0) = aten::norm(%514, %76, %78, %79), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # /home/hans/.conda/envs/hans/lib/python3.8/site-packages/torch/functional.py:1337:0
  %styles.18 : Float(1:256, 256:1, requires_grad=0, device=cuda:0) = aten::div(%514, %80), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:54:0
  %82 : int = prim::Constant[value=0](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:60:0
  %w.111 : Float(1:294912, 128:2304, 256:9, 3:3, 3:1, requires_grad=0, device=cuda:0) = aten::unsqueeze(%weight.29, %82), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:60:0
  %84 : int = prim::Constant[value=1](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:61:0
  %85 : int = prim::Constant[value=-1](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:61:0
  %86 : int = prim::Constant[value=1](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:61:0
  %87 : int = prim::Constant[value=1](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:61:0
  %88 : int[] = prim::ListConstruct(%36, %84, %85, %86, %87), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0
  %89 : Float(1:256, 1:256, 256:1, 1:1, 1:1, requires_grad=0, device=cuda:0) = aten::reshape(%styles.18, %88), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:61:0
  %w.112 : Float(1:294912, 128:2304, 256:9, 3:3, 3:1, requires_grad=0, device=cuda:0) = aten::mul(%w.111, %89), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:61:0
  %91 : Float(1:294912, 128:2304, 256:9, 3:3, 3:1, requires_grad=0, device=cuda:0) = aten::mul(%w.112, %w.112), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:63:0
  %92 : int = prim::Constant[value=2](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:63:0
  %93 : int = prim::Constant[value=3](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:63:0
  %94 : int = prim::Constant[value=4](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:63:0
  %95 : int[] = prim::ListConstruct(%92, %93, %94), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0
  %96 : bool = prim::Constant[value=0](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:63:0
  %97 : None = prim::Constant(), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0
  %98 : Float(1:128, 128:1, requires_grad=0, device=cuda:0) = aten::sum(%91, %95, %96, %97), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:63:0
  %99 : Double(requires_grad=0, device=cpu) = prim::Constant[value={1e-08}](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:63:0
  %100 : int = prim::Constant[value=1](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:63:0
  %101 : Float(1:128, 128:1, requires_grad=0, device=cuda:0) = aten::add(%98, %99, %100), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:63:0
  %dcoefs.12 : Float(1:128, 128:1, requires_grad=0, device=cuda:0) = aten::rsqrt(%101), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:63:0
  %103 : int = prim::Constant[value=-1](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:65:0
  %104 : int = prim::Constant[value=1](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:65:0
  %105 : int = prim::Constant[value=1](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:65:0
  %106 : int = prim::Constant[value=1](), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:65:0
  %107 : int[] = prim::ListConstruct(%35, %103, %104, %105, %106), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0
  %108 : Float(1:128, 128:1, 1:1, 1:1, 1:1, requires_grad=0, device=cuda:0) = aten::reshape(%dcoefs.12, %107), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:65:0
  %w.113 : Float(1:294912, 128:2304, 256:9, 3:3, 3:1, requires_grad=0, device=cuda:0) = aten::mul(%w.112, %108), scope: __module.synthesis/__module.synthesis.b256/__module.synthesis.b256.conv0 # <string>:65:0
... (graph truncated for the character limit)

My code can be found here: stylegan2-ada-pytorch/tvm.py at quant · JCBrouwer/stylegan2-ada-pytorch · GitHub

Assuming TVM is on the PYTHONPATH, I think it should be possible to reproduce with the following (although this uses a slightly different pretrained network; that one gives TypeError: convert "67.88225099390856" with const first):

git clone https://github.com/JCBrouwer/stylegan2-ada-pytorch.git
cd stylegan2-ada-pytorch
pip install click requests tqdm pyspng ninja imageio-ffmpeg==0.4.3
wget https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqwild.pkl
python -m quantization.tvm

Does anyone know what this could be related to, or how I could go about debugging it?

Update: I’ve managed to track down the error and fix it.

The Generator module uses fp16 in the last (largest) few layers of the model. To prevent fp16 overflow, it normalizes the weights and the modulation inputs. This is why the error only occurs deep in the network.

batch_size = x.shape[0]
out_channels, in_channels, kh, kw = weight.shape

# Pre-normalize the weights and styles to keep fp16 activations from overflowing.
if x.dtype == torch.float16 and demodulate:
    weight = weight * (
        # the np.sqrt(...) term is where the Double constant {48} at %63 comes from
        1 / np.sqrt(in_channels * kh * kw) / weight.norm(float("inf"), dim=[1, 2, 3], keepdim=True)
    )
    styles = styles / styles.norm(float("inf"), dim=1, keepdim=True)

Here np.sqrt(in_channels * kh * kw) was equal to 48.0 (the value in the error above).

Switching this to np.sqrt(int(in_channels) * int(kh) * int(kw)) avoids the error.
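In context, the patched snippet looks roughly like this (same variable names as above); with plain ints the 1 / np.sqrt(...) factor stays an ordinary Python float, so it no longer appears to get traced through the Double tensor constant at %63:

if x.dtype == torch.float16 and demodulate:
    # cast the sizes to plain Python ints so np.sqrt returns an ordinary float
    weight = weight * (
        1 / np.sqrt(int(in_channels) * int(kh) * int(kw)) / weight.norm(float("inf"), dim=[1, 2, 3], keepdim=True)
    )
    styles = styles / styles.norm(float("inf"), dim=1, keepdim=True)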

TBH this seems like a bug to me, although I’ll admit I’m not experienced enough with TVM to know if this is expected behavior.

Coincidentally, the L_infinity norms in this snippet are also what prevent the model from compiling to ONNX, so I think it’s probably better to just force-disable fp16 and avoid these issues altogether.
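For example (a sketch assuming the reference StyleGAN2-ADA API, where the synthesis blocks take a force_fp32 flag that gets forwarded from the generator’s forward; whether a given fork exposes it the same way is an assumption on my part):

import torch

# G is the loaded generator (e.g. the G_ema network from the .pkl)
z = torch.randn([1, G.z_dim], device="cuda")
c = torch.zeros([1, G.c_dim], device="cuda")

# force_fp32=True is forwarded down to the synthesis blocks, so everything runs
# in fp32 and the overflow-guard normalization (with its inf-norms) never executes
img = G(z, c, force_fp32=True)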

Yes, this is a bug. I think the error is saying that data at tvm/pytorch.py at 38e0bbed7d4156242067aced19b9dbb3d49c2ef7 · apache/tvm · GitHub is a torch or numpy expression that TVM doesn’t understand, but we then try to use it in the Relay expression on the next line.

We can probably fix this by wrapping data with _expr.const(...) when it is not already an instance of a Relay Expr. We also need to extract the numpy array from data when it is a torch tensor.
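Something along these lines (a hypothetical sketch, not the actual frontend code; the helper name _ensure_relay_expr is made up):

import torch
from tvm import relay
from tvm.relay import expr as _expr

def _ensure_relay_expr(data, dtype="float32"):
    # real Relay expressions pass through untouched
    if isinstance(data, _expr.Expr):
        return data
    # torch tensors (e.g. the 0-dim Double constant): pull out the numpy array
    if isinstance(data, torch.Tensor):
        return relay.const(data.detach().cpu().numpy())
    # plain Python / numpy scalars such as 48.0
    return relay.const(data, dtype=dtype)

# e.g. in the reciprocal converter:
#   data = _ensure_relay_expr(inputs[0], input_types[0])
#   return _expr.const(1.0, dtype=input_types[0]) / data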
