[RFC] Relax Upstreaming

Hi all,

We sent out an RFC for upstreaming Relax (Relay Next): https://github.com/apache/tvm-rfcs/pull/89. We would love to hear your thoughts and feedback, so please feel free to review the RFC and leave comments on it!

After this upstreaming lands, users can write a Relax program in TVMScript (with dynamic shapes and cross-layer interactions between Relax, TensorIR, and PackedFunc), compile the program to a VM executable, and run it on the Relax VM, as demonstrated by the code snippet below:

import numpy as np
import tvm
from tvm import relax
import tvm.script
from tvm.script import relax as R, tir as T

# Relax IRModule written in TVMScript
@tvm.script.ir_module
class MyIRModule:
    # This is a TIR PrimFunc which calls the TIR intrinsic T.exp
    @T.prim_func
    def tir_exp_func(x: T.handle, y: T.handle):
        # Symbolic length shared by the input and output buffers
        n = T.var("int32")
        X = T.match_buffer(x, (n,), "float32")
        Y = T.match_buffer(y, (n,), "float32")
        for i in T.serial(n):
            Y[i] = T.exp(X[i])

    # This is a Relax function which contains a dataflow block
    # representing a computational graph, as well as a call to an
    # opaque packed function which performs an in-place update to the
    # data in variable gv0.
    @R.function
    def relax_func(x: R.Tensor[(n, k), "float32"], w: R.Tensor[_, "float32"]):
        # n, k above are symbolic variables to represent Tensor dimensions
        with R.dataflow():
            lv0 = R.match_shape(w, (k, m))
            lv1: R.Tensor[(n, m), "float32"] = R.dot(x, lv0)
            lv2: R.Tensor[(n * m,), "float32"] = R.flatten(lv1)
            lv3: R.Shape = (n * m,)
            gv0 = R.call_tir(tir_exp_func, [lv2], lv3, dtype="float32")
            R.outputs(gv0)

        R.call_packed("custom_inplace_update", gv0)
        return gv0

# Print IRModule with syntax highlighting
MyIRModule.show()

# Build the Relax IRModule
target = tvm.target.Target("llvm")
exec = relax.vm.build(MyIRModule, target)

# Dump the VM executable instructions as text
print(exec.as_text())

# Run the function on Relax VM runtime
vm = relax.VirtualMachine(exec, tvm.cpu())
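# relax_func takes two arguments: x of shape (n, k) and w of shape (k, m)
shape = (2, 3)
data = tvm.nd.array(np.random.rand(*shape).astype(np.float32))
weight = tvm.nd.array(np.random.rand(3, 4).astype(np.float32))
res = vm["relax_func"](data, weight)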
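
To sanity-check the VM result, it can help to have a plain NumPy reference for what the dataflow block in relax_func computes. The following is only a sketch under assumptions: relax_func_reference is a hypothetical helper name (it is not part of TVM or Relax), and it deliberately leaves out the opaque "custom_inplace_update" packed call, since the snippet does not say what that function does.

```python
import numpy as np


def relax_func_reference(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    """NumPy sketch of the dataflow block: exp(flatten(dot(x, w))).

    Hypothetical helper for comparison only; it does not model the
    in-place update performed by the opaque packed function.
    """
    n, k = x.shape
    k_w, m = w.shape
    # Mirrors R.match_shape(w, (k, m)): the inner dimensions must agree.
    if k != k_w:
        raise ValueError(f"shape mismatch: x is {x.shape}, w is {w.shape}")
    lv1 = x @ w                 # R.dot(x, lv0), shape (n, m)
    lv2 = lv1.reshape(n * m)    # R.flatten(lv1), shape (n * m,)
    return np.exp(lv2)          # tir_exp_func applied elementwise


# Example inputs with n = 2, k = 3, m = 4 (arbitrary sizes chosen here).
rng = np.random.default_rng(0)
x_np = rng.random((2, 3), dtype=np.float32)
w_np = rng.random((3, 4), dtype=np.float32)
expected = relax_func_reference(x_np, w_np)
print(expected.shape)
```

If the packed function leaves the values unchanged, the VM output converted with res.numpy() should be close to this reference when both are given the same inputs; otherwise the reference only describes the value of gv0 before the in-place update.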

Excellent! TVM is more and more mature.
