Hi @tqchen ,
While working on the solution, I see two more problems that I’d like to discuss here:
- No control over flattening of buffers: If a buffer has already been flattened via `FlattenLowAxisSepDimensions`, it shouldn't be flattened again by TIR passes like `FlattenBuffer` or `FlattenStorage`. Is it possible to skip flattening for such buffers in the TIR passes?
- Propagating axis_separator information across passes: If a pass that runs between `AlterOpImpl` (which introduces `R.memory_axis_separator` for each tensor) and `FlattenLowAxisSepDimensions` introduces new Relax operators, those operators need to propagate the axis_separators information across the graph. This would mean either changing multiple passes or ensuring that as few passes as possible run between `AlterOpImpl` and `FlattenLowAxisSepDimensions`.
The reason we want axis_separator information available across the graph is memory allocation: it determines whether a buffer is allocated as an N-d buffer or as a flat 1-d buffer.
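To make the N-d vs 1-d distinction concrete, here is a small sketch of how axis separators map a logical shape to a physical allocation shape. It assumes the usual TIR semantics, where the separators split the logical axes into groups and each group is flattened into one physical axis. `physical_shape` is a hypothetical helper for illustration, not a TVM API.

```python
from math import prod
from typing import Sequence, Tuple


def physical_shape(logical_shape: Sequence[int], axis_separators: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: flatten each group of logical axes delimited by the separators."""
    bounds = [0, *axis_separators, len(logical_shape)]
    # Each consecutive (start, end) pair delimits one group of logical axes,
    # which becomes a single physical axis whose extent is the product of the group.
    return tuple(prod(logical_shape[start:end]) for start, end in zip(bounds, bounds[1:]))


print(physical_shape((2, 3, 4, 5), [2]))  # (6, 20): a 2-d allocation
print(physical_shape((2, 3, 4, 5), []))   # (120,): a flat 1-d allocation
```

With no separators the buffer collapses to a single 1-d allocation; each additional separator adds one physical dimension, which is what memory planning would need to know.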
One possible solution: if axis_separators are added to the arguments of the replacement functions (used in `AlterOpImpl`), an analysis pass can be invoked before memory planning to identify the axis_separators in the graph and collate them in a data structure, which the memory planning pass can then use to allocate memory.
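A rough sketch of that proposal, using a toy IR rather than real Relax nodes: `Call`, `collect_axis_separators`, and `plan_allocations` are all hypothetical names, and the assumption is that only the `AlterOpImpl` replacement calls carry an `axis_separators` argument. The analysis pass records separators per output tensor, and the planner derives how many physical dimensions to allocate.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class Call:
    """Toy stand-in for a Relax call node (hypothetical, not the TVM API)."""
    op: str
    output: str
    # Assumed to be set only on AlterOpImpl replacement calls.
    axis_separators: Optional[Tuple[int, ...]] = None


def collect_axis_separators(calls: List[Call]) -> Dict[str, Tuple[int, ...]]:
    """Analysis pass: record the axis separators attached to each produced tensor."""
    info: Dict[str, Tuple[int, ...]] = {}
    for call in calls:
        if call.axis_separators is not None:
            info[call.output] = call.axis_separators
    return info


def plan_allocations(shapes: Dict[str, Tuple[int, ...]],
                     info: Dict[str, Tuple[int, ...]]) -> Dict[str, int]:
    """Memory planning: number of physical dimensions to allocate per tensor."""
    # Tensors with no recorded separators fall back to a flat 1-d allocation.
    return {name: len(info.get(name, ())) + 1 for name in shapes}


calls = [Call("conv2d_replacement", "t0", (2,)), Call("relu", "t1")]
shapes = {"t0": (1, 8, 8, 32), "t1": (1, 8, 8, 32)}
plan = plan_allocations(shapes, collect_axis_separators(calls))
print(plan)  # {'t0': 2, 't1': 1}
```

Note that `t1`, produced by an operator without separator information, silently falls back to 1-d; this is the same propagation gap described in the second point above, so the analysis only helps if every relevant producer carries the argument.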
Please share your thoughts/comments on this.
Thank you!