Introduction
This design document proposes adding heterogeneous execution support to Relax, enabling it to compile and optimize deep learning workloads across multiple types of devices, such as CPUs, GPUs, and specialized accelerators. With heterogeneous execution, we can take advantage of the unique features and capabilities of each target device.
Proposed Design
The proposed design for heterogeneous support in Relax consists of the following components:
Target Abstraction
We will reuse the existing Target abstraction.
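For reference, a Target can already be constructed from a string or a config dict, which is sufficient to describe each device kind involved in heterogeneous execution. A minimal sketch using the existing API:

from tvm.target import Target

# Targets for the devices participating in heterogeneous execution
cpu_target = Target("llvm")
gpu_target = Target("cuda -arch=sm_80")  # same Target kind, extra attributes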
Device Abstraction
Virtual Device
VDevice, a subclass of GlobalInfo, will be introduced. It denotes the data storage representation during compilation and outlines how to compile and compute it.
class VDeviceNode : public GlobalInfoNode {
 public:
  Target target;
  int vdevice_id;            // The virtual device id; this enables us to
                             // differentiate between distinct devices with
                             // the same Target, such as multiple GPUs. It
                             // might be changed during runtime.
  MemoryScope memory_scope;
};

class VDevice : public GlobalInfo {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(VDevice, GlobalInfo, VDeviceNode);
};
# The corresponding Python binding
class VDevice(GlobalInfo):
    def __init__(self,
                 target: Union[str, dict, Target],
                 vdevice_id: int = 0,
                 memory_scope: str = "global") -> None:
        ...
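As a usage sketch (the import path below is an assumption; the binding could be exposed elsewhere), distinct physical devices sharing one Target are told apart only by vdevice_id:

from tvm.ir import VDevice  # assumed import path for the new binding

# Two distinct GPUs that share the same Target, distinguished by vdevice_id
gpu0 = VDevice("cuda", 0)
gpu1 = VDevice("cuda", 1)
host = VDevice("llvm", 0, "global")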
To help create VDevice in the IR, we will introduce a new syntactic sugar, R.VDevice. All virtual devices should be defined and added into the global_infos of IRModule using I.module_global_infos({"vdevice": vdevice_list}).
# python/tvm/script/parser/relax/entry.py
# R.VDevice
def VDevice(
    target: Union[str, dict, Target],
    vdevice_id: int = 0,
    memory_scope: str = "global",
)
TensorStructInfo
A new member, Optional<VDevice> vdevice, will be added into TensorStructInfoNode. This indicates the device on which the tensor is expected to reside and be computed.
class TensorStructInfoNode : public StructInfoNode {
 public:
  Optional<VDevice> vdevice;  // virtual device
To help users annotate expressions with TensorStructInfo, we will extend R.Tensor to take an additional optional vdevice parameter. The value of vdevice can be either in the format "device_type:i", which represents the i-th vdevice of a specific device type, or "vdevice:j", which refers to the j-th element within the vdevice list defined in the global_infos. The i is optional; the default value 0 will be applied if it is not specified.
# R.Tensor
def Tensor(
    shape: Optional[Union[List[Union[PrimExpr, str]], ShapeExpr]] = None,
    dtype: Optional[str] = None,
    vdevice: Optional[str] = None,  # the value could be "cuda",
                                    # "cuda:0", or "vdevice:0"
    ndim: int = -1,
)
The following is an example of what the IR would look like. The "vdevice:1" refers to the second virtual device defined in global_infos. The "cuda:0" corresponds to the first virtual device with the cuda device type; looking it up via util.lookup_vdevice resolves it to the second element of the list. An error will be raised if no matching virtual device is found.
from tvm.script.parser import ir as I
from tvm.script.parser import relax as R

vdevice = [
    R.VDevice("llvm"),
    R.VDevice("cuda", 0),
    R.VDevice("cuda -arch=sm_80", 0),
    R.VDevice("metal", 0, "global"),
]

@I.ir_module
class Module:
    I.module_global_infos({"vdevice": vdevice})

    @R.function
    def foo(
        a: R.Tensor((2, 3), "float32", "cuda"),
        b: R.Tensor((2, 3), "float32", "cuda:0"),
        c: R.Tensor((2, 3), "float32", "vdevice:1"),
    ):
        s1 = R.add(a, b)
        s = R.add(s1, c)
        return s
# use module pass `UpdateVDevice` to update a specific virtual device,
# append a new one, or reset them all
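To make the resolution rule concrete, the following is a minimal Python sketch of the lookup logic; util.lookup_vdevice is only named in prose above, so its actual signature may differ:

def lookup_vdevice(vdevice_list, annotation: str):
    # Split "cuda:0" / "vdevice:1" into a kind and an optional index;
    # the index defaults to 0 when omitted (e.g. plain "cuda")
    kind, _, idx = annotation.partition(":")
    i = int(idx) if idx else 0
    if kind == "vdevice":
        # "vdevice:j" indexes directly into the global_infos list
        return vdevice_list[i]
    # Otherwise take the i-th entry whose target kind matches, so
    # "cuda:0" resolves to the second element of the list above
    matches = [vd for vd in vdevice_list if vd.target.kind.name == kind]
    if i >= len(matches):
        raise ValueError(f"no matching virtual device for {annotation}")
    return matches[i]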
Device Collaboration
To optimize workloads that involve multiple devices, Relax will provide mechanisms for device collaboration, such as data copy between devices.
New operator hint_on_device
A special annotation op, R.hint_on_device, hints that the input expression should be executed on a specific device. This hint will be used by RealizeVDevice to propagate virtual device information across the entire IR. An error will be reported if a virtual device conflict is found.
def hint_on_device(input: relax.Expr, device: Union[str, Device]) -> relax.Expr:
    """
    Parameters
    ----------
    input : relax.Expr
        The expression to be annotated.
    device : Union[Device, str]
        The device to annotate with.

    Returns
    -------
    result : relax.Expr
        ...
    """

a = R.hint_on_device(x, tvm.cpu())
b = R.hint_on_device(y, tvm.cuda())  # b is on a different device from a
c = R.add(a, b)  # Error occurs, a conflict is found
New Relax operator to_vdevice(input, tgt_vdevice)
Copy the input expression to a specific target VDevice. This operator is considered pure; no in-place operation happens, so it is allowed to appear in a dataflow block. The operator pattern of to_vdevice is OpPatternKind::kOpaque, so it will not be fused by FuseOps.
@R.function
def foo(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")):
    with R.dataflow():
        x1 = R.hint_on_device(x, tvm.cpu())
        y1 = R.hint_on_device(y, tvm.cuda())
        s1 = R.add(x1, x1)
        s2 = R.to_vdevice(s1, "cuda:0")
        s = R.add(y1, s2)
        R.output(s)
    return s
VM builtin PackedFunc to_device
We will introduce a new PackedFunc, vm.builtin.to_device, to help copy data across different devices. It takes two arguments, {input, tgt_vdevice}: the first is the input data, and tgt_vdevice is the target vdevice to which the input data is copied. The VMCodeGen will be updated to emit vm.builtin.to_device wherever relax.op.to_vdevice is encountered and the source and target vdevices are not the same.
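At runtime the emitted builtin boils down to a cross-device NDArray copy. A minimal sketch of that semantics using the public NDArray API (the builtin's internals may differ):

import numpy as np
import tvm

# Conceptually what vm.builtin.to_device does: copy the input data
# onto the device described by the target vdevice
src = tvm.nd.array(np.zeros((2, 3), dtype="float32"), device=tvm.cpu())
dst = src.copyto(tvm.cuda(0))  # skipped when source and target vdevices match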
Virtual Device Passes
The StructInfo rules about virtual devices will be enforced in the Normalizer. Some additional cases, such as backward propagation, will be implemented in a separate pass called RealizeVDevice.
InferStructInfo in Normalize
The helper function InferStructInfo in block_builder will be extended to deduce the forward virtual device information. If the virtual device is determined, the information will be propagated forward throughout the Relax IR. To simplify the deduction, the virtual device is required to be either fully constrained or fully unconstrained. If it is not fully constrained, the default virtual device, which is the first element defined in the vdevice list of global_infos, will be applied. The last element of vdevice serves as the host virtual device if not explicitly specified. For instance, in the following example, the after function shows the anticipated result of normalizing the before function.
@R.function
def before(
    x: R.Tensor((2, 3), "float32", "cuda"),
    y: R.Tensor((2, 3), "float32", "cuda"),
) -> R.Tensor((2, 3), "float32"):
    a = R.add(x, y)
    b = R.multiply(a, x)
    return b

@R.function
def after(
    x: R.Tensor((2, 3), "float32", "cuda"),
    y: R.Tensor((2, 3), "float32", "cuda"),
) -> R.Tensor((2, 3), "float32", "cuda"):
    a: R.Tensor((2, 3), "float32", "cuda") = R.add(x, y)
    b: R.Tensor((2, 3), "float32", "cuda") = R.multiply(a, x)
    return b
RealizeVDevice pass
We are introducing a new pass called RealizeVDevice to handle cases that InferStructInfo may have missed, such as backward propagation. The hint_on_device and to_vdevice annotations will be used to help propagate virtual device information across the entire IR. As shown in the following code, the return of function ModBefore["func1"] has its vdevice annotated; that vdevice should be propagated to all the TensorStructInfos in this function, and ModExpected["func1"] is the expected IR after applying RealizeVDevice. The virtual device information can also be propagated across functions; for example, in the ModBefore["caller"] function below, the output of calling ModBefore["callee"] is supposed to be on cuda, and this information will be propagated to the callee.
@I.ir_module
class ModBefore:
    @R.function
    def func1(
        x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
    ) -> R.Tensor((2, 3), "float32", "cuda:0"):
        a = R.add(x, y)
        b = R.multiply(a, x)
        return b

    @R.function
    def func2(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")):
        with R.dataflow():
            s1 = R.add(x, y)
            s2 = R.multiply(s1, s1)
            s3 = R.hint_on_device(s2, tvm.cpu())
            # s3's vdevice is annotated; the expressions preceding s3
            # are expected to have the same vdevice
            R.output(s3)
        return s3

    @R.function
    def callee(x: R.Tensor((2, 3), "float32")):
        r = R.add(x, x)
        return r

    @R.function
    def caller(x: R.Tensor((2, 3), "float32")):
        s1 = callee(x)
        s2 = R.hint_on_device(s1, tvm.cuda())
        s3 = R.add(s2, s2)
        return s3
@I.ir_module
class ModExpected:
    @R.function
    def func1(
        x: R.Tensor((2, 3), "float32", "cuda:0"),
        y: R.Tensor((2, 3), "float32", "cuda:0"),
    ) -> R.Tensor((2, 3), "float32", "cuda:0"):
        a: R.Tensor((2, 3), "float32", "cuda:0") = R.add(x, y)
        b: R.Tensor((2, 3), "float32", "cuda:0") = R.multiply(a, x)
        return b

    ...
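Assuming the pass is exposed under relax.transform like other module passes (a sketch, not a confirmed API surface), applying it would look like:

from tvm import relax

# Propagate the annotated vdevices through ModBefore; the result
# is expected to match ModExpected above
mod = relax.transform.RealizeVDevice()(ModBefore)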
UpdateVDevice pass
The virtual devices in the global_infos of an IRModule can be updated using the UpdateVDevice pass.
def UpdateVDevice(
    new_vdevice: Union[VDevice, List[VDevice]], index: Optional[int] = None
) -> tvm.ir.transform.Pass:
    # If new_vdevice is a list, the existing vdevice list will be reset.
    # When the index is not specified, new_vdevice will be appended
    # to vdevice in global_infos. If index is provided, the specific
    # element will be updated. The affected TensorStructInfo in the IR
    # will be updated accordingly.
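A hedged usage sketch, again assuming exposure under relax.transform; here the second vdevice entry (index 1) is replaced:

from tvm import relax
from tvm.ir import VDevice  # assumed import path

# Replace the second entry of the vdevice list; every TensorStructInfo
# that refers to it will be updated accordingly
mod = relax.transform.UpdateVDevice(VDevice("cuda -arch=sm_80"), index=1)(mod)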
Conclusion
Adding heterogeneous support to Relax will enable it to compile and optimize deep learning workloads across a wide range of devices, leveraging the unique capabilities of each target device.