AXIS_SEPARATOR in Relax Tensor

Hello All,

This is a follow-up to the existing post on axis separator support in Relax.

Our next goal is to enable axis separators in the Relax Tensor class so that the memory allocator can distinguish between 1-D and N-dimensional layouts and allocate memory appropriately.

While going through the code, I see that the Relax.Tensor class invokes the TensorStructInfo class, which in turn creates an object of TensorStructInfoNode here.

Could you please suggest whether TensorStructInfoNode is the right class in which to add AXIS_SEPARATORS, or should this change be made somewhere else in TVM?

Thank you! :slight_smile:

cc: @psrivas2 @junrushao @sunggg @sanirudh @jverma @kparzysz

Hi @abhikran-quic,

Thanks for starting this topic. We added support for n-dimensional allocation in Relax with PR #15178. While lowering, in order to identify that the physical memory shape of a Tensor is N-D, we would need to look up the AXIS_SEPARATORS, so yes, it would be necessary to have that information somewhere.

I think TensorStructInfo seems like a good place for that. As long as we keep the default for AXIS_SEPARATORS empty, the current flow would not be affected, and we can explicitly read it and use it as needed.
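
As a rough illustration of the empty-default idea (purely hypothetical, since TensorStructInfo does not currently carry an axis_separators field), a lowering or memory-planning step could read the field with an empty fallback so that the existing 1-D flow stays untouched whenever nothing is set:

def get_axis_separators(sinfo):
    # axis_separators is the hypothetical new field; None or empty means
    # "no separator", i.e. keep today's flat 1-D allocation behaviour.
    seps = getattr(sinfo, "axis_separators", None)
    return [int(s) for s in seps] if seps else []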

It would still be good to get some suggestions from others who have worked on the struct info framework, so I’ll add a couple of others to the discussion:

@tqchen @slyubomirsky

1 Like

I think this is a case where we should think a bit more carefully. At the moment, TensorStructInfo aims to keep logical shape information and leaves the logical-to-physical mapping info implicit in the passes.

While the axis separator can help reason about possible layouts, it would also add to the complexity of struct info deduction, since every tensor operator would then need to consider how to propagate this information.

So it would be useful to first think more carefully about how many layout candidates we have in mind. If there are not that many, an implicit layout-flattening pass (one that implicitly takes the N-D layout and maps it to the related 1-D layout) could be a good starting point before introducing the general representation. Additionally, let us think about propagation rules in general.

2 Likes

Adding more to the discussion, here is one possible way to embed the logical-to-physical info (namely a 2-D memory-flattening hint):

@R.function
def main_canonicalized(x: R.Tensor((4, 5, 128))):
    # split on axis 1, this is also the default
    lv1 = R.memory_axis_separator(x, axis=1)
    ...
    # hint to split on axis 2 instead
    lvn = R.memory_axis_separator(x, axis=2)

Then a pass, FlattenLowAxisSepDimensions, would come and rewrite all tensor data into the format A[i0, i1, ..., i_n], where by default (i0, i1, ..., i_{n-1}) are flattened into one dimension and i_n is flattened into another dimension, and it would rewrite the indexing in the related TIR functions as well.
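
To make the mapping concrete, here is a minimal worked sketch (physical_shape is not an existing TVM helper), assuming a separator index k means that logical axes [0, k) collapse into the first physical dimension and axes [k, n) into the second:

import math

# Physical (post-flattening) shape for a logical shape with one axis separator k.
def physical_shape(logical_shape, k):
    return (math.prod(logical_shape[:k]), math.prod(logical_shape[k:]))

print(physical_shape((4, 5, 128), 2))   # (20, 128): the last axis stays separate
print(physical_shape((4, 5, 128), 1))   # (4, 640)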

The main rationale is that even in the presence of N-D memory, it is usually desirable for the lower dimension to be more regular (i.e., a multiple of the page size, or in many cases exactly the page size). So the number of possible mappings is not that large. Additionally, being directly in the physical canonical format makes scheduling and optimizing the TensorIR much easier.

The tradeoff here is that if the reasoning about flattening is confined to one pass, or a few, it is simpler to have those passes reason about the related info.

Coming back to the high level, there are tradeoffs in when to put things into StructInfo. The advantage of StructInfo is of course that the information is deduced (by default) for every operator and is available to the many passes that reason about it. The possible downside is the same: every pass and operator needs to reason about the related information. Shape is in TensorStructInfo because it is used in many passes and we want to recover that info by default in all cases; additionally, users do think more about shape information when constructing kernels, e.g. with te. We make an explicit choice for now to keep the logical-to-physical mapping mostly at its default (i.e., for the 1-D memory case, flattening) so that reasoning about local performance is easier, while still leaving room for layout propagation.

We can start with something like the pass-based approach first, learn whatever lessons come out of it, and then think about strategies; those lessons would help us decide the further steps.

2 Likes

I agree with @tqchen; I would be very hesitant to hoist low-level information like axis separators into Relax, and especially to make it part of TensorStructInfo by default (as it would mean everything would have to reason about it).

What would be the main reasons to deal with axis separators in Relax as opposed to TIR? If there is really only one “killer application,” perhaps it might be worth dealing with it dynamically or finding some other kind of solution that would not require changing a basic part of the language.

I would be happy to chat about this live if it would be helpful. (E.g., we could make this a community meeting topic.)

2 Likes

Thanks a lot for the detailed explanation @tqchen and @slyubomirsky

I understand the concern about modifying StructInfo. My only thought was to store the logical-to-physical memory mapping somewhere in Relax, so that when we do static memory planning, we don’t just assume a 1-D allocation but compute the actual physical memory shape. StructInfo just seemed like a logical place for this, as it is what the current memory planning/codegen passes use to calculate the memory required for a tensor, but we can discuss other alternatives.

I’ll wait for @abhikran-quic to also reply with his perspective, but yes, I agree that it might be great to talk this out in a dev meeting to understand the different perspectives and come up with some ideas.

2 Likes

For static memory planning, I think it is not a big issue, since that can happen after we run FlattenLowAxisSepDimensions, so the memory planner will be able to capture the N-D needs.

2 Likes

Thank you so much @tqchen @slyubomirsky @sanirudh for your detailed inputs on this!

I have some thoughts and questions listed below. I also discussed them with @sanirudh. It would be great if we could address them in the next unity meeting; I’ll add this topic to the agenda.

  1. The axis separator will be set by the AlterOpImpl pass, and hence it can embed R.memory_axis_separator for each tensor.

  2. IIUC, FlattenLowAxisSepDimensions should be invoked before all TIR passes. Now, if a TIR pass depends on the logical shape of a tensor, how should we handle that scenario? Should the logical shape be stored/saved in an attribute so that TIR passes can run smoothly?

Thanks.

1 Like

FlattenLowAxisSepDimensions does not need to be invoked before all TIR passes, I think; only the ones related to allocation.

Most TIR perf-related passes, like scheduling, actually benefit from the canonical form, since the access pattern is clearer in that case; having things in logical form does not help as much. We analyze index access patterns (via affine maps), so I think the shape dependency is not a concern, but if there is an example of a TIR pass that benefits from the original shape, we can check and discuss.

Note that in most cases we will likely do packing, which means the shape has more dimensions and the new form carries more information.

2 Likes

Thank you @tqchen! I will try this approach and update here if I have further questions.

Hi @tqchen,

While working on the solution, I see two more problems that I’d like to discuss here:

  1. No control over flattening of buffers: If a buffer is flattened via FlattenLowAxisSepDimensions, then it shouldn’t be flattened again by TIR passes like FlattenBuffer or FlattenStorage. Is it possible to avoid flattening in those TIR passes?
  2. Propagating axis_separator information across passes: If a pass executed between AlterOpImpl (which introduces R.memory_axis_separator for each tensor) and FlattenLowAxisSepDimensions introduces new Relax operators, the new operators need to propagate the axis_separators information across the graph. This would mean changing multiple passes, or ensuring that as few passes as possible execute between AlterOpImpl and FlattenLowAxisSepDimensions.

The reason we want to carry axis_separator information across the graph is to allocate memory correctly for N-D vs. 1-D buffers.

One possible solution to accomplish the goal: if axis_separator is added to the arguments of the replacement functions (used in AlterOpImpl), an analysis pass can be invoked before memory planning to identify the axis_separators in the graph and collate that information into a data structure that the memory planning pass can use to allocate memory.
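
A minimal sketch of what such an analysis could look like, assuming the hypothetical R.memory_axis_separator hint op proposed earlier in the thread; collect_axis_separators is not an existing TVM pass, and a real implementation would more likely build on the Relax expression visitor infrastructure than walk bindings by hand:

import tvm
from tvm import relax

def collect_axis_separators(func: relax.Function):
    """Map each annotated tensor var to the separator axis requested for it."""
    seps = {}
    for block in func.body.blocks:  # body of a normalized Relax function is a SeqExpr
        for binding in block.bindings:
            value = binding.value
            if (isinstance(value, relax.Call)
                    and value.op == tvm.ir.Op.get("relax.memory_axis_separator")):
                # args[0] is the annotated tensor; the axis attribute is hypothetical
                seps[value.args[0]] = int(value.attrs.axis)
    return seps

The memory planning pass could then consult this mapping to decide whether a given tensor needs an N-D or a plain 1-D allocation.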

Please share your thoughts/comments on this.

Thank you!