Conflict-free shared memory permutation in TensorIR

To get the best GEMM/conv performance with GPU Tensor Cores, in most cases we need to permute memory to avoid shared memory bank conflicts. Current SOTA libraries like CUTLASS use the following permutation strategy to avoid bank conflicts in shared memory stores and loads:

as described here. It is very carefully designed; I took a few days to implement a CUDA version of this permutation and achieved the same performance as CUTLASS.

In the new TVM and TensorIR, I found a new schedule primitive named transform_layout that can cover this permutation stage, but I failed to re-implement it, because it seems that TensorIR script currently does not support some operators like << and ^.

The schedule code is:

    block_read = sch.cache_read(block, idx, "shared")

    def shared_conflict_free_permutation(i, j):
        def shared_8x64_to_free_4x128_layout(i, j):
            elements_per_thread = 16
            element_id = (i << 6) + j
            lane_id = element_id // elements_per_thread
            element_id_in_thread = element_id % elements_per_thread
            shared_c = lane_id % 8
            shared_s = lane_id // 8
            shared_row = (shared_c & 1) | ((shared_c >> 1) & 2)
            shared_col = ((shared_c << 1) & 4) | shared_s ^ shared_row
            return (shared_row, shared_col * elements_per_thread + element_id_in_thread)
        return (i // 8, j // 64, *shared_8x64_to_free_4x128_layout(i % 8, j % 64))

    sch.transform_layout(block_read, ("write", 0), shared_conflict_free_permutation)

and I got the following error message:

Traceback (most recent call last):
  File "/workspace/v-leiwang3/tvm_gpu_gemm/schedule/tensorize_mma/mma_i32_conflict_free.py", line 163, in <module>
    block_shared_A = fetch_to_shared(block_b, 0)
  File "/workspace/v-leiwang3/tvm_gpu_gemm/schedule/tensorize_mma/mma_i32_conflict_free.py", line 150, in fetch_to_shared
    sch.transform_layout(block_read, ("write", 0), shared_conflict_free_permutation)
  File "/workspace/v-leiwang3/tvm/python/tvm/tir/schedule/_type_checker.py", line 338, in wrap
    return func(*args, **kwargs)
  File "/workspace/v-leiwang3/tvm/python/tvm/tir/schedule/schedule.py", line 2764, in transform_layout
    self, block, buffer_index, buffer_index_type_enum, index_map, pad_value
  File "/workspace/v-leiwang3/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  77: TVMFuncCall
        at /workspace/v-leiwang3/tvm/src/runtime/c_runtime_api.cc:477
  76: tvm::runtime::PackedFuncObj::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
        at /workspace/v-leiwang3/tvm/include/tvm/runtime/packed_func.h:1217
  75: Call
        at /workspace/v-leiwang3/tvm/include/tvm/runtime/packed_func.h:1213
  74: operator()
        at /workspace/v-leiwang3/tvm/include/tvm/runtime/packed_func.h:1731
  73: unpack_call<void, 6, tvm::tir::<lambda(tvm::tir::Schedule, const tvm::tir::BlockRV&, int, int, const tvm::tir::IndexMap&, const tvm::runtime::Optional<tvm::tir::IndexMap>&)> >
        at /workspace/v-leiwang3/tvm/include/tvm/runtime/packed_func.h:1671
  72: run<>
        at /workspace/v-leiwang3/tvm/include/tvm/runtime/packed_func.h:1631
  71: run<tvm::runtime::TVMMovableArgValueWithContext_>
        at /workspace/v-leiwang3/tvm/include/tvm/runtime/packed_func.h:1631
  70: run<tvm::runtime::TVMMovableArgValueWithContext_, tvm::runtime::TVMMovableArgValueWithContext_>
        at /workspace/v-leiwang3/tvm/include/tvm/runtime/packed_func.h:1631
  69: run<tvm::runtime::TVMMovableArgValueWithContext_, tvm::runtime::TVMMovableArgValueWithContext_, tvm::runtime::TVMMovableArgValueWithContext_>
        at /workspace/v-leiwang3/tvm/include/tvm/runtime/packed_func.h:1631
  68: run<tvm::runtime::TVMMovableArgValueWithContext_, tvm::runtime::TVMMovableArgValueWithContext_, tvm::runtime::TVMMovableArgValueWithContext_, tvm::runtime::TVMMovableArgValueWithContext_>
        at /workspace/v-leiwang3/tvm/include/tvm/runtime/packed_func.h:1631
  67: run<tvm::runtime::TVMMovableArgValueWithContext_, tvm::runtime::TVMMovableArgValueWithContext_, tvm::runtime::TVMMovableArgValueWithContext_, tvm::runtime::TVMMovableArgValueWithContext_, tvm::runtime::TVMMovableArgValueWithContext_>
        at /workspace/v-leiwang3/tvm/include/tvm/runtime/packed_func.h:1631
  66: run<tvm::runtime::TVMMovableArgValueWithContext_, tvm::runtime::TVMMovableArgValueWithContext_, tvm::runtime::TVMMovableArgValueWithContext_, tvm::runtime::TVMMovableArgValueWithContext_, tvm::runtime::TVMMovableArgValueWithContext_, tvm::runtime::TVMMovableArgValueWithContext_>
        at /workspace/v-leiwang3/tvm/include/tvm/runtime/packed_func.h:1659
  65: operator()
        at /workspace/v-leiwang3/tvm/src/tir/schedule/schedule.cc:257
  64: tvm::tir::TracedScheduleNode::TransformLayout(tvm::tir::BlockRV const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&)
        at /workspace/v-leiwang3/tvm/src/tir/schedule/traced_schedule.cc:522
  63: tvm::tir::ConcreteScheduleNode::TransformLayout(tvm::tir::BlockRV const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&)
        at /workspace/v-leiwang3/tvm/src/tir/schedule/concrete_schedule.cc:793
  62: tvm::tir::TransformLayout(tvm::tir::ScheduleState, tvm::tir::StmtSRef const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&)
        at /workspace/v-leiwang3/tvm/src/tir/schedule/primitive/layout_transformation.cc:997
  61: operator()
        at /workspace/v-leiwang3/tvm/src/tir/schedule/primitive/layout_transformation.cc:996
  60: tvm::tir::IndexMap::NonSurjectiveInverse(tvm::runtime::Array<tvm::Range, void>) const
        at /workspace/v-leiwang3/tvm/src/tir/ir/index_map.cc:137
  59: tvm::tir::IndexMapInverseImpl(tvm::tir::IndexMap const&, tvm::runtime::Array<tvm::Range, void> const&, tvm::arith::IterMapLevel)
        at /workspace/v-leiwang3/tvm/src/tir/ir/index_map.cc:94
  58: tvm::arith::DetectIterMap(tvm::runtime::Array<tvm::PrimExpr, void> const&, tvm::runtime::Map<tvm::tir::Var, tvm::Range, void, void> const&, tvm::PrimExpr const&, tvm::arith::IterMapLevel, tvm::arith::Analyzer*, bool)
        at /workspace/v-leiwang3/tvm/src/arith/iter_affine_map.cc:1152
  57: tvm::arith::IterMapRewriter::RewriteAndUpdatePadding(tvm::PrimExpr const&)
        at /workspace/v-leiwang3/tvm/src/arith/iter_affine_map.cc:214
  56: tvm::arith::IterMapRewriter::Rewrite(tvm::PrimExpr const&)
        at /workspace/v-leiwang3/tvm/src/arith/iter_affine_map.cc:209
  55: tvm::arith::IterMapRewriter::DirectMutate(tvm::PrimExpr const&)
  54: tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>::VisitExpr(tvm::PrimExpr const&)
        at /workspace/v-leiwang3/tvm/include/tvm/tir/expr_functor.h:114
  53: tvm::NodeFunctor<tvm::PrimExpr (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*) const
        at /workspace/v-leiwang3/tvm/include/tvm/node/functor.h:97
  52: tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)#7}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)
        at /workspace/v-leiwang3/tvm/include/tvm/tir/expr_functor.h:170
  51: tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)#7}::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*) const
        at /workspace/v-leiwang3/tvm/include/tvm/tir/expr_functor.h:170
  50: tvm::tir::ExprMutator::VisitExpr_(tvm::tir::CallNode const*)
        at /workspace/v-leiwang3/tvm/src/tir/ir/expr_functor.cc:165
  49: Map<tvm::tir::ExprMutator::VisitExpr_(const tvm::tir::CallNode*)::<lambda(const tvm::PrimExpr&)> >
        at /workspace/v-leiwang3/tvm/include/tvm/runtime/container/array.h:622
  48: MapHelper<tvm::tir::ExprMutator::VisitExpr_(const tvm::tir::CallNode*)::<lambda(const tvm::PrimExpr&)> >
        at /workspace/v-leiwang3/tvm/include/tvm/runtime/container/array.h:780
  47: operator()
        at /workspace/v-leiwang3/tvm/src/tir/ir/expr_functor.cc:164
  46: tvm::arith::IterMapRewriter::VisitExpr(tvm::PrimExpr const&)
  45: tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>::VisitExpr(tvm::PrimExpr const&)
        at /workspace/v-leiwang3/tvm/include/tvm/tir/expr_functor.h:114
  44: tvm::NodeFunctor<tvm::PrimExpr (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*) const
        at /workspace/v-leiwang3/tvm/include/tvm/node/functor.h:97
  43: tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)#7}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)
        at /workspace/v-leiwang3/tvm/include/tvm/tir/expr_functor.h:170
  42: tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)#7}::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*) const
        at /workspace/v-leiwang3/tvm/include/tvm/tir/expr_functor.h:170
  41: tvm::tir::ExprMutator::VisitExpr_(tvm::tir::CallNode const*)
        at /workspace/v-leiwang3/tvm/src/tir/ir/expr_functor.cc:165
  40: Map<tvm::tir::ExprMutator::VisitExpr_(const tvm::tir::CallNode*)::<lambda(const tvm::PrimExpr&)> >
        at /workspace/v-leiwang3/tvm/include/tvm/runtime/container/array.h:622
  39: MapHelper<tvm::tir::ExprMutator::VisitExpr_(const tvm::tir::CallNode*)::<lambda(const tvm::PrimExpr&)> >
        at /workspace/v-leiwang3/tvm/include/tvm/runtime/container/array.h:780
  38: operator()
        at /workspace/v-leiwang3/tvm/src/tir/ir/expr_functor.cc:164
  37: tvm::arith::IterMapRewriter::VisitExpr(tvm::PrimExpr const&)
  36: tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>::VisitExpr(tvm::PrimExpr const&)
        at /workspace/v-leiwang3/tvm/include/tvm/tir/expr_functor.h:114
  35: tvm::NodeFunctor<tvm::PrimExpr (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*) const
        at /workspace/v-leiwang3/tvm/include/tvm/node/functor.h:97
  34: tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)#14}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)
        at /workspace/v-leiwang3/tvm/include/tvm/tir/expr_functor.h:177
  33: tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)#14}::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*) const
        at /workspace/v-leiwang3/tvm/include/tvm/tir/expr_functor.h:177
  32: tvm::arith::IterMapRewriter::VisitExpr_(tvm::tir::FloorModNode const*)
        at /workspace/v-leiwang3/tvm/src/arith/iter_affine_map.cc:1652
  31: tvm::arith::IterMapRewriter::DirectMutate(tvm::PrimExpr const&)
  30: tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>::VisitExpr(tvm::PrimExpr const&)
        at /workspace/v-leiwang3/tvm/include/tvm/tir/expr_functor.h:114
  29: tvm::NodeFunctor<tvm::PrimExpr (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*) const
        at /workspace/v-leiwang3/tvm/include/tvm/node/functor.h:97
  28: tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)#13}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)
        at /workspace/v-leiwang3/tvm/include/tvm/tir/expr_functor.h:176
  27: tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)#13}::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*) const
        at /workspace/v-leiwang3/tvm/include/tvm/tir/expr_functor.h:176
  26: tvm::arith::IterMapRewriter::VisitExpr_(tvm::tir::FloorDivNode const*)
        at /workspace/v-leiwang3/tvm/src/arith/iter_affine_map.cc:1595
  25: tvm::arith::IterMapRewriter::SplitFloorDivConst(tvm::arith::IterSplitExpr, tvm::PrimExpr, tvm::PrimExpr)
        at /workspace/v-leiwang3/tvm/src/arith/iter_affine_map.cc:1536
  24: tvm::arith::IterMapRewriter::PadDividendToDivisor(tvm::arith::IterSplitExpr, tvm::PrimExpr, tvm::PrimExpr)
        at /workspace/v-leiwang3/tvm/src/arith/iter_affine_map.cc:1374
  23: tvm::arith::Analyzer::Simplify(tvm::PrimExpr const&, int)
        at /workspace/v-leiwang3/tvm/src/arith/analyzer.cc:137
  22: tvm::arith::RewriteSimplifier::operator()(tvm::PrimExpr const&)
        at /workspace/v-leiwang3/tvm/src/arith/rewrite_simplify.cc:1991
  21: tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>::operator()(tvm::PrimExpr const&)
        at /workspace/v-leiwang3/tvm/include/tvm/tir/expr_functor.h:104
  20: tvm::tir::StmtExprMutator::VisitExpr(tvm::PrimExpr const&)
        at /workspace/v-leiwang3/tvm/include/tvm/tir/stmt_functor.h:324
  19: tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>::VisitExpr(tvm::PrimExpr const&)
        at /workspace/v-leiwang3/tvm/include/tvm/tir/expr_functor.h:114
  18: tvm::NodeFunctor<tvm::PrimExpr (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*) const
        at /workspace/v-leiwang3/tvm/include/tvm/node/functor.h:97
  17: tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)#14}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)
        at /workspace/v-leiwang3/tvm/include/tvm/tir/expr_functor.h:177
  16: tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)#14}::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*) const
        at /workspace/v-leiwang3/tvm/include/tvm/tir/expr_functor.h:177
  15: tvm::arith::RewriteSimplifier::Impl::VisitExpr_(tvm::tir::FloorModNode const*)
        at /workspace/v-leiwang3/tvm/src/arith/rewrite_simplify.cc:959
  14: tvm::tir::ExprMutator::VisitExpr_(tvm::tir::FloorModNode const*)
        at /workspace/v-leiwang3/tvm/src/tir/ir/expr_functor.cc:198
  13: tvm::tir::StmtExprMutator::VisitExpr(tvm::PrimExpr const&)
        at /workspace/v-leiwang3/tvm/include/tvm/tir/stmt_functor.h:324
  12: tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>::VisitExpr(tvm::PrimExpr const&)
        at /workspace/v-leiwang3/tvm/include/tvm/tir/expr_functor.h:114
  11: tvm::NodeFunctor<tvm::PrimExpr (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*) const
        at /workspace/v-leiwang3/tvm/include/tvm/node/functor.h:97
  10: tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)#7}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)
        at /workspace/v-leiwang3/tvm/include/tvm/tir/expr_functor.h:170
  9: tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)#7}::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*) const
        at /workspace/v-leiwang3/tvm/include/tvm/tir/expr_functor.h:170
  8: tvm::arith::RewriteSimplifier::Impl::VisitExpr_(tvm::tir::CallNode const*)
        at /workspace/v-leiwang3/tvm/src/arith/rewrite_simplify.cc:1874
  7: tvm::arith::IRMutatorWithAnalyzer::VisitExpr_(tvm::tir::CallNode const*)
        at /workspace/v-leiwang3/tvm/src/arith/ir_mutator_with_analyzer.cc:157
  6: tvm::tir::ExprMutator::VisitExpr_(tvm::tir::CallNode const*)
        at /workspace/v-leiwang3/tvm/src/tir/ir/expr_functor.cc:165
  5: Map<tvm::tir::ExprMutator::VisitExpr_(const tvm::tir::CallNode*)::<lambda(const tvm::PrimExpr&)> >
        at /workspace/v-leiwang3/tvm/include/tvm/runtime/container/array.h:622
  4: MapHelper<tvm::tir::ExprMutator::VisitExpr_(const tvm::tir::CallNode*)::<lambda(const tvm::PrimExpr&)> >
        at /workspace/v-leiwang3/tvm/include/tvm/runtime/container/array.h:780
  3: operator()
        at /workspace/v-leiwang3/tvm/src/tir/ir/expr_functor.cc:164
  2: tvm::tir::StmtExprMutator::VisitExpr(tvm::PrimExpr const&)
        at /workspace/v-leiwang3/tvm/include/tvm/tir/stmt_functor.h:324
  1: tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>::VisitExpr(tvm::PrimExpr const&)
        at /workspace/v-leiwang3/tvm/include/tvm/tir/expr_functor.h:114
  0: tvm::NodeFunctor<tvm::PrimExpr (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::PrimExpr (tvm::PrimExpr const&)>*) const
        at /workspace/v-leiwang3/tvm/include/tvm/node/functor.h:95
  File "/workspace/v-leiwang3/tvm/include/tvm/node/functor.h", line 95
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (can_dispatch(n)) is false: NodeFunctor calls un-registered function on type arith.IterSplitExpr

Is there any plan to support this, or is it easy to extend?
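For reference, the permutation in the schedule code above can be sanity-checked outside TVM. The sketch below evaluates the same arithmetic on plain Python integers and checks that every element of an 8x64 tile lands on a distinct slot of a 4x128 tile. The helper `is_permutation` is a hypothetical name introduced here for illustration; it is not part of TVM.

```python
def shared_8x64_to_free_4x128_layout(i, j):
    # Same arithmetic as the schedule code, evaluated on Python ints.
    elements_per_thread = 16
    element_id = (i << 6) + j
    lane_id = element_id // elements_per_thread
    element_id_in_thread = element_id % elements_per_thread
    shared_c = lane_id % 8
    shared_s = lane_id // 8
    shared_row = (shared_c & 1) | ((shared_c >> 1) & 2)
    # In Python, ^ binds tighter than |, so the original expression
    # parses as (...) | (shared_s ^ shared_row); made explicit here.
    shared_col = ((shared_c << 1) & 4) | (shared_s ^ shared_row)
    return (shared_row, shared_col * elements_per_thread + element_id_in_thread)


def is_permutation(layout, src_shape, dst_shape):
    """Hypothetical helper: True if layout maps src_shape one-to-one onto dst_shape."""
    seen = set()
    for i in range(src_shape[0]):
        for j in range(src_shape[1]):
            r, c = layout(i, j)
            if not (0 <= r < dst_shape[0] and 0 <= c < dst_shape[1]):
                return False  # index falls outside the destination tile
            seen.add((r, c))
    # One-to-one and onto: every destination slot was hit exactly once.
    return len(seen) == dst_shape[0] * dst_shape[1]


print(is_permutation(shared_8x64_to_free_4x128_layout, (8, 64), (4, 128)))
```

This only shows that the mapping itself introduces no padding on an 8x64 tile; it says nothing about whether TVM's index-map analysis can prove that.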

We have some ongoing effort led by @Hzfengsy and @spectrometerHBH.

Hi, could you kindly provide some script to reproduce it? The stack trace seems to be a bug in iteration analysis.

The code to reproduce can be accessed at

I agree there are some bugs in iteration analysis; I ran into them before I found the limits of the expression support. Thanks a lot.

Seems this is impacted by recent commits to TransformLayout. There are some checks here https://github.com/apache/tvm/blob/main/src/tir/schedule/primitive/layout_transformation.cc#L1093-L1096 that rely on NonSurjectiveInverse analysis, which doesn’t cover all cases. The operators used here (xor) are not supported by affine analysis but are allowed in TransformLayout. The previous behavior was to allow the new buffer to be padded, with the padding region never accessed by the transformed program. Maybe we can disable this check if pad_value is not provided. @Lunderberg may have suggestions.
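To illustrate the distinction between affine and non-affine operators: shifts and power-of-two masks have exact floor-division and modulo equivalents, which is the form the iterator analysis in the stack trace is visiting (FloorDivNode, FloorModNode), while xor has no such rewrite in general. The sketch below uses plain Python integers and two hypothetical helper names; whether the rewritten form actually passes transform_layout is not verified here, since the xor term is still present.

```python
def row_col_bitwise(shared_c, shared_s):
    # Row/column computation as written in the original schedule code.
    row = (shared_c & 1) | ((shared_c >> 1) & 2)
    col = ((shared_c << 1) & 4) | (shared_s ^ row)
    return row, col


def row_col_arith(shared_c, shared_s):
    # Same computation for shared_c in [0, 8) and shared_s in [0, 4),
    # with shifts and masks replaced by floordiv, floormod and multiply.
    # Bit 0 of shared_c stays at bit 0; bit 2 moves to bit 1.
    row = shared_c % 2 + 2 * (shared_c // 4)
    # Bit 1 of shared_c moves to bit 2. The low two bits are still an xor,
    # which is the part that is not a sum of iterators.
    col = 4 * ((shared_c // 2) % 2) + (shared_s ^ row)
    return row, col


# Exhaustively compare both forms over the ranges used by the layout.
mismatches = [
    (c, s)
    for c in range(8)
    for s in range(4)
    if row_col_bitwise(c, s) != row_col_arith(c, s)
]
print(mismatches)
```

The bitwise ORs can be replaced by additions only because the operands occupy disjoint bits in these ranges; that assumption does not hold for arbitrary inputs.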

Interesting, a while ago I was also trying to realize the CUTLASS permuted layout in TVM. My branch from then was https://github.com/masahi/tvm/tree/permuted-shmem and the test case https://github.com/masahi/tvm/blob/abda50710ea64eddb82ce7c9151d108b99efcf0a/tests/python/unittest/test_mma_permuted.py.

I remember I got ~35 TFLOPs on RTX 3070, but it was worse than my simpler baseline schedule with naive shmem layout (~40 TFLOPs).

@vinx13 @masahi Thanks a lot, this is quite interesting. I think I need to roll back TVM and test again. I implemented the CUTLASS dp4a permutation with TVM, and it can even be about 3~4% faster than CUTLASS in the NN layout, so I believe we can do the same thing with Tensor Cores.

Hi @vinx13, could you kindly help me check this layout permutation? It doesn’t pass with the latest code.

def B_global_16x32_to_shared_load_16x32_layout(i, j):
    thread_id = i * 2 + j // 16
    row = (i // 8) * 8 + thread_id % 8
    col = (j % 16) + 16 * ((thread_id // 8) % 2)

    return row, col

error:

TVMError: Check failed: (padded_iter_map->errors.empty()) is false: Could not parse mapping as sum of iterators.  Error: Mapped indices are not independent.

I checked the code, and it seems to be caused by an incomplete mapping, but the mapping from src to dst is complete. I will keep trying to resolve it, and it would be nice to have some advice :grinning_face_with_smiling_eyes:

code to reproduce

import tvm
import numpy as np
import tvm.testing
from tvm.script import tir as T

'''
    evaluate tir transform layout
'''

M = 16
N = 32

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, [M, N], dtype="int8")
        B = T.match_buffer(b, [N, N], dtype="int8")

        for i, j in T.grid(M, N):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj]

ir_module = MyModule
print(ir_module)
sch = tvm.tir.Schedule(ir_module, debug_mask="all")

def B_global_16x32_to_shared_load_16x32_layout(i, j):
    thread_id = i * 2 + j // 16
    row = (i // 8) * 8 + thread_id % 8
    col = (j % 16) + 16 * ((thread_id // 8) % 2)

    return row, col

block_b = sch.get_block("B")
sch.transform_layout(block_b, ("read", 0),
                     B_global_16x32_to_shared_load_16x32_layout)
i, j = sch.get_loops(block_b)
sch.bind(i, "threadIdx.x")

# build and run
ctx = tvm.cuda(0)
cuda_mod = tvm.build(sch.mod, target="cuda")
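Independently of TVM, the claim that the mapping from src to dst is complete can be checked with plain Python integers. The sketch below enumerates the 16x32 tile and also compares the mapping against a simplified form derived by hand from the definition; `simplified_layout` is a hypothetical name introduced here, not something from the original script.

```python
def B_global_16x32_to_shared_load_16x32_layout(i, j):
    # Mapping from the post, unchanged.
    thread_id = i * 2 + j // 16
    row = (i // 8) * 8 + thread_id % 8
    col = (j % 16) + 16 * ((thread_id // 8) % 2)
    return row, col


def simplified_layout(i, j):
    # Hand-derived equivalent for 0 <= i < 16, 0 <= j < 32:
    # thread_id % 8 == 2 * (i % 4) + j // 16 and thread_id // 8 == i // 4.
    row = (i // 8) * 8 + 2 * (i % 4) + j // 16
    col = (j % 16) + 16 * ((i // 4) % 2)
    return row, col


M, N = 16, 32
outputs = set()
same_as_simplified = True
for i in range(M):
    for j in range(N):
        rc = B_global_16x32_to_shared_load_16x32_layout(i, j)
        outputs.add(rc)
        same_as_simplified = same_as_simplified and rc == simplified_layout(i, j)

in_bounds = all(0 <= r < M and 0 <= c < N for r, c in outputs)
# A complete, padding-free mapping hits all M * N destination slots.
print(len(outputs) == M * N, in_bounds, same_as_simplified)
```

The simplified form makes the structure visible: bit 2 of i ends up in the column while j // 16 ends up in the row, so the row and column each mix pieces of both input iterators. That may be related to the "Mapped indices are not independent" message, though this is a guess and not confirmed against the TVM source.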

Try commenting out https://github.com/apache/tvm/blob/main/src/tir/schedule/primitive/layout_transformation.cc#L1120-L1133

If we know xor doesn’t introduce padding, we can skip this check

I commented out the related check code, and it works fine. Thanks a lot.