CUTLASS Support

Hello,

I am trying to use CUTLASS within TVM, but I have not been able to get it working.

To begin with, I am using the matrix multiplication example from "4. Matrix Multiplication — Dive into Deep Learning Compiler 0.1 documentation".

Instead of using

    A, B, C = d2ltvm.matmul(n, n, n) 

I am writing my kernel with a lambda function.

    from tvm import te

    # A and B placeholders are assumed here; their shapes follow from the compute rule
    A = te.placeholder((batch_sz, d1, d3), name='A')
    B = te.placeholder((batch_sz, d3, d2), name='B')
    k = te.reduce_axis((0, d3), name='k')
    output_shape = (batch_sz, d1, d2)
    algorithm = lambda l, i, j: te.sum(A[l, i, k] * B[l, k, j], axis=k)
    C = te.compute(output_shape, algorithm, name='C')
    s = te.create_schedule(C.op)

The rest of the code is the same. When I compared the performance of the resulting kernel against PyTorch's batched matrix multiplication, the code above is about 3x slower.
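
For context, here is roughly how I timed the two sides (a self-contained sketch with made-up sizes, not my exact script; the PyTorch side is a crude wall-clock measurement):

    import time
    import numpy as np
    import torch
    import tvm
    from tvm import te

    batch_sz, d1, d2, d3 = 8, 96, 64, 64  # illustrative sizes only

    A = te.placeholder((batch_sz, d1, d3), name='A')
    B = te.placeholder((batch_sz, d3, d2), name='B')
    k = te.reduce_axis((0, d3), name='k')
    C = te.compute((batch_sz, d1, d2),
                   lambda l, i, j: te.sum(A[l, i, k] * B[l, k, j], axis=k),
                   name='C')
    s = te.create_schedule(C.op)
    mod = tvm.build(s, [A, B, C], target="llvm")  # default schedule, CPU

    dev = tvm.cpu(0)
    a = tvm.nd.array(np.random.rand(batch_sz, d1, d3).astype("float32"), dev)
    b = tvm.nd.array(np.random.rand(batch_sz, d3, d2).astype("float32"), dev)
    c = tvm.nd.array(np.zeros((batch_sz, d1, d2), dtype="float32"), dev)
    print("TVM :", mod.time_evaluator(mod.entry_name, dev, number=100)(a, b, c).mean)

    x, y = torch.from_numpy(a.numpy()), torch.from_numpy(b.numpy())
    t0 = time.time()
    for _ in range(100):
        torch.bmm(x, y)
    print("torch:", (time.time() - t0) / 100)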

In order to achieve better performance, I came across the thread [RFC][BYOC] NVIDIA CUTLASS Integration. My idea is to use a CUTLASS-based version for better performance. To get my hands wet with a CUTLASS-based example, user masahi pointed to their CUTLASS example on GitHub, but whenever I try to import tune_cutlass_kernels or build_cutlass_kernels_vm used in the examples, I get a "package not found" error. I am not sure where to get these packages.

Moreover, I am not sure how these functions are used. Looking at the examples posted, I believe those functions are important for building and tuning CUTLASS kernels.
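
For reference, this is roughly the flow I understand from that example (a hedged sketch based on the test script, not something I have been able to run; the exact import paths and signatures of tune_cutlass_kernels / build_cutlass_kernels may differ between TVM versions):

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.cutlass import partition_for_cutlass
    from tvm.contrib.cutlass import tune_cutlass_kernels, build_cutlass_kernels

    def profile_and_build(mod, params, sm, tmp_dir="./tmp"):
        # Carve supported ops (e.g. nn.batch_matmul) out into CUTLASS regions
        mod = partition_for_cutlass(mod)
        # Profile candidate CUTLASS kernels and pick one per region
        mod, num_cutlass_partition = tune_cutlass_kernels(mod, sm, tmp_dir=tmp_dir)
        with tvm.transform.PassContext(opt_level=3):
            lib = relay.build(mod, target="cuda", params=params)
        # Compile and link the generated CUTLASS source
        lib = build_cutlass_kernels(lib, sm, tmp_dir)
        return lib, num_cutlass_partition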

If anyone can help me with an example of how to use CUTLASS-based kernels for batched matrix multiplication in TVM, that would be great.

P.S. To install TVM, I followed the install-from-source guide. I also enabled CUTLASS support while building.

Here is the link to the GitHub example: tvm-cutlass-eval/cudnn.py - GitHub

That 3rd-party example may be outdated. Please see

I tried this example, but it gives an error.

Currently, using CUTLASS requires building TVM from source

In my CMake build, I have set USE_CUTLASS to ON.

# Enable using CUTLASS as a BYOC backend
# Need to have USE_CUDA=ON
set(USE_CUTLASS ON)

set(USE_CUTLASS ON) should be OK, and it works well on my machine.

Please make sure you’ve changed the config.cmake under the build folder, rather than the one in the cmake folder. (reference)
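
A quick way to double-check which flags the libtvm you are importing was actually built with (a small sketch; tvm.support.libinfo() returns the baked-in cmake options):

    import tvm

    # Should print "ON" if CUTLASS support was compiled in
    print(tvm.support.libinfo().get("USE_CUTLASS"))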

Yes, this is the exact thing that I am following, and I retried everything yet got the same error. I reinstalled CUDA and all the drivers. I can query the nvcc version and nvidia-smi. I have also installed cuDNN and cuBLAS after this step. After that I built the CUTLASS library as well, and then followed the steps for building TVM from source.

I created the build directory, copied over config.cmake, and made the change. I installed the package to Python and I can import tvm. However, the CUTLASS integration in TVM still doesn’t work.

That error message just means you don’t have the tvm/3rdparty/cutlass directory. Please make sure you have it (it should be there if you cloned TVM with --recursive). You don’t need to build CUTLASS; it is a header-only library.

In the downloaded folder I do see the cutlass folder under 3rdparty. However, I guess that when I install the TVM Python bindings via setup.py, the cutlass folder doesn’t get copied over.

Because it is looking for the cutlass folder in

/home/username/miniconda3/lib/python3.9/site-packages/tvm-0.10.dev484+g2e83e03b2-py3.9-linux-x86_64.egg/tvm/contrib/cutlass/../../../../3rdparty/cutlass.

while the downloaded folder is at

/home/username/tvm, which contains the 3rdparty/cutlass and build folders.

I see. When I wrote the code that looks for the cutlass folder, I was assuming that TVM is “installed” by manually modifying PYTHONPATH (“method 1” in Install from Source — tvm 0.10.dev0 documentation). I think most developers don’t use setup.py to install TVM.
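
For example, the same effect as the shell exports in “method 1” can be had from inside Python (a hedged equivalent; adjust the path to your checkout):

    import sys

    # Make the source checkout take precedence before importing tvm
    sys.path.insert(0, "/home/username/tvm/python")
    import tvm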

If you modify tvm_root at https://github.com/apache/tvm/blob/bdcfa01eae3ffe8c6d39aa26d0d1e5b311d47efb/python/tvm/contrib/cutlass/build.py#L39 according to your environment, it should work.
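
Concretely, something like the following (an illustration only; the hard-coded path matches the layout described above, not a general fix):

    import os

    # In python/tvm/contrib/cutlass/build.py, replace the relative lookup
    # with an absolute path to the checkout that contains 3rdparty/cutlass.
    tvm_root = "/home/username/tvm"
    assert os.path.exists(os.path.join(tvm_root, "3rdparty", "cutlass"))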


Yes, I used method 1 and was able to compile using CUTLASS, but then I get a segmentation fault when I try to run the code posted here: tvm/test_cutlass.py

Here is the detailed output.

========================================================================================================================================================= test session starts ==========================================================================================================================================================
platform linux -- Python 3.8.10, pytest-7.1.3, pluggy-1.0.0
rootdir: /home/username/example
collected 12 items                                                                                                                                                                                                                                                                                                                     

test_cutlass.py FFFFFFatal Python error: Segmentation fault

Current thread 0x00007f1e4d5b0740 (most recent call first):
  File "/home/username/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 227 in __call__
  File "/home/username/tvm/python/tvm/runtime/module.py", line 271 in save
  File "/home/username/tvm/python/tvm/runtime/module.py", line 499 in export_library
  File "/home/username/tvm/python/tvm/relay/backend/executor_factory.py", line 203 in export_library
  File "/home/username/tvm/python/tvm/contrib/test_cutlass/build.py", line 599 in finalize_modules
  File "/home/username/example/test_cutlass.py", line 285 in profile_and_build
  File "/home/username/example/test_cutlass.py", line 406 in verify_batch_matmul
  File "/home/username/example/test_cutlass.py", line 510 in test_batch_matmul
  File "/home/username/.local/lib/python3.8/site-packages/_pytest/python.py", line 192 in pytest_pyfunc_call
  File "/home/username/.local/lib/python3.8/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/home/username/.local/lib/python3.8/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/home/username/.local/lib/python3.8/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/home/username/.local/lib/python3.8/site-packages/_pytest/python.py", line 1761 in runtest
  File "/home/username/.local/lib/python3.8/site-packages/_pytest/runner.py", line 166 in pytest_runtest_call
  File "/home/username/.local/lib/python3.8/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/home/username/.local/lib/python3.8/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/home/username/.local/lib/python3.8/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/home/username/.local/lib/python3.8/site-packages/_pytest/runner.py", line 259 in <lambda>
  File "/home/username/.local/lib/python3.8/site-packages/_pytest/runner.py", line 338 in from_call
  File "/home/username/.local/lib/python3.8/site-packages/_pytest/runner.py", line 258 in call_runtest_hook
  File "/home/username/.local/lib/python3.8/site-packages/_pytest/runner.py", line 219 in call_and_report
  File "/home/username/.local/lib/python3.8/site-packages/_pytest/runner.py", line 130 in runtestprotocol
  File "/home/username/.local/lib/python3.8/site-packages/_pytest/runner.py", line 111 in pytest_runtest_protocol
  File "/home/username/.local/lib/python3.8/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/home/username/.local/lib/python3.8/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/home/username/.local/lib/python3.8/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/home/username/.local/lib/python3.8/site-packages/_pytest/main.py", line 347 in pytest_runtestloop
  File "/home/username/.local/lib/python3.8/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/home/username/.local/lib/python3.8/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/home/username/.local/lib/python3.8/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/home/username/.local/lib/python3.8/site-packages/_pytest/main.py", line 322 in _main
  File "/home/username/.local/lib/python3.8/site-packages/_pytest/main.py", line 268 in wrap_session
  File "/home/username/.local/lib/python3.8/site-packages/_pytest/main.py", line 315 in pytest_cmdline_main
  File "/home/username/.local/lib/python3.8/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/home/username/.local/lib/python3.8/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/home/username/.local/lib/python3.8/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/home/username/.local/lib/python3.8/site-packages/_pytest/config/__init__.py", line 164 in main
  File "/home/username/tvm/python/tvm/testing/utils.py", line 1734 in main
  File "test_cutlass.py", line 964 in <module>
Segmentation fault (core dumped)

@masahi Do you have any idea why the segmentation fault happens?

Sorry, I’ve never seen a segfault from that test. It looks like the segfault happens when we try to export the compiled binary as a shared lib.

I was able to get it to run, and it seems that it doesn’t result in a segmentation fault, but I guess CUTLASS support is limited to certain CUDA compute capabilities. I have a Tesla P100 GPU, which has the sm60 architecture. What would be the procedure to get this code to run?

Oh, we only support sm 75 and higher (those with Tensor Cores). So you cannot run it on a Tesla P100, unfortunately.

@masahi So a GPU with CUDA compute capability 8.6 should be able to run it? Or should I make any changes?

Yes, I was testing on a 3070 during development. It should work out of the box.
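
To double-check what TVM sees on your machine, you can query the device’s compute capability directly (a small sketch using the standard Device attribute):

    import tvm

    dev = tvm.cuda(0)
    print(dev.compute_version)  # CUTLASS BYOC needs "7.5" or higher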

So, I tried this on a Tesla T4, which has compute capability 7.5, and yet the code fails for me.

I am running the code at tvm/test_cutlass.py.

The error is as follows.

azureuser@pipeline1:~/examples$ cat cutlass.log 
INFO:root:before partitioning:
def @main(%x: Tensor[(8, 96, 64), float16], %y: Tensor[(8, 64, 64), float16]) {
  nn.batch_matmul(%x, %y, out_dtype="float16", transpose_b=True)
}
INFO:root:after partitioning:
def @main(%x: Tensor[(8, 96, 64), float16] /* ty=Tensor[(8, 96, 64), float16] */, %y: Tensor[(8, 64, 64), float16] /* ty=Tensor[(8, 64, 64), float16] */) -> Tensor[(8, 96, 64), float16] {
  @tvmgen_default_cutlass_main_0(%x, %y) /* ty=Tensor[(8, 96, 64), float16] */
}
def @tvmgen_default_cutlass_main_0(%cutlass_0_i0: Tensor[(8, 96, 64), float16] /* ty=Tensor[(8, 96, 64), float16] */, %cutlass_0_i1: Tensor[(8, 64, 64), float16] /* ty=Tensor[(8, 64, 64), float16] */, Inline=1, Compiler="cutlass", global_symbol="tvmgen_default_cutlass_main_0", Primitive=1) -> Tensor[(8, 96, 64), float16] {
  %0 = fn (%FunctionVar_0_0: Tensor[(8, 96, 64), float16] /* ty=Tensor[(8, 96, 64), float16] */, %FunctionVar_0_1: Tensor[(8, 64, 64), float16] /* ty=Tensor[(8, 64, 64), float16] */, PartitionedFromPattern="nn.batch_matmul_", Composite="cutlass.batch_matmul") -> Tensor[(8, 96, 64), float16] {
    nn.batch_matmul(%FunctionVar_0_0, %FunctionVar_0_1, out_dtype="float16", transpose_b=True) /* ty=Tensor[(8, 96, 64), float16] */
  } /* ty=fn (Tensor[(8, 96, 64), float16], Tensor[(8, 64, 64), float16]) -> Tensor[(8, 96, 64), float16] */;
  %0(%cutlass_0_i0, %cutlass_0_i1) /* ty=Tensor[(8, 96, 64), float16] */
}
INFO:cutlass:Tuning for CUTLASS
INFO:cutlass:invoking evaluation ['./tmp/cutlass_tensorop_h1688gemm_256x128_32x2_tn_align8', '96', '64', '64']
INFO:cutlass:invoking evaluation ['./tmp/cutlass_tensorop_h1688gemm_128x256_32x2_tn_align8', '96', '64', '64']
INFO:cutlass:invoking evaluation ['./tmp/cutlass_tensorop_h1688gemm_128x128_32x2_tn_align8', '96', '64', '64']
INFO:cutlass:invoking evaluation ['./tmp/cutlass_tensorop_h1688gemm_64x128_32x2_tn_align8', '96', '64', '64']
INFO:cutlass:invoking evaluation ['./tmp/cutlass_tensorop_h1688gemm_128x64_32x2_tn_align8', '96', '64', '64']
INFO:cutlass:invoking evaluation ['./tmp/cutlass_tensorop_h1688gemm_64x64_32x2_tn_align8', '96', '64', '64']
INFO:cutlass:invoking evaluation ['./tmp/cutlass_tensorop_h1688gemm_64x128_64x2_tn_align8', '96', '64', '64']
INFO:cutlass:Picked the first kernel found cutlass_tensorop_h1688gemm_256x128_32x2_tn_align8
INFO:cutlass:Creating CSource module for CUTLASS
INFO:cutlass:Compiling generated CUTLASS code
INFO:cutlass:Loading compiled CUTLASS code
INFO:cutlass:Done with CUTLASS compilation
WARNING:autotvm:One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
INFO:te_compiler:Using batch_matmul_tensorcore.cuda for nn.batch_matmul based on highest priority (20)
INFO:root:before partitioning:
def @main(%x: Tensor[(8, 96, 64), float16], %y: Tensor[(8, 64, 64), float16]) {
  nn.batch_matmul(%x, %y, out_dtype="float16", transpose_b=True)
}
INFO:root:after partitioning:
def @main(%x: Tensor[(8, 96, 64), float16] /* ty=Tensor[(8, 96, 64), float16] */, %y: Tensor[(8, 64, 64), float16] /* ty=Tensor[(8, 64, 64), float16] */) -> Tensor[(8, 96, 64), float16] {
  @tvmgen_default_cutlass_main_0(%x, %y) /* ty=Tensor[(8, 96, 64), float16] */
}
def @tvmgen_default_cutlass_main_0(%cutlass_0_i0: Tensor[(8, 96, 64), float16] /* ty=Tensor[(8, 96, 64), float16] */, %cutlass_0_i1: Tensor[(8, 64, 64), float16] /* ty=Tensor[(8, 64, 64), float16] */, Inline=1, Compiler="cutlass", global_symbol="tvmgen_default_cutlass_main_0", Primitive=1) -> Tensor[(8, 96, 64), float16] {
  %0 = fn (%FunctionVar_0_0: Tensor[(8, 96, 64), float16] /* ty=Tensor[(8, 96, 64), float16] */, %FunctionVar_0_1: Tensor[(8, 64, 64), float16] /* ty=Tensor[(8, 64, 64), float16] */, PartitionedFromPattern="nn.batch_matmul_", Composite="cutlass.batch_matmul") -> Tensor[(8, 96, 64), float16] {
    nn.batch_matmul(%FunctionVar_0_0, %FunctionVar_0_1, out_dtype="float16", transpose_b=True) /* ty=Tensor[(8, 96, 64), float16] */
  } /* ty=fn (Tensor[(8, 96, 64), float16], Tensor[(8, 64, 64), float16]) -> Tensor[(8, 96, 64), float16] */;
  %0(%cutlass_0_i0, %cutlass_0_i1) /* ty=Tensor[(8, 96, 64), float16] */
}
INFO:cutlass:Tuning for CUTLASS
INFO:cutlass:invoking evaluation ['./tmp/cutlass_tensorop_h1688gemm_256x128_32x2_tn_align8', '96', '64', '64']
INFO:cutlass:invoking evaluation ['./tmp/cutlass_tensorop_h1688gemm_128x256_32x2_tn_align8', '96', '64', '64']
INFO:cutlass:invoking evaluation ['./tmp/cutlass_tensorop_h1688gemm_128x128_32x2_tn_align8', '96', '64', '64']
INFO:cutlass:invoking evaluation ['./tmp/cutlass_tensorop_h1688gemm_64x128_32x2_tn_align8', '96', '64', '64']
INFO:cutlass:invoking evaluation ['./tmp/cutlass_tensorop_h1688gemm_128x64_32x2_tn_align8', '96', '64', '64']
INFO:cutlass:invoking evaluation ['./tmp/cutlass_tensorop_h1688gemm_64x64_32x2_tn_align8', '96', '64', '64']
INFO:cutlass:invoking evaluation ['./tmp/cutlass_tensorop_h1688gemm_64x128_64x2_tn_align8', '96', '64', '64']
INFO:cutlass:Picked the first kernel found cutlass_tensorop_h1688gemm_256x128_32x2_tn_align8
INFO:cutlass:Creating CSource module for CUTLASS
INFO:cutlass:Compiling generated CUTLASS code
INFO:cutlass:Loading compiled CUTLASS code
INFO:cutlass:Done with CUTLASS compilation
INFO:te_compiler:Using batch_matmul_tensorcore.cuda for nn.batch_matmul based on highest priority (20)
INFO:cutlass:Tuning for CUTLASS
INFO:cutlass:Picked the default kernel cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1
INFO:cutlass:Creating CSource module for CUTLASS
INFO:cutlass:Compiling generated CUTLASS code
INFO:cutlass:Loading compiled CUTLASS code
INFO:cutlass:Done with CUTLASS compilation
Traceback (most recent call last):
  File "test_cutlass.py", line 963, in <module>
    test_batch_matmul()
  File "test_cutlass.py", line 518, in test_batch_matmul
    verify_batch_matmul(
  File "test_cutlass.py", line 397, in verify_batch_matmul
    rt_mod, dev, num_partition = profile_and_build_vm(mod, {}, sm)
  File "test_cutlass.py", line 318, in profile_and_build_vm
    vm_exec = relay.vm.compile(mod, target=[cuda, cutlass], params=params)
  File "/home/azureuser/tvm/python/tvm/relay/backend/vm.py", line 67, in compile
    compiler.lower(mod, target, target_host)
  File "/home/azureuser/tvm/python/tvm/relay/backend/vm.py", line 126, in lower
    self._lower(mod, raw_targets)
  File "/home/azureuser/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
ValueError: Traceback (most recent call last):
  54: TVMFuncCall
  53: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::vm::VMCompiler::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  52: tvm::relay::vm::VMCompiler::Lower(tvm::IRModule, tvm::runtime::Array<tvm::Target, void> const&)
  51: tvm::relay::vm::VMCompiler::LowerImpl(tvm::IRModule)
  50: tvm::relay::vm::VMCompiler::OptimizeModuleImpl(tvm::IRModule)
  49: tvm::transform::Pass::operator()(tvm::IRModule) const
  48: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  47: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  46: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  45: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  44: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  43: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  42: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relay3tec7LowerTEENS0_6StringENS_17CompilationConfigESt8functionIFvNS_8BaseFuncEEEEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SL_SP_
  41: tvm::relay::tec::LowerTE(tvm::IRModule const&, tvm::runtime::String const&, std::function<void (tvm::BaseFunc)>, tvm::CompilationConfig)
  40: tvm::transform::Pass::operator()(tvm::IRModule) const
  39: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  38: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  37: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_5relay8FunctionES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS5_3tec15LowerTensorExprENSD_10TECompilerESt8functionIFvNS_8BaseFuncEEENS_17CompilationConfigEEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SM_SQ_
  36: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  35: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  34: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlR
  33: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
  32: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::FunctionNode const*)
  31: _ZN3tvm5relay9transform22DeviceAwareExprMutator21DeviceAwar
  30: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
  29: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  28: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  27: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlR
  26: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::LetNode const*)
  25: tvm::relay::tec::LowerTensorExprMutator::PreVisitLetBinding_(tvm::relay::Var const&, tvm::RelayExpr const&)
  24: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  23: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  22: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlR
  21: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
  20: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  19: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  18: _ZZN3tvm5relay11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlR
  17: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
  16: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
  15: tvm::relay::tec::LowerTensorExprMutator::MakeLoweredCall(tvm::BaseFunc const&, tvm::GlobalVar const&, tvm::runtime::Array<tvm::RelayExpr, void>, tvm::Span, tvm::Target const&, tvm::runtime::Map<tvm::GlobalVar, tvm::BaseFunc, void, void> const&)
  14: tvm::relay::tec::TECompilerImpl::LowerShapeFunc(tvm::relay::tec::CCacheKey const&)
  13: tvm::relay::tec::TECompilerImpl::LowerShapeFuncInternal(tvm::relay::tec::CCacheKey const&)
  12: tvm::relay::tec::ShapeFuncFor(tvm::relay::Function const&, tvm::Target const&, tvm::GlobalVarSupply)
  11: tvm::relay::tec::MakeShapeFunc::Create(tvm::relay::Function const&, tvm::Target const&, tvm::GlobalVarSupply)
  10: tvm::relay::tec::MakeShapeFunc::VisitExpr(tvm::RelayExpr const&)
  9: tvm::relay::backend::MemoizedExprTranslator<tvm::runtime::Array<tvm::te::Tensor, void> >::VisitExpr(tvm::RelayExpr const&)
  8: tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  7: _ZZN3tvm5relay11ExprFunctorIFNS_7runtime5ArrayINS_2te6TensorEvEERKNS_
  6: tvm::relay::tec::MakeShapeFunc::VisitExpr_(tvm::relay::CallNode const*)
  5: tvm::relay::tec::MakeShapeFunc::VisitExpr(tvm::RelayExpr const&)
  4: tvm::relay::backend::MemoizedExprTranslator<tvm::runtime::Array<tvm::te::Tensor, void> >::VisitExpr(tvm::RelayExpr const&)
  3: tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  2: _ZZN3tvm5relay11ExprFunctorIFNS_7runtime5ArrayINS_2te6TensorEvEERKNS_
  1: tvm::relay::tec::MakeShapeFunc::VisitExpr_(tvm::relay::CallNode const*)
  0: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<TVMFuncCreateFromCFunc::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  File "/home/azureuser/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/azureuser/tvm/python/tvm/relay/op/nn/_nn.py", line 1459, in batch_matmul_shape_func
    _batch_matmul_shape_func(
  File "/home/azureuser/miniconda3/envs/TVM/lib/python3.8/site-packages/decorator.py", line 232, in fun
    return caller(func, *(extras + args), **kw)
  File "/home/azureuser/tvm/python/tvm/te/hybrid/__init__.py", line 60, in wrapped_func
    return source_to_op(src, args, func.__globals__, closure_vars)
  File "/home/azureuser/tvm/python/tvm/te/hybrid/parser.py", line 644, in source_to_op
    parser = parse_python(src, args, symbols, closure_vars)
  File "/home/azureuser/tvm/python/tvm/te/hybrid/parser.py", line 614, in parse_python
    parser.parsed_body = parser.visit(root)
  File "/home/azureuser/miniconda3/envs/TVM/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/home/azureuser/tvm/python/tvm/te/hybrid/parser.py", line 229, in visit_Module
    return self.visit(node.body[0])
  File "/home/azureuser/miniconda3/envs/TVM/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/home/azureuser/tvm/python/tvm/te/hybrid/parser.py", line 242, in visit_FunctionDef
    res = visit_list_to_block(self.visit, node.body)
  File "/home/azureuser/tvm/python/tvm/te/hybrid/parser.py", line 58, in visit_list_to_block
    lst = [visit(stmt) for stmt in lst if not utils.is_docstring(stmt)]
  File "/home/azureuser/tvm/python/tvm/te/hybrid/parser.py", line 58, in <listcomp>
    lst = [visit(stmt) for stmt in lst if not utils.is_docstring(stmt)]
  File "/home/azureuser/miniconda3/envs/TVM/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/home/azureuser/tvm/python/tvm/te/hybrid/parser.py", line 303, in visit_Assign
    rhs = self.visit(node.value)
  File "/home/azureuser/miniconda3/envs/TVM/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/home/azureuser/tvm/python/tvm/te/hybrid/parser.py", line 470, in visit_Call
    args = [self.visit(i) for i in node.args]
  File "/home/azureuser/tvm/python/tvm/te/hybrid/parser.py", line 470, in <listcomp>
    args = [self.visit(i) for i in node.args]
  File "/home/azureuser/miniconda3/envs/TVM/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/home/azureuser/tvm/python/tvm/te/hybrid/parser.py", line 574, in visit_Tuple
    return tuple(self.visit(i) for i in node.elts)
  File "/home/azureuser/tvm/python/tvm/te/hybrid/parser.py", line 574, in <genexpr>
    return tuple(self.visit(i) for i in node.elts)
  File "/home/azureuser/miniconda3/envs/TVM/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/home/azureuser/tvm/python/tvm/te/hybrid/parser.py", line 387, in visit_Subscript
    _internal_assert(
  File "/home/azureuser/tvm/python/tvm/te/hybrid/utils.py", line 43, in _internal_assert
    raise ValueError(err)
ValueError: All indices are supposed to be constants

This is not related to CUTLASS. Maybe you are using Python 3.9 or later? You can probably fix this error by updating your TVM to include the commit in https://github.com/apache/tvm/pull/12769, or by using Python 3.8.

I am using Python 3.8.13 to run it.

I was able to run the same code with Python 3.9.13. Since the commit is already merged into the main branch, I already have the above update, and it seems that it is not backward compatible with Python 3.8.