[DistIR] Convert IRModule to a distributed version and propagate sharding

I would like to convert a module generated via some frontend into a distributed version. The module has a main function with several calls to a global function, and I would like to distribute each of those calls.

My current strategy is to insert annotate_sharding calls via an expr mutator and then apply the PropagateSharding pass. However, I am running into issues with PropagateSharding because the global function's signature uses Tensor rather than DTensor.

Here's an example. Running the script below fails with:

TVMError: Argument 0 type mismatch: expected R.Tensor((128, 128), dtype="float32"), given R.DTensor((128, 128), "float32", R.device_mesh((1,), R.Range(0, 1)), "R")
from tvm.script.parser import ir as I
from tvm.script.parser import relax as R

@I.ir_module
class Model:
    I.module_attrs({"device_num": 2})
    I.module_global_infos(
        {"mesh": [R.device_mesh((1,), I.Range(0, 1)), R.device_mesh((1,), I.Range(1, 2))]}
    )
    @R.function(private=True)
    def mlp(
        x: R.Tensor((128, 128), "float32"),
        A: R.Tensor((128, 128), "float32"),
        B: R.Tensor((128, 128), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        with R.dataflow():
            lv0 = R.matmul(x, A)
            Y = R.nn.relu(lv0)
            Z = R.matmul(Y, B)
            R.output(Z)
        return Z

    @R.function
    def main(
        x: R.Tensor((128, 128), "float32"),
        A: R.Tensor((128, 128), "float32"),
        B: R.Tensor((128, 128), "float32"),
        C: R.Tensor((128, 128), "float32"),
        D: R.Tensor((128, 128), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        R.func_attr({"num_input": 1})
        cls = Model
        with R.dataflow():
            lv0 = cls.mlp(x, A, B)
            lv1 = cls.mlp(lv0, C, D)
            R.output(lv1)
        return lv1

from tvm.relax.distributed import Placement
from tvm import ir, relax
import itertools as it

@relax.expr_functor.mutator
class PP(relax.PyExprMutator):
    """Wrap the activation input of every call to `mlp` in annotate_sharding."""

    def __init__(self, mod: ir.IRModule):
        super().__init__()
        self.mod = mod
        assert "mesh" in mod.global_infos
        mesh = mod.global_infos["mesh"]
        # Cycle through the meshes so successive mlp calls are placed on
        # successive device meshes.
        self.dev_iter = it.cycle(mesh)

    def visit_call_(self, call):
        call = self.visit_expr_post_order(call)
        if isinstance(call.op, ir.GlobalVar) and call.op.name_hint == "mlp":
            # Annotate only the activation (first argument); the weights keep
            # their original layout.
            value = call.args[0]
            sharded_input = relax.op.distributed.annotate_sharding(
                value,
                device_mesh=next(self.dev_iter),
                placement=Placement.from_text("R"),
            )
            new_args = [sharded_input] + list(call.args[1:])
            return relax.Call(call.op, new_args, call.attrs, call.sinfo_args, call.span)

        return call

mod = Model.clone()
# Rewrite main, then propagate sharding across the whole module.
mod["main"] = PP(mod).visit_expr(mod["main"])
mod = relax.distributed.transform.PropagateSharding()(mod)
mod.show()
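
For comparison, the DistIR material I have found declares the distributed signature directly with R.DTensor, which is the sinfo the error message above refers to. Here is a minimal sketch of what that would look like for mlp, adapted from the style of the TVM DistIR tests (the exact mesh index and placement strings are my assumption):

from tvm.script.parser import ir as I
from tvm.script.parser import relax as R

@I.ir_module
class ModelDTensor:
    I.module_attrs({"device_num": 1})
    I.module_global_infos({"mesh": [R.device_mesh((1,), I.Range(0, 1))]})

    @R.function
    def mlp(
        # "mesh[0]" refers to the first entry of the module's mesh list;
        # "R" means fully replicated on that mesh.
        x: R.DTensor((128, 128), "float32", "mesh[0]", "R"),
        A: R.DTensor((128, 128), "float32", "mesh[0]", "R"),
        B: R.DTensor((128, 128), "float32", "mesh[0]", "R"),
    ) -> R.DTensor((128, 128), "float32", "mesh[0]", "R"):
        with R.dataflow():
            lv0 = R.matmul(x, A)
            Y = R.nn.relu(lv0)
            Z = R.matmul(Y, B)
            R.output(Z)
        return Z

This matches the DTensor sinfo in the error, so it looks like the pass turns the annotated call-site values into DTensors but does not rewrite the Tensor-typed signature of the private callee to match.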

@DiTian @fPecc any thoughts on how to programmatically introduce annotate_sharding calls on an existing module with calls to private functions?

Hi @slai-nick. Sadly, I have not had time yet to work with the distributed IR.

Ok no worries.

I haven't found much else on DistIR in TVM other than this Colab notebook, which only showcases DistIR on modules written directly in TVMScript.

The current integration doesn't seem complete, at least as far as the passes go; the sources still contain some TODOs. For example, the LegalizeRedistribute pass (legalize_redistribute.cc) doesn't support AlltoAll and AllGather communications and only handles 1-D meshes.

The fact that I haven't yet succeeded in adding annotate_sharding calls programmatically could be down to my novice TVM skills, or the current integration may simply not handle what I want to do.

My goal with DistIR is to perform analysis of distributed programs, and I would like to see it fully featured in TVM.

I would like to dedicate some time to extending the integration and would greatly appreciate support and exchange with other interested people and/or regular contributors.

I think the PropagateSharding pass does not correctly handle the nested function call case (in your code, main calls into mlp). One possible workaround is sketched below.
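
A minimal, untested sketch: keep your mutator, but inline the private mlp bodies into main before running PropagateSharding, so no cross-function call boundary remains. This assumes relax.transform.InlinePrivateFunctions can flatten your module; I have not verified it on this example:

from tvm import relax

mod = Model.clone()
# Insert annotate_sharding at each mlp call site, as in your PP mutator.
mod["main"] = PP(mod).visit_expr(mod["main"])
# Inline the private mlp bodies into main; after this there is only one
# function left for PropagateSharding to rewrite.
mod = relax.transform.InlinePrivateFunctions()(mod)
mod = relax.distributed.transform.PropagateSharding()(mod)
mod.show()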

Also, you are right that the integration is not complete. You are welcome to extend it.