[PyTorch] dyn.strided_slice loses shape information

Hello everyone, I have a question about compiling the PyTorch 1.9 retinanet_resnet50_fpn model, more specifically while trying to compile this line (github.com/pytorch/vision/blob/v0.10.0/torchvision/models/detection/_utils.py#L205).

Traced JIT graph:

aten::slice: Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step)

%6064 : int = prim::Constant[value=0](), scope: __module.model # /home/ubuntu/.local/lib/python3.6/site-packages/torchvision/models/detection/_utils.py:205:0
%6065 : int = prim::Constant[value=0](), scope: __module.model # /home/ubuntu/.local/lib/python3.6/site-packages/torchvision/models/detection/_utils.py:205:0
%6066 : int = prim::Constant[value=9223372036854775807](), scope: __module.model # /home/ubuntu/.local/lib/python3.6/site-packages/torchvision/models/detection/_utils.py:205:0
%6067 : int = prim::Constant[value=1](), scope: __module.model # /home/ubuntu/.local/lib/python3.6/site-packages/torchvision/models/detection/_utils.py:205:0
%6068 : Float(0, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::slice(%rel_codes.1, %6064, %6065, %6066, %6067), scope: __module.model # /home/ubuntu/.local/lib/python3.6/site-packages/torchvision/models/detection/_utils.py:205:0
%6069 : int = prim::Constant[value=1](), scope: __module.model # /home/ubuntu/.local/lib/python3.6/site-packages/torchvision/models/detection/_utils.py:205:0
%6070 : int = prim::Constant[value=0](), scope: __module.model # /home/ubuntu/.local/lib/python3.6/site-packages/torchvision/models/detection/_utils.py:205:0
%6071 : int = prim::Constant[value=9223372036854775807](), scope: __module.model # /home/ubuntu/.local/lib/python3.6/site-packages/torchvision/models/detection/_utils.py:205:0
%6072 : int = prim::Constant[value=4](), scope: __module.model # /home/ubuntu/.local/lib/python3.6/site-packages/torchvision/models/detection/_utils.py:205:0
%6073 : Float(0, 1, strides=[4, 4], requires_grad=0, device=cpu) = aten::slice(%6068, %6069, %6070, %6071, %6072), scope: __module.model # /home/ubuntu/.local/lib/python3.6/site-packages/torchvision/models/detection/_utils.py:205:0

My understanding is that here we get an N-by-1 dx vector, which is later stacked with the others in _utils.py#L223.
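For reference, the two aten::slice calls above are just basic slicing with a step of 4; a minimal NumPy sketch (with a hypothetical N = 5) of what they compute:

```python
import numpy as np

# Hypothetical rel_codes: N boxes with 4 regression values each.
rel_codes = np.arange(20, dtype=np.float32).reshape(5, 4)

# First aten::slice: full slice over dim 0 (start=0, end=INT64_MAX, step=1).
full = rel_codes[0:, :]
# Second aten::slice: dim 1 with start=0, step=4 -> every 4th column.
dx = full[:, 0::4]
assert dx.shape == (5, 1)  # an N-by-1 column, as described above
```

Because the step is 4 and the second dimension is statically 4, the result is always one column per row, which is why the traced graph knows the output is N-by-1.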

while the relay graph generates this

%1804 = adv_index(%1802) /* ty=Tensor[(?, 4), float32] */;
%1844 = where(%1839, %1835, %1838) /* ty=Tensor[(2), int32] */;
%1845 = cast(%1843, dtype="int64") /* ty=Tensor[(2), int64] */;
%1846 = dyn.strided_slice(%1804, %1844, %1845, meta[relay.Constant][88] /* ty=Tensor[(2), int32] */, begin=None, end=None, strides=None, axes=None) /* ty=Tensor[(?, ?), float32] */;
Later, the missing dimension causes an error while unbinding along a static dimension at this line (github.com/pytorch/vision/blob/v0.10.0/torchvision/models/detection/transform.py#L287).

The error message is this:
in unbind, ishapes: (?, ?)
Traceback (most recent call last):
  File "retinanet_test.py", line 110, in <module>
    retina_net_lab()
  File "retinanet_test.py", line 74, in retina_net_lab
    mod, params = relay.frontend.from_pytorch(script_module, shape_list)
  File "/home/ubuntu/neo-ai/tvm/python/tvm/relay/frontend/pytorch.py", line 3363, in from_pytorch
    ret = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)[0]
  File "/home/ubuntu/neo-ai/tvm/python/tvm/relay/frontend/pytorch.py", line 2785, in convert_operators
    inputs, _get_input_types(op_node, outputs, default_dtype=self.default_dtype)
  File "/home/ubuntu/neo-ai/tvm/python/tvm/relay/frontend/pytorch.py", line 2142, in unbind
    res_split = _op.split(data, selections, dim)
  File "/home/ubuntu/neo-ai/tvm/python/tvm/relay/op/transform.py", line 908, in split
    ret_size = len(indices_or_sections) + 1
TypeError: object of type 'Any' has no len()

I wonder whether the dyn.strided_slice behavior is expected and whether there is any workaround to enable this model? @masahi Thanks! :slightly_smiling_face:

Yes, your observation is correct. We cannot support PT retinanet, for two reasons:

  • Our dynamic strided slice doesn’t work well when the input shape is partially static and partially dynamic: it makes the output shape dynamic in all dimensions, even if slicing happens only along a certain dimension (the batch axis, etc.). Unfortunately, this is a limitation of how runtime shapes are represented in Relay: runtime shapes are fully dynamic in all dimensions.

  • Our split op doesn’t support dynamic sections. This is the error you got. Even if the first point cannot be overcome, in principle we can support retinanet if our split op supports dynamic sections (although performance would be worse than in the ideal case). This would be the easier solution in the short term.

Hi Masahi, I think dynamic split will be impossible, because the output tuple would have an unknown size, and the size of a tuple must be static in Relay. However, it could be possible if the PyTorch graph tells us the number of outputs and we know it to be static.
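To make the constraint concrete: unbind(x, dim) conceptually splits x into x.shape[dim] pieces, so the number of outputs equals that extent. A sketch of why the extent must be static (NumPy used purely for illustration, not the Relay API):

```python
import numpy as np

# unbind(x, dim) splits x into x.shape[dim] slices along dim.
# The number of outputs -- the tuple size -- equals that extent,
# so in Relay it must be a compile-time constant, not '?'.
x = np.zeros((2, 3, 4), dtype=np.float32)
dim = 1
sections = x.shape[dim]              # must be a static int
parts = np.split(x, sections, axis=dim)
assert len(parts) == 3               # tuple size fixed by the static extent
assert parts[0].shape == (2, 1, 4)
```

When the extent along `dim` is `?`, `sections` is the `Any` value in the traceback above, and `len(indices_or_sections)` fails.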


Is this somewhat a design choice? I am asking because advanced indexing, strided slice, etc. are heavily used in detection models, and dropping support for not just a single model but a whole family of models seems pretty opinionated to me.

I wouldn’t say that. The decision to represent runtime shapes as fully dynamic was a design choice, and the lack of support for partially static/dynamic runtime shapes is a consequence and limitation of that design choice.

I’m fully aware of the limitation of dyn.strided_slice, so I did some work to preserve static shapes as much as possible in partially dynamic slicing. See the following PRs for details:

In particular, the axes argument added in 8165 can be used to preserve shape information in the batch-only slicing case, for example. I didn’t update the PT frontend to make use of this, so I’m going to take another look at RetinaNet support to see if we can import this model now.
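The idea behind restricting the slice to specific axes can be sketched as follows (NumPy used only to illustrate the shapes, with made-up sizes; this is not the Relay API):

```python
import numpy as np

# Slicing only along axis 0 (the batch axis) leaves the remaining
# dimensions untouched, so their static sizes (here, 4) could be
# preserved even when the batch extent itself is dynamic.
x = np.zeros((7, 4), dtype=np.float32)
batch_slice = x[2:5]                 # slice restricted to axis 0
assert batch_slice.shape == (3, 4)   # second dimension stays statically 4
```

An axes-aware strided slice can record that only axis 0 was touched, and therefore keep `(?, 4)` instead of collapsing to `(?, ?)`.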

@jsheng-jian Ok, I was able to remove all dyn.strided_slice ops, but shape information is also lost by dyn.reshape, like this:

  %1170 = stack(%1163, axis=2) /* ty=Tensor[(?, 2, 2), float32] */;
  %1548 = dyn.reshape(%1170, %1171, newshape=[]) /* ty=Tensor[(?, ?), float32] */;

%1548 above should have shape (?, 4), but dyn.reshape makes everything dynamic. We just want to reshape (?, 2, 2) into (?, 4), but we cannot do that, and we don’t have a workaround for this problem.
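The intended semantics would merge only the two static trailing axes while leaving the batch axis dynamic; a NumPy sketch with a made-up batch size of 6 (this is what dyn.reshape cannot currently express statically):

```python
import numpy as np

# stack output: (?, 2, 2) -- batch axis dynamic, trailing dims static.
x = np.zeros((6, 2, 2), dtype=np.float32)

# Desired reshape: merge only the last two axes -> (?, 4).
# With -1 standing in for the dynamic batch axis, the trailing
# dimension 2 * 2 = 4 is fully determined by static information.
y = x.reshape(-1, 4)
assert y.shape == (6, 4)
```

Since 2 * 2 = 4 is computable at compile time, the last dimension of the result could in principle stay static even though the batch dimension is `?`.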

In PT, this reshape corresponds to this line: vision/boxes.py at be4ff9a37831c5bee9392bb403a1228131167972 · pytorch/vision · GitHub

Hi,

I think I am seeing similar behavior when trying to compile YOLOX ONNX models with TVM:

%40 = dyn.strided_slice(%0, %8, %9, meta[relay.Constant][7] /* ty=Tensor[(4), int64] */, 
begin=None, end=None, strides=None, axes=None) /* ty=Tensor[(?, ?, ?, ?), float32] */;

Can you support YOLOX?

Thanks!