Ok I understand the problem. This is due to
When we reach the second dequantize below, we need to find its input scale and zero point (zp). I do this with a DFS over the inputs, walking backward until we reach a quantized op that produces a new quantized tensor.
%47 : (Tensor, Tensor) = prim::TupleConstruct(%input.2, %Xq.1)
%7 : Tensor, %8 : Tensor = prim::TupleUnpack(%47)
%48 : Tensor = aten::dequantize(%7) # /home/masa/anaconda3/envs/torch-1.10/lib/python3.7/site-packages/torch/nn/quantized/modules/__init__.py:84:0
%49 : Tensor = aten::dequantize(%8)
The issue is that I assumed the quantized tensor comes first in the input list, while here we need to visit the second input of TupleConstruct. So we end up doing the DFS on %input.2, which finds the wrong input qparams.
When I implemented that, I was not sure whether it would cause a problem in practice. Your example is the first case where this strategy fails.
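To make the failure concrete, here is a minimal standalone sketch. It does not use the real frontend code: the `Node` class, both lookup functions, and the scale/zp values are made up for illustration. It only mimics the shape of the graph above. The second function shows one hypothetical way a backward lookup could follow the tuple index instead of taking the first hit.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

QParams = Tuple[float, int]  # (scale, zero_point)


@dataclass
class Node:
    """Toy stand-in for a TorchScript value; not the frontend's real data structure."""
    name: str
    kind: str
    inputs: List["Node"] = field(default_factory=list)
    qparams: Optional[QParams] = None  # set only on ops that create a quantized tensor
    index: int = 0  # which tuple element this value reads (TupleUnpack outputs only)


def dfs_first_hit(node: Node) -> Optional[QParams]:
    """Backward DFS that returns the first qparams found, visiting inputs in order."""
    if node.qparams is not None:
        return node.qparams
    for inp in node.inputs:
        found = dfs_first_hit(inp)
        if found is not None:
            return found
    return None


def dfs_tuple_aware(node: Node) -> Optional[QParams]:
    """Same walk, but a TupleUnpack output follows only its own tuple element."""
    if node.qparams is not None:
        return node.qparams
    if node.kind == "TupleUnpack" and node.inputs[0].kind == "TupleConstruct":
        return dfs_tuple_aware(node.inputs[0].inputs[node.index])
    for inp in node.inputs:
        found = dfs_tuple_aware(inp)
        if found is not None:
            return found
    return None


# Made-up qparams, chosen only so the two tensors are distinguishable.
input_2 = Node("%input.2", "quantized_op", qparams=(0.5, 0))
xq_1 = Node("%Xq.1", "quantized_op", qparams=(0.25, 3))
tup = Node("%47", "TupleConstruct", inputs=[input_2, xq_1])
v8 = Node("%8", "TupleUnpack", inputs=[tup], index=1)
deq = Node("%49", "dequantize", inputs=[v8])

wrong = dfs_first_hit(deq)    # picks up the qparams of %input.2
right = dfs_tuple_aware(deq)  # picks up the qparams of %Xq.1
```

In this toy graph the first-hit walk returns the qparams of %input.2 for the dequantize of %8, which is the mismatch described above.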
A better solution is to propagate qparam information forward, so that backward lookup won’t be necessary. For example, in
we can return (_expr.const(inputs[1]), _expr.const(inputs[2])) along with the original relay output. This is the most elegant solution (and also close to what PyTorch does), but it requires many changes to the existing converter functions so that they can retrieve the relay input.
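A rough sketch of the forward-propagation idea, using plain strings in place of relay expressions so it runs without TVM. The `Converted` type and all three `convert_*` functions are hypothetical and are not the existing converter signatures. The only detail taken from above is that `inputs[1]` and `inputs[2]` carry the scale and zp.

```python
from typing import List, NamedTuple, Optional, Tuple


class Converted(NamedTuple):
    """Hypothetical converter result: the output plus the qparams it carries."""
    expr: str  # stand-in for a relay expression
    qparams: Optional[Tuple[float, int]]  # (scale, zero_point), None if not quantized


def convert_quantize_per_tensor(inputs: list) -> Converted:
    # inputs[0] is the data, inputs[1] the scale, inputs[2] the zero point.
    data, scale, zero_point = inputs[0], inputs[1], inputs[2]
    return Converted(f"quantize({data.expr})", (scale, zero_point))


def convert_tuple_construct(inputs: List[Converted]) -> List[Converted]:
    # Each element keeps its own qparams, so unpacking cannot mix them up.
    return list(inputs)


def convert_dequantize(inp: Converted) -> Converted:
    if inp.qparams is None:
        raise ValueError(f"no qparams attached to {inp.expr}")
    scale, zero_point = inp.qparams
    return Converted(f"dequantize({inp.expr}, scale={scale}, zp={zero_point})", None)


# Made-up scale/zp values for illustration.
a = convert_quantize_per_tensor([Converted("a", None), 0.5, 0])
b = convert_quantize_per_tensor([Converted("b", None), 0.25, 3])
first, second = convert_tuple_construct([a, b])
out = convert_dequantize(second)
```

Because the qparams travel with each value, the dequantize converter never has to search backward through the graph.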
An easier but uglier way would be to record the output scale and zp in a global dictionary after each conversion. Then we can look up the input scale and zp from anywhere.
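A minimal sketch of the dictionary approach, again standalone and with hypothetical names (`OUTPUT_QPARAMS` and the `record_*` helpers do not exist in the frontend). The scale/zp values are made up. Tuple ops simply copy the recorded entries through, so the lookup at dequantize becomes a single dictionary access.

```python
from typing import Dict, List, Tuple, Union

QParams = Tuple[float, int]  # (scale, zero_point)

# Hypothetical global table: output value name -> qparams (or a list of them for tuples).
OUTPUT_QPARAMS: Dict[str, Union[QParams, List[QParams]]] = {}


def record_quantized_output(out: str, scale: float, zero_point: int) -> None:
    """Call after converting an op that produces a new quantized tensor."""
    OUTPUT_QPARAMS[out] = (scale, zero_point)


def record_tuple_construct(out: str, inputs: List[str]) -> None:
    """A tuple records the qparams of each element, in order."""
    OUTPUT_QPARAMS[out] = [OUTPUT_QPARAMS[name] for name in inputs]


def record_tuple_unpack(outs: List[str], tuple_name: str) -> None:
    """Each unpacked output inherits the qparams of its own element."""
    for out, qparams in zip(outs, OUTPUT_QPARAMS[tuple_name]):
        OUTPUT_QPARAMS[out] = qparams


def input_qparams(input_name: str) -> QParams:
    """Lookup used when converting dequantize; no graph traversal needed."""
    return OUTPUT_QPARAMS[input_name]


# Replay the graph above in conversion order, with made-up qparams.
record_quantized_output("%input.2", 0.5, 0)
record_quantized_output("%Xq.1", 0.25, 3)
record_tuple_construct("%47", ["%input.2", "%Xq.1"])
record_tuple_unpack(["%7", "%8"], "%47")

qparams_for_48 = input_qparams("%7")  # input of the first dequantize
qparams_for_49 = input_qparams("%8")  # input of the second dequantize
```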
Does this make sense? If possible, can you send a PR to fix this problem (in whatever way you prefer)? I can do it, but since I always find it difficult to get someone to review my PyTorch frontend changes, it might take a while for the change to land. However, if you send a fix, I can review and merge it right away.