I got an error while using init_ccl() with di.ThreadedSession. Here is my code for MLP sharding. Note that init_ccl() does not throw an error when the device list passed into it is empty. Please help, and thanks in advance.
import os
from tvm.script.parser import ir as I, relax as R, tir as T
from tvm import dlight
from tvm import relax
import tvm
from tvm.target.target import Target, cuda
import tvm.exec.disco_worker
import tvm.runtime.disco as di
from tvm._ffi.registry import list_global_func_names, get_global_func, register_func, register_object
from tvm.runtime.container import ShapeTuple
from colorama import Fore, Back, Style
import numpy as np
@tvm.script.ir_module
class MLP: # pylint: disable=too-few-public-methods
    @R.function
    def main(
        x: R.Tensor((128, 128), "float32"),
        W1: R.Tensor((128, 128), "float32"),
        W2: R.Tensor((128, 128), "float32"),
    ) -> R.Tensor((128, 128), "float32"):
        R.func_attr({"global_symbol": "main"})
        with R.dataflow():
            lv0: R.Tensor((128, 128), "float32") = R.matmul(x, W1)
            lv1: R.Tensor((128, 128), "float32") = R.nn.gelu(lv0)
            lv2: R.Tensor((128, 128), "float32") = R.matmul(lv1, W2)
            R.output(lv2)
        return lv2
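# A NumPy reference of the unsharded MLP for sanity-checking results later
# (my own sketch, not part of the TVM API; as far as I know R.nn.gelu is the
# erf-based GELU, so the reference uses math.erf rather than the tanh form).
import math
def mlp_reference(x_np, w1_np, w2_np):
    erf = np.vectorize(math.erf)
    lv0 = x_np @ w1_np
    lv1 = 0.5 * lv0 * (1.0 + erf(lv0 / math.sqrt(2.0)))
    return lv1 @ w2_np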
@tvm.script.ir_module
class ShardedMLP: # pylint: disable=too-few-public-methods
    @R.function
    def main(
        x: R.Tensor((128, 128), "float32"),
        W1: R.Tensor((128, 64), "float32"),  # shard along axis 1
        W2: R.Tensor((64, 128), "float32"),  # shard along axis 0
    ) -> R.Tensor((128, 128), "float32"):
        R.func_attr({"global_symbol": "main"})
        with R.dataflow():
            broadcast_x: R.Tensor((128, 128), "float32") = R.ccl.broadcast_from_worker0(x)
            lv0: R.Tensor((128, 64), "float32") = R.matmul(broadcast_x, W1)
            lv1: R.Tensor((128, 64), "float32") = R.nn.gelu(lv0)
            lv2: R.Tensor((128, 128), "float32") = R.matmul(lv1, W2)
            lv3: R.Tensor((128, 128), "float32") = R.ccl.allreduce(lv2, "sum")
            R.output(lv3)
        return lv3
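# Quick host-side sanity check of the sharding math (a sketch of my own):
# W1 is split column-wise and W2 row-wise; gelu is elementwise so it commutes
# with the column split, and summing the per-shard partials (the allreduce
# "sum" above) recovers the unsharded matmul result.
_rng = np.random.default_rng(0)
_x = _rng.standard_normal((128, 128)).astype("float32")
_w1 = _rng.standard_normal((128, 128)).astype("float32")
_w2 = _rng.standard_normal((128, 128)).astype("float32")
_partials = (_x @ _w1[:, :64]) @ _w2[:64, :] + (_x @ _w1[:, 64:]) @ _w2[64:, :]
assert np.allclose(_partials, (_x @ _w1) @ _w2, rtol=1e-3, atol=1e-2)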
# Helper to pick the device/target for the chosen CCL backend
# (not actually called below; the target is constructed explicitly instead).
def create_device_target(ccl):
    if ccl == "nccl":
        dev = tvm.cuda(0)
    else:
        dev = tvm.rocm(0)
    target = tvm.target.Target.from_device(dev)
    return (dev, target)
def relax_build(mod, target):
    with target:
        mod = relax.get_pipeline("zero")(mod)  # pylint: disable=no-value-for-parameter
        mod = dlight.ApplyDefaultSchedule(  # pylint: disable=not-callable
            dlight.gpu.Matmul(),
            dlight.gpu.GEMV(),
            dlight.gpu.Reduction(),
            dlight.gpu.GeneralReduction(),
            dlight.gpu.Fallback(),
        )(mod)
    return relax.build(mod, target=target)
target = Target(target=cuda(options=["-arch=sm_89"]), host="llvm")
# mod = MLP
mod = ShardedMLP
mod = relax.distributed.transform.PropagateSharding()(mod)
mod = relax.distributed.transform.LowerDistIR()(mod)
relax_build(mod, target).export_library("data/mod.so")
shape = (128, 128)
devices = [x for x in range(1)]
sess = di.ThreadedSession(num_workers=len(devices))
# sess = di.ProcessSession(num_workers=len(devices), entrypoint="nccl_process") # entrypoint="python's module name"
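# Observation: if `devices` is an empty list here (so no device ids reach
# init_ccl), the call below does not raise; the NCCL error shown further down
# only appears once at least one device id is passed.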
sess.init_ccl("nccl", *devices) # ERROR
# exec = sess.load_vm_module("data/mod.so", device=tvm.cuda(0))
exec = sess.load_vm_module("data/mod.so")
print(Fore.GREEN + "[S] Load mod.so" + Fore.RESET)
# np.random.seed(0)
# x_np = np.random.randn(128, 128).astype("float32")
# A_np = np.random.randn(128, 128).astype("float32")
# B_np = np.random.randn(128, 128).astype("float32")
# ---------- Use this if not using sharding ----------
# x = sess.empty((128, 128), "float32", device=tvm.cuda(0))
# A = sess.empty((128, 128), "float32", device=tvm.cuda(0))
# B = sess.empty((128, 128), "float32", device=tvm.cuda(0))
# x = sess.empty((128, 128), "float32")
# A = sess.empty((128, 128), "float32")
# B = sess.empty((128, 128), "float32")
# ---------- Use this if using sharding ----------
x = sess.empty((128, 128), "float32")
A = sess.empty((128, 64), "float32")
B = sess.empty((64, 128), "float32")
# x = sess.empty((128, 128), "float32", device=tvm.cuda(0))
# A = sess.empty((128, 64), "float32", device=tvm.cuda(0))
# B = sess.empty((64, 128), "float32", device=tvm.cuda(0))
# x.debug_copy_from(0, x_np)
# A.debug_copy_from(0, A_np)
# B.debug_copy_from(0, B_np)
while True:
    res_dist = exec["main"](x, A, B)
    print(Fore.GREEN + "[S] Execute function" + Fore.RESET)
    # res_dev = tvm.nd.empty((128, 128), "float32", device=tvm.cpu(0))
    # sess.copy_from_worker_0(res_dev, res_dist)
    # sess.sync_worker_0()
    # res_dev = res_dev.numpy()
    # print(Fore.BLUE + "[P] Print numpy result" + Fore.RESET)
    # print(res_dev)
The error message that occurs at sess.init_ccl("nccl", *devices) is shown below.
terminate called after throwing an instance of 'tvm::runtime::InternalError'
what(): [09:36:35] /home/myhome/workspace/tvm-unity/src/runtime/disco/nccl/nccl.cc:196: ncclErrror: internal error - please report this issue to the NCCL developers
Stack trace:
0: tvm::runtime::nccl::InitCCLPerWorker(tvm::runtime::ShapeTuple, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)
at /home/myhome/workspace/tvm-unity/src/runtime/disco/nccl/nccl.cc:178
1: tvm::runtime::DiscoWorker::Impl::CallPacked(tvm::runtime::DiscoWorker*, long, tvm::runtime::PackedFunc, tvm::runtime::TVMArgs const&)
at /home/myhome/workspace/tvm-unity/src/runtime/disco/disco_worker.cc:193
2: tvm::runtime::DiscoWorker::Impl::MainLoop(tvm::runtime::DiscoWorker*)
at /home/myhome/workspace/tvm-unity/src/runtime/disco/disco_worker.cc:81
3: 0x00007f61f0edc252
4: start_thread
at ./nptl/pthread_create.c:442
5: 0x00007f628792684f
at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81
6: 0xffffffffffffffff
Platform
- CPU: Intel x86_64, 48 cores
- GPU: NVIDIA GeForce RTX 4090 × 2
- CUDA Toolkit (information from nvidia-smi):
  - NVIDIA-SMI 535.146.02
  - Driver Version: 535.146.02
  - CUDA Version: 12.2
- CUDNN: version 12.2
- NCCL (built from source): version 2.19.4
TVM Unity is built from here.