Model parallelism of inference for large models like GPT2 with TVM

Can somebody shed some light on whether it's technically feasible to do model-parallel inference for large models on TVM?

From my perspective, the following areas should be investigated:

  1. Whether the computation graph (torch.jit.trace) can capture communication operators like allgather.
  2. TVM needs to support the allgather op for specific devices, which requires integrating TVM with an underlying communication library such as MPI or NCCL.
  3. A launcher provided by the original framework (such as PyTorch) is needed to provide the communication context to each communicator (process).

Anything else?
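To make items 1 and 2 more concrete, here is a minimal single-process sketch of what an allreduce (sum) op and an allgather op would have to compute. It is plain Python with hypothetical function names and no real communication: each inner list stands for the tensor held by one rank, whereas a real TVM op would call into MPI or NCCL across processes.

```python
from typing import List

# Each inner list is the (flattened) local tensor held by one rank.
PerRank = List[List[float]]


def allreduce_sum(per_rank: PerRank) -> PerRank:
    """Every rank ends up with the elementwise sum of all ranks' tensors."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]


def allgather(per_rank: PerRank) -> PerRank:
    """Every rank ends up with all ranks' tensors concatenated in rank order."""
    gathered = [x for shard in per_rank for x in shard]
    return [list(gathered) for _ in per_rank]


local = [[1.0, 2.0], [3.0, 4.0]]  # world_size = 2
print(allreduce_sum(local))  # [[4.0, 6.0], [4.0, 6.0]]
print(allgather(local))      # [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]]
```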

@tqchen @jroesch Would you like to comment on this topic? Big models are a clear AI trend. If the TVM community has no plan to support model-parallel inference, could you comment on its technical feasibility?

This topic is very interesting. We currently have a pending RFC/PR ([RFC] Compute graph pipeline with new subgraph executor) related to model parallelism. It is not designed for model parallelism, but it does some of the work model parallelism requires, such as horizontally splitting the model, pipeline execution, reducing memory requirements, and cross-device memory movement. With the help of TVM RPC, the devices/targets can also be distributed.

I think it could help with deploying large models, as long as the communication operators are bypassed.
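The pipeline idea can be illustrated roughly as below. This is not the RFC's subgraph executor API; it is a hypothetical plain-Python toy that assumes the model is a sequential list of layers, splits it into contiguous stages (one per device), and runs micro-batches on a fill-and-drain schedule so that different stages work on different micro-batches in the same time step.

```python
from typing import Callable, List, Tuple

Layer = Callable[[float], float]


def split_into_stages(layers: List[Layer], num_stages: int) -> List[List[Layer]]:
    """Split a sequential model into contiguous, roughly equal stages (one per device)."""
    base, extra = divmod(len(layers), num_stages)
    stages, start = [], 0
    for s in range(num_stages):
        size = base + (1 if s < extra else 0)
        stages.append(layers[start:start + size])
        start += size
    return stages


def run_stage(stage: List[Layer], x: float) -> float:
    for layer in stage:
        x = layer(x)
    return x


def pipeline_schedule(num_stages: int, num_micro: int) -> List[List[Tuple[int, int]]]:
    """At time step t, stage s works on micro-batch t - s, if that micro-batch exists."""
    return [
        [(s, t - s) for s in range(num_stages) if 0 <= t - s < num_micro]
        for t in range(num_stages + num_micro - 1)
    ]


def pipeline_forward(stages: List[List[Layer]], inputs: List[float]) -> List[float]:
    acts = list(inputs)  # current activation of each micro-batch
    for step in pipeline_schedule(len(stages), len(acts)):
        # Entries in one step touch different micro-batches, so they could run concurrently.
        for s, m in step:
            acts[m] = run_stage(stages[s], acts[m])  # real system: cross-device copy after this
    return acts


layers = [lambda x: x + 1, lambda x: x * 2, lambda x: x - 3]
stages = split_into_stages(layers, 2)  # stage 0: two layers, stage 1: one layer
print(pipeline_forward(stages, [0.0, 1.0, 2.0]))  # [-1.0, 1.0, 3.0]
```

In a real deployment each stage would live on its own device (possibly reached through TVM RPC), and the hand-off between stages is where the cross-device memory movement happens.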

@hjiang Thanks for your suggestion. Yes, we did consider this method: splitting the computation graph and offloading the sub-graphs to different devices. The drawbacks of this method are that it's not scalable, and that some large models like GPT2 inherently have a mechanism for model-parallel inference.
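For context, one common form of built-in model parallelism in GPT2-sized models is Megatron-LM-style tensor parallelism, where a linear layer's weight is sharded across ranks and a collective reassembles the result. Assuming that scheme (the post does not say which one is used), the sketch below checks that a column-parallel split followed by an allgather, and a row-parallel split followed by an allreduce, both reproduce the unsharded matmul. Ranks are simulated in one process, and integers keep the comparison exact.

```python
from typing import List

Matrix = List[List[int]]  # w[i][j]: in_features x out_features
Vector = List[int]


def matvec(x: Vector, w: Matrix) -> Vector:
    """Unsharded reference: y = x @ w."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def column_parallel(x: Vector, w: Matrix, world_size: int) -> Vector:
    """Rank r owns output columns [r*c, (r+1)*c); an allgather concatenates the pieces."""
    c = len(w[0]) // world_size  # assumes out_features divisible by world_size
    partials = [matvec(x, [row[r * c:(r + 1) * c] for row in w]) for r in range(world_size)]
    return [v for p in partials for v in p]  # allgather in rank order


def row_parallel(x: Vector, w: Matrix, world_size: int) -> Vector:
    """Rank r owns input rows [r*k, (r+1)*k); an allreduce sums the partial outputs."""
    k = len(w) // world_size  # assumes in_features divisible by world_size
    partials = [matvec(x[r * k:(r + 1) * k], w[r * k:(r + 1) * k]) for r in range(world_size)]
    return [sum(vals) for vals in zip(*partials)]  # allreduce (sum)


x = [1, 2, 3, 4]
w = [[i * 4 + j for j in range(4)] for i in range(4)]
print(matvec(x, w) == column_parallel(x, w, 2) == row_parallel(x, w, 2))  # True
```

Each rank holds only 1/world_size of the weight, at the cost of one collective per sharded layer.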

@yezhouhai Could you share some information about how frameworks like PyTorch handle parallel inference? More specifically, who is responsible for assigning parts of the model to a device?

PyTorch uses the DDP component to handle distributed training/inference. It provides communication primitives and a DDP optimizer. It's the user's responsibility to split the model. We finally succeeded in enabling model-parallel inference with TVM.
Steps:

  1. Hook the PyTorch DDP primitives. In PyTorch, allreduce and allgather are APIs, not operators, so they can't be captured by jit trace. Instead, I replaced them with dummy operators (allreduce, allgather) in PyTorch ATen (very few lines, 8 in total).
  2. Then torch.jit.trace/script can capture the allreduce/allgather operators in the model.
  3. Add allreduce/allgather op support in TVM. This means you need to integrate a communication library into TVM. This step really takes a lot of engineering work.
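Why step 1 is needed can be shown with a toy tracer. This is a hypothetical model of tracing, not how torch.jit works internally: the tracer only records calls routed through its op wrapper, so a direct communication-library call runs eagerly but leaves no node in the graph, while the same call wrapped as a dummy op shows up as a node that could later be mapped to a TVM op.

```python
from typing import Callable, List

Fn = Callable[[float], float]


class ToyTracer:
    """Hypothetical stand-in for a graph tracer: it only sees calls made through `op`."""

    def __init__(self) -> None:
        self.graph: List[str] = []

    def op(self, name: str, fn: Fn) -> Fn:
        def wrapped(x: float) -> float:
            self.graph.append(name)  # recorded as a node in the captured graph
            return fn(x)
        return wrapped


def comm_library_allreduce(x: float) -> float:
    """Stand-in for a communication API call: pretend 2 ranks each hold the same x."""
    return x * 2


def model_with_api_call(t: ToyTracer, x: float) -> float:
    x = t.op("dense", lambda v: v + 1)(x)
    return comm_library_allreduce(x)  # runs, but leaves no node in the graph


def model_with_dummy_op(t: ToyTracer, x: float) -> float:
    x = t.op("dense", lambda v: v + 1)(x)
    return t.op("allreduce", comm_library_allreduce)(x)  # step 1: exposed as an op


t1, t2 = ToyTracer(), ToyTracer()
print(model_with_api_call(t1, 1.0), t1.graph)  # 4.0 ['dense']
print(model_with_dummy_op(t2, 1.0), t2.graph)  # 4.0 ['dense', 'allreduce']
```

Both versions compute the same value; only the second leaves the collective visible in the graph for step 3 to handle.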

For steps 1 and 2, there may be an easier way to add allreduce/allgather operators to the Relay graph.
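One possible shape of that easier route is a graph-rewrite pass that inserts the collectives after import, driven by how each dense layer is sharded. The sketch below uses a hypothetical toy IR (a list of Node objects with a shard attribute), not Relay; in TVM this would presumably be written as a Relay pass instead. It assumes Megatron-style sharding, where a column-parallel layer feeding a row-parallel layer through elementwise ops only needs no allgather in between.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional

ELEMENTWISE = {"gelu", "relu"}  # ops that can run directly on sharded activations


@dataclass
class Node:
    op: str
    attrs: Dict[str, str] = field(default_factory=dict)


def next_non_elementwise(graph: List[Node], i: int) -> Optional[Node]:
    for n in graph[i + 1:]:
        if n.op not in ELEMENTWISE:
            return n
    return None


def insert_collectives(graph: List[Node]) -> List[Node]:
    """Insert allreduce/allgather after sharded dense ops (toy IR, not Relay)."""
    out: List[Node] = []
    for i, node in enumerate(graph):
        out.append(node)
        shard = node.attrs.get("shard") if node.op == "dense" else None
        if shard == "row":
            # Each rank holds a partial sum of the full output.
            out.append(Node("allreduce", {"reduce": "sum"}))
        elif shard == "column":
            # Each rank holds a slice of the output features; gather them unless the
            # consumer is a row-parallel dense, which wants exactly that slice.
            nxt = next_non_elementwise(graph, i)
            if not (nxt is not None and nxt.op == "dense" and nxt.attrs.get("shard") == "row"):
                out.append(Node("allgather", {"axis": "-1"}))
    return out


mlp = [Node("dense", {"shard": "column"}), Node("gelu"), Node("dense", {"shard": "row"})]
print([n.op for n in insert_collectives(mlp)])  # ['dense', 'gelu', 'dense', 'allreduce']
```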


By the way, local node information such as rank and world_size is passed through environment variables.
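The post does not name the environment variables. A minimal sketch, assuming the RANK and WORLD_SIZE names that PyTorch's torchrun launcher sets, of how each process could read its rank and world size (DistContext and dist_context_from_env are hypothetical names):

```python
import os
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class DistContext:
    rank: int
    world_size: int


def dist_context_from_env(env: Mapping[str, str] = os.environ) -> DistContext:
    """Read this process's rank and world size, defaulting to a single process."""
    rank = int(env.get("RANK", "0"))  # assumed variable names (as set by torchrun)
    world_size = int(env.get("WORLD_SIZE", "1"))
    if not 0 <= rank < world_size:
        raise ValueError(f"invalid RANK={rank} for WORLD_SIZE={world_size}")
    return DistContext(rank, world_size)


print(dist_context_from_env({"RANK": "1", "WORLD_SIZE": "4"}))  # DistContext(rank=1, world_size=4)
```

Passing the mapping explicitly keeps the function testable; in a launched process the default reads the real environment.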

@yezhouhai, could you tell me which communication library you integrated into TVM? Thanks!

Any updates from TVM on model parallelism? 🙂