[RFC][Unity][Relax] Heterogeneous Execution for Relax


The goal of this design document is to add heterogeneous execution to Relax, enabling it to compile and optimize deep learning workloads across multiple types of devices, such as CPUs, GPUs, and specialized accelerators. By enabling heterogeneous execution, we can take advantage of the unique features and capabilities of each target device.

Proposed Design

The proposed design for heterogeneous support in Relax consists of the following components:

Target Abstraction

We will reuse the existing Target.

Device Abstraction

Virtual Device

VDevice, a subclass of GlobalInfo, will be introduced. It denotes the data storage representation during compilation and specifies how to compile and compute on that data.

class VDeviceNode : public GlobalInfoNode {
 public:
  Target target;
  int vdevice_id;  // The virtual device id. It enables us to
                   // differentiate between distinct devices with
                   // the same Target, such as multiple GPUs. It might
                   // be changed at runtime.
  MemoryScope memory_scope;
};

class VDevice : public GlobalInfo {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(VDevice, GlobalInfo, VDeviceNode);
};

# The corresponding Python binding
class VDevice(GlobalInfo):
    def __init__(self,
                 target: Union[str, dict, Target],
                 vdevice_id: int = 0,
                 memory_scope: str = "global") -> None:
        ...
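The field semantics above can be mirrored in a small pure-Python model, which makes the role of vdevice_id concrete: two entries with the same target are still distinct virtual devices when their ids differ. VDeviceModel and its target_kind property are hypothetical names for this sketch, not part of TVM, and the target kind is assumed to be the first token of the target string.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class VDeviceModel:
    """Hypothetical pure-Python mirror of VDeviceNode's fields (not TVM's class)."""

    target: str              # target string, e.g. "cuda" or "cuda -arch=sm_80"
    vdevice_id: int = 0      # tells apart devices that share the same target
    memory_scope: str = "global"

    @property
    def target_kind(self) -> str:
        # Assumption: the target kind is the first token of the target string.
        return self.target.split()[0]


gpu0 = VDeviceModel("cuda", 0)
gpu1 = VDeviceModel("cuda", 1)
# Same target kind, but different virtual devices because the ids differ.
print(gpu0.target_kind == gpu1.target_kind, gpu0 == gpu1)  # True False
```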

To help create a VDevice in the IR, we will introduce a new piece of syntactic sugar, R.VDevice. All virtual devices should be defined and added to the global_infos of the IRModule using I.module_global_infos({"vdevice": vdevice_list}).

# python/tvm/script/parser/relax/entry.py
# R.VDevice
def VDevice(
    target: Union[str, dict, Target],
    vdevice_id: int = 0,
    memory_scope: str = "global",
) -> VDevice:
    ...

A new member, Optional<VDevice> vdevice, will be added to TensorStructInfoNode. It indicates where the tensor is expected to be executed.

class TensorStructInfoNode : public StructInfoNode {
  // ... existing members ...
  Optional<VDevice> vdevice;  // virtual device
};

To help users annotate expressions with TensorStructInfo, we will extend R.Tensor with an additional optional vdevice parameter. Its value can take one of two forms: "device_type:i", which represents the i-th vdevice of a specific device type, or "vdevice:j", which refers to the j-th element of the vdevice list defined in global_infos. The index i is optional and defaults to 0 when it is not specified.

# R.Tensor
def Tensor(
    shape: Optional[Union[List[Union[PrimExpr, str]], ShapeExpr]] = None,
    dtype: Optional[str] = None,
    vdevice: Optional[str] = None,  # the value could be "cuda",
                                    # "cuda:0", or "vdevice:0"
    ndim: int = -1,
):
    ...
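A minimal sketch of how the two string forms could be resolved against the global vdevice list, assuming "device_type:i" picks the i-th entry whose target kind is device_type (taken here as the first token of the target string). resolve_vdevice and target_kind are hypothetical helpers, not the util.lookup_vdevice API, and plain target strings stand in for VDevice objects.

```python
from typing import List


def target_kind(target: str) -> str:
    # Hypothetical: the kind is the first token, "cuda -arch=sm_80" -> "cuda".
    return target.split()[0]


def resolve_vdevice(spec: str, vdevices: List[str]) -> int:
    """Return the index into `vdevices` that an R.Tensor vdevice string denotes.

    "vdevice:j" selects entry j directly; "device_type:i" selects the i-th
    entry whose target kind is device_type; a missing ":i" means i = 0.
    """
    kind, _, index_str = spec.partition(":")
    index = int(index_str) if index_str else 0
    if kind == "vdevice":
        candidates = list(range(len(vdevices)))
    else:
        candidates = [i for i, t in enumerate(vdevices) if target_kind(t) == kind]
    if not 0 <= index < len(candidates):
        raise ValueError(f"no virtual device matches {spec!r}")
    return candidates[index]


# Target strings of a vdevice list like the one in the module example.
vdevices = ["cuda", "cuda -arch=sm_80", "metal"]
print(resolve_vdevice("cuda", vdevices))       # 0
print(resolve_vdevice("cuda:0", vdevices))     # 0
print(resolve_vdevice("vdevice:1", vdevices))  # 1
```

Raising on a failed lookup matches the behavior described for unmatched virtual devices; an alternative reading of "cuda:i" (matching on vdevice_id instead of position) would only change how `candidates` is built.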

The following example shows what the IR would look like. "vdevice:1" refers to the second virtual device defined in global_infos. "cuda:0" corresponds to the first virtual device with the cuda device type, so looking up the vdevice with util.lookup_vdevice should match the first element. An error will be raised if no matching virtual device is found.

from tvm.script.parser import ir as I
from tvm.script.parser import relax as R

vdevice = [
    R.VDevice("cuda", 0),
    R.VDevice("cuda -arch=sm_80", 0),
    R.VDevice("metal", 0, "global"),
]

@I.ir_module
class Module:
    I.module_global_infos({"vdevice": vdevice})

    @R.function
    def foo(
        a: R.Tensor((2, 3), "float32", "cuda"),
        b: R.Tensor((2, 3), "float32", "cuda:0"),
        c: R.Tensor((2, 3), "float32", "vdevice:1"),
    ):
        s1 = R.add(a, b)
        s = R.add(s1, c)
        return s

# use module pass `UpdateVDevice` to update a specific virtual device,
# append a new one, or reset them all

Device Collaboration

To optimize workloads that involve multiple devices, Relax will provide mechanisms for device collaboration, such as copying data between devices.

New operator hint_on_device

R.hint_on_device is a special annotation op that hints that the input expression should be executed on a specific device. RealizeVDevice uses this hint to propagate virtual device information across the entire IR, and an error is reported if a virtual device conflict is found.

def hint_on_device(input: relax.Expr, device: Union[str, Device]) -> relax.Expr:
    """
    input : relax.Expr
        The expression to be annotated.
    device : Union[Device, str]
        The device to annotate with.
    result : relax.Expr
        The annotated expression.
    """

a = R.hint_on_device(x, tvm.cpu())
b = R.hint_on_device(y, tvm.cuda()) # b is on different device from a
c = R.add(a, b) # Error occurs, conflict is found

New Relax operator to_vdevice(input, tgt_vdevice)

Copies the input expression to a specific target VDevice. This operator is pure (no in-place operation happens), so it is allowed to appear in a dataflow block. Its operator pattern is OpPatternKind::kOpaque, so it will not be fused by FuseOps.

@R.function
def foo(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")):
    with R.dataflow():
        x1 = R.hint_on_device(x, tvm.cpu())
        y1 = R.hint_on_device(y, tvm.cuda())
        s1 = R.add(x1, x1)               # computed on the CPU
        s2 = R.to_vdevice(s1, "cuda:0")  # copied to the GPU
        s = R.add(y1, s2)                # both operands are now on cuda
        R.output(s)
    return s

VM builtin PackedFunc to_device

We will introduce a new PackedFunc, vm.builtin.to_device, to copy data across different devices. It takes two arguments, {input, tgt_vdevice}: input is the data to copy, and tgt_vdevice is the target vdevice to copy it to. The VMCodeGen will be updated to emit vm.builtin.to_device wherever relax.op.to_device is encountered and the source and target vdevices are not the same.
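The codegen rule above (emit a copy only when the source and target vdevices differ) can be sketched as a toy lowering step. VDev and lower_to_vdevice are hypothetical, and the emitted pseudo-instruction strings only illustrate the decision; they are not VMCodeGen's real output.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class VDev:
    """Hypothetical stand-in for a VDevice: target string, id, memory scope."""

    target: str
    vdevice_id: int = 0
    memory_scope: str = "global"


def lower_to_vdevice(value: str, src: VDev, dst: VDev) -> List[str]:
    """Toy lowering of to_vdevice(value, dst) for a value currently on src.

    Returns pseudo-instructions: nothing when the vdevices are identical,
    otherwise a single vm.builtin.to_device call.
    """
    if src == dst:
        return []  # same vdevice: the value is reused as-is
    return [f"{value}_copy = vm.builtin.to_device({value}, {dst.target}:{dst.vdevice_id})"]


print(lower_to_vdevice("s1", VDev("llvm"), VDev("cuda")))
# ['s1_copy = vm.builtin.to_device(s1, cuda:0)']
print(lower_to_vdevice("s1", VDev("cuda"), VDev("cuda")))  # []
```

Because the comparison is on whole vdevices, "cuda" and "nvptx" count as different here, so a copy is emitted even though both run on the same CUDA device.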

Virtual Device Passes

The StructInfo rules about virtual devices will be enforced in the Normalizer. Some additional cases, such as backward propagation, will be handled in a separate pass called RealizeVDevice.

InferStructInfo in Normalize

The helper function InferStructInfo in block_builder will be extended to deduce virtual device information in the forward direction. Once the virtual device of a value is determined, it is propagated forward through the Relax IR. To simplify the deduction, the virtual device is required to be either fully constrained or fully unconstrained. If it is not fully constrained, the default virtual device, which is the first element of vdevice in global_infos, is applied. Unless specified explicitly, the last element of vdevice serves as the host virtual device. For instance, in the following example, ‘after’ represents the expected form of the ‘before’ function after normalization.

def before(
    x: R.Tensor((2, 3), "float32", "cuda"),
    y: R.Tensor((2, 3), "float32", "cuda"),
) -> R.Tensor((2, 3), "float32"):
    a = R.add(x, y)
    b = R.multiply(a, x)
    return b

def after(
    x: R.Tensor((2, 3), "float32", "cuda"),
    y: R.Tensor((2, 3), "float32", "cuda")
) -> R.Tensor((2, 3), "float32", "cuda"):
    a: R.Tensor((2, 3), "float32", "cuda") = R.add(x, y)
    b: R.Tensor((2, 3), "float32", "cuda") = R.multiply(a, x)
    return b
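The forward rule can be sketched over a straight-line function, assuming each binding's result lives on the single vdevice shared by its arguments and that unannotated parameters take the default (first) vdevice. infer_forward and the (var, op, args) binding tuples are hypothetical, and strings stand in for VDevice objects.

```python
from typing import Dict, List, Optional, Tuple

Binding = Tuple[str, str, List[str]]  # (result var, op name, argument vars)


def infer_forward(
    params: Dict[str, Optional[str]],
    bindings: List[Binding],
    vdevices: List[str],
) -> Dict[str, str]:
    """Hypothetical forward vdevice deduction for a straight-line function.

    Unannotated parameters take the default vdevice (the first list entry);
    each binding inherits the single vdevice shared by its arguments, and
    arguments on different vdevices are reported as a conflict.
    """
    default = vdevices[0]
    env = {name: (dev if dev is not None else default) for name, dev in params.items()}
    for var, op, args in bindings:
        devices = {env[arg] for arg in args}
        if len(devices) != 1:
            raise ValueError(f"{op} for {var!r} mixes vdevices {sorted(devices)}")
        env[var] = devices.pop()
    return env


# The `before` function: x and y are annotated with "cuda".
env = infer_forward(
    {"x": "cuda", "y": "cuda"},
    [("a", "add", ["x", "y"]), ("b", "multiply", ["a", "x"])],
    vdevices=["cuda", "llvm"],
)
print(env["a"], env["b"])  # cuda cuda
```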

RealizeVDevice pass

We are introducing a new pass, RealizeVDevice, to handle cases that InferStructInfo may miss, such as backward propagation. hint_on_device and to_vdevice will be used to help propagate virtual device information across the entire IR. In the following code, the return value of ModBefore["func1"] has its vdevice annotated, so that vdevice should be propagated to all the TensorStructInfos in the function; ModExpected["func1"] is the expected IR after applying RealizeVDevice. Virtual device information can also be propagated across functions: in ModBefore["caller"] below, the output of calling ModBefore["callee"] is expected to be on cuda, and this information will be propagated to the callee.

class ModBefore:
    def func1(
        x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
    ) -> R.Tensor((2, 3), "float32", "cuda:0"):
        a = R.add(x, y)
        b = R.multiply(a, x)
        return b

    def func2(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")):
        with R.dataflow():
            s1 = R.add(x, y)
            s2 = R.multiply(s1, s1)
            s3 = R.hint_on_device(s2, tvm.cpu())
            # s3's vdevice is annotated, so the expressions preceding s3
            # are expected to have the same vdevice
            R.output(s3)
        return s3

    def callee(x: R.Tensor((2, 3), "float32")):
        r = R.add(x, x)
        return r

    def caller(x: R.Tensor((2, 3), "float32")):
        s1 = gfunc(x)  # gfunc refers to the callee function
        s2 = R.hint_on_device(s1, tvm.cuda())
        s3 = R.add(s2, s2)
        return s3

class ModExpected:
    def func1(
        x: R.Tensor((2, 3), "float32", "cuda:0"),
        y: R.Tensor((2, 3), "float32", "cuda:0")
    ) -> R.Tensor((2, 3), "float32", "cuda:0"):
        a: R.Tensor((2, 3), "float32", "cuda:0") = R.add(x, y)
        b: R.Tensor((2, 3), "float32", "cuda:0") = R.multiply(a, x)
        return b
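The backward direction that RealizeVDevice adds can be sketched by walking the bindings in reverse and letting each unannotated argument inherit the vdevice of the value it feeds. realize_backward and its inputs are hypothetical; the `known` map stands in for return-type annotations and hint_on_device targets, and strings stand in for VDevice objects.

```python
from typing import Dict, List, Optional, Tuple

Binding = Tuple[str, str, List[str]]  # (result var, op name, argument vars)


def realize_backward(
    params: List[str],
    bindings: List[Binding],
    known: Dict[str, str],
) -> Dict[str, Optional[str]]:
    """Hypothetical backward propagation of vdevices through a function body.

    `known` holds vdevices fixed by annotations (e.g. the return type or a
    hint_on_device). In a straight-line body every use follows its
    definition, so one reverse sweep visits all users before the producer.
    """
    env: Dict[str, Optional[str]] = {name: None for name in params}
    env.update({var: None for var, _, _ in bindings})
    env.update(known)
    for var, _op, args in reversed(bindings):
        dev = env[var]
        if dev is None:
            continue  # nothing known about this value yet
        for arg in args:
            if env[arg] is None:
                env[arg] = dev
            elif env[arg] != dev:
                raise ValueError(f"vdevice conflict on {arg!r}: {env[arg]} vs {dev}")
    return env


# ModBefore["func1"]: only the return value b is annotated with "cuda:0".
env = realize_backward(
    ["x", "y"],
    [("a", "add", ["x", "y"]), ("b", "multiply", ["a", "x"])],
    known={"b": "cuda:0"},
)
print(env)  # {'x': 'cuda:0', 'y': 'cuda:0', 'a': 'cuda:0', 'b': 'cuda:0'}
```

The result matches ModExpected["func1"]: every tensor in the body ends up on "cuda:0".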

UpdateVDevice pass

The virtual devices in the global_infos of an IRModule can be updated using the UpdateVDevice pass.

def UpdateVDevice(
    new_vdevice: Union[VDevice, List[VDevice]], index: Optional[int] = None
) -> tvm.ir.transform.Pass:
    # If new_vdevice is a list, the existing vdevices will be reset.
    # Otherwise, when index is not specified, new_vdevice is appended
    # to vdevice in global_infos; when index is provided, that specific
    # element is updated. The affected TensorStructInfo in the IR
    # will be updated accordingly.
    ...
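The three update modes described in the comments can be illustrated on a plain list. update_vdevice is a hypothetical helper that models only the global_infos list; rewriting the affected TensorStructInfo is not modeled, and strings stand in for VDevice objects.

```python
from typing import List, Optional, Union


def update_vdevice(
    current: List[str],
    new_vdevice: Union[str, List[str]],
    index: Optional[int] = None,
) -> List[str]:
    """Hypothetical list semantics of UpdateVDevice on global_infos["vdevice"].

    A list replaces everything; a single entry is appended when index is
    None and otherwise overwrites current[index]. Returns a new list.
    """
    if isinstance(new_vdevice, list):
        return list(new_vdevice)  # reset all vdevices
    updated = list(current)
    if index is None:
        updated.append(new_vdevice)  # append a new vdevice
    else:
        updated[index] = new_vdevice  # update one specific vdevice
    return updated


print(update_vdevice(["cuda", "metal"], "llvm"))       # ['cuda', 'metal', 'llvm']
print(update_vdevice(["cuda", "metal"], "vulkan", 1))  # ['cuda', 'vulkan']
```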


Adding heterogeneous support to Relax will enable it to compile and optimize deep learning workloads across a wide range of devices, leveraging the unique capabilities of each target device.


Hi, happy to see this. How is it going now?

Glad to hear that, it's exciting. Has any branch been pushed yet? I haven't seen anything about "Heterogeneous Execution for Relax" on https://github.com/mlc-ai/relax.git or https://github.com/apache/tvm.git

Please refer to the tracking issue.

I like it, just with questions on some of the implications, and potential for future expansion.

  1. For the VDevice class, can we support dynamic vdevice_id?

    Currently, the VDeviceNode::vdevice_id is an integer, which only allows for static dispatch to a fixed number of GPUs. Can this parameter be a PrimExpr instead? This would allow for expression of dynamic dispatch (e.g. an implementation that loops over GPUs, launching a kernel on each). While later lowering passes could remove the dynamism, such as by unrolling the loop over GPUs, it can be convenient to write the initial functions generically.

  2. Is there a representation of the entire heterogeneous target?

    In the RFC, I only saw how the Target and VDevice structures would be used to represent delegation to a specific target/device, but I didn’t see a way to express all the devices available in a heterogeneous setup, which would be useful when determining which devices an optimizer may dispatch work onto.

  3. Can the VDevice structure be exposed for use in TIR as well?

    There are a few places where it would allow for simplification of the existing TIR. For example, the "device_type", "device_id" could be merged into a VDevice, or the BufferNode::scope could hold a virtual device.

  4. Does the IRModule::global_info need to contain the VDevice?

    Since the underlying C++ TensorStructInfo contains the VDevice directly, it looks like the primary purpose is to allow the string shorthand in R.Tensor. If that is the case, why does the IRModule::global_infos need to contain the virtual devices at all?

  5. Is there a difference between specifying VDevice in a Tensor, or with hint_on_device?

    It looks like there are two ways to annotate the VDevice, either as part of the Tensor type annotation, or with a later R.hint_on_device operator. Is there a semantic difference between the two?

    # Using R.Tensor annotation
    x: R.Tensor(shape, dtype, vdevice) = R.my_op(...)
    # Using hint_on_device
    temp: R.Tensor(shape, dtype) = R.my_op(...)
    x = R.hint_on_device(temp, vdevice)
  6. Can R.hint_on_device be applied to function inputs?

    Since R.hint_on_device alters the interpretation of its inputs, R.hint_on_device(arg, "cuda") would state that the argument arg was already on "cuda" when provided by the caller.

  7. Does the normalization need to depend on the global_info?

Having the normalization of an expression depend on non-local information seems likely to surprise users. (e.g. Mutating a tensor would result in normalization of operations using that tensor, which as written would then default to the vdevice in global_info instead of any previously-provided R.Tensor annotations.)

Can the normalization be based on the operator arguments instead? I’m picturing the following rules which would avoid this issue.

  1. If all arguments have the same `VDevice` annotation, then the
     result should have that `VDevice` annotation.

  2. If arguments have inconsistent `VDevice`, (e.g. `R.add(a,b)`,
     where `a` is on `tvm.cuda()` and `b` is on `tvm.cpu()`), then
     raise an error.

  3. If some arguments lack a `VDevice` annotation, then the result
     does not have a `VDevice` annotation.

  8. How does relax.op.to_device handle targets that share a TargetKind?

    Currently, the RFC states that a vm.builtin.to_device will be generated whenever the source and target vdevices are not the same. Because multiple targets may use the same underlying device (e.g. both "cuda" and "nvptx" generate code that runs on the kDLCUDA device type), this could result in unnecessary copies. Can we instead generate a to_device only when the vdevice->target->GetTargetDeviceType() or the vdevice->vdevice_id differs?
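The argument-based normalization rules and the device-type check proposed above can both be sketched in a few lines of pure Python. normalize_result applies the three rules, checking for inconsistent annotations before falling back to "no annotation", and needs_copy compares device type and id instead of whole vdevices. Both functions and the DEVICE_TYPE_OF_KIND table are hypothetical; the table only encodes the example that "cuda" and "nvptx" share the kDLCUDA device type.

```python
from typing import Optional, Sequence

# Hypothetical table: several target kinds map to one DLPack device type.
DEVICE_TYPE_OF_KIND = {
    "llvm": "kDLCPU",
    "cuda": "kDLCUDA",
    "nvptx": "kDLCUDA",
    "metal": "kDLMetal",
}


def normalize_result(arg_vdevices: Sequence[Optional[str]]) -> Optional[str]:
    """Derive an op's result vdevice purely from its arguments."""
    annotated = [v for v in arg_vdevices if v is not None]
    distinct = set(annotated)
    if len(distinct) > 1:
        # Inconsistent annotations, e.g. one cuda argument and one cpu argument.
        raise ValueError(f"inconsistent vdevices: {sorted(distinct)}")
    if not distinct or len(annotated) < len(arg_vdevices):
        return None  # some argument lacks a vdevice: leave the result unannotated
    return distinct.pop()  # all arguments agree


def needs_copy(src_kind: str, src_id: int, dst_kind: str, dst_id: int) -> bool:
    """Copy only if the physical device type or the device id differs."""
    src = (DEVICE_TYPE_OF_KIND[src_kind], src_id)
    dst = (DEVICE_TYPE_OF_KIND[dst_kind], dst_id)
    return src != dst


print(normalize_result(["cuda", "cuda"]))  # cuda
print(normalize_result(["cuda", None]))    # None
print(needs_copy("cuda", 0, "nvptx", 0))   # False
print(needs_copy("cuda", 0, "cuda", 1))    # True
```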
