Hello,
I noticed that if I have a PyTorch network with a nested output like (out_1, (out_2, out_3)) – LSTM being the classic example – then after applying relay.frontend.from_pytorch the inner tuple is unpacked and the final return type is just a single flat tuple.
Why is that?
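Roughly what I have in mind is the sketch below (the wrapper class name, layer sizes, input shape, and input name are arbitrary, picked just for illustration; I am assuming the frontend handles aten::lstm here):

    import torch
    from tvm import relay

    class LSTMWrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lstm = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

        def forward(self, inputs):
            # PyTorch returns output, (h_n, c_n), i.e. a nested tuple
            return self.lstm(inputs)

    input_shape = (4, 5, 8)  # (batch, seq_len, features)
    scripted = torch.jit.trace(LSTMWrapper().eval(), torch.randn(input_shape))
    mod, params = relay.frontend.from_pytorch(scripted, [("input", input_shape)])
    mod = relay.transform.InferType()(mod)
    # I would expect (Tensor, (Tensor, Tensor)) here, but the inner tuple
    # gets unpacked into a flat 3-tuple
    print(mod["main"].ret_type)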
Additionally, I wrote a simple test to check the behavior for more deeply nested tuples, and the same thing happens:
    import torch
    from tvm import relay


    def test_inner_tuple():
        class torch_net(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dense_1 = torch.nn.Linear(10, 20)
                self.dense_2 = torch.nn.Linear(10, 30)
                self.dense_3 = torch.nn.Linear(10, 40)
                self.dense_4 = torch.nn.Linear(10, 50)

            def forward(self, inputs):
                out_1 = self.dense_1(inputs)
                out_2 = self.dense_2(inputs)
                out_3 = self.dense_3(inputs)
                out_4 = self.dense_4(inputs)
                # doubly nested tuple output
                return ((out_1, out_2), out_3), out_4

        input_shape = (32, 10)
        pytorch_net = torch_net()
        scripted_model = torch.jit.trace(pytorch_net.eval(),
                                         torch.randn(input_shape))
        mod, params = relay.frontend.from_pytorch(scripted_model,
                                                  [('input', input_shape)])
        mod = relay.transform.InferType()(mod)
        # expected return type: the same nesting as the forward() output
        exp_output = relay.TupleType([relay.TupleType([relay.TupleType([relay.TensorType([32, 20]),
                                                                        relay.TensorType([32, 30])]),
                                                       relay.TensorType([32, 40])]),
                                      relay.TensorType([32, 50])])
        # don't know how to check in another manner
        assert str(mod["main"].ret_type) == str(exp_output)
        # fails: mod["main"].ret_type is actually the flat tuple
        # (Tensor[(32, 20), float32], Tensor[(32, 30), float32],
        #  Tensor[(32, 40), float32], Tensor[(32, 50), float32])
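(As an aside on the string comparison: I suppose something like tvm.ir.structural_equal(mod["main"].ret_type, exp_output) could be used instead, but I have not checked whether that is the intended way to compare types.)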
Am I missing something basic, or is this behavior just there for convenience?