[DISCUSS] Default Compilation Flow for Scan and Sort

This post originates from the original issue here: https://github.com/apache/tvm/issues/15851

It would be great for us to discuss our strategy for implementing operations like scan (cumsum) and sort. Both of these operators require specialized parallel implementations per target. In the meantime, it is useful to also support TIR callbacks.

I can see a few ways to do this:

  • W0: Keep these operators as-is and do not legalize them during legalization; enable target-specific dispatching so we can dispatch these operators to the proper platform-specific implementations when we have them.
  • W1: Enable target-dependent legalization for cumsum as well, which would allow us to optionally have a TensorIR implementation of these operators.

As of now, I think making sure our pipeline robustly supports W0 out of the box might be a good starting point. Sort and cumsum are useful routines for parallel GPU sampling, so that is a good motivation to start formalizing support for them.
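
To make W0 concrete, here is a minimal sketch of what such a dispatch pass could look like, assuming a hypothetical registry of library implementations. The `DispatchSortScan` mutator and the thrust function name are illustrative placeholders, not existing TVM code:

```python
import tvm
from tvm import relax

# Hypothetical mapping from (target kind, op name) to a packed-func name.
# "tvm.contrib.thrust.sort" is an illustrative placeholder.
LIBRARY_DISPATCH = {
    ("cuda", "relax.sort"): "tvm.contrib.thrust.sort",
}

@relax.expr_functor.mutator
class DispatchSortScan(relax.PyExprMutator):
    """Rewrite sort/scan operator calls into library calls when available."""

    def __init__(self, mod, target_kind):
        super().__init__(mod)
        self.target_kind = target_kind

    def visit_call_(self, call: relax.Call) -> relax.Expr:
        call = self.visit_expr_post_order(call)
        if isinstance(call.op, tvm.ir.Op):
            key = (self.target_kind, call.op.name)
            if key in LIBRARY_DISPATCH:
                # Destination-passing-style call into the runtime library.
                return relax.call_dps_packed(
                    LIBRARY_DISPATCH[key],
                    list(call.args),
                    out_sinfo=call.struct_info,
                )
        return call
```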

Could we not have a simple sequential version generated by legalization that wouldn’t be target specific (a slower version), and then leave it to the targets to introduce optimized versions?

Sorry if this question seems a bit naive, but I’m not able to understand why we cannot legalize them like any other unoptimized generic op and let targets figure out how to replace them with an optimized version.

The main issue is that, as of now, once scan and parallel sort are legalized it is harder for us to pick them out and optimize them further. So we might need to treat them specially.

We can indeed allow target-specific legalization, so for CPU-like targets we legalize them to normal TIR. The goal of the discussion is to enable us to get reasonably performant support for these ops across our GPU and other flows.

1 Like

I like @sanirudh’s suggestion. Maybe it would be good for us to distinguish between “legalization” and “optimization”. For operators, I’d view “legalization” as a mandatory replacement with a valid representation of an operator at a lower level of abstraction, while “optimization” is a conditional replacement of an operator. That is, legalization steps must replace all instances of the operator and must be valid for all targets, while optimization steps are not required to replace anything and must only be valid for the replacements they perform.

With those definitions, I would say that the use of topi.cumsum is then an invalid legalization, because its validity depends on having a GPU target. I’d propose another option:

  • W2: Enable a target-dependent optimization that runs before legalization. The current replacement with topi.cumsum becomes a GPU-only optimization step. The legalization uses a target-agnostic representation.

This would allow the optimized GPU-specific routines to be used so we don’t lose runtime performance. Because the target-dependent optimization step occurs before legalization, the optimization can look for instances of R.cumsum, rather than looking for whatever lower-level operators the R.cumsum would be legalized into.
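
As a minimal sketch of the W2 ordering (assuming a hypothetical `DispatchCumsumGPU` optimization pass; only the pass ordering is the point here):

```python
import tvm
from tvm import relax

def build_pipeline(target: tvm.target.Target) -> tvm.transform.Sequential:
    passes = []
    if target.kind.name == "cuda":
        # Conditional replacement ("optimization"): only applied, and only
        # required to be valid, on GPU targets.
        passes.append(DispatchCumsumGPU())  # hypothetical pass
    # Mandatory replacement ("legalization"): must cover every remaining
    # instance and be valid for all targets.
    passes.append(relax.transform.LegalizeOps())
    return tvm.transform.Sequential(passes)
```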

2 Likes

These are great points. I think we are generally going with W2 already (for some of the library dispatching steps).

The main thing to keep in mind is that there can be cases where TensorIR legalization for certain ops is not possible (due to lack of shape information) or “hard” and not necessarily used in practice (e.g. sort uses dispatching in most cases).

So from that point of view, we are looking for a “safety net” strategy to cover those cases, which might happen after legalization; or we can allow the safety-net dispatch to happen before legalization, with the mindset that legalization will error out if some of the ops remain.
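
A minimal sketch of that safety-net check, assuming the dispatch pass runs first and using an illustrative list of ops that must not survive lowering:

```python
import tvm
from tvm import relax

# Illustrative list of ops that must be dispatched or legalized away.
MUST_BE_LOWERED = {"relax.sort", "relax.cumsum"}

@relax.expr_functor.visitor
class CheckLowered(relax.PyExprVisitor):
    """Error out if any listed op survived dispatch + legalization."""

    def visit_call_(self, call: relax.Call) -> None:
        if isinstance(call.op, tvm.ir.Op) and call.op.name in MUST_BE_LOWERED:
            raise RuntimeError(
                f"{call.op.name} was neither dispatched to a library "
                "nor legalized to TensorIR"
            )
        # Continue visiting the arguments of the call.
        for arg in call.args:
            self.visit_expr(arg)
```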

1 Like

The main thing to keep in mind is that there can be cases where legalization to TIR for certain ops is not possible

Potentially related: I recently implemented this change, which allows the legalization of Relax operators to include other Relax functions. The goal would be to allow legalization to occur in sequential steps, potentially interspersed with optimization steps. (e.g. Legalize R.nn.attention to expose internal R.reshape operators to the relax graph, optimize to see if they can be removed, then legalize further.)

So long as we can legalize to Relax operators at a lower level of abstraction, direct legalization to a single TensorIR implementation is no longer required.
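
For illustration, a legalization function in this style could return lower-level Relax operators instead of a PrimFunc. The decomposition below is a rough sketch (unscaled attention, assuming a (batch, seq, heads, head_dim) layout), not the exact rewrite from the linked change:

```python
from tvm import relax

def legalize_attention(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr:
    q, k, v = call.args[:3]
    # Move heads before seq so matmul batches over (batch, heads).
    q_t = bb.emit(relax.op.permute_dims(q, [0, 2, 1, 3]))
    k_t = bb.emit(relax.op.permute_dims(k, [0, 2, 3, 1]))
    v_t = bb.emit(relax.op.permute_dims(v, [0, 2, 1, 3]))
    # Scaling by 1/sqrt(head_dim) is omitted for brevity.
    scores = bb.emit(relax.op.matmul(q_t, k_t))              # (b, n, s, s)
    weights = bb.emit(relax.op.nn.softmax(scores, axis=-1))
    out = bb.emit(relax.op.matmul(weights, v_t))             # (b, n, s, h)
    # The exposed permute_dims/matmul ops stay in the Relax graph, where
    # later optimization passes can simplify them before further legalization.
    return relax.op.permute_dims(out, [0, 2, 1, 3])

# LegalizeOps accepts a map from op name to a custom legalize function.
pipeline = relax.transform.LegalizeOps({"relax.nn.attention": legalize_attention})
```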

1 Like

Legalization into smaller ops is also indeed relevant. I think as of now we are looking at the harder cases where a unit of work such as sort generally has platform dependencies, and we need to have a solid solution for those.

I can see that cumsum belongs to a different category, where we can have effective optimizations (e.g. thrust), but a naive version that only parallelizes across the spatial dimensions could still work. Still, in this case, having a good implementation can help.

1 Like

where a unit of work such as sort generally has platform dependencies,

Maybe this is me not quite understanding the problem. I’d view each operator as a definition of the output, and not a definition of the steps required to produce that output. How big the “unit of work” is, and therefore any platform dependency, would then exist within the current implementation of the operator’s legalization, not in the operator itself.

Would it be accurate to rephrase the problem as follows: “For some operators, a platform-independent implementation will always have poor performance. We want a solution that exposes platform-specific information to be used at an early stage of relax lowering.”

1 Like

It would be useful to ground the discussion in the case of sort as an example. Sorting on CPU and GPU is quite different in nature. For cases like elementwise ops, the mapping from trivial loops to GPU is reasonably well defined. But for sort, on CPU we would like to call std::sort, while on GPUs we would likely end up with the platform-native sort if it is available.

I don’t think we need a new solution, since the pipeline can already take target-specific info. So I would just like to bring up the question of how we would like to deal with ops like sort: where should the mapping of those ops occur (i.e. whether it happens before legalization, or after legalization as a safety net where legalize skips those ops)?

Put another way, I think the operators we are looking at roughly fall into three categories:

  • C0: Prefer TensorIR codegen. We prefer to leverage TensorIR because it offers broader coverage and opportunities like fusion.
  • C1: Middle state. We have both library and TensorIR legalizations, and we would like to be able to pick the best one and switch between them. Some of the cutlass-dispatched ops belong to this category.
  • C2: Prefer runtime library. As of now, we do not necessarily have a good generic TensorIR implementation (yet), and we would likely offload to a runtime library when one is available (e.g. sort).

C0 and C1 generally have legalization. C2 usually does not. So the main goal of this thread is to gather thoughts on where we place C2 lowering (in the form of target-dependent library dispatch). Note that the situation can change as we get better codegen for C2; say we have a radix sort implemented through TensorIR, then it gets moved to C0/C1.

2 Likes

Would this mean that when we handle C2 lowering through library dispatch, all targets would have to implement their own versions to be able to support this op, and would see an error if that’s not done?

Could we still implement a naive version for legalize that would be generic (perhaps printing a warning about performance for these ops), but run an earlier pass that introduces the library dispatch if it exists, so that legalization has nothing to do for these? My thought here is to try to support these ops on newer targets without errors, so that we don’t block work on other ops in a model that contains these ops (especially if the C2 ops are not high priority for performance). Users can think about introducing an optimized version later for a new target when the performance of that op becomes important enough for them.
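
A minimal sketch of that arrangement, assuming a hypothetical `DispatchSortLibrary` pass and overriding the cumsum legalization with a generic version that warns:

```python
import logging
import tvm
from tvm import relax, topi

logger = logging.getLogger(__name__)

def legalize_cumsum_generic(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr:
    # Generic fallback: always valid, but likely slow on GPUs.
    logger.warning(
        "Using generic TIR cumsum; consider registering a library "
        "dispatch for this target if performance matters."
    )
    return bb.call_te(topi.cumsum, call.args[0], call.attrs.axis)

pipeline = tvm.transform.Sequential([
    DispatchSortLibrary(),  # hypothetical: rewrites ops that have a library impl
    relax.transform.LegalizeOps({"relax.cumsum": legalize_cumsum_generic}),
])
```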

@sanirudh great point, I think it is worthwhile to have some generic dispatch, e.g. dispatching to a TensorIR impl can be one way of handling C2. The main difference is that in such cases the TensorIR might still be specialized (i.e. for sort we will need to have a GPU radix sort, while on CPU maybe we can do quicksort).

1 Like

So the main goal of this thread is to gather thoughts on where we place C2 lowering (in the form of target-dependent library dispatch).

From an API perspective, I think C0, C1, and C2 only differ by the replacement used for each operator. Since they each produce a relax::Expr as output, the only difference is in the expression used as the replacement:

  • C0: The relax operator is replaced by a call to a newly-generated PrimFunc.
  • C1: The relax operator may be replaced by a call to a newly-generated PrimFunc, or may be replaced by a call to an R.ExternFunc.
  • C2: The relax operator is replaced by a call to an R.ExternFunc.

C0 and C1 generally have legalization. C2 usually does not.

Wouldn’t C2 cases still have a legalization step in which relax.op.some_operator(*args) gets replaced with R.ExternFunc("implementation_of_some_operator")(*args)? That is, every operator that isn’t a low-level builtin gets replaced at some point. C2 operators just have a 1:1 correspondence between the relax representation and the external function that implements it.
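
To illustrate the shared interface, a rough sketch of what the C0 and C2 replacement functions might look like (the extern name below is the same placeholder as above):

```python
from tvm import relax, topi

def replace_c0(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr:
    # C0: legalize by generating a PrimFunc through TE/TIR.
    return bb.call_te(topi.cumsum, call.args[0])

def replace_c2(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr:
    # C2: 1:1 replacement with a call into an external runtime function.
    return relax.call_dps_packed(
        "implementation_of_some_operator",  # placeholder extern name
        list(call.args),
        out_sinfo=call.struct_info,
    )
```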

Users can think about introducing an optimized version later for a new target when the performance of that op becomes important enough for them.

I like this approach and this categorization quite a bit. The categorization also defines what can be done externally to the operator. If I were to summarize the key aspects of it:

  • Every replacement, whether “legalization” or “optimization”, replaces the call to the operator with a relax::Expr.
  • Every operator must have a legalization that is valid for every target.
  • Some operators may have replacements that are applied prior to legalization.

The main difference is that in such cases the TensorIR might still be specialized (i.e. for sort we will need to have a GPU radix sort, while on CPU maybe we can do quicksort)

In the view of “legalization” and “optimization”, I’d consider this case as two possible optimizations (radix sort if GPU, quicksort if CPU) that are independently checked prior to legalization. That is, rather than having specialized forms produced during legalization, all specialized forms are produced prior to legalization. If that happens to cover all possible cases, great, but there’s no requirement for them to do so.

1 Like

Wouldn’t C2 cases still have a legalization step in which relax.op.some_operator(*args) gets replaced with R.ExternFunc("implementation_of_some_operator")(*args)? That is, every operator that isn’t a low-level builtin gets replaced at some point. C2 operators just have a 1:1 correspondence between the relax representation and the external function that implements it.

I guess the term “legalization,” particularly in the context of Relax.transform.LegalizeOps, refers to the process of lowering Relax operations to TIR PrimFuncs. In C2, there are no PrimFuncs; the update from relax.op.some_operator(*args) to R.call_dps_packed("externfunc_name", (*args)) occurs independently of the legalization process.

As @Lunderberg mentioned above, FLegalize as it stands right now already supports legalizing to other relax operators with this change. So, if I’m not mistaken, we don’t have to assume that the output of LegalizeOps is always a PrimFunc; it can also be an R.call_dps_packed.

I think the last statement in @Lunderberg’s previous reply sums up exactly what I also wanted to suggest.

rather than having specialized forms produced during legalization, all specialized forms are produced prior to legalization. If that happens to cover all possible cases, great, but there’s no requirement for them to do so.

1 Like
