[BYOC][ONNX] Question about indices tensor of gather-scatter ops

Here is a segment of a Relay graph imported from an ONNX model:

  %73 = nn.batch_flatten(%72) /* ty=Tensor[(16, 1369), float32] */;
  %74 = argmax(%73, axis=[1], keepdims=True) /* ty=Tensor[(16, 1), int32] */;
  %75 = cast(%74, dtype="int64") /* ty=Tensor[(16, 1), int64] */;
  ...
  %86 = less(%75, 0 /* ty=int64 */) /* ty=Tensor[(16, 1), bool] */;
  %87 = add(%75, 1369 /* ty=int64 */) /* ty=Tensor[(16, 1), int64] */;
  %89 = where(%86, %87, %75) /* ty=Tensor[(16, 1), int64] */;
  %88 = squeeze(%85, axis=[1]) /* ty=Tensor[(16, 1369), float32] */;
  %95 = gather(%88, %89, axis=1) /* ty=Tensor[(16, 1), float32] */;
  ...

My accelerator runtime’s Gather/Scatter ops require i32 indices tensors, but Relay uses i64 indices by default. Is there a simple way to convert all indices tensors in a Relay graph to i32?
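Before narrowing, it helps to confirm that i32 can hold every index. Valid gather indices along an axis of size `dim` lie in `[-dim, dim)`, so the cast is lossless whenever `dim <= 2**31`. The minimal sketch below is plain Python rather than TVM; `indices_fit_int32` and `to_int32` are hypothetical helper names, and `ctypes.c_int32` is used only to emulate a C-style narrowing cast.

```python
import ctypes

INT32_MAX = 2 ** 31 - 1

def indices_fit_int32(dim: int) -> bool:
    """Valid indices for an axis of size `dim` lie in [-dim, dim).

    Both ends fit in int32 exactly when dim <= 2**31.
    """
    return dim <= INT32_MAX + 1

def to_int32(x: int) -> int:
    """Emulate a C-style int64 -> int32 narrowing cast (wraps on overflow)."""
    return ctypes.c_int32(x).value

# The gathered axis in the graph above has 1369 elements, so narrowing is safe there.
assert indices_fit_int32(1369)
print(to_int32(1368), to_int32(-1369))  # 1368 -1369
print(to_int32(2 ** 31))                # -2147483648 (silent wrap-around)
```

The wrap-around in the last line is why a blanket cast should only be applied when the axis extent is known to be small enough.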

Also, when mapping the ONNX ops Gather and GatherElements, the ONNX frontend inserts less+add+where ops on the indices (this is done by the normalize_gather_indices function in onnx.py), as shown in the example Relay graph above. However, the indices tensors produced by ops like argmax/argmin/topk are guaranteed to be non-negative. It would be better to add a rule that checks whether the indices tensor is produced by one of those ops and skips the “normalization” in that case.
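The proposed rule can be sketched on a toy expression tree. This is not the real `normalize_gather_indices` from onnx.py and not the Relay API: `Node`, `normalize_indices_toy` and `eval_normalize` are hypothetical names. The sketch looks through casts (like `%75` above) to find the op that produced the indices, and emits the less+add+where pattern only when that op is not one of argmax/argmin/topk.

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class Node:
    op: str                         # e.g. "var", "argmax", "cast", "const"
    args: Tuple["Node", ...] = ()
    value: int = 0                  # only meaningful for "const" nodes

# Ops named in the post whose index outputs are never negative.
NON_NEGATIVE_INDEX_OPS = {"argmax", "argmin", "topk"}

def produces_non_negative(expr: Node) -> bool:
    """Look through casts to find the op that produced the indices.

    Note: in Relay, topk's indices come out through a tuple projection;
    this toy IR does not model that.
    """
    while expr.op == "cast":
        expr = expr.args[0]
    return expr.op in NON_NEGATIVE_INDEX_OPS

def normalize_indices_toy(indices: Node, dim: int) -> Node:
    """Emit where(less(i, 0), add(i, dim), i) unless i is known non-negative."""
    if produces_non_negative(indices):
        return indices  # skip the normalization entirely
    is_neg = Node("less", (indices, Node("const", value=0)))
    wrapped = Node("add", (indices, Node("const", value=dim)))
    return Node("where", (is_neg, wrapped, indices))

def eval_normalize(i: int, dim: int) -> int:
    """Scalar semantics of the less+add+where pattern (%86, %87, %89)."""
    return i + dim if i < 0 else i

# cast(argmax(x)) mirrors %74/%75 in the graph: the rule fires, no where is emitted.
idx = Node("cast", (Node("argmax", (Node("var"),)),))
assert normalize_indices_toy(idx, 1369) is idx
assert normalize_indices_toy(Node("var"), 1369).op == "where"
```

For non-negative inputs `eval_normalize(i, dim) == i`, which is why the three extra ops are pure overhead after argmax/argmin/topk.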

I don’t see an easy way to do this. Maybe you can write a Relay pass that inserts the cast?
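The shape of such a pass can be sketched on a toy IR. In TVM one would likely subclass `tvm.relay.expr_functor.ExprMutator`, override `visit_call`, wrap the indices argument in `relay.cast(indices, "int32")`, and rerun type inference; the version below only mimics that post-order, memoized rewrite in plain Python. `Expr`, `INDEX_OPS` and `cast_indices_to_int32` are hypothetical, and treating the second argument as the indices tensor is an assumption of this sketch.

```python
from dataclasses import dataclass
from typing import Dict, Tuple

@dataclass(frozen=True)
class Expr:
    op: str                         # "var", "gather", "cast", ...
    args: Tuple["Expr", ...] = ()
    dtype: str = "float32"          # output dtype of this node
    name: str = ""                  # only used for "var" nodes

# Ops whose second argument is assumed to be an indices tensor.
INDEX_OPS = {"gather", "gather_nd", "scatter", "scatter_add", "scatter_nd"}

def cast_indices_to_int32(root: Expr) -> Expr:
    """Rewrite the graph so every int64 indices input of INDEX_OPS is cast to int32."""
    memo: Dict[int, Expr] = {}  # keyed by object id, so shared subgraphs stay shared

    def visit(e: Expr) -> Expr:
        if id(e) in memo:
            return memo[id(e)]
        new_args = tuple(visit(a) for a in e.args)  # post-order: rewrite inputs first
        if e.op in INDEX_OPS and len(new_args) > 1 and new_args[1].dtype == "int64":
            cast = Expr("cast", (new_args[1],), dtype="int32")
            new_args = (new_args[0], cast) + new_args[2:]
        out = Expr(e.op, new_args, e.dtype, e.name)
        memo[id(e)] = out
        return out

    return visit(root)

data = Expr("var", dtype="float32", name="x")
idx = Expr("var", dtype="int64", name="i")
out = cast_indices_to_int32(Expr("gather", (data, idx)))
print(out.args[1].op, out.args[1].dtype)  # cast int32
```

Because indices that are already int32 are left untouched, running the pass twice does not stack casts.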

Good point, a PR is welcome.
