I would like to convert a module generated by a frontend into a distributed version. The module's main function makes several calls to a global function, and I would like to distribute each of those calls. My current strategy is to insert annotate_sharding calls via an expr mutator and then apply PropagateSharding. However, PropagateSharding runs into trouble because the global function's signature uses Tensor rather than DTensor:

TVMError: Argument 0 type mismatch: expected R.Tensor((128, 128), dtype="float32"), given R.DTensor((128, 128), "float32", R.device_mesh((1,), R.Range(0, 1)), "R")
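If I understand the error correctly, propagation turns the first argument of each call into a DTensor while mlp still declares plain Tensor parameters. I assume the callee would have to look something like the sketch below for the types to line up (this is only my guess at the spelling, based on the DTensor type printed in the error; the body just returns x to keep the sketch minimal):

from tvm.script.parser import ir as I
from tvm.script.parser import relax as R

@I.ir_module
class SignatureSketch:
    I.module_attrs({"device_num": 2})
    I.module_global_infos(
        {"mesh": [R.device_mesh((1,), I.Range(0, 1)), R.device_mesh((1,), I.Range(1, 2))]}
    )

    # Hypothetical DTensor-typed version of mlp's signature; body elided to a plain return.
    @R.function(private=True)
    def mlp(
        x: R.DTensor((128, 128), "float32", "mesh[0]", "R"),
        A: R.DTensor((128, 128), "float32", "mesh[0]", "R"),
        B: R.DTensor((128, 128), "float32", "mesh[0]", "R"),
    ) -> R.DTensor((128, 128), "float32", "mesh[0]", "R"):
        return x

But even if that is the right spelling, hard-coding mesh[0] into mlp's signature seems to defeat the point of sending the two call sites to different meshes, which is why I annotate at the call sites instead.

Here is the full script that reproduces the error: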
from tvm.script.parser import ir as I
from tvm.script.parser import relax as R


@I.ir_module
class Model:
    I.module_attrs({"device_num": 2})
    I.module_global_infos(
        {"mesh": [R.device_mesh((1,), I.Range(0, 1)), R.device_mesh((1,), I.Range(1, 2))]}
    )

    @R.function(private=True)
    def mlp(
        x: R.Tensor((128, 128), "float32"),
        A: R.Tensor((128, 128), "float32"),
        B: R.Tensor((128, 128), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        with R.dataflow():
            lv0 = R.matmul(x, A)
            Y = R.nn.relu(lv0)
            Z = R.matmul(Y, B)
            R.output(Z)
        return Z

    @R.function
    def main(
        x: R.Tensor((128, 128), "float32"),
        A: R.Tensor((128, 128), "float32"),
        B: R.Tensor((128, 128), "float32"),
        C: R.Tensor((128, 128), "float32"),
        D: R.Tensor((128, 128), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        R.func_attr({"num_input": 1})
        cls = Model
        with R.dataflow():
            lv0 = cls.mlp(x, A, B)
            lv1 = cls.mlp(lv0, C, D)
            R.output(lv1)
        return lv1
from tvm.relax.distributed import Placement
from tvm import ir, relax
import itertools as it


@relax.expr_functor.mutator
class PP(relax.PyExprMutator):
    """Wrap the first argument of every call to mlp in annotate_sharding,
    cycling through the device meshes declared in the module's global infos."""

    def __init__(self, mod: ir.IRModule):
        super().__init__()
        self.mod = mod
        assert "mesh" in mod.global_infos
        mesh = mod.global_infos["mesh"]
        # Alternate between the device meshes, one per call site.
        self.dev_iter = it.cycle(mesh)

    def visit_call_(self, call):
        call = self.visit_expr_post_order(call)
        if isinstance(call.op, ir.GlobalVar) and call.op.name_hint == "mlp":
            # Annotate the data argument so PropagateSharding can pick up
            # the desired mesh/placement for this call.
            value = call.args[0]
            sharded_input = relax.op.distributed.annotate_sharding(
                value,
                device_mesh=next(self.dev_iter),
                placement=Placement.from_text("R"),
            )
            new_args = [sharded_input] + list(call.args[1:])
            return relax.Call(call.op, new_args, call.attrs, call.sinfo_args, call.span)
        return call


mod = Model.clone()
mod["main"] = PP(mod).visit_expr(mod["main"])
mod = relax.distributed.transform.PropagateSharding()(mod)
mod.show()
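In case it helps narrow things down: if I stop right after the mutator and skip the pass, the rewritten main looks the way I intended, with annotate_sharding applied to the first argument of both mlp calls, so the failure really does come from the propagation step. Same Model and PP as above:

# Run only the rewrite and inspect it; PropagateSharding is the step that raises the error.
mod = Model.clone()
mod["main"] = PP(mod).visit_expr(mod["main"])
mod.show()  # annotate_sharding now wraps the first argument of each mlp call
# mod = relax.distributed.transform.PropagateSharding()(mod)  # <- raises the TVMError above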