[RFC][Unity][Relax]DistIR Compilation Pipeline


In this document, we propose a design to lower DistIR, a logical unified distributed IR, to ranked physical single-device Relax IRs that are ready to be consumed by the current TVM compiler infrastructure. After the ranked Relax IRs are compiled, the resulting executable can be executed in a distributed way by a runtime such as Relax VM and Disco.

We also point out that some design choices made by DistIR and Disco may miss potential optimization opportunities and add extra complexity when it comes to irregular sharding.


For cloud foundational model providers, being able to execute inference in a distributed way is important to guarantee the SLA (especially latency) as well as efficiency.

We at Kwai AI Platform Compiler Group are working on compiling foundational models with TVM-Unity and MLC-LLM, and we noticed that there is still a gap between the high-level DistIR and the low-level single-device Relax IR. We have also seen multiple questions and issues raised by the community regarding this direction.

Hence, we propose a design that connects DistIR with the low-level infrastructure provided by TVM, enabling TVM to support distributed inference.

Related Work in Community


DistIR: a Relax dialect that describes logical distributed sharding strategies. DistIR also provides infrastructure to propagate sharding information, just like GSPMD. DistIR represents sharding strategies through DTensor and device mesh, which is expressive enough for DP, TP and PP. GSPMD has an abstraction redundancy in that a “partially tiled tensor” is just a special case of a “tiled tensor”; DistIR unifies the two and thereby eliminates this redundancy.


Disco: a framework-agnostic SPMD runtime for distributed inference/training. It is a single-controller system consisting of a controller and a group of workers; the controller broadcasts commands to all workers, which then execute them. The single-controller architecture provides tremendous programmability because it lets programmers write distributed code just like centralized code.


There are two challenges in lowering DistIR to single-device Relax IRs.

C1: Irregular sharding

Both DistIR and Disco adopt the SPMD paradigm. However, with irregular sharding, execution naturally differs across devices, and this complexity must be resolved at some level.

We argue that it is beneficial to represent this difference explicitly in the IR instead of hiding it. Hiding it eventually shifts the complexity to lower-level components, such as the runtime, which adds complexity there and misses potential optimization opportunities.

We give two examples to demonstrate this.

1. Tail Block

First, in the code snippet below, the ‘M’ axis of the ‘matmul’ is split across two devices. Notice that the shapes of tensors A and C differ between device 0 (64 rows) and device 1 (63 rows). With a single program, i.e. a single IR, we cannot represent this difference.

class ShardingAlongMAxis:

    I.module_attrs({"device_num": 2})
    I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2))]})
    # The 'M' axis (127 rows) is split across 2 devices: 64 rows on device 0, 63 rows on device 1.
    def foo(
      A: R.DTensor((127, 128), "float32", "mesh[0]", "S[0]"),
      B: R.DTensor((128, 128), "float32", "mesh[0]", "R"),
    ) -> R.DTensor((127, 128), "float32", "mesh[0]", "S[0]"):
        C: R.DTensor((127, 128), "float32", "mesh[0]", "S[0]") = R.matmul(A, B, out_dtype="void")
        return C

Applying padding is one possible way to resolve this difference. However, padding is asymmetric across devices: in this example, device 0 does not require padding but device 1 does. A new ‘asymmetric_pad’ op would be needed, which is exactly what we meant earlier by ‘shifting complexity to lower-level components’.

On the other hand, padding wastes extra computation and memory, and hence misses potential optimization opportunities.

2. Sharding the Spatial Axis of Sliding Window Operations

Another example is sharding the spatial axis of operations such as CONV and POOL. In this case, the input shards on neighboring devices overlap, which requires extra asymmetric communication between devices. We could certainly add a new op to support this kind of irregular communication, but the list of ops would keep growing, and analysis and optimization over them would become more difficult.
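To make the tail-block arithmetic concrete, here is a minimal Python sketch that computes each device’s extent along the split axis and how much padding it would need to match the largest shard. It assumes a ceil-first chunking policy (devices take up to ⌈dim / n⌉ rows in order); the helper names `shard_extents` and `pad_amounts` are hypothetical and not part of DistIR.

```python
def shard_extents(dim, num_ranks):
    """Split `dim` into contiguous chunks, each rank taking up to ceil(dim / num_ranks)
    elements in rank order (an assumed chunking policy)."""
    chunk = -(-dim // num_ranks)  # ceiling division
    extents = []
    remaining = dim
    for _ in range(num_ranks):
        take = min(chunk, remaining)
        extents.append(take)
        remaining -= take
    return extents


def pad_amounts(dim, num_ranks):
    """Padding each rank would need so that every shard has the largest shard's extent."""
    chunk = -(-dim // num_ranks)
    return [chunk - extent for extent in shard_extents(dim, num_ranks)]


print(shard_extents(127, 2))  # [64, 63]
print(pad_amounts(127, 2))    # [0, 1]
```

For the 127-row example this yields extents [64, 63] and padding [0, 1], which is exactly the per-device asymmetry an ‘asymmetric_pad’ op would have to encode.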


C2: Abstraction leak in DistIR

In DistIR, there are two kinds of sharding: R, which stands for Replicate, and S, which stands for non-overlapping Sharding. However, neither of them can represent sharding along a reduction axis.

For example, in the following code snippet, the ‘K’ axis of the ‘matmul’ is split across two devices. DistIR currently uses ‘R’ to represent the result matrix C, which gives the ‘matmul’ an implicit ‘AllReduce’ epilogue. This causes two problems:

  • First, it forces the ‘allreduce’ to be executed right after the ‘matmul’, which may miss optimization opportunities.

    For example, a ‘relu’ following the ‘matmul’ cannot be fused with it because of the implicit ‘allreduce’.

  • Second, the implicitness makes analysis and lowering more difficult.

    For the ‘relu’ example above, we could certainly develop optimization passes that identify the pattern and fuse it, but this introduces extra complexity and makes pass developers’ lives harder.

class ShardingAlongReductionAxis:

    I.module_attrs({"device_num": 2})
    I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2))]})
    # Represents the sharding strategy with DTensor; the 'annotate_sharding' op is removed.
    def foo(
      A: R.DTensor((128, 128), "float32", "mesh[0]", "S[1]"), 
      B: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]"),
    ) -> R.DTensor((128, 128), "float32", "mesh[0]", "R"):
        C: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.matmul(A, B, out_dtype="void")
        return C

Proposed Design

To solve the challenges mentioned above, we propose the following design.

Overall Architecture

Ranked Relax IR & VM Executable

The core concept in this design is the Ranked Relax IRModule and the Ranked Relax VM Executable, which are the distributed counterparts of the current single-device versions. “Ranked” means we use an independent IR/executable for each unique rank (or device) in the communication group (device mesh). Note that we only allocate an IR/executable per unique rank: if multiple ranks have exactly the same structure, they share the same IR/executable. Hence, our design is essentially an instance of the MPMD paradigm.

Compilation phase

Users first annotate their desired sharding strategy, and we then reuse DistIR’s ShardingPropagation pass to propagate it to the entire graph. Next, we run a series of DistIR lowering passes, explained in a later section, to split the logical unified DistIR into ranked physical single-device Relax IRs that are ready to be consumed by the TVM compiler infrastructure. Note that if all ranks are exactly the same, the lowered Ranked Relax IRModule contains only a single common IR. Finally, each ranked IR is compiled with the current TVM compiler infrastructure to produce the final Ranked Relax VM Executable.

Runtime phase

Users first initialize their desired distributed context (e.g., through Disco’s Session object), which assigns a unique rank id to each device. Users can then invoke the inference function, and the VM needs to dispatch to the right function in the Ranked Relax VM Executable according to the device’s rank. Finally, the necessary synchronization is performed (e.g., using Disco’s SyncDevice function) and the result is returned to the user.
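The rank-based dispatch can be pictured with a small Python model. This is only a sketch under stated assumptions: `RankedExecutable`, `rank_to_func` and `invoke` are hypothetical names, not Relax VM or Disco APIs, and the lambdas stand in for compiled ranked functions.

```python
from typing import Callable, Dict


class RankedExecutable:
    """Hypothetical stand-in for a Ranked Relax VM Executable: one entry per
    unique ranked IR, plus a table mapping every rank to its entry."""

    def __init__(self, functions: Dict[str, Callable], rank_to_func: Dict[int, str]):
        self.functions = functions
        self.rank_to_func = rank_to_func

    def invoke(self, rank: int, *args):
        # Each worker passes the rank id assigned by the distributed context
        # and runs the ranked function registered for that rank.
        return self.functions[self.rank_to_func[rank]](*args)


# Four ranks but only two unique ranked IRs: ranks 0-2 share foo_0.
exe = RankedExecutable(
    functions={"foo_0": lambda x: ("foo_0", x), "foo_1": lambda x: ("foo_1", x)},
    rank_to_func={0: "foo_0", 1: "foo_0", 2: "foo_0", 3: "foo_1"},
)
results = [exe.invoke(rank, "input") for rank in range(4)]
```

Keeping the rank-to-function table separate from the function list is what lets identical ranks share one compiled IR.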

Introduce Sharding Type ‘I’

To resolve C2, we introduce a new sharding type ‘I’, which stands for an intermediate tensor waiting to be reduced, to express reduction semantics explicitly.

‘I’ may have multiple subtypes for different reduction types, including ‘I_SUM’, ‘I_MEAN’, ‘I_MIN’ and ‘I_MAX’. Note that the subtype could also be implemented as an op attribute. Examples and advantages of introducing this new type are elaborated in Example 3.
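A minimal Python sketch of what each ‘I’ subtype means: a per-device partial value that becomes a replicated (‘R’) value once the matching reduction is applied. The `Placement` enum and `resolve_partial` helper are hypothetical, ‘S’ placements are omitted for brevity, and treating ‘I_MEAN’ as the mean of the partials is an assumed semantics.

```python
from enum import Enum


class Placement(Enum):
    """Toy placement tags; 'I_*' marks a partial value still to be reduced
    across devices (the proposed sharding type 'I')."""
    R = "R"
    I_SUM = "I_SUM"
    I_MEAN = "I_MEAN"
    I_MIN = "I_MIN"
    I_MAX = "I_MAX"


def resolve_partial(partials, placement):
    """Combine per-device partial values of one element into the replicated
    value, i.e. what an allreduce of the matching kind would produce."""
    if placement is Placement.I_SUM:
        return sum(partials)
    if placement is Placement.I_MEAN:
        return sum(partials) / len(partials)
    if placement is Placement.I_MIN:
        return min(partials)
    if placement is Placement.I_MAX:
        return max(partials)
    raise ValueError(f"{placement} is not an intermediate placement")
```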

Compilation Pipeline

In this part, we introduce our main design, the compilation pipeline. The overall pipeline is shown in the image below. We will introduce each pass through an example IR.

Input Relax IR

We use the following example to demonstrate how the pipeline works.

class LoweringDistIRExampleModule:
    def foo(
        x: R.Tensor((128, 128), "float32"),
        A: R.Tensor((128, 127), "float32"),
        B: R.Tensor((127, 128), "float32"),
        C: R.Tensor((128, 1), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        lv0 = R.matmul(x, A)
        Y = R.nn.gelu(lv0)
        Z = R.matmul(Y, B)
        O = R.multiply(Z, C)
        return O


This pass lets users describe their sharding strategy through DistIR’s ‘annotate_sharding’ op.

class LoweringDistIRExampleModule:  
    ##### Distributed Environment Infos Begin #####
    I.module_attrs({"device_num": 2})
    I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2))]})
    ##### Distributed Environment Infos End   #####
    def foo(
        x: R.Tensor((128, 128), "float32"),
        A: R.Tensor((128, 127), "float32"),
        B: R.Tensor((127, 128), "float32"),
        C: R.Tensor((128, 1), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        ##### Sharding Strategy Begin #####
        A = R.dist.annotate_sharding(A, device_mesh="mesh[0]", placement="S[1]")
        ##### Sharding Strategy End   #####
        lv0 = R.matmul(x, A)
        Y = R.nn.gelu(lv0)
        Z = R.matmul(Y, B)
        O = R.multiply(Z, C)
        return O


This pass, DistIR’s ShardingPropagation, propagates the sharding strategy to the entire graph.

class LoweringDistIRExampleModule:  
    I.module_attrs({"device_num": 2})
    I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2))]})
    # Represents the sharding strategy with DTensor; the 'annotate_sharding' op is removed.
    def foo(
      x: R.DTensor((128, 128), "float32", "mesh[0]", "R"), 
      A: R.DTensor((128, 127), "float32", "mesh[0]", "S[1]"), 
      B: R.DTensor((127, 128), "float32", "mesh[0]", "S[0]"),
      C: R.DTensor((128, 1), "float32", "mesh[0]", "R"),
    ) -> R.DTensor((128, 128), "float32", "mesh[0]", "R"):
        A_1: R.DTensor((128, 127), "float32", "mesh[0]", "S[1]") = A
        lv0: R.DTensor((128, 127), "float32", "mesh[0]", "S[1]") = R.matmul(x, A_1, out_dtype="void")
        Y: R.DTensor((128, 127), "float32", "mesh[0]", "S[1]") = R.nn.gelu(lv0)
        Z: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.matmul(Y, B, out_dtype="void")
        O: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.multiply(Z, C)
        return O
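The placements above follow from per-operator propagation rules. Below is a toy Python rule table for ‘matmul’ on a 1-D mesh, written for illustration only: it is not DistIR’s ShardingPropagation implementation, and it reports the K-split case as ‘I_SUM’ (the proposed type) where current DistIR writes ‘R’ with an implicit allreduce.

```python
def matmul_out_placement(lhs, rhs):
    """Output placement of C = A @ B on a 1-D mesh (toy rule table).
    'R' = replicated, 'S[i]' = sharded on axis i, 'I_SUM' = partial sums."""
    rules = {
        ("R", "R"): "R",
        ("S[0]", "R"): "S[0]",      # rows of A split -> rows of C split
        ("R", "S[1]"): "S[1]",      # columns of B split -> columns of C split
        ("S[1]", "S[0]"): "I_SUM",  # K split -> each device holds a partial sum
    }
    if (lhs, rhs) not in rules:
        raise ValueError(f"no rule for ({lhs}, {rhs}); a redistribute is needed first")
    return rules[(lhs, rhs)]


# The two matmuls of the running example:
lv0_placement = matmul_out_placement("R", "S[1]")    # x @ A
z_placement = matmul_out_placement("S[1]", "S[0]")   # Y @ B (gelu keeps S[1])
```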


In this pass, we use the new sharding type ‘I’ to explicitly express the reduction semantics that were previously hidden. This lays the foundation for applying further optimizations in a simple and natural way.

class LoweringDistIRExampleModule:  
    I.module_attrs({"device_num": 2})
    I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2))]})
    def foo(
      x: R.DTensor((128, 128), "float32", "mesh[0]", "R"), 
      A: R.DTensor((128, 127), "float32", "mesh[0]", "S[1]"), 
      B: R.DTensor((127, 128), "float32", "mesh[0]", "S[0]"),
      C: R.DTensor((128, 1), "float32", "mesh[0]", "R"),
    ) -> R.DTensor((128, 128), "float32", "mesh[0]", "R"):
        A_1: R.DTensor((128, 127), "float32", "mesh[0]", "S[1]") = A
        lv0: R.DTensor((128, 127), "float32", "mesh[0]", "S[1]") = R.matmul(x, A_1, out_dtype="void")
        Y: R.DTensor((128, 127), "float32", "mesh[0]", "S[1]") = R.nn.gelu(lv0)
        ##### Matmul After Legalization Begin #####
        Z_1: R.DTensor((128, 128), "float32", "mesh[0]", "I_SUM") = R.matmul(Y, B, out_dtype="void")
        Z: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.dist.redistribute(Z_1, device_mesh="mesh[0]", placement="R")
        ##### Matmul After Legalization End   #####
        O: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.multiply(Z, C)
        return O


This pass stands for a series of global optimizations that should be applied to the global unified DistIR; in practice, it could consist of multiple passes. Here, as an example, we swap the order of ‘redistribute’ and ‘multiply’. The swap is legal because an elementwise multiply by a replicated tensor distributes over the pending sum. After the swap, ‘multiply’ can be fused with ‘matmul’.

class LoweringDistIRExampleModule:  
    I.module_attrs({"device_num": 2})
    I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2))]})
    def foo(
      x: R.DTensor((128, 128), "float32", "mesh[0]", "R"), 
      A: R.DTensor((128, 127), "float32", "mesh[0]", "S[1]"), 
      B: R.DTensor((127, 128), "float32", "mesh[0]", "S[0]"),
      C: R.DTensor((128, 1), "float32", "mesh[0]", "R"),
    ) -> R.DTensor((128, 128), "float32", "mesh[0]", "R"):
        A_1: R.DTensor((128, 127), "float32", "mesh[0]", "S[1]") = A
        lv0: R.DTensor((128, 127), "float32", "mesh[0]", "S[1]") = R.matmul(x, A_1, out_dtype="void")
        Y: R.DTensor((128, 127), "float32", "mesh[0]", "S[1]") = R.nn.gelu(lv0)
        Z_1: R.DTensor((128, 128), "float32", "mesh[0]", "I_SUM") = R.matmul(Y, B, out_dtype="void")
        ##### Switch multiply and redistribute Begin #####
        Z_2: R.DTensor((128, 128), "float32", "mesh[0]", "I_SUM") = R.multiply(Z_1, C)
        O: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.dist.redistribute(Z_2, device_mesh="mesh[0]", placement="R")
        ##### Switch multiply and redistribute End   #####
        return O
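The swap can be sketched as a peephole rewrite over a flat list of bindings. The `Binding` tuple, the `placements` dictionary and the `_scaled` temporary name are hypothetical simplifications of Relax IR; the rule only matches when the redistributed value is the first operand of ‘multiply’ and has no other users.

```python
from typing import NamedTuple, Tuple


class Binding(NamedTuple):
    var: str
    op: str
    args: Tuple[str, ...]
    placement: str


def swap_redistribute_multiply(block, placements):
    """Rewrite `v = redistribute(p: I_SUM -> R); o = multiply(v, c: R)` into
    `t = multiply(p, c): I_SUM; o = redistribute(t -> R)`."""
    out = list(block)
    i = 0
    while i + 1 < len(out):
        redist, mul = out[i], out[i + 1]
        uses = sum(b.args.count(redist.var) for b in out)
        if (redist.op == "redistribute"
                and placements[redist.args[0]] == "I_SUM"
                and redist.placement == "R"
                and mul.op == "multiply"
                and mul.args[0] == redist.var
                and placements[mul.args[1]] == "R"
                and uses == 1):
            src, tmp = redist.args[0], redist.args[0] + "_scaled"
            out[i] = Binding(tmp, "multiply", (src, mul.args[1]), "I_SUM")
            out[i + 1] = Binding(mul.var, "redistribute", (tmp,), "R")
            placements[tmp] = "I_SUM"
        i += 1
    return out


block = [
    Binding("Z_1", "matmul", ("Y", "B"), "I_SUM"),
    Binding("Z", "redistribute", ("Z_1",), "R"),
    Binding("O", "multiply", ("Z", "C"), "R"),
]
placements = {"Y": "S[1]", "B": "S[0]", "C": "R", "Z_1": "I_SUM", "Z": "R", "O": "R"}
rewritten = swap_redistribute_multiply(block, placements)
```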


This pass lowers the global unified DistIR to ranked single-device Relax IR in three steps:

  • First, copy the original IR for each rank.
  • Second, convert DTensors to Tensors and lower the ‘redistribute’ op to an actual communication op such as ‘allreduce’. This can be done by first visiting the function signature, converting its DTensors to Tensors, and then propagating the converted Tensors through the entire graph.
  • Finally, run an elimination pass to eliminate redundant ranked IRs, leaving only one IR for each unique rank.
class LoweringDistIRExampleModule:  
    I.module_attrs({"device_num": 2})
    I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2))]})
    # Ranked Relax IR for rank 0
    def foo_0(
      x: R.Tensor((128, 128), "float32"), 
      A: R.Tensor((128, 64), "float32"), 
      B: R.Tensor((64, 128), "float32"),
      C: R.Tensor((128, 1), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        lv0: R.Tensor((128, 64), "float32") = R.matmul(x, A, out_dtype="void")
        Y: R.Tensor((128, 64), "float32") = R.nn.gelu(lv0)
        Z: R.Tensor((128, 128), "float32") = R.matmul(Y, B, out_dtype="void")
        Z_1: R.Tensor((128, 128), "float32") = R.multiply(Z, C)
        O: R.Tensor((128, 128), "float32") = R.call_packed("vm.dist.allreduce", Z_1, Z_1)
        return O
    # Ranked Relax IR for rank 1
    def foo_1(
      x: R.Tensor((128, 128), "float32"), 
      A: R.Tensor((128, 63), "float32"), 
      B: R.Tensor((63, 128), "float32"),
      C: R.Tensor((128, 1), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        ##### Padding only happens at rank 1 #####
        A_1: R.Tensor((128, 64), "float32") = R.pad(A, (128, 64), value=0)
        B_1: R.Tensor((64, 128), "float32") = R.pad(B, (64, 128), value=0)
        ##### Padding only happens at rank 1 #####
        lv0: R.Tensor((128, 64), "float32") = R.matmul(x, A_1, out_dtype="void")
        Y: R.Tensor((128, 64), "float32") = R.nn.gelu(lv0)
        Z: R.Tensor((128, 128), "float32") = R.matmul(Y, B_1, out_dtype="void")
        Z_1: R.Tensor((128, 128), "float32") = R.multiply(Z, C)
        O: R.Tensor((128, 128), "float32") = R.call_packed("vm.dist.allreduce", Z_1, Z_1)
        return O
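Steps one to three can be approximated at the signature level in Python: derive each rank’s local Tensor shapes from the DTensor shapes and placements, then keep one IR per unique signature. This is a sketch under assumptions: ceil-first chunking as in the tail-block example, ‘I_*’ partials keeping the full shape, and signature equality standing in for the full IR comparison a real elimination pass would perform; all helper names are hypothetical.

```python
def local_shape(global_shape, placement, rank, num_ranks):
    """Local Tensor shape of a DTensor on `rank` (ceil-first chunking assumed)."""
    if placement == "R" or placement.startswith("I_"):
        return tuple(global_shape)  # replicated values and partial sums keep the full shape
    axis = int(placement[2:-1])  # "S[1]" -> 1
    dim = global_shape[axis]
    chunk = -(-dim // num_ranks)
    start = min(rank * chunk, dim)
    shape = list(global_shape)
    shape[axis] = min(chunk, dim - start)
    return tuple(shape)


def lower_signature(params, num_ranks):
    """Steps 1 and 2: one copy of the signature per rank, DTensor -> local Tensor."""
    return [tuple(local_shape(s, p, r, num_ranks) for s, p in params) for r in range(num_ranks)]


def dedup(signatures):
    """Step 3: keep one IR per unique signature and map every rank to it."""
    unique, rank_to_ir = [], {}
    for rank, sig in enumerate(signatures):
        if sig not in unique:
            unique.append(sig)
        rank_to_ir[rank] = unique.index(sig)
    return unique, rank_to_ir


# Signature of foo in the running example: (global shape, placement) per parameter.
params = [((128, 128), "R"), ((128, 127), "S[1]"), ((127, 128), "S[0]"), ((128, 1), "R")]
two_rank_sigs = lower_signature(params, 2)
unique_4, rank_to_ir_4 = dedup(lower_signature(params, 4))
```

With two ranks this reproduces the (128, 64)/(128, 63) split of foo_0 and foo_1; with four ranks, the first three ranks share one IR and only the last needs its own.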


This pass stands for a series of local optimizations that should be applied to ranked single-device Relax IR; in practice, it could consist of multiple passes. Here, we fuse ‘multiply’ and ‘matmul’ as an example.

class LoweringDistIRExampleModule:  
    I.module_attrs({"device_num": 2})
    I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2))]})
    # Ranked Relax IR for rank 0
    def foo_0(
      x: R.Tensor((128, 128), "float32"), 
      A: R.Tensor((128, 64), "float32"), 
      B: R.Tensor((64, 128), "float32"),
      C: R.Tensor((128, 1), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        lv0: R.Tensor((128, 64), "float32") = R.matmul(x, A, out_dtype="void")
        Y: R.Tensor((128, 64), "float32") = R.nn.gelu(lv0)
        Z: R.Tensor((128, 128), "float32") = R.fused_matmul_multiply(Y, B, C, out_dtype="void")
        O: R.Tensor((128, 128), "float32") = R.call_packed("vm.dist.allreduce", Z, Z)
        return O
    # Ranked Relax IR for rank 1
    def foo_1(
      x: R.Tensor((128, 128), "float32"), 
      A: R.Tensor((128, 63), "float32"), 
      B: R.Tensor((63, 128), "float32"),
      C: R.Tensor((128, 1), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        ##### Padding only happens at rank 1 #####
        A_1: R.Tensor((128, 64), "float32") = R.pad(A, (128, 64), value=0)
        B_1: R.Tensor((64, 128), "float32") = R.pad(B, (64, 128), value=0)
        ##### Padding only happens at rank 1 #####
        lv0: R.Tensor((128, 64), "float32") = R.matmul(x, A_1, out_dtype="void")
        Y: R.Tensor((128, 64), "float32") = R.nn.gelu(lv0)
        Z: R.Tensor((128, 128), "float32") = R.fused_matmul_multiply(Y, B_1, C, out_dtype="void")
        O: R.Tensor((128, 128), "float32") = R.call_packed("vm.dist.allreduce", Z, Z)
        return O
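‘R.fused_matmul_multiply’ is a placeholder name in the IR above. As a plain-Python reference of what such a fused kernel would compute (a sketch, not a TVM kernel), the scale by C is applied to each accumulated dot product before it is written out, so the unfused (128, 128) intermediate never has to be materialized.

```python
def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def multiply_col(z, c):
    # c has shape (M, 1) and broadcasts along columns, like C in the example.
    return [[z[i][j] * c[i][0] for j in range(len(z[0]))] for i in range(len(z))]


def fused_matmul_multiply(a, b, c):
    """One pass: accumulate each dot product, scale it, then store it."""
    m, k, n = len(a), len(b), len(b[0])
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for kk in range(k):
                acc += a[i][kk] * b[kk][j]
            out[i][j] = acc * c[i][0]
    return out


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
c = [[2.0], [0.5]]
fused = fused_matmul_multiply(a, b, c)
unfused = multiply_col(matmul(a, b), c)
```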

Advantages of the design

In this part, we demonstrate the advantages of the design through example outputs.

Example 1 —— Tail Block

In the following code snippet, we can see that the tail block is resolved naturally because we can use a different IR for each rank. This brings other advantages as well:

  • Enables more ways to tackle irregularities. To resolve a tail block with MPMD, we can either pad it or compile a standalone executable for that particular rank (though most kernel executables can be shared). We can also fall back to SPMD, since it is just a special case of MPMD.
  • No need to invent new infrastructure. With MPMD, we resolve irregularities at the expense of redundancy, which lets us reuse all the infrastructure TVM provides today, including the operator inventory, optimization passes, lowering passes, etc., since each ranked IR is just a single-device IR that TVM already supports.
  • Enables more optimization opportunities. In the example below, tensors A and B on rank 1 have shapes (128, 2) and (2, 128); padding them to rank 0’s shapes wastes extra memory. With MPMD, we can instead compile kernels specifically for these shapes and save that memory.
class TailBlockExample:  
    I.module_attrs({"device_num": 2})
    I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2))]})
    # Ranked Relax IR for rank 0
    def foo_0(
      x: R.Tensor((128, 128), "float32"), 
      A: R.Tensor((128, 64), "float32"), 
      B: R.Tensor((64, 128), "float32"),
      C: R.Tensor((128, 1), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        lv0: R.Tensor((128, 64), "float32") = R.matmul(x, A, out_dtype="void")
        Y: R.Tensor((128, 64), "float32") = R.nn.gelu(lv0)
        Z: R.Tensor((128, 128), "float32") = R.matmul(Y, B, out_dtype="void")
        Z_1: R.Tensor((128, 128), "float32") = R.call_packed("vm.dist.allreduce", Z, Z)
        O: R.Tensor((128, 128), "float32") = R.multiply(Z_1, C)
        return O
    # Ranked Relax IR for rank 1
    def foo_1(
      x: R.Tensor((128, 128), "float32"), 
      A: R.Tensor((128, 2), "float32"), 
      B: R.Tensor((2, 128), "float32"),
      C: R.Tensor((128, 1), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        ##### Padding only happens at rank 1 #####
        A_1: R.Tensor((128, 64), "float32") = R.pad(A, (128, 64), value=0)
        B_1: R.Tensor((64, 128), "float32") = R.pad(B, (64, 128), value=0)
        ##### Padding only happens at rank 1 #####
        lv0: R.Tensor((128, 64), "float32") = R.matmul(x, A_1, out_dtype="void")
        Y: R.Tensor((128, 64), "float32") = R.nn.gelu(lv0)
        Z: R.Tensor((128, 128), "float32") = R.matmul(Y, B_1, out_dtype="void")
        Z_1: R.Tensor((128, 128), "float32") = R.call_packed("vm.dist.allreduce", Z, Z)
        O: R.Tensor((128, 128), "float32") = R.multiply(Z_1, C)
        return O
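To put a number on the waste, the arithmetic below counts only the two padded weight shards on rank 1 and the two matmuls that consume them, assuming float32 (4 bytes) and counting one multiply-add as 2 FLOPs. Intermediates such as lv0 and Y, which are also inflated from 2 to 64 columns, are ignored, so this is a lower bound for this example.

```python
from math import prod

FLOAT32_BYTES = 4


def padding_waste_bytes(actual_shape, padded_shape, itemsize=FLOAT32_BYTES):
    """Bytes spent on zero padding when a shard is padded to a larger shape."""
    return (prod(padded_shape) - prod(actual_shape)) * itemsize


def matmul_flops(m, k, n):
    """FLOPs of an (m, k) @ (k, n) matmul, one multiply-add counted as 2 FLOPs."""
    return 2 * m * k * n


# Rank 1: A is (128, 2) and B is (2, 128), padded to rank 0's (128, 64) and (64, 128).
wasted_bytes = (padding_waste_bytes((128, 2), (128, 64))
                + padding_waste_bytes((2, 128), (64, 128)))
wasted_flops = ((matmul_flops(128, 128, 64) - matmul_flops(128, 128, 2))
                + (matmul_flops(128, 64, 128) - matmul_flops(128, 2, 128)))
```

Here that is 63,488 wasted bytes and 4,063,232 wasted FLOPs, roughly 31 times the useful work of those two matmuls on rank 1.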

Example 2 —— Sharding CONV Op

In the following code snippet, we can see that sharding along the spatial axis of sliding-window operators is also resolved naturally. Other advantages include:

  • No need to invent complex communication operators. With this design, complex communication logic can be expressed with simple ops, e.g., send and recv.
  • Enables more general and composable optimizations. Another advantage of building complex communication logic from simple communication ops is that optimization patterns and passes need only target these basic building blocks, enabling more general, composable optimizations.
class ShardingConvExample:
    I.module_attrs({"device_num": 2})
    I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2))]})
    def foo_0(
      a: R.Tensor((1, 3, 5, 3), "float32"), # N, H, W, Cin (global rows 0-2)
      w: R.Tensor((3, 3, 3, 10), "float32") # Kh, Kw, Cin, Cout
    ) -> R.Tensor((1, 2, 5, 10), "float32"):
      # Receive the halo row (global row 3) from rank 1
      from_rank1 = R.alloc_tensor((1, 1, 5, 3), "float32")
      R.call_packed("vm.dist.recv", from_rank1, sendRank=1)
      # Send the last local row (global row 2) to rank 1
      to_rank1 = R.slice(a, begin=[0, 2, 0, 0], end=[1, 3, 5, 3])
      R.call_packed("vm.dist.send", to_rank1, recvRank=1)
      # Concatenate the halo row below the local activation
      a_0 = R.concat((a, from_rank1), axis=1)
      # Execute conv2d (assumed: W padded by 1 so the output width stays 5; H is not padded)
      r = R.nn.conv2d(a_0, w, padding=(0, 1), data_layout="NHWC", kernel_layout="HWIO")
      return r
    def foo_1(
      a: R.Tensor((1, 2, 5, 3), "float32"), # N, H, W, Cin (global rows 3-4)
      w: R.Tensor((3, 3, 3, 10), "float32") # Kh, Kw, Cin, Cout
    ) -> R.Tensor((1, 1, 5, 10), "float32"):
      # Send the first local row (global row 3) to rank 0
      to_rank0 = R.slice(a, begin=[0, 0, 0, 0], end=[1, 1, 5, 3])
      R.call_packed("vm.dist.send", to_rank0, recvRank=0)
      # Receive the halo row (global row 2) from rank 0
      from_rank0 = R.alloc_tensor((1, 1, 5, 3), "float32")
      R.call_packed("vm.dist.recv", from_rank0, sendRank=0)
      # The halo row precedes the local rows, so it is concatenated first
      a_1 = R.concat((from_rank0, a), axis=1)
      # Execute conv2d (same assumed padding as rank 0)
      r = R.nn.conv2d(a_1, w, padding=(0, 1), data_layout="NHWC", kernel_layout="HWIO")
      return r
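Which rows each rank must exchange can be derived mechanically. The Python sketch below assumes a stride-1 convolution with no padding along H, and an ownership policy in which each output row is computed by the rank owning the center row of its window; that policy is our assumption, chosen because it reproduces the split in the example above. `halo_plan` is a hypothetical helper.

```python
def halo_plan(row_splits, kernel_h):
    """For a row-split, stride-1 convolution (no padding along H), return per rank
    the global output rows it computes and the input rows it must receive."""
    bounds, start = [], 0
    for extent in row_splits:
        bounds.append((start, start + extent))
        start += extent
    total_out = start - kernel_h + 1

    def owner(row):
        return next(r for r, (lo, hi) in enumerate(bounds) if lo <= row < hi)

    plan = []
    for rank in range(len(bounds)):
        # Assumed policy: output row r belongs to the owner of its window's center row.
        outs = [r for r in range(total_out) if owner(r + kernel_h // 2) == rank]
        needed = sorted({r + d for r in outs for d in range(kernel_h)})
        recv = {}
        for row in needed:
            src = owner(row)
            if src != rank:
                recv.setdefault(src, []).append(row)
        plan.append({"outputs": outs, "recv": recv})
    return plan


plan = halo_plan([3, 2], kernel_h=3)
```

For the 3/2 row split this gives rank 0 output rows [0, 1] plus global row 3 from rank 1, and rank 1 output row [2] plus global row 2 from rank 0, matching the send/recv pairs in foo_0 and foo_1.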

Example 3 —— Introducing ‘I’ in Sharding Type

Another design choice we made is to add a new sharding type ‘I’, which stands for intermediate. The example below shows that explicitly expressing an intermediate tensor waiting to be reduced enables more fusion opportunities.

# Without Sharding Type "I"
class LoweringDistIRExampleModule:  
    I.module_attrs({"device_num": 2})
    I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2))]})
    def foo(
      A: R.DTensor((128, 8), "float32", "mesh[0]", "S[1]"), 
      B: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]"),
      C: R.DTensor((128, 1), "float32", "mesh[0]", "R"),
    ) -> R.DTensor((128, 128), "float32", "mesh[0]", "R"):
        # We cannot fuse 'multiply' into matmul because of the implicit reduce.
        # If we want to force the fusion, new patterns or rules that identifies the implicit reduce must be developed.
        D_1: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.matmul(A, B, out_dtype="void")
        O: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.multiply(D_1, C)
        return O

# With Sharding Type "I"
class LoweringDistIRExampleModule:  
    I.module_attrs({"device_num": 2})
    I.module_global_infos({"mesh": [R.device_mesh((2,), I.Range(0, 2))]})
    def foo(
      A: R.DTensor((128, 8), "float32", "mesh[0]", "S[1]"), 
      B: R.DTensor((128, 128), "float32", "mesh[0]", "S[0]"),
      C: R.DTensor((128, 1), "float32", "mesh[0]", "R"),
    ) -> R.DTensor((128, 128), "float32", "mesh[0]", "R"):
        D_1: R.DTensor((128, 128), "float32", "mesh[0]", "I_SUM") = R.matmul(A, B, out_dtype="void")
        # The order of 'redistribute' and 'multiply' can be switched and then 'multiply' can be fused into 'matmul'
        D: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.dist.redistribute(D_1, device_mesh="mesh[0]", placement="R")
        O: R.DTensor((128, 128), "float32", "mesh[0]", "R") = R.multiply(D, C)
        return O

The design has other advantages as well:

  • Enables local optimizations. Splitting the IR per unique rank lets us optimize each ranked IR separately and locally.
  • A more layered organization of optimizations and passes. By introducing the Ranked Relax IRModule, optimizations and analyses naturally fall into two categories: global ones applied to the unified DistIR and local ones applied to the ranked IRs.

Disadvantages of the design

Some disadvantages of this design include:

  • Local optimizations on communication are prone to deadlock. Since we split the IR into multiple ranked IRs, deadlock may happen when we apply local optimizations such as reordering communication ops. This could be solved by introducing a helper ‘barrier’ op: only ops inside the same barrier region can be reordered.
  • Extra compilation overhead. For large-scale distributed applications, we may need to compile a large number of ranked IRs, which adds compilation overhead. The size of the executable may also grow.
  • Extra runtime complexity. The runtime now needs to support the MPMD paradigm, which requires further engineering effort.
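The deadlock risk is easy to reproduce with a toy model. The Python sketch below assumes blocking, unbuffered (rendezvous) send/recv, where a send completes only when the peer is sitting at the matching recv; with buffered sends the reordered program could still finish. The proposed ‘barrier’ op is not modeled, and `simulate` is a hypothetical helper.

```python
def simulate(programs):
    """Run per-rank lists of ("send", peer) / ("recv", peer) ops under rendezvous
    semantics. Return True if every rank finishes, False on deadlock."""
    pcs = [0] * len(programs)
    while True:
        if all(pc == len(prog) for pc, prog in zip(pcs, programs)):
            return True
        progressed = False
        for rank, prog in enumerate(programs):
            if pcs[rank] >= len(prog):
                continue
            kind, peer = prog[pcs[rank]]
            if kind != "send" or pcs[peer] >= len(programs[peer]):
                continue
            # The send completes only if the peer is currently waiting on us.
            if programs[peer][pcs[peer]] == ("recv", rank):
                pcs[rank] += 1
                pcs[peer] += 1
                progressed = True
        if not progressed:
            return False


# Communication order of the conv example: rank 0 receives first, rank 1 sends first.
original = [[("recv", 1), ("send", 1)], [("send", 0), ("recv", 0)]]
# A purely local reordering on rank 0 makes both ranks block on send.
reordered = [[("send", 1), ("recv", 1)], [("send", 0), ("recv", 0)]]
```

Each rank’s reordering looks harmless in isolation, which is why a local pass needs some cross-rank constraint such as the barrier regions suggested above.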

Future Possibilities

Some future directions include:

  • Explore ways to combine distributed inference with other LLM optimizations, such as KVCache, paged memory management and continuous batching.
  • Explore possible compiler-level optimizations, including optimizations of communication ops and joint optimizations across communication and compute ops.
  • Explore ways to support asynchronous inference. This could help pipeline parallelism fully utilize device resources and improve throughput.

Our Plan

Since all work in the field of distributed inference/training is at a very early stage, including our design, whose implementation has not yet started, we would like to discuss it with the community more broadly and deeply to avoid duplicated work and conflicting designs.


Kwai AI Platform Compiler Group.


Great work! Thanks for finding out the limitations of current design of DistIR. Irregular sharding is indeed an important problem and I’m happy that you give a representation for MPMD strategy.

The second challenge is open to discussion. DistIR hides the allreduce in the matmul operator on purpose. For high-level Relax operators, we don’t expect users to apply fine-grained optimizations; the only goal of the high-level operators is simplicity. If users want to express advanced optimizations in DistIR, we expect them to use TensorIR. Below is a TensorIR matmul sharded along the reduction dimension, which contains an allreduce.

    def matmul2(A: T.Buffer((T.int64(128), T.int64(64)), "float32"), B: T.Buffer((T.int64(64), T.int64(128)), "float32"), matmul_1: T.Buffer((T.int64(128), T.int64(128)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(64)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(A[v_i0, v_k], B[v_k, v_i1])
                T.writes(matmul_1[v_i0, v_i1])
                with T.init():
                    matmul_1[v_i0, v_i1] = T.float32(0)
                matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] + A[v_i0, v_k] * B[v_k, v_i1]
        with T.block("all_reduce"):
            T.evaluate(T.allreduce, src_buffer=matmul_1.data, dst_buffer=matmul_1.data, size=128*128, group=(mesh="mesh[0]", dims=[0]))

Through this representation in TensorIR, we can enable fine-grained scheduling of communication and computation.

Furthermore, operator fusion may be implemented in another way. The precondition for fusing matmul with an operator X is that the sum and X can be swapped, which holds only for a limited set of operators. In the multiply case, it can be fused into a later operator.
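A toy check in plain Python of this precondition, using made-up partial sums held by two devices after a K-sharded matmul: multiplying by a replicated scalar commutes with the cross-device sum, while ‘relu’ does not, so only the former can be moved before the allreduce and fused into the matmul.

```python
def relu(v):
    return [max(x, 0.0) for x in v]


def add(a, b):
    return [x + y for x, y in zip(a, b)]


def scale(v, c):
    return [x * c for x in v]


# Toy partial sums of one output row held by device 0 and device 1.
partial_0 = [1.0, -3.0, 2.0]
partial_1 = [-2.0, 1.0, 0.5]

# X = multiply by a replicated value: applying X before or after the sum agrees.
mul_after_sum = scale(add(partial_0, partial_1), 2.0)
mul_before_sum = add(scale(partial_0, 2.0), scale(partial_1, 2.0))

# X = relu: applying X before the sum changes the result.
relu_after_sum = relu(add(partial_0, partial_1))
relu_before_sum = add(relu(partial_0), relu(partial_1))
```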

Anyway, I’m open to adding “I” as a sharding type if you have a strong opinion on this, because it does not require much effort. Please let me know what you think. Thanks.


Thanks for replying! I think the logic behind DistIR to hide communication is solid.

At first I was concerned that hiding communication may miss potential optimizations, but your example makes it clear that the optimization can be done at the TIR level. Also, as you stated, this helps users focus on modeling at the Relax IR level.

Hence, I don’t think we need to add an ‘I’ type for now.

Hi @DiTian,

I am also working on the Distributed Inference topic, and wanted to ask you about the status of DistIR. Is there a working example in the repository? I only found the tests, which may be enough to start with, but tests are always limited in scope.
