For example, when we try to implement a reduction schedule plan that requires both inter-thread reduction and a shared-memory cache:
import tvm
from tvm.script import ir as I
from tvm.script import tir as T
@T.prim_func
def main(A: T.Buffer((128, 150528), "float32"), B: T.Buffer((128, 150528), "float32"), C: T.Buffer((128, 128), "float32")):
    T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"})
    # with T.block("root"):
    C_local = T.alloc_buffer((128, 128), scope="local")
    A_shared = T.alloc_buffer((128, 150528), scope="shared")
    B_shared = T.alloc_buffer((128, 150528), scope="shared")
    for ax0_0_ax1_0_fused in T.thread_binding(256, thread="blockIdx.x"):
        for ax0_1_ax1_1_fused in T.thread_binding(64, thread="threadIdx.y"):
            for ax2_1_1_fused in T.thread_binding(2, thread="threadIdx.x"):
                for ax2_0 in range(392):
                    for ax0_ax1_fused_0 in range(6):
                        for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(2, thread="threadIdx.x"):
                                for ax0_ax1_fused_3 in T.serial(4):
                                    with T.block("A_shared"):
                                        v0 = T.axis.spatial(128, ax0_0_ax1_0_fused // 16 * 8 + (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 8 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 384)
                                        v1 = T.axis.spatial(150528, ax2_0 * 384 + (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 8 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 384)
                                        T.reads(A[v0, v1])
                                        T.writes(A_shared[v0, v1])
                                        A_shared[v0, v1] = A[v0, v1]
                    for ax0_ax1_fused_0 in range(6):
                        for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(2, thread="threadIdx.x"):
                                for ax0_ax1_fused_3 in T.serial(4):
                                    with T.block("B_shared"):
                                        v0 = T.axis.spatial(128, ax0_0_ax1_0_fused % 16 * 8 + (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 8 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 384)
                                        v1 = T.axis.spatial(150528, ax2_0 * 384 + (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 8 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 384)
                                        T.reads(B[v0, v1])
                                        T.writes(B_shared[v0, v1])
                                        B_shared[v0, v1] = B[v0, v1]
                    for ax2_1_0 in range(192):
                        with T.block("B"):
                            v0 = T.axis.spatial(128, ax0_0_ax1_0_fused // 16 * 8 + ax0_1_ax1_1_fused // 8)
                            v1 = T.axis.spatial(128, ax0_0_ax1_0_fused % 16 * 8 + ax0_1_ax1_1_fused % 8)
                            v2 = T.axis.reduce(150528, ax2_0 * 384 + ax2_1_0 * 2 + ax2_1_1_fused)
                            T.reads(A_shared[v0, v2], B_shared[v1, v2])
                            T.writes(C_local[v0, v1])
                            with T.init():
                                C_local[v0, v1] = T.float32(0)
                            C_local[v0, v1] = C_local[v0, v1] + A_shared[v0, v2] * B_shared[v1, v2]
            with T.block("C_local"):
                v0 = T.axis.spatial(128, ax0_0_ax1_0_fused // 16 * 8 + ax0_1_ax1_1_fused // 8)
                v1 = T.axis.spatial(128, ax0_0_ax1_0_fused % 16 * 8 + ax0_1_ax1_1_fused % 8)
                T.reads(C_local[v0, v1])
                T.writes(C[v0, v1])
                C[v0, v1] = C_local[v0, v1]
mod = tvm.IRModule.from_expr(main)
sch = tvm.tir.Schedule(mod, debug_mask="all")
rt_mod = tvm.build(sch.mod, target="cuda")
Building this module fails with the following error message:
Traceback (most recent call last):
File "/home/t-leiwang/mlc_workspace/unity/python/tvm/dlight/test_tir.py", line 56, in <module>
dense_relu_0_rt_mod = tvm.build(sch.mod, target="cuda")
File "/home/t-leiwang/ladder_workspace/LadderTVM/python/tvm/driver/build_module.py", line 281, in build
rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
File "/home/t-leiwang/ladder_workspace/LadderTVM/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
12: TVMFuncCall
11: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)>::AssignTypedLambda<tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}>(tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
10: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
9: tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&)
8: tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
7: tvm::transform::Pass::operator()(tvm::IRModule) const
6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
5: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
4: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
3: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
2: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_3tir9transform13MakePackedAPIEvEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SF_SJ_
1: tvm::tir::transform::MakePackedAPI()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}::operator()(tvm::IRModule, tvm::transform::PassContext) const [clone .isra.0]
0: tvm::tir::MakePackedAPI(tvm::tir::PrimFunc&&)
File "/transforms/make_packed_api.cc", line 295
TVMError: Not all Vars are passed in api_args: 'ax0_ax1_fused_1' 'ax0_ax1_fused_2' 'ax0_ax1_fused_1' 'ax0_ax1_fused_2' is not bound to any variables
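The two variables named in the error, ax0_ax1_fused_1 and ax0_ax1_fused_2, are the cooperative-fetch loops that are re-bound to threadIdx.y / threadIdx.x inside the already thread-bound outer loops. For context, TIR of essentially this shape can be produced with tvm.tir.Schedule primitives along the following lines; the matmul workload and split factors here are a reconstruction for illustration, not necessarily the exact script that generated the IR above.

import tvm
from tvm import te

M, N, K = 128, 128, 150528
A = te.placeholder((M, K), name="A", dtype="float32")
B = te.placeholder((N, K), name="B", dtype="float32")
k = te.reduce_axis((0, K), name="k")
C = te.compute((M, N), lambda i, j: te.sum(A[i, k] * B[j, k], axis=k), name="C")

sch = tvm.tir.Schedule(te.create_prim_func([A, B, C]))
blk = sch.get_block("C")
C_local = sch.cache_write(blk, 0, "local")

# Spatial tiling: one 8x8 tile of C per block, one output element per threadIdx.y lane.
i, j, kr = sch.get_loops(blk)
i0, i1 = sch.split(i, factors=[16, 8])
j0, j1 = sch.split(j, factors=[16, 8])
sch.reorder(i0, j0, i1, j1)
bx = sch.fuse(i0, j0)
ty = sch.fuse(i1, j1)
sch.bind(bx, "blockIdx.x")
sch.bind(ty, "threadIdx.y")
sch.reverse_compute_at(C_local, ty)

# Reduction: a factor-2 piece bound to threadIdx.x (the inter-thread reduction),
# with 392 x 192 serial iterations per thread.
k0, k1, k2 = sch.split(kr, factors=[392, 192, 2])
sch.reorder(k2, k0, k1)
sch.bind(k2, "threadIdx.x")

# Shared-memory caches for A and B, cooperatively fetched once per k0 iteration.
A_shared = sch.cache_read(blk, 0, "shared")
B_shared = sch.cache_read(blk, 1, "shared")
for cache in (A_shared, B_shared):
    sch.compute_at(cache, k0)
    fused = sch.fuse(*sch.get_loops(cache)[-2:])
    _, f_ty, f_tx, _ = sch.split(fused, factors=[None, 64, 2, 4])
    sch.bind(f_ty, "threadIdx.y")
    sch.bind(f_tx, "threadIdx.x")

print(sch.mod)  # TIR of essentially the same shape as the script above
# tvm.build(sch.mod, target="cuda") is then expected to hit the same MakePackedAPI error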
However, the same schedule plan works correctly when the workload and schedule are written with the tensor expression (TE) API.
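One common way to express this plan in TE is the classic rfactor-based cross-thread reduction combined with shared-memory cooperative fetching; a rough, illustrative sketch (not our exact script, and the split factors are only for illustration):

import tvm
from tvm import te

M, N, K = 128, 128, 150528
A = te.placeholder((M, K), name="A")
B = te.placeholder((N, K), name="B")
k = te.reduce_axis((0, K), name="k")
C = te.compute((M, N), lambda i, j: te.sum(A[i, k] * B[j, k], axis=k), name="C")

s = te.create_schedule(C.op)

# Split the reduction and rfactor it so that the factor-2 piece becomes the
# inter-thread dimension of the reduction.
ko, kt = s[C].split(C.op.reduce_axis[0], factor=2)
CF = s.rfactor(C, kt)

block_x = te.thread_axis("blockIdx.x")
thread_y = te.thread_axis("threadIdx.y")
thread_x = te.thread_axis("threadIdx.x")

# 8x8 output tile per block, one output element per threadIdx.y lane.
i, j = s[C].op.axis
io, ii = s[C].split(i, factor=8)
jo, ji = s[C].split(j, factor=8)
s[C].reorder(io, jo, ii, ji)
s[C].bind(s[C].fuse(io, jo), block_x)
s[C].bind(s[C].fuse(ii, ji), thread_y)

# Cross-thread reduction over the remaining extent-2 reduce axis of C.
s[C].bind(s[C].op.reduce_axis[0], thread_x)
s[CF].compute_at(s[C], s[C].op.reduce_axis[0])
s[C].set_store_predicate(thread_x.var.equal(0))

# Shared-memory caches for A and B, cooperatively fetched once per outer k chunk.
AA = s.cache_read(A, "shared", [CF])
BB = s.cache_read(B, "shared", [CF])
koo, koi = s[CF].split(s[CF].op.reduce_axis[0], factor=192)
for SS in (AA, BB):
    s[SS].compute_at(s[CF], koo)
    fused = s[SS].fuse(*s[SS].op.axis)
    fused, f_inner = s[SS].split(fused, factor=4)   # innermost factor left serial
    fused, ftx = s[SS].split(fused, factor=2)
    fused, fty = s[SS].split(fused, factor=64)
    s[SS].bind(fty, thread_y)
    s[SS].bind(ftx, thread_x)

rt_mod = tvm.build(s, [A, B, C], target="cuda")  # expected to build via the TE lowering path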