Does MetaSchedule support tuning dynamic-shape workloads now?

I tried to tune a dynamic-shape GEMM workload with the following script:

import tvm
from tvm.script import tir as T
import tvm.meta_schedule as ms

@T.prim_func
def gemm_dyn_shape(a: T.handle, b: T.handle, c: T.handle):
    N = T.int32()
    K = T.int32()
    M = T.int32()
    A = T.match_buffer(a, (N, K))
    B = T.match_buffer(b, (K, M))
    C = T.match_buffer(c, (N, M))
    for i, j, k in T.grid(N, M, K):
        with T.block("gemm"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            T.reads(A[vi, vk], B[vk, vj])
            T.writes(C[vi, vj])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

target = tvm.target.Target("llvm -num-cores=10", host="llvm -num-cores=10")
ms.tune_tir(gemm_dyn_shape, target=target, work_dir="/tmp", max_trials_global=10)

Running it fails with the following error:

  ......
  9: tvm::meta_schedule::TaskRecord::TaskRecord(tvm::meta_schedule::TuneContext, double)
  8: tvm::tir::EstimateTIRFlops(tvm::IRModule const&)
  7: _ZZN3tvm3tir11StmtFunctorIFNS0_7TResultERKNS0_4StmtEEE10InitVTableEvENUlRKNS_7runtime9ObjectRefEPS7_
  6: non-virtual thunk to tvm::tir::FlopEstimator::VisitStmt_(tvm::tir::BlockRealizeNode const*)
  5: _ZZN3tvm3tir11StmtFunctorIFNS0_7TResultERKNS0_4StmtEEE10InitVTableEvENUlRKNS_7runtime9ObjectRefEPS7_
  4: tvm::tir::FlopEstimator::VisitStmt_(tvm::tir::ForNode const*)
  3: _ZZN3tvm3tir11StmtFunctorIFNS0_7TResultERKNS0_4StmtEEE10InitVTableEvENUlRKNS_7runtime9ObjectRefEPS7_
  2: tvm::tir::FlopEstimator::VisitStmt_(tvm::tir::ForNode const*)
  1: _ZZN3tvm3tir11StmtFunctorIFNS0_7TResultERKNS0_4StmtEEE10InitVTableEvENUlRKNS_7runtime9ObjectRefEPS7_
  0: tvm::tir::FlopEstimator::VisitStmt_(tvm::tir::ForNode const*)
  File "/mnt/disk5/wll/code/tvm_unity/src/tir/analysis/estimate_flops.cc", line 143
TVMError:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (int_imm) is false: TypeError: Expect the extent of a loop to be IntImm, but gets: tir.Var

It seems that FLOP estimation requires static shapes. To bypass this, I used a dummy value in the FLOP estimation and removed Postproc::DisallowDynamicLoop() from the default postprocessors in postproc.cc. However, this resulted in a segmentation fault.
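To see why the loop extents must be static, here is a toy model of FLOP estimation in plain Python: the count is the product of the loop extents times the FLOPs in the loop body, which for this GEMM is 2 * N * M * K. It is a hypothetical sketch, not TVM's FlopEstimator; `Loop`, `estimate_flops`, and the `dummy_extent` parameter are made up for illustration, and symbolic extents are modeled as strings standing in for tir.Var. Substituting a dummy value, as described above, lets the count go through, but the number it produces is unrelated to the real problem size.

```python
from dataclasses import dataclass
from typing import List, Optional, Union

# Toy stand-in for TIR loop extents: an int plays the role of tir.IntImm,
# a string (e.g. "N") plays the role of a symbolic tir.Var.
Extent = Union[int, str]


@dataclass
class Loop:
    var: str
    extent: Extent


def estimate_flops(loops: List[Loop], flops_per_iter: int,
                   dummy_extent: Optional[int] = None) -> int:
    """Hypothetical FLOP count for a perfect loop nest (not TVM's code):
    product of all loop extents times the FLOPs of the innermost body."""
    total = flops_per_iter
    for loop in loops:
        extent = loop.extent
        if not isinstance(extent, int):
            if dummy_extent is None:
                # Analogous to the failing check: a symbolic extent has no
                # numeric value that can be multiplied into the count.
                raise TypeError(
                    f"loop {loop.var!r} has symbolic extent {extent!r}, "
                    "expected a constant integer")
            # Analogous to the workaround in the post: use a placeholder.
            extent = dummy_extent
        total *= extent
    return total


# GEMM body: one multiply and one add per (i, j, k) iteration.
static = [Loop("i", 128), Loop("j", 256), Loop("k", 64)]
dynamic = [Loop("i", "N"), Loop("j", "M"), Loop("k", "K")]

print(estimate_flops(static, flops_per_iter=2))                   # 4194304
print(estimate_flops(dynamic, flops_per_iter=2, dummy_extent=32))  # 65536
```

With static extents the count is exact; with symbolic extents the toy either fails (as TVM does) or returns a value driven entirely by the placeholder.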

I am not sure whether MetaSchedule currently does not support tuning dynamic-shape workloads, or whether there is an error in my code.

MetaSchedule is designed for static-shape workloads, not dynamic shapes. Support for dynamic-shape tuning is under discussion.
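Until then, one common way to work within the static-shape requirement is to tune a few representative shapes separately (for example, by binding the buffer shapes with `PrimFunc.specialize` and passing each result to `ms.tune_tir`) and, at run time, pad the input up to the nearest tuned size. The plain-Python sketch below shows only that selection step; `pick_bucket` and the bucket sizes in `TUNED_N` are hypothetical, and none of this is a MetaSchedule API.

```python
import bisect
from typing import Sequence


def pick_bucket(n: int, buckets: Sequence[int]) -> int:
    """Return the smallest pre-tuned static size >= n (hypothetical helper).

    `buckets` must be sorted ascending; each entry stands for one shape that
    was specialized and tuned as a separate static workload."""
    idx = bisect.bisect_left(buckets, n)
    if idx == len(buckets):
        raise ValueError(f"size {n} exceeds the largest tuned bucket {buckets[-1]}")
    return buckets[idx]


# Hypothetical tuned sizes for the dynamic N dimension of the GEMM.
TUNED_N = [64, 128, 256, 512]

for n in (1, 64, 100, 300):
    bucket = pick_bucket(n, TUNED_N)
    # A would be zero-padded from n to `bucket` rows before calling the kernel
    # tuned for that static shape, and C sliced back to its first n rows.
    print(n, "->", bucket, "pad rows:", bucket - n)
```

Zero-padding along N (or K) leaves the valid part of C unchanged, so the result can simply be sliced back; the cost is wasted work on the padded rows, which is why bucket sizes are usually kept reasonably close together.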
