I attempted to tune a dynamic gemm workload using the script detailed below:
import tvm
from tvm.script import tir as T
import tvm.meta_schedule as ms


@T.prim_func
def gemm_dyn_shape(a: T.handle, b: T.handle, c: T.handle):
    N = T.int32()
    K = T.int32()
    M = T.int32()
    A = T.match_buffer(a, (N, K))
    B = T.match_buffer(b, (K, M))
    C = T.match_buffer(c, (N, M))
    for i, j, k in T.grid(N, M, K):
        with T.block("gemm"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            T.reads(A[vi, vk], B[vk, vj])
            T.writes(C[vi, vj])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


target = tvm.target.Target("llvm -num-cores=10", host="llvm -num-cores=10")
ms.tune_tir(gemm_dyn_shape, target=target, work_dir="/tmp", max_trials_global=10)
Then, I encountered the following error:
......
9: tvm::meta_schedule::TaskRecord::TaskRecord(tvm::meta_schedule::TuneContext, double)
8: tvm::tir::EstimateTIRFlops(tvm::IRModule const&)
7: _ZZN3tvm3tir11StmtFunctorIFNS0_7TResultERKNS0_4StmtEEE10InitVTableEvENUlRKNS_7runtime9ObjectRefEPS7_
6: non-virtual thunk to tvm::tir::FlopEstimator::VisitStmt_(tvm::tir::BlockRealizeNode const*)
5: _ZZN3tvm3tir11StmtFunctorIFNS0_7TResultERKNS0_4StmtEEE10InitVTableEvENUlRKNS_7runtime9ObjectRefEPS7_
4: tvm::tir::FlopEstimator::VisitStmt_(tvm::tir::ForNode const*)
3: _ZZN3tvm3tir11StmtFunctorIFNS0_7TResultERKNS0_4StmtEEE10InitVTableEvENUlRKNS_7runtime9ObjectRefEPS7_
2: tvm::tir::FlopEstimator::VisitStmt_(tvm::tir::ForNode const*)
1: _ZZN3tvm3tir11StmtFunctorIFNS0_7TResultERKNS0_4StmtEEE10InitVTableEvENUlRKNS_7runtime9ObjectRefEPS7_
0: tvm::tir::FlopEstimator::VisitStmt_(tvm::tir::ForNode const*)
File "/mnt/disk5/wll/code/tvm_unity/src/tir/analysis/estimate_flops.cc", line 143
TVMError:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
Check failed: (int_imm) is false: TypeError: Expect the extent of a loop to be IntImm, but gets: tir.Var
It seems that the FLOPs estimation requires static shapes. To bypass this issue, I used a dummy value in the FLOPs estimation and removed Postproc::DisallowDynamicLoop() from the default postprocessors in postproc.cc. However, this resulted in a segmentation fault.
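To illustrate why I think the estimator needs static extents, here is a toy model in plain Python. It is not TVM's implementation; SymVar and estimate_flops are hypothetical names I made up for this sketch. The idea is that the FLOP count is the per-iteration cost multiplied by every loop extent, so a symbolic extent leaves nothing to multiply by.

```python
from dataclasses import dataclass
from typing import Sequence, Union


@dataclass(frozen=True)
class SymVar:
    """Hypothetical stand-in for a symbolic loop extent (e.g. a tir.Var)."""
    name: str


Extent = Union[int, SymVar]


def estimate_flops(extents: Sequence[Extent], flops_per_iter: int) -> int:
    """Toy estimate: flops_per_iter times the product of all loop extents.

    Raises TypeError when an extent is not a constant integer, mirroring
    (only in spirit) the check that fails in the trace above.
    """
    total = flops_per_iter
    for extent in extents:
        if not isinstance(extent, int):
            raise TypeError(
                f"Expect the extent of a loop to be a constant, but got: {extent!r}"
            )
        total *= extent
    return total


if __name__ == "__main__":
    # Static GEMM: one multiply and one add per (i, j, k) iteration.
    print(estimate_flops([128, 64, 32], flops_per_iter=2))

    # Dynamic GEMM: N is symbolic, so no number can be produced.
    try:
        estimate_flops([SymVar("N"), 64, 32], flops_per_iter=2)
    except TypeError as err:
        print("estimation failed:", err)
```

Under this model, substituting a dummy value only gets past the estimate itself; anything later in the pipeline that also assumes constant extents would still see the symbolic loops.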
I am not sure whether tuning dynamic-shape workloads is currently unsupported in MetaSchedule or whether there is an error in my code.