Pre-RFC: Multiple device support in Relay VM

Currently the Relay VM only supports a single device:

It would be useful to support multiple devices, e.g. for heterogeneous splitting of networks for data center workloads. I’ve been messing with this change on a private branch but I don’t have anything presentable yet. It is (technically) possible to represent this at the Relay IR level; the attributes for device copy & storage allocations have slots for (static) device IDs. However, the device ID information is currently thrown away during compilation for the VM. I’d like to change that.

The work to support heterogeneous execution has laid some groundwork here:

However, more invasive changes are needed. In particular, the VM bytecode format will need to be modified to include device IDs on AllocStorage and DeviceCopy, and that data will need to be plumbed through various compilation passes.
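To make the bytecode change concrete, here is a rough Python model of what "device IDs on AllocStorage and DeviceCopy" could look like. This is a sketch under stated assumptions: the real VM instructions are defined in C++, and every field name below (including `device_id`, `src_device_id`, `dst_device_id`) and the helper `devices_used` are hypothetical, not the actual TVM instruction layout.

```python
from dataclasses import dataclass
from typing import List, Union

# Hypothetical, simplified model of two VM instructions.
# Field names are illustrative only, not TVM's real bytecode layout.

@dataclass(frozen=True)
class AllocStorage:
    dst: int        # destination register
    size: int       # allocation size in bytes
    alignment: int
    device_id: int  # NEW: static ID of the device the storage lives on

@dataclass(frozen=True)
class DeviceCopy:
    dst: int            # destination register
    src: int            # source register
    src_device_id: int  # NEW: device holding the source tensor
    dst_device_id: int  # NEW: device to copy onto

Instruction = Union[AllocStorage, DeviceCopy]

def devices_used(program: List[Instruction]) -> List[int]:
    """Collect the static device IDs a program refers to, e.g. so the VM
    could check at load time that each one can be instantiated."""
    ids = set()
    for instr in program:
        if isinstance(instr, AllocStorage):
            ids.add(instr.device_id)
        elif isinstance(instr, DeviceCopy):
            ids.update((instr.src_device_id, instr.dst_device_id))
    return sorted(ids)
```

For example, a program that allocates on devices 0 and 1 and copies register 0 from device 0 to device 1 reports `[0, 1]`. Keeping the IDs static in the instruction stream matches the current assumption that device assignments are known at compile time.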

Key questions:

  • What should the API for annotating modules with device information look like? It would be nice to support both homogeneous splitting (e.g. across a batch dimension) and heterogeneous splitting (anything else).
  • Should device selection be static or dynamic? Static is simpler to implement; dynamic would be more flexible and could, e.g., check how many devices are available and adapt accordingly. The analysis passes that determine device associations currently assume static device assignments.
  • How should we deal with constants? A simple implementation would change the relation between constants and devices from one-to-one to one-to-many. Alternatively, all constants could be logically associated with the CPU and loaded dynamically onto particular devices as needed.
  • How should this API be tested? I don’t believe the CI machines have multiple GPUs. One solution would be to implement a new device type, virtual cpu, which is essentially the same as the regular CPU but allows multiple contexts to be instantiated, and forbids using a tensor associated with one context in another.
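The "virtual cpu" testing idea from the last question can be sketched in a few lines of plain Python. Everything here (`VirtualCPU`, `Tensor`, `CrossDeviceError`, `add`, `device_copy`) is a hypothetical stand-in, not TVM's runtime API; the point is only to show the invariant such a device would enforce: kernels refuse mixed-device inputs, and an explicit copy is the only way to move data.

```python
from dataclasses import dataclass
from typing import List

@dataclass(frozen=True)
class VirtualCPU:
    """Hypothetical 'virtual cpu': behaves like the ordinary CPU but can be
    instantiated many times, one context per device_id."""
    device_id: int

@dataclass
class Tensor:
    data: List[float]
    device: VirtualCPU

class CrossDeviceError(RuntimeError):
    pass

def check_same_device(*tensors: Tensor) -> VirtualCPU:
    # Reject any operation that mixes tensors from different contexts.
    dev = tensors[0].device
    for t in tensors[1:]:
        if t.device != dev:
            raise CrossDeviceError(f"tensor on {t.device} used with tensor on {dev}")
    return dev

def add(a: Tensor, b: Tensor) -> Tensor:
    dev = check_same_device(a, b)
    return Tensor([x + y for x, y in zip(a.data, b.data)], dev)

def device_copy(t: Tensor, dst: VirtualCPU) -> Tensor:
    # The only sanctioned way to move data between virtual devices.
    return Tensor(list(t.data), dst)
```

With `cpu0, cpu1 = VirtualCPU(0), VirtualCPU(1)`, adding a tensor on `cpu0` to one on `cpu1` raises `CrossDeviceError`, while copying the second tensor to `cpu0` first succeeds. A missing device_copy in the compiled program would therefore fail loudly on a single-CPU CI machine instead of silently working.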

Hi, there’s some overlap with It tries to at least ensure device_id, as an uninterpreted int, is plumbed through from annotation to device_copy, parameter metadata, etc.

Your key questions, however, are very good ones and well outside the scope of the RFC:

  • Given a few device annotations, we heuristically default devices for the rest of the program. But that’s a whole optimization problem in itself.
  • Right, it would be amazing to be able to shard a tensor across devices on the N dimension.
  • We consider the ‘devices’ we plan with to be ‘virtual devices’, but there’s currently no way to control the mapping from virtual to actual. We may want to choose the actual at runtime based on load, capabilities, etc.
  • For constants I think we could rewrite:
       @global_const = ...constant implicitly on device A...
       .... device_copy(@global_const, A, B) ...
       .... device_copy(@global_const, A, C) ...

    by partially evaluating the device_copy and hoisting the result into new constants on B and C. However, we currently don’t have a way to represent globally bound constants.
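A toy version of the constant rewrite above, over a made-up three-node IR rather than Relay. Since there is currently no way to represent globally bound constants, this sketch simply identifies constants by a name string; `Const`, `DeviceCopy`, `Call` and `hoist_constant_copies` are all hypothetical. It handles only the narrow case of a device_copy applied directly to a constant, not general partial evaluation.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

@dataclass(frozen=True)
class Const:
    name: str                  # stands in for a globally bound constant
    value: Tuple[float, ...]
    device: str

@dataclass(frozen=True)
class DeviceCopy:
    arg: "Expr"
    src: str
    dst: str

@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple["Expr", ...]

Expr = Union[Const, DeviceCopy, Call]

def hoist_constant_copies(expr: Expr,
                          cache: Optional[Dict[Tuple[str, str], Const]] = None) -> Expr:
    """Replace device_copy(const on A, A, B) with a new constant on B.
    The cache ensures each (constant, target device) pair is created once."""
    if cache is None:
        cache = {}
    if isinstance(expr, DeviceCopy) and isinstance(expr.arg, Const):
        c = expr.arg
        if c.device != expr.src:
            raise ValueError("copy source does not match the constant's device")
        key = (c.name, expr.dst)
        if key not in cache:
            cache[key] = Const(f"{c.name}_{expr.dst}", c.value, expr.dst)
        return cache[key]
    if isinstance(expr, DeviceCopy):
        return DeviceCopy(hoist_constant_copies(expr.arg, cache), expr.src, expr.dst)
    if isinstance(expr, Call):
        return Call(expr.op, tuple(hoist_constant_copies(a, cache) for a in expr.args))
    return expr  # plain constants are left where they are
```

Applied to a program containing `device_copy(@global_const, A, B)` and `device_copy(@global_const, A, C)`, this yields two new constants, `global_const_B` on B and `global_const_C` on C, and repeated copies to the same device share one constant. In a real pass the new constants would need to become globally bound so the VM loads them once per device.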
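On choosing the actual device for each virtual device at runtime: the simplest possible policy is a round-robin assignment that folds N virtual devices onto however many physical devices are present. The function below is a hypothetical sketch, not an existing TVM API, and it deliberately ignores load and capabilities, which a real policy would have to weigh.

```python
from typing import Dict, List

def map_virtual_devices(virtual_ids: List[int], num_physical: int) -> Dict[int, int]:
    """Hypothetical runtime policy: assign each distinct virtual device to a
    physical device index round-robin, so a plan made for N virtual devices
    still runs when fewer physical devices are available."""
    if num_physical < 1:
        raise ValueError("need at least one physical device")
    return {v: i % num_physical for i, v in enumerate(sorted(set(virtual_ids)))}
```

For example, four virtual devices on a two-GPU machine map as `{0: 0, 1: 1, 2: 0, 3: 1}`. One consequence of this design is that two virtual devices may land on the same physical device, so any device_copy between them would become a same-device copy that the runtime could elide.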
