AXIS_SEPARATOR in Relax Tensor

Hi @tqchen,

While working on a solution, I ran into two more problems that I'd like to discuss here:

  1. No control over flattening of buffers: if a buffer has already been flattened by FlattenLowAxisSepDimensions, it should not be flattened again by TIR passes such as FlattenBuffer or StorageFlatten. Is it possible to skip flattening for such buffers in the TIR passes?
  2. Propagating axis_separator information across passes: if a pass that runs between AlterOpImpl (which introduces R.memory_axis_separator for each tensor) and FlattenLowAxisSepDimensions introduces new Relax operators, those operators must propagate the axis_separators information through the graph. This would mean either changing multiple passes or keeping the number of passes that run between AlterOpImpl and FlattenLowAxisSepDimensions to a minimum.

We need the axis_separator information across the graph so that memory allocation can decide, for each buffer, whether to allocate it as an N-d or a 1-d buffer.
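To make the N-d vs 1-d distinction concrete, here is a minimal Python sketch of how, as I understand TIR's flattening semantics, axis_separators partition a buffer's logical axes into physical axes. The function name `physical_shape` is hypothetical and not part of TVM; each separator index splits the axes, and the extents within each group are multiplied together.

```python
from math import prod
from typing import List, Sequence


def physical_shape(shape: Sequence[int], axis_separators: Sequence[int]) -> List[int]:
    """Hypothetical helper: group logical axes into physical axes.

    A separator at index i splits the axes before i from the axes at and after i.
    With no separators the buffer collapses to a single (1-d) physical axis;
    with k separators it becomes a (k+1)-d physical buffer.
    """
    seps = list(axis_separators)
    if seps != sorted(seps) or any(s < 0 or s > len(shape) for s in seps):
        raise ValueError("axis_separators must be sorted and within [0, len(shape)]")
    bounds = [0, *seps, len(shape)]
    # Multiply the extents of each group of logical axes into one physical extent.
    return [prod(shape[lo:hi]) for lo, hi in zip(bounds, bounds[1:])]
```

For example, `physical_shape([1, 8, 8, 32], [])` gives `[2048]` (a 1-d allocation), while `physical_shape([1, 8, 8, 32], [2])` gives `[8, 256]` (a 2-d allocation).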

One possible solution: if axis_separators are added to the arguments of the replacement functions (used in AlterOpImpl), an analysis pass could run before memory planning to identify the axis_separators in the graph and collate them into a data structure that the memory planning pass can then use to allocate memory.
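The proposed analysis-then-plan flow could look roughly like the Python sketch below. It does not use TVM APIs: `Call`, `collect_axis_separators`, and `plan_allocations` are all hypothetical stand-ins, and the graph is modeled as a flat list of calls to replacement functions, each assumed to carry its output's axis_separators as an argument.

```python
from dataclasses import dataclass, field
from math import prod
from typing import Dict, List, Sequence, Tuple


@dataclass
class Call:
    """Hypothetical stand-in for a call to a replacement function after AlterOpImpl."""
    op: str
    output: str
    out_shape: Tuple[int, ...]
    # Assumption: axis_separators are passed as an argument of the replacement function.
    axis_separators: List[int] = field(default_factory=list)


def collect_axis_separators(calls: Sequence[Call]) -> Dict[str, List[int]]:
    """Analysis step: map each produced tensor to its axis_separators."""
    return {c.output: list(c.axis_separators) for c in calls}


def plan_allocations(
    calls: Sequence[Call], sep_info: Dict[str, List[int]]
) -> Dict[str, Tuple[int, int]]:
    """Memory-planning step: choose the physical rank and total size per tensor.

    A tensor with no separators is allocated as a 1-d buffer; a tensor with
    k separators is allocated as a (k+1)-d buffer.
    """
    plan = {}
    for c in calls:
        seps = sep_info.get(c.output, [])
        plan[c.output] = (len(seps) + 1, prod(c.out_shape))
    return plan
```

Usage: for `calls = [Call("op_a", "lv0", (1, 8, 8, 32), [2]), Call("op_b", "lv1", (1, 64))]`, the analysis step yields `{"lv0": [2], "lv1": []}` and the planner yields `{"lv0": (2, 2048), "lv1": (1, 64)}`, so `lv0` would get a 2-d allocation and `lv1` a 1-d one. Keeping collection separate from planning means intermediate passes only need to preserve the replacement-function arguments rather than explicitly propagate separator annotations.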

Please share your thoughts/comments on this.

Thank you!