The `shape_func` computes the output shape at runtime. Let's take `DynamicStridedSlice` as an example. If you run:
# fmt: off
# Relax module whose output shape depends on the runtime *values* of
# begin/end/strides, so only the rank (ndim=4) of the result is declared.
@tvm.script.ir_module
class DynamicStridedSlice:
    @R.function
    def main(x: R.Tensor(("m", "n", 10, 10), "float32"), begin: R.Tensor((4,),"int64"), end: R.Tensor((4,),"int64"), strides: R.Tensor((4,),"int64")) -> R.Tensor("float32", ndim=4):
        # Symbolic int64 dims for the "m" and "n" entries of x's shape annotation.
        m = T.int64()
        n = T.int64()
        gv: R.Tensor("float32", ndim=4) = R.dynamic_strided_slice(x, begin, end, strides)
        return gv
# fmt: on
# NOTE(review): assumes `from tvm.relax.transform import LegalizeOps` is in scope;
# the import is not shown in this snippet.
# LegalizeOps lowers R.dynamic_strided_slice into TIR PrimFuncs; the printed result
# below contains both the slicing kernel and a `shape_func` that computes the
# output shape at runtime.
slice_mod = LegalizeOps()(DynamicStridedSlice)
print(slice_mod.script())
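you will get an IRModule that contains a `shape_func` PrimFunc, like the one below. In `main`, `shape_func` runs first and writes the output shape into an int64 tensor. `vm.builtin.tensor_to_shape` then converts that tensor into a shape value, and `R.match_cast` binds it to the symbolic variables `s, s_1, s_2, s_3`. Finally, `dynamic_strided_slice` is called with an output of that shape: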
@I.ir_module
class Module:
    # Slicing kernel. The output buffer's extents (s, s_1, s_2, s_3) are symbolic
    # and are supplied by the caller; B, C, D receive begin, end, strides from
    # `main` (C/end is not read by this kernel).
    # NOTE(review): begin is only clamped from above here (T.min(B[i], dim - 1));
    # unlike shape_func, negative begin values are not wrapped by adding the
    # dimension — confirm this matches the intended semantics for negative begin.
    @T.prim_func
    def dynamic_strided_slice(var_A: T.handle, B: T.Buffer((T.int64(4),), "int64"), C: T.Buffer((T.int64(4),), "int64"), D: T.Buffer((T.int64(4),), "int64"), var_T_strided_slice_dynamic: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        m, n = T.int64(), T.int64()
        A = T.match_buffer(var_A, (m, n, T.int64(10), T.int64(10)))
        s, s_1, s_2, s_3 = T.int64(), T.int64(), T.int64(), T.int64()
        T_strided_slice_dynamic = T.match_buffer(var_T_strided_slice_dynamic, (s, s_1, s_2, s_3))
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(s, s_1, s_2, s_3):
            with T.block("T_strided_slice_dynamic"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[T.min(B[T.int64(0)], m - T.int64(1)) + v_ax0 * D[T.int64(0)], T.min(B[T.int64(1)], n - T.int64(1)) + v_ax1 * D[T.int64(1)], T.min(B[T.int64(2)], T.int64(9)) + v_ax2 * D[T.int64(2)], T.min(B[T.int64(3)], T.int64(9)) + v_ax3 * D[T.int64(3)]], B[T.int64(0):T.int64(4)], D[T.int64(0):T.int64(4)])
                T.writes(T_strided_slice_dynamic[v_ax0, v_ax1, v_ax2, v_ax3])
                # out[i...] = A[min(begin, dim - 1) + i * stride] along each axis.
                T_strided_slice_dynamic[v_ax0, v_ax1, v_ax2, v_ax3] = A[T.min(B[T.int64(0)], m - T.int64(1)) + v_ax0 * D[T.int64(0)], T.min(B[T.int64(1)], n - T.int64(1)) + v_ax1 * D[T.int64(1)], T.min(B[T.int64(2)], T.int64(9)) + v_ax2 * D[T.int64(2)], T.min(B[T.int64(3)], T.int64(9)) + v_ax3 * D[T.int64(3)]]

    # Shape function: writes the 4 output extents into an int64 buffer.
    # For axis i, the nested T.Select(v_i == ...) picks that axis's size
    # dim_i in (m, n, 10, 10). Negative begin/end values are wrapped by adding
    # dim_i, then clamped:
    #   stride > 0: begin/end clamped to [0, dim_i],
    #               extent = (end - begin + stride - 1) // stride
    #   stride < 0: begin/end clamped to [-1, dim_i - 1],
    #               extent = (begin - end - stride - 1) // (-stride)
    # i.e. a ceiling division of the clamped span by |stride|.
    @T.prim_func
    def shape_func(var_A: T.handle, B: T.Buffer((T.int64(4),), "int64"), C: T.Buffer((T.int64(4),), "int64"), D: T.Buffer((T.int64(4),), "int64"), T_shape_func_strided_slice_dynamic: T.Buffer((T.int64(4),), "int64")):
        T.func_attr({"tir.noalias": T.bool(True)})
        m, n = T.int64(), T.int64()
        A = T.match_buffer(var_A, (m, n, T.int64(10), T.int64(10)))
        # with T.block("root"):
        for i in range(T.int64(4)):
            with T.block("T_shape_func_strided_slice_dynamic"):
                v_i = T.axis.spatial(T.int64(4), i)
                T.reads(D[v_i], B[v_i], C[v_i])
                T.writes(T_shape_func_strided_slice_dynamic[v_i])
                T_shape_func_strided_slice_dynamic[v_i] = T.Select(D[v_i] < T.int64(0), (T.min(T.max(T.Select(B[v_i] < T.int64(0), B[v_i] + T.Select(v_i == T.int64(3), T.int64(10), T.Select(v_i == T.int64(2), T.int64(10), T.Select(v_i == T.int64(1), n, T.Select(v_i == T.int64(0), m, T.int64(-1))))), B[v_i]), T.int64(-1)), T.Select(v_i == T.int64(3), T.int64(10), T.Select(v_i == T.int64(2), T.int64(10), T.Select(v_i == T.int64(1), n, T.Select(v_i == T.int64(0), m, T.int64(-1))))) - T.int64(1)) - T.min(T.max(T.Select(C[v_i] < T.int64(0), C[v_i] + T.Select(v_i == T.int64(3), T.int64(10), T.Select(v_i == T.int64(2), T.int64(10), T.Select(v_i == T.int64(1), n, T.Select(v_i == T.int64(0), m, T.int64(-1))))), C[v_i]), T.int64(-1)), T.Select(v_i == T.int64(3), T.int64(10), T.Select(v_i == T.int64(2), T.int64(10), T.Select(v_i == T.int64(1), n, T.Select(v_i == T.int64(0), m, T.int64(-1))))) - T.int64(1)) - D[v_i] - T.int64(1)) // (D[v_i] * T.int64(-1)), (T.min(T.max(T.Select(C[v_i] < T.int64(0), C[v_i] + T.Select(v_i == T.int64(3), T.int64(10), T.Select(v_i == T.int64(2), T.int64(10), T.Select(v_i == T.int64(1), n, T.Select(v_i == T.int64(0), m, T.int64(-1))))), C[v_i]), T.int64(0)), T.Select(v_i == T.int64(3), T.int64(10), T.Select(v_i == T.int64(2), T.int64(10), T.Select(v_i == T.int64(1), n, T.Select(v_i == T.int64(0), m, T.int64(-1)))))) + D[v_i] - T.min(T.max(T.Select(B[v_i] < T.int64(0), B[v_i] + T.Select(v_i == T.int64(3), T.int64(10), T.Select(v_i == T.int64(2), T.int64(10), T.Select(v_i == T.int64(1), n, T.Select(v_i == T.int64(0), m, T.int64(-1))))), B[v_i]), T.int64(0)), T.Select(v_i == T.int64(3), T.int64(10), T.Select(v_i == T.int64(2), T.int64(10), T.Select(v_i == T.int64(1), n, T.Select(v_i == T.int64(0), m, T.int64(-1)))))) - T.int64(1)) // D[v_i])

    @R.function
    def main(x: R.Tensor(("m", "n", 10, 10), dtype="float32"), begin: R.Tensor((4,), dtype="int64"), end: R.Tensor((4,), dtype="int64"), strides: R.Tensor((4,), dtype="int64")) -> R.Tensor(dtype="float32", ndim=4):
        m = T.int64()
        n = T.int64()
        cls = Module
        s = T.int64()
        s_1 = T.int64()
        s_2 = T.int64()
        s_3 = T.int64()
        # 1) Run shape_func to get the output shape as a (4,) int64 tensor.
        gv = R.call_tir(cls.shape_func, (x, begin, end, strides), out_sinfo=R.Tensor((4,), dtype="int64"))
        # 2) Convert that tensor into a Relax shape value.
        gv1: R.Shape(ndim=4) = R.call_pure_packed("vm.builtin.tensor_to_shape", gv, sinfo_args=(R.Shape(ndim=4),))
        # 3) Bind the symbolic vars s, s_1, s_2, s_3 to the computed extents.
        gv2: R.Shape([s, s_1, s_2, s_3]) = R.match_cast(gv1, R.Shape([s, s_1, s_2, s_3]))
        # 4) Run the slicing kernel with an output of the now-known shape.
        gv_1 = R.call_tir(cls.dynamic_strided_slice, (x, begin, end, strides), out_sinfo=R.Tensor((s, s_1, s_2, s_3), dtype="float32"))
        return gv_1