How many Relax ops support dynamic shape now?

It seems that only a limited set of Relax ops support dynamic shape now, such as dynamic_strided_slice. How do we use dynamic shape in Relax?

I’m focusing on one of the important Relax features: dynamic shape inference.

I have searched the relevant tests and only found one, test_e2e_op_dynamic: tvm/tests/python/relax/test_e2e_op_dynamic.py at unity · apache/tvm (github.com)

I tried to build a matmul IRModule and lower it with LegalizeOps(), but there is no “shape_func”-like PrimFunc in the legalized module. The following is my test code:

import tvm
from tvm import relax, topi
from tvm.ir.module import IRModule
from tvm.relax.transform import LegalizeOps
from tvm.script import relax as R
from tvm.script import tir as T

@tvm.script.ir_module
class DynMatModule:
    @R.function
    def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor((12, 24), "float32")) -> R.Tensor("float32", ndim=2):
        m = T.int64()
        n = T.int64()
        with R.dataflow():
            gv: R.Tensor("float32", ndim=2) = R.matmul(x, w)
            R.output(gv)

        return gv

mat_mod = LegalizeOps()(DynMatModule)
print(mat_mod.script())

and the result:

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor(("m", "n"), dtype="float32"), w: R.Tensor((12, 24), dtype="float32")) -> 
R.Tensor(dtype="float32", ndim=2):
        m = T.int64()
        n = T.int64()
        with R.dataflow():
            gv: R.Tensor(dtype="float32", ndim=2) = R.matmul(x, w, out_dtype="void")
            R.output(gv)
        return gv

It seems there is no rule to auto-generate the “shape_func”, if I understand correctly. Is there any manual or tutorial that tells users how to construct a user-defined “shape_func”?

I’m not sure what you mean by shape_func. But in this case, it’s not a “legal” example (to be specific, it’s supported by Relax but not a common usage). Here, the input shape is m * n, while the weight is 12 * 24. That implicitly indicates that n == 12 if it’s legal.

In common usage, please change n to 12 directly.

I fixed the clerical error, but the result is the same.

x: R.Tensor(("m", 12), "float32")

# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor(("m", 12), dtype="float32"), w: R.Tensor((12, 24), dtype="float32")) -> 
R.Tensor(dtype="float32", ndim=2):
        m = T.int64()
        with R.dataflow():
            gv: R.Tensor(dtype="float32", ndim=2) = R.matmul(x, w, out_dtype="void")
            R.output(gv)
        return gv

The shape_func is there to calculate the output shape at runtime. Let’s take DynamicStridedSlice as an example. If you run:

import tvm
from tvm.relax.transform import LegalizeOps
from tvm.script import relax as R
from tvm.script import tir as T

@tvm.script.ir_module
class DynamicStridedSlice:
    @R.function
    def main(x: R.Tensor(("m", "n", 10, 10), "float32"), begin: R.Tensor((4,), "int64"), end: R.Tensor((4,), "int64"), strides: R.Tensor((4,), "int64")) -> R.Tensor("float32", ndim=4):
        m = T.int64()
        n = T.int64()
        gv: R.Tensor("float32", ndim=4) = R.dynamic_strided_slice(x, begin, end, strides)
        return gv

slice_mod = LegalizeOps()(DynamicStridedSlice)
print(slice_mod.script())

you will get an IRModule with a “shape_func”-like PrimFunc:

@I.ir_module
class Module:
    @T.prim_func
    def dynamic_strided_slice(var_A: T.handle, B: T.Buffer((T.int64(4),), "int64"), C: T.Buffer((T.int64(4),), "int64"), D: T.Buffer((T.int64(4),), "int64"), var_T_strided_slice_dynamic: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        m, n = T.int64(), T.int64()
        A = T.match_buffer(var_A, (m, n, T.int64(10), T.int64(10)))
        s, s_1, s_2, s_3 = T.int64(), T.int64(), T.int64(), T.int64()
        T_strided_slice_dynamic = T.match_buffer(var_T_strided_slice_dynamic, (s, s_1, s_2, s_3))
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(s, s_1, s_2, s_3):
            with T.block("T_strided_slice_dynamic"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[T.min(B[T.int64(0)], m - T.int64(1)) + v_ax0 * D[T.int64(0)], T.min(B[T.int64(1)], n - T.int64(1)) + v_ax1 * D[T.int64(1)], T.min(B[T.int64(2)], T.int64(9)) + v_ax2 * D[T.int64(2)], T.min(B[T.int64(3)], T.int64(9)) + v_ax3 * D[T.int64(3)]], B[T.int64(0):T.int64(4)], D[T.int64(0):T.int64(4)])
                T.writes(T_strided_slice_dynamic[v_ax0, v_ax1, v_ax2, v_ax3])
                T_strided_slice_dynamic[v_ax0, v_ax1, v_ax2, v_ax3] = A[T.min(B[T.int64(0)], m - T.int64(1)) + v_ax0 * D[T.int64(0)], T.min(B[T.int64(1)], n - T.int64(1)) + v_ax1 * D[T.int64(1)], T.min(B[T.int64(2)], T.int64(9)) + v_ax2 * D[T.int64(2)], T.min(B[T.int64(3)], T.int64(9)) + v_ax3 * D[T.int64(3)]]

    @T.prim_func
    def shape_func(var_A: T.handle, B: T.Buffer((T.int64(4),), "int64"), C: T.Buffer((T.int64(4),), "int64"), D: T.Buffer((T.int64(4),), "int64"), T_shape_func_strided_slice_dynamic: T.Buffer((T.int64(4),), "int64")):
        T.func_attr({"tir.noalias": T.bool(True)})
        m, n = T.int64(), T.int64()
        A = T.match_buffer(var_A, (m, n, T.int64(10), T.int64(10)))
        # with T.block("root"):
        for i in range(T.int64(4)):
            with T.block("T_shape_func_strided_slice_dynamic"):
                v_i = T.axis.spatial(T.int64(4), i)
                T.reads(D[v_i], B[v_i], C[v_i])
                T.writes(T_shape_func_strided_slice_dynamic[v_i])
                T_shape_func_strided_slice_dynamic[v_i] = T.Select(D[v_i] < T.int64(0), (T.min(T.max(T.Select(B[v_i] < T.int64(0), B[v_i] + T.Select(v_i == T.int64(3), T.int64(10), T.Select(v_i == T.int64(2), T.int64(10), T.Select(v_i == T.int64(1), n, T.Select(v_i == T.int64(0), m, T.int64(-1))))), B[v_i]), T.int64(-1)), T.Select(v_i == T.int64(3), T.int64(10), T.Select(v_i == T.int64(2), T.int64(10), T.Select(v_i == T.int64(1), n, T.Select(v_i == T.int64(0), m, T.int64(-1))))) - T.int64(1)) - T.min(T.max(T.Select(C[v_i] < T.int64(0), C[v_i] + T.Select(v_i == T.int64(3), T.int64(10), T.Select(v_i == T.int64(2), T.int64(10), T.Select(v_i == T.int64(1), n, T.Select(v_i == T.int64(0), m, T.int64(-1))))), C[v_i]), T.int64(-1)), T.Select(v_i == T.int64(3), T.int64(10), T.Select(v_i == T.int64(2), T.int64(10), T.Select(v_i == T.int64(1), n, T.Select(v_i == T.int64(0), m, T.int64(-1))))) - T.int64(1)) - D[v_i] - T.int64(1)) // (D[v_i] * T.int64(-1)), (T.min(T.max(T.Select(C[v_i] < T.int64(0), C[v_i] + T.Select(v_i == T.int64(3), T.int64(10), T.Select(v_i == T.int64(2), T.int64(10), T.Select(v_i == T.int64(1), n, T.Select(v_i == T.int64(0), m, T.int64(-1))))), C[v_i]), T.int64(0)), T.Select(v_i == T.int64(3), T.int64(10), T.Select(v_i == T.int64(2), T.int64(10), T.Select(v_i == T.int64(1), n, T.Select(v_i == T.int64(0), m, T.int64(-1)))))) + D[v_i] - T.min(T.max(T.Select(B[v_i] < T.int64(0), B[v_i] + T.Select(v_i == T.int64(3), T.int64(10), T.Select(v_i == T.int64(2), T.int64(10), T.Select(v_i == T.int64(1), n, T.Select(v_i == T.int64(0), m, T.int64(-1))))), B[v_i]), T.int64(0)), T.Select(v_i == T.int64(3), T.int64(10), T.Select(v_i == T.int64(2), T.int64(10), T.Select(v_i == T.int64(1), n, T.Select(v_i == T.int64(0), m, T.int64(-1)))))) - T.int64(1)) // D[v_i])

    @R.function
    def main(x: R.Tensor(("m", "n", 10, 10), dtype="float32"), begin: R.Tensor((4,), dtype="int64"), end: R.Tensor((4,), dtype="int64"), strides: R.Tensor((4,), dtype="int64")) -> R.Tensor(dtype="float32", ndim=4):
        m = T.int64()
        n = T.int64()
        cls = Module
        s = T.int64()
        s_1 = T.int64()
        s_2 = T.int64()
        s_3 = T.int64()
        gv = R.call_tir(cls.shape_func, (x, begin, end, strides), out_sinfo=R.Tensor((4,), dtype="int64"))
        gv1: R.Shape(ndim=4) = R.call_pure_packed("vm.builtin.tensor_to_shape", gv, sinfo_args=(R.Shape(ndim=4),))
        gv2: R.Shape([s, s_1, s_2, s_3]) = R.match_cast(gv1, R.Shape([s, s_1, s_2, s_3]))
        gv_1 = R.call_tir(cls.dynamic_strided_slice, (x, begin, end, strides), out_sinfo=R.Tensor((s, s_1, s_2, s_3), dtype="float32"))
        return gv_1
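
For reference, here is a minimal sketch of actually running this module end to end with the Relax VM, modeled on test_e2e_op_dynamic.py (the llvm target and the concrete input values are just my assumptions for illustration):

import numpy as np
import tvm
from tvm import relax

# Build the legalized module and create a VM (target is an assumption).
ex = relax.build(slice_mod, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())

# Bind m = 3, n = 5 at runtime; slice [0:2, 0:4, 0:10, 0:10] with stride 1.
x = tvm.nd.array(np.random.rand(3, 5, 10, 10).astype("float32"))
begin = tvm.nd.array(np.array([0, 0, 0, 0], dtype="int64"))
end = tvm.nd.array(np.array([2, 4, 10, 10], dtype="int64"))
strides = tvm.nd.array(np.array([1, 1, 1, 1], dtype="int64"))

# shape_func runs first and computes the output shape (2, 4, 10, 10).
out = vm["main"](x, begin, end, strides)
print(out.shape)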

I’m not sure, so I’ll just state my understanding. Both matmul and dynamic_strided_slice support dynamic shape. Relax registers FInferStructInfo functions for ops.

  1. The shape inference of matmul is more direct in your example, because the symbolic output shape does not rely on runtime state. You can simply apply the FoldConstant pass, and the output shape can be inferred by the function as (m, 24). The IR can then be further lowered to a TIR PrimFunc (see the sketch after this list).

  2. The output shape of dynamic_strided_slice cannot be obtained during compilation because begin, end, and strides are unknown until runtime. If you look into the registered legalize_ops, TVM inserts the shape_func explicitly to determine the output shape at runtime.
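
To illustrate point 1, here is a minimal sketch (my own illustrative module, assuming the unity branch APIs; gv is left unannotated so the inferred struct info is visible): once n is fixed to 12, the matmul output shape (m, 24) is already known symbolically at compile time, so no runtime shape_func is needed.

import tvm
from tvm.script import relax as R

@tvm.script.ir_module
class FixedMatModule:
    @R.function
    def main(x: R.Tensor(("m", 12), "float32"), w: R.Tensor((12, 24), "float32")):
        with R.dataflow():
            # struct info is inferred by matmul's registered FInferStructInfo
            gv = R.matmul(x, w)
            R.output(gv)
        return gv

# The binding var already carries the symbolic output shape at compile time:
binding = FixedMatModule["main"].body.blocks[0].bindings[-1]
print(binding.var.struct_info)  # R.Tensor((m, 24), dtype="float32")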

I hope this helps.