Why does relay.frontend.from_pytorch unpack inner tuples?

Hello,

I noticed that when my PyTorch network has a nested output such as (out_1, (out_2, out_3)), with LSTM being the classic example, relay.frontend.from_pytorch unpacks the inner tuple and the final return type is a single flat tuple, (out_1, out_2, out_3). Why is that?

I also wrote a simple test with more deeply nested tuples, and I see the same behavior:

```python
import torch
from tvm import relay


def test_inner_tuple():
    class torch_net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.dense_1 = torch.nn.Linear(10, 20)
            self.dense_2 = torch.nn.Linear(10, 30)
            self.dense_3 = torch.nn.Linear(10, 40)
            self.dense_4 = torch.nn.Linear(10, 50)

        def forward(self, inputs):
            out_1 = self.dense_1(inputs)
            out_2 = self.dense_2(inputs)
            out_3 = self.dense_3(inputs)
            out_4 = self.dense_4(inputs)
            # Three levels of nesting: (((out_1, out_2), out_3), out_4)
            return ((out_1, out_2), out_3), out_4

    input_shape = (32, 10)
    pytorch_net = torch_net()
    scripted_model = torch.jit.trace(pytorch_net.eval(),
                                     torch.randn(input_shape))
    mod, params = relay.frontend.from_pytorch(scripted_model,
                                              [('input', input_shape)])
    mod = relay.transform.InferType()(mod)

    # Expected return type, mirroring the nesting of forward()'s output
    exp_output = relay.TupleType([relay.TupleType([relay.TupleType([relay.TensorType([32, 20]),
                                                                    relay.TensorType([32, 30])]),
                                                   relay.TensorType([32, 40])]),
                                  relay.TensorType([32, 50])])
    # Comparing string forms, since I don't know a better way to check this
    assert str(mod["main"].ret_type) == str(exp_output)
    # This assertion fails; the actual return type is flat:
    # mod["main"].ret_type = (Tensor[(32, 20), float32], Tensor[(32, 30), float32], Tensor[(32, 40), float32], Tensor[(32, 50), float32])
```

Am I missing something basic, or is this behavior intentional, for convenience?

This appears to be caused by the _jit_pass_lower_all_tuples PyTorch JIT pass that we run during preprocessing.
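
To make the observable effect concrete, here is a plain-Python sketch of what lowering all tuples does to the return structure: nested tuples end up flattened depth-first, left to right, into one flat tuple, which matches the return type reported in the test above. This only illustrates the effect on the output; it is not how the JIT pass is implemented, and flatten_outputs is a hypothetical helper.

```python
def flatten_outputs(value):
    """Flatten arbitrarily nested tuples depth-first, left to right.

    Illustrative only: mimics the effect that lowering all tuples has on
    the model's return structure, not the JIT pass itself. A non-tuple
    value is returned as a 1-tuple.
    """
    if not isinstance(value, tuple):
        return (value,)
    flat = ()
    for item in value:
        flat += flatten_outputs(item)
    return flat


# The return structure from the test above, with strings standing in for tensors.
nested = ((("out_1", "out_2"), "out_3"), "out_4")
print(flatten_outputs(nested))  # ('out_1', 'out_2', 'out_3', 'out_4')
```

The LSTM-style output ("out_1", ("out_2", "out_3")) likewise becomes ("out_1", "out_2", "out_3").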

See the thread "[Bug] [Frontend][Pytorch] Relay IR is inconsistent with that of the original model" (post #4 by masahi) for why we run it.

However, running this pass is not supposed to change the user-facing API, so this is a bug. We have a check that decides whether the pass can be run safely. Right now it only inspects the input types, but your use case shows that it also needs to check the output types. Would you be interested in contributing a fix?
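
A minimal sketch of what an output-side check could look like, assuming the TorchScript graph API exposed through torch._C (graph.outputs(), Value.type(), Type.kind() and TupleType.elements()). The helper names are hypothetical, not existing TVM functions. The policy here, refusing to lower when any graph output is a tuple that contains another tuple, is one possible choice: judging from the behavior reported above, a flat tuple output already comes back as a single flat tuple, so only nesting is treated as unsafe.

```python
def is_nested_tuple_type(ty):
    """Return True if a TorchScript type is a tuple with a tuple element.

    Assumes `ty` offers TorchScript's `kind()` and, for tuples, `elements()`.
    """
    if ty.kind() != "TupleType":
        return False
    return any(el.kind() == "TupleType" for el in ty.elements())


def outputs_allow_tuple_lowering(graph):
    """Hypothetical check: treat lowering all tuples as safe only if no
    graph output is a nested tuple, so the user-visible return structure
    stays unchanged."""
    return not any(is_nested_tuple_type(out.type()) for out in graph.outputs())


# Possible usage with a traced model (an assumption, not existing TVM code):
#   if not outputs_allow_tuple_lowering(scripted_model.graph):
#       enable_lower_all_tuples = False
```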

Also note that nested tuples are not supported by our graph runtime (which is what you'd be using with relay.build), so they are flattened during memory planning. If you want to preserve nested tuples at runtime, you can use our VM runtime.
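
If you stay on the graph runtime, one workaround is to regroup the flat outputs on the Python side. Below is a sketch of a hypothetical unflatten_like helper (not part of TVM) that rebuilds the original nesting from a flat sequence, using the PyTorch model's own output as the structural template. It assumes the flat order is depth-first, left to right, as in the flattened return type shown in the test above.

```python
_MISSING = object()


def unflatten_like(flat, template):
    """Rebuild `template`'s tuple nesting from the flat sequence `flat`.

    Leaves of `template` (anything that is not a tuple) only act as
    placeholders; they are replaced by consecutive items of `flat`.
    """
    items = iter(flat)

    def build(node):
        if isinstance(node, tuple):
            return tuple(build(child) for child in node)
        leaf = next(items, _MISSING)
        if leaf is _MISSING:
            raise ValueError("flat sequence has fewer items than the template")
        return leaf

    result = build(template)
    if next(items, _MISSING) is not _MISSING:
        raise ValueError("flat sequence has more items than the template")
    return result


# Usage sketch: `template` could be the PyTorch model's output on an example
# input, and `flat` the list of outputs read back from the graph runtime.
print(unflatten_like([1, 2, 3, 4], (((0, 0), 0), 0)))  # (((1, 2), 3), 4)
```

Using the PyTorch output as the template avoids hard-coding the nesting, at the cost of running the original model once to obtain it.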

Hello masahi, and thank you for your in-depth response. I would like to contribute a fix, but unfortunately I won't be able to start until September, as I have to finish my dissertation. If the issue is still open by then, I will start investigating and implementing a fix.

If this is blocking your progress, you can unblock yourself for now by setting enable_lower_all_tuples = False in the PyTorch frontend.