Build issues with TVM 0.7dev

I am trying to use the auto TensorCore codegen feature in TVM 0.7dev, following this tutorial: https://docs.tvm.ai/tutorials/optimize/opt_matmul_auto_tensorcore.html

Commit: a2429c1fa61cf54d1890e887572c8fa93c467d7a. I built TVM on the Docker image nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04.

-- The C compiler identification is GNU 7.4.0
-- The CXX compiler identification is GNU 7.4.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Build with RPC support...
-- Build with Graph runtime support...
-- Build VTA runtime with target: sim
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Looking for pthread_create
-- Looking for pthread_create - not found
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE
-- Found CUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda
-- Found CUDA_CUDA_LIBRARY=/usr/local/cuda/targets/x86_64-linux/lib/stubs/libcuda.so
-- Found CUDA_CUDART_LIBRARY=/usr/local/cuda/lib64/libcudart.so
-- Found CUDA_NVRTC_LIBRARY=/usr/local/cuda/lib64/libnvrtc.so
-- Found CUDA_CUDNN_LIBRARY=/usr/lib/x86_64-linux-gnu/libcudnn.so
-- Found CUDA_CUBLAS_LIBRARY=/usr/lib/x86_64-linux-gnu/libcublas.so
-- Found CUDA_CUBLASLT_LIBRARY=CUDA_CUBLASLT_LIBRARY-NOTFOUND
-- Build with CUDA support
-- Build with cuDNN support
-- Use llvm-config=llvm-config-8
-- /usr/lib/llvm-8/include
-- Found LLVM_INCLUDE_DIRS=/usr/lib/llvm-8/include
-- Found LLVM_DEFINITIONS= -D_GNU_SOURCE -D__STDC_CONSTANT_MACROS -D__STDC_FORMAT_MACROS -D__STDC_LIMIT_MACROS
-- Found TVM_LLVM_VERSION=80
-- Build with LLVM
-- Set TVM_LLVM_VERSION=80
-- Use BLAS library /usr/lib/x86_64-linux-gnu/libopenblas.so
-- Build with contrib.sort
-- Build with contrib.hybriddump
-- Performing Test SUPPORT_CXX11
-- Performing Test SUPPORT_CXX11 - Success
-- Build with c++11
-- Build with thread support...
-- Check if compiler accepts -pthread
-- Check if compiler accepts -pthread - yes
-- Configuring done
-- Generating done
-- Build files have been written to: /usr/tvm/build
Scanning dependencies of target tvm_runtime
[  1%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/builtin_fp16.cc.o
[  1%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/cpu_device_api.cc.o
[  1%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/c_runtime_api.cc.o
[  1%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/container.cc.o
[  2%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/module.cc.o
[  2%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/ndarray.cc.o
[  2%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/file_util.cc.o
[  2%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/dso_library.cc.o
[  2%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/library_module.cc.o
[  3%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/object.cc.o
[  3%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/registry.cc.o
[  3%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/system_library.cc.o
[  4%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/thread_pool.cc.o
[  4%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/threading_backend.cc.o
[  4%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/workspace_pool.cc.o
[  4%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/vm/executable.cc.o
[  5%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/vm/memory_manager.cc.o
[  5%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/vm/vm.cc.o
[  5%] Building CXX object CMakeFiles/tvm_runtime.dir/3rdparty/bfloat16/bfloat16.cc.o
[  6%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/rpc/rpc_device_api.cc.o
[  6%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/rpc/rpc_event_impl.cc.o
[  6%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/rpc/rpc_module.cc.o
Scanning dependencies of target tvm
[  7%] Building CXX object CMakeFiles/tvm.dir/src/node/container.cc.o
[  7%] Building CXX object CMakeFiles/tvm.dir/src/node/reflection.cc.o
[  7%] Building CXX object CMakeFiles/tvm.dir/src/node/repr_printer.cc.o
[  8%] Building CXX object CMakeFiles/tvm.dir/src/node/serialization.cc.o
In file included from /usr/tvm/src/node/reflection.cc:28:0:
/usr/tvm/include/tvm/ir/attrs.h: In lambda function:
/usr/tvm/include/tvm/ir/attrs.h:768:21: error: 'strcmp' is not a member of 'std'
           if (!std::strcmp(key, args.values[i].v_str)) {
                     ^~~~~~
[  8%] Building CXX object CMakeFiles/tvm.dir/src/ir/adt.cc.o
make[2]: *** [CMakeFiles/tvm.dir/src/node/reflection.cc.o] Error 1
make[2]: *** Waiting for unfinished jobs....
CMakeFiles/tvm.dir/build.make:86: recipe for target 'CMakeFiles/tvm.dir/src/node/reflection.cc.o' failed
[  8%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/rpc/rpc_server_env.cc.o
[ 10%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/rpc/rpc_session.cc.o
[ 10%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/rpc/rpc_socket_impl.cc.o
[ 10%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/graph/graph_runtime.cc.o
In file included from /usr/tvm/include/tvm/ir/type_relation.h:30:0,
                 from /usr/tvm/include/tvm/relay/type.h:29,
                 from /usr/tvm/src/ir/adt.cc:24:
/usr/tvm/include/tvm/ir/attrs.h: In lambda function:
/usr/tvm/include/tvm/ir/attrs.h:768:21: error: 'strcmp' is not a member of 'std'
           if (!std::strcmp(key, args.values[i].v_str)) {
                     ^~~~~~
CMakeFiles/tvm.dir/build.make:158: recipe for target 'CMakeFiles/tvm.dir/src/ir/adt.cc.o' failed
make[2]: *** [CMakeFiles/tvm.dir/src/ir/adt.cc.o] Error 1
[ 11%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/cuda/cuda_device_api.cc.o
[ 11%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/cuda/cuda_module.cc.o
[ 11%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/contrib/cudnn/conv_forward.cc.o
[ 11%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/contrib/cudnn/cudnn_utils.cc.o
[ 12%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/contrib/cblas/cblas.cc.o
[ 12%] Building CXX object CMakeFiles/tvm_runtime.dir/src/runtime/contrib/sort/sort.cc.o
CMakeFiles/Makefile2:557: recipe for target 'CMakeFiles/tvm.dir/all' failed
make[1]: *** [CMakeFiles/tvm.dir/all] Error 2
make[1]: *** Waiting for unfinished jobs....
[ 12%] Linking CXX shared library libtvm_runtime.so
[ 12%] Built target tvm_runtime
Makefile:129: recipe for target 'all' failed

Any idea how to resolve it?

This is my Dockerfile:

FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04

RUN apt-get update --fix-missing

COPY tvm/docker/install/ubuntu_install_core.sh /install/ubuntu_install_core.sh
RUN bash /install/ubuntu_install_core.sh

# Python: basic dependencies
RUN apt-get update && apt-get install -y python3-dev python3-pip
RUN pip3 install numpy nose-timer cython decorator scipy

# LLVM
#RUN echo deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial-6.0 main \
#     >> /etc/apt/sources.list.d/llvm.list && \
#     wget -O - http://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - && \
#     apt-get update && apt-get install -y --force-yes llvm-6.0
COPY tvm/docker/install/ubuntu_install_llvm.sh /install/ubuntu_install_llvm.sh
RUN bash /install/ubuntu_install_llvm.sh

# Jupyter notebook.
RUN pip3 install matplotlib Image Pillow jupyter[notebook]

# Deep learning frameworks
RUN pip3 install mxnet tensorflow keras gluoncv

# Build TVM
COPY . /usr/
COPY tvm/docker/install/install_tvm_gpu.sh /install/install_tvm_gpu.sh
RUN bash /install/install_tvm_gpu.sh

# Environment variables
ENV PYTHONPATH=/usr/tvm/python:/usr/tvm/topi/python:/usr/tvm/nnvm/python/:/usr/tvm/vta/python:${PYTHONPATH}
ENV PATH=/usr/local/nvidia/bin:${PATH}
ENV PATH=/usr/local/cuda/bin:${PATH}
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/nvidia/lib64:${LD_LIBRARY_PATH}

Add #include <cstring> in attrs.h.
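For reference, a minimal sketch of where the include would go (assuming the standard-library headers are listed near the top of include/tvm/ir/attrs.h at this commit):

// include/tvm/ir/attrs.h
#include <cstring>  // provides std::strcmp

// ... the call around line 768 that failed to compile without this header:
if (!std::strcmp(key, args.values[i].v_str)) {
  // ...
}

With GCC 7.4 the header is apparently not pulled in transitively, which is why std::strcmp is reported as not a member of std.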

Thanks for your quick reply!

I have solved the TVM compile problem.

However, when I try to run the example, I hit the following problem.

Device: Nvidia Tesla T4 (compute_version:7.5)

I ran the example with the default parameters (only changing dtype from float16 to int8).

xgboost search log:

ConfigSpace (len=288, space_map=
   0 bx: OtherOption([2, 4, 8]) len=3
   1 by: OtherOption([8, 16, 32, 64]) len=4
   2 step_k: OtherOption([1, 2, 4, 8, 16, 32]) len=6
   3 v: OtherOption([4, 8, 16, 32]) len=4
)
Get devices for measurement successfully!
No: 1   GFLOPS: 828.26/828.26   result: MeasureResult(costs=(2.0256e-05,), error_no=0, all_cost=0.7898979187011719, timestamp=1582783952.6947682) [('bx', 2), ('by', 64), ('step_k', 32), ('v', 8)],None,141
No: 2   GFLOPS: 1107.90/1107.90 result: MeasureResult(costs=(1.5143200000000001e-05,), error_no=0, all_cost=0.8722705841064453, timestamp=1582783953.1710722)     [('bx', 2), ('by', 16), ('step_k', 16), ('v', 8)],None,123
No: 3   GFLOPS: 372.70/1107.90  result: MeasureResult(costs=(4.5015399999999996e-05,), error_no=0, all_cost=0.7959163188934326, timestamp=1582783953.60059)       [('bx', 4), ('by', 32), ('step_k', 1), ('v', 8)],None,79
No: 4   GFLOPS: 0.00/1107.90    result: MeasureResult(costs=(TVMError('Traceback (most recent call last):\n  File "/usr/tvm/src/target/source/codegen_cuda.cc", line 225\nTVMError: Cannot convert type int8x32 to CUDA type\n[bt] (0) /usr/tvm/build/libtvm.so(dmlc::StackTrace[abi:cxx11](unsigned long)+0x1f5) [0x7f7ca1f2f6a5]\n[bt] (1) /usr/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x3e) [0x7f7ca1f303ae]\n[bt] (2) /usr/tvm/build/libtvm.so(tvm::codegen::CodeGenCUDA::PrintType(tvm::runtime::DataType, std::ostream&)+0xd2) [0x7f7ca22f7912]\n[bt] (3) /usr/tvm/build/libtvm.so(tvm::codegen::CodeGenC::GetBufferRef[abi:cxx11](tvm::runtime::DataType, tvm::tir::VarNode const*, tvm::PrimExpr)+0x232) [0x7f7ca22e99e2]\n[bt] (4) /usr/tvm/build/libtvm.so(tvm::codegen::CodeGenC::GetVecLoad[abi:cxx11](tvm::runtime::DataType, tvm::tir::VarNode const*, tvm::PrimExpr)+0x3c) [0x7f7ca22df73c]\n[bt] (5) /usr/tvm/build/libtvm.so(tvm::codegen::CodeGenC::VisitExpr_(tvm::tir::LoadNode const*, std::ostream&)+0x1a8) [0x7f7ca22e39e8]\n[bt] (6) /usr/tvm/build/libtvm.so(tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)+0xaf) [0x7f7ca22e0f9f]\n[bt] (7) /usr/tvm/build/libtvm.so(tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)+0x47) [0x7f7ca22f0e27]\n[bt] (8) /usr/tvm/build/libtvm.so(tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::StoreNode const*)+0x3ad) [0x7f7ca22e441d]\n[bt] (9) /usr/tvm/build/libtvm.so(tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const+0x56) [0x7f7ca1ff4d36]\n',),), error_no=2, all_cost=0.028113603591918945, timestamp=1582783952.2496214)      [('bx', 8), ('by', 16), ('step_k', 16), ('v', 32)],None,269
...
XGB stopped. Best iteration: [87] tr-a-recall@64:0.81142        tr-map:0.33333
XGB train: 8.96 obs: 280        error: 40       n_cache: 288
SA iter: 50     last_update: 20 max-0: -inf     max-1: 1.28     temp: 0.90      elapsed: 0.75
SA iter: 70     last_update: 20 elapsed: 1.04
SA Maximums: [(1.2786218, 175), (1.148524, 85)]
...
Finish loading 288 records

Best Config:

[('bx', 2), ('by', 32), ('step_k', 32), ('v', 8)],None,138
Finish loading 288 records
Cannot find config for target=cuda, workload=None. A fallback configuration is used, which may bring great performance regression.
produce compute {
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 4
  // attr [compute.local] storage_scope = "local"
  allocate compute.local[int32 * 8]
  // attr [A.shared] storage_scope = "shared"
  allocate A.shared[int8 * 128]
  // attr [B.shared] storage_scope = "shared"
  allocate B.shared[int8 * 256]
  // attr [A.shared.local] storage_scope = "local"
  allocate A.shared.local[int8 * 16]
  // attr [B.shared.local] storage_scope = "local"
  allocate B.shared.local[int8 * 128]
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 32
  // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
  // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 8
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 2
  produce compute.local {
    for (j.c.init, 0, 8) {
      compute.local[j.c.init] = 0
    }
    // attr [iter_var(k.outer, )] pragma_tensor_core = 1
    for (k.outer, 0, 32) {
      produce A.shared {
        for (ax0.ax1.outer.fused.outer, 0, 2) {
          // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 8
          // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
          // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 2
          A.shared[ramp((((ax0.ax1.outer.fused.outer*64) + (threadIdx.y*8)) + (threadIdx.x*4)), 1, 4)] = A[ramp(((((((blockIdx.y*4096) + (ax0.ax1.outer.fused.outer*2048)) + (floordiv(threadIdx.y, 2)*512)) + (k.outer*16)) + (floormod(threadIdx.y, 2)*8)) + (threadIdx.x*4)), 1, 4)]
        }
      }
      produce B.shared {
        for (ax0.ax1.outer.fused.outer, 0, 4) {
          // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 8
          // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
          // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 2
          B.shared[ramp((((ax0.ax1.outer.fused.outer*64) + (threadIdx.y*8)) + (threadIdx.x*4)), 1, 4)] = B[ramp(((((((k.outer*8192) + (ax0.ax1.outer.fused.outer*2048)) + (floordiv(threadIdx.y, 2)*512)) + (blockIdx.x*16)) + (floormod(threadIdx.y, 2)*8)) + (threadIdx.x*4)), 1, 4)]
        }
      }
      produce A.shared.local {
        for (ax1, 0, 16) {
          A.shared.local[ax1] = A.shared[((threadIdx.y*16) + ax1)]
        }
      }
      produce B.shared.local {
        for (ax0, 0, 16) {
          for (ax1, 0, 8) {
            B.shared.local[((ax0*8) + ax1)] = B.shared[(((ax0*16) + (threadIdx.x*8)) + ax1)]
          }
        }
      }
      for (k.inner.inner, 0, 16) {
        for (j.c, 0, 8) {
          compute.local[j.c] = (compute.local[j.c] + (int32(A.shared.local[k.inner.inner])*int32(B.shared.local[((k.inner.inner*8) + j.c)])))
        }
      }
    }
  }
  for (j.inner.inner.inner, 0, 8) {
    compute[(((((blockIdx.y*4096) + (threadIdx.y*512)) + (blockIdx.x*16)) + (threadIdx.x*8)) + j.inner.inner.inner)] = compute.local[j.inner.inner.inner]
  }
}

CUDA source:

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)
#include <sm_61_intrinsics.h>
#endif
extern "C" __global__ void default_function_kernel0( signed char* __restrict__ A,  signed char* __restrict       __ B,  int* __restrict__ compute) {
   int compute_local[8];
  __shared__ signed char A_shared[128];
  __shared__ signed char B_shared[256];
   signed char A_shared_local[16];
   signed char B_shared_local[128];
  for (int j_c_init = 0; j_c_init < 8; ++j_c_init) {
    compute_local[(j_c_init)] = 0;
  }
  for (int k_outer = 0; k_outer < 32; ++k_outer) {
    __syncthreads();
    for (int ax0_ax1_outer_fused_outer = 0; ax0_ax1_outer_fused_outer < 2; ++ax0_ax1_outer_fused_outer) {
      (( int*)(A_shared + ((((ax0_ax1_outer_fused_outer * 64) + (((int)threadIdx.y) * 8)) + (((int)threadIdx.x) * 4)))))[0] = (( int*)(A + (((((((((int)blockIdx.y) * 4096) + (ax0_ax1_outer_fused_outer * 2048)) + ((((int)threadIdx.y) >> 1) * 512)) + (k_outer * 16)) + ((((int)threadIdx.y) & 1) * 8)) + (((int)threadIdx.x) * 4)))))[0];
    }
    for (int ax0_ax1_outer_fused_outer1 = 0; ax0_ax1_outer_fused_outer1 < 4; ++ax0_ax1_outer_fused_outer1) {
      (( int*)(B_shared + ((((ax0_ax1_outer_fused_outer1 * 64) + (((int)threadIdx.y) * 8)) + (((int)threadIdx.x) * 4)))))[0] = (( int*)(B + (((((((k_outer * 8192) + (ax0_ax1_outer_fused_outer1 * 2048)) + ((((int)threadIdx.y) >> 1) * 512)) + (((int)blockIdx.x) * 16)) + ((((int)threadIdx.y) & 1) * 8)) + (((int)threadIdx.x) * 4)))))[0];
    }
    __syncthreads();
    for (int ax1 = 0; ax1 < 16; ++ax1) {
      A_shared_local[(ax1)] = A_shared[(((((int)threadIdx.y) * 16) + ax1))];
    }
    for (int ax0 = 0; ax0 < 16; ++ax0) {
      for (int ax11 = 0; ax11 < 8; ++ax11) {
        B_shared_local[(((ax0 * 8) + ax11))] = B_shared[((((ax0 * 16) + (((int)threadIdx.x) * 8)) + ax11))];
      }
    }
    for (int k_inner_inner = 0; k_inner_inner < 16; ++k_inner_inner) {
      for (int j_c = 0; j_c < 8; ++j_c) {
        compute_local[(j_c)] = (compute_local[(j_c)] + (((int)A_shared_local[(k_inner_inner)]) * ((int)B_shared_local[(((k_inner_inner * 8) + j_c))])));
      }
    }
  }
  for (int j_inner_inner_inner = 0; j_inner_inner_inner < 8; ++j_inner_inner_inner) {
    compute[((((((((int)blockIdx.y) * 4096) + (((int)threadIdx.y) * 512)) + (((int)blockIdx.x) * 16)) + (((int)threadIdx.x) * 8)) + j_inner_inner_inner))] = compute_local[(j_inner_inner_inner)];
  }
}

It seems that normal CUDA source code (not TensorCore source code using the wmma API) was generated.
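For comparison, a kernel that actually uses TensorCores goes through the nvcuda::wmma intrinsics. Below is a minimal, hand-written sketch of the wmma API for a single 16x16x16 fp16 tile (just an illustration, compiled with -arch=sm_70 or newer, not what TVM emits):

#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

// One warp computes one 16x16 output tile: D = A (16x16 fp16) * B (16x16 fp16), accumulated in fp32.
// Launch with a single warp, e.g. wmma_16x16x16<<<1, 32>>>(dA, dB, dD);
__global__ void wmma_16x16x16(const half* a, const half* b, float* d) {
  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;

  wmma::fill_fragment(c_frag, 0.0f);               // zero the accumulator
  wmma::load_matrix_sync(a_frag, a, 16);           // leading dimension 16
  wmma::load_matrix_sync(b_frag, b, 16);
  wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);  // TensorCore multiply-accumulate
  wmma::store_matrix_sync(d, c_frag, 16, wmma::mem_row_major);
}

None of these wmma calls (fragment, load_matrix_sync, mma_sync, store_matrix_sync) appear in the generated kernel above.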

If I remember correctly, we only support FP16 TensorCore. However, I am not familiar with this part. Maybe @Hzfengsy knows the details.

Based on the error message "TVMError: Cannot convert type int8x32 to CUDA type", I changed the input matrix shape from "M, N, L = 512, 32, 512" to "M, N, L = 32, 32, 32".

The error disappeared, but normal CUDA source code was still generated.

xgboost search log:

ConfigSpace (len=288, space_map=
   0 bx: OtherOption([2, 4, 8]) len=3
   1 by: OtherOption([8, 16, 32, 64]) len=4
   2 step_k: OtherOption([1, 2, 4, 8, 16, 32]) len=6
   3 v: OtherOption([4, 8, 16, 32]) len=4
)
Get devices for measurement successfully!
No: 1   GFLOPS: 8.67/8.67       result: MeasureResult(costs=(7.5596e-06,), error_no=0, all_cost=0.8503806591033936, timestamp=1582785647.0259442)        [('bx', 4), ('by', 32), ('step_k', 4), ('v', 32)],None,247
No: 2   GFLOPS: 8.23/8.67       result: MeasureResult(costs=(7.959e-06,), error_no=0, all_cost=0.8734729290008545, timestamp=1582785647.4638467) [('bx', 2), ('by', 32), ('step_k', 32), ('v', 8)],None,138
No: 3   GFLOPS: 7.92/8.67       result: MeasureResult(costs=(8.279399999999999e-06,), error_no=0, all_cost=0.8526017665863037, timestamp=1582785647.8984694)     [('bx', 2), ('by', 64), ('step_k', 16), ('v', 16)],None,201
No: 4   GFLOPS: 9.19/9.19       result: MeasureResult(costs=(7.1314e-06,), error_no=0, all_cost=0.8400857448577881, timestamp=1582785648.328787) [('bx', 4), ('by', 32), ('step_k', 16), ('v', 4)],None,55
No: 5   GFLOPS: 8.60/9.19       result: MeasureResult(costs=(7.617e-06,), error_no=0, all_cost=0.8763003349304199, timestamp=1582785648.77706)   [('bx', 4), ('by', 16), ('step_k', 4), ('v', 32)],None,244

Best config:

[('bx', 2), ('by', 32), ('step_k', 4), ('v', 4)],None,30
Finish loading 288 records
Cannot find config for target=cuda, workload=None. A fallback configuration is used, which may bring great performance regression.
produce compute {
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 4
  // attr [compute.local] storage_scope = "local"
  allocate compute.local[int32 * 8]
  // attr [A.shared] storage_scope = "shared"
  allocate A.shared[int8 * 128]
  // attr [B.shared] storage_scope = "shared"
  allocate B.shared[int8 * 256]
  // attr [A.shared.local] storage_scope = "local"
  allocate A.shared.local[int8 * 16]
  // attr [B.shared.local] storage_scope = "local"
  allocate B.shared.local[int8 * 128]
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 2
  // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
  // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 8
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 2
  produce compute.local {
    for (j.c.init, 0, 8) {
      compute.local[j.c.init] = 0
    }
    // attr [iter_var(k.outer, )] pragma_tensor_core = 1
    for (k.outer, 0, 2) {
      produce A.shared {
        for (ax0.ax1.outer.fused.outer, 0, 2) {
          // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 8
          // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
          // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 2
          A.shared[ramp((((ax0.ax1.outer.fused.outer*64) + (threadIdx.y*8)) + (threadIdx.x*4)), 1, 4)] = A[ramp(((((((blockIdx.y*256) + (ax0.ax1.outer.fused.outer*128)) + (floordiv(threadIdx.y, 2)*32)) + (k.outer*16)) + (floormod(threadIdx.y, 2)*8)) + (threadIdx.x*4)), 1, 4)]
        }
      }
      produce B.shared {
        for (ax0.ax1.outer.fused.outer, 0, 4) {
          // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 8
          // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
          // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 2
          B.shared[ramp((((ax0.ax1.outer.fused.outer*64) + (threadIdx.y*8)) + (threadIdx.x*4)), 1, 4)] = B[ramp(((((((k.outer*512) + (ax0.ax1.outer.fused.outer*128)) + (floordiv(threadIdx.y, 2)*32)) + (blockIdx.x*16)) + (floormod(threadIdx.y, 2)*8)) + (threadIdx.x*4)), 1, 4)]
        }
      }
      produce A.shared.local {
        for (ax1, 0, 16) {
          A.shared.local[ax1] = A.shared[((threadIdx.y*16) + ax1)]
        }
      }
      produce B.shared.local {
        for (ax0, 0, 16) {
          for (ax1, 0, 8) {
            B.shared.local[((ax0*8) + ax1)] = B.shared[(((ax0*16) + (threadIdx.x*8)) + ax1)]
          }
        }
      }
      for (k.inner.inner, 0, 16) {
        for (j.c, 0, 8) {
          compute.local[j.c] = (compute.local[j.c] + (int32(A.shared.local[k.inner.inner])*int32(B.shared.local[((k.inner.inner*8) + j.c)])))
        }
      }
    }
  }
  for (j.inner.inner.inner, 0, 8) {
    compute[(((((blockIdx.y*256) + (threadIdx.y*32)) + (blockIdx.x*16)) + (threadIdx.x*8)) + j.inner.inner.inner)] = compute.local[j.inner.inner.inner]
  }
}

CUDA source:

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)
#include <sm_61_intrinsics.h>
#endif
extern "C" __global__ void default_function_kernel0( signed char* __restrict__ A,  signed char* __restrict__ B,  int* __restrict__ compute) {
   int compute_local[8];
  __shared__ signed char A_shared[128];
  __shared__ signed char B_shared[256];
   signed char A_shared_local[16];
   signed char B_shared_local[128];
  for (int j_c_init = 0; j_c_init < 8; ++j_c_init) {
    compute_local[(j_c_init)] = 0;
  }
  for (int k_outer = 0; k_outer < 2; ++k_outer) {
    __syncthreads();
    for (int ax0_ax1_outer_fused_outer = 0; ax0_ax1_outer_fused_outer < 2; ++ax0_ax1_outer_fused_outer) {
      (( int*)(A_shared + ((((ax0_ax1_outer_fused_outer * 64) + (((int)threadIdx.y) * 8)) + (((int)threadIdx.x) * 4)))))[0] = (( int*)(A + (((((((((int)blockIdx.y) * 256) + (ax0_ax1_outer_fused_outer * 128)) + ((((int)threadIdx.y) >> 1) * 32)) + (k_outer * 16)) + ((((int)threadIdx.y) & 1) * 8)) + (((int)threadIdx.x) * 4)))))[0];
    }
    for (int ax0_ax1_outer_fused_outer1 = 0; ax0_ax1_outer_fused_outer1 < 4; ++ax0_ax1_outer_fused_outer1) {
      (( int*)(B_shared + ((((ax0_ax1_outer_fused_outer1 * 64) + (((int)threadIdx.y) * 8)) + (((int)threadIdx.x) * 4)))))[0] = (( int*)(B + (((((((k_outer * 512) + (ax0_ax1_outer_fused_outer1 * 128)) + ((((int)threadIdx.y) >> 1) * 32)) + (((int)blockIdx.x) * 16)) + ((((int)threadIdx.y) & 1) * 8)) + (((int)threadIdx.x) * 4)))))[0];
    }
    __syncthreads();
    for (int ax1 = 0; ax1 < 16; ++ax1) {
      A_shared_local[(ax1)] = A_shared[(((((int)threadIdx.y) * 16) + ax1))];
    }
    for (int ax0 = 0; ax0 < 16; ++ax0) {
      for (int ax11 = 0; ax11 < 8; ++ax11) {
        B_shared_local[(((ax0 * 8) + ax11))] = B_shared[((((ax0 * 16) + (((int)threadIdx.x) * 8)) + ax11))];
      }
    }
    for (int k_inner_inner = 0; k_inner_inner < 16; ++k_inner_inner) {
      for (int j_c = 0; j_c < 8; ++j_c) {
        compute_local[(j_c)] = (compute_local[(j_c)] + (((int)A_shared_local[(k_inner_inner)]) * ((int)B_shared_local[(((k_inner_inner * 8) + j_c))])));
      }
    }
  }
  for (int j_inner_inner_inner = 0; j_inner_inner_inner < 8; ++j_inner_inner_inner) {
    compute[((((((((int)blockIdx.y) * 256) + (((int)threadIdx.y) * 32)) + (((int)blockIdx.x) * 16)) + (((int)threadIdx.x) * 8)) + j_inner_inner_inner))] = compute_local[(j_inner_inner_inner)];
  }
}

After the PR (https://github.com/apache/incubator-tvm/pull/4546) was merged, we should support all types (Float16, Int8, even Int4 and Int1) on TensorCore. However, this problem is caused by the Auto TensorCore CodeGen pass, written by @MinminSun and @jcf94. Maybe they can offer help.

Hi @songy, have you also tried fp16? Does it work? This info may help us analyze the problem.

Yes, I also tried fp16 and it failed too.

xgboost search log:

ConfigSpace (len=288, space_map=
   0 bx: OtherOption([2, 4, 8]) len=3
   1 by: OtherOption([8, 16, 32, 64]) len=4
   2 step_k: OtherOption([1, 2, 4, 8, 16, 32]) len=6
   3 v: OtherOption([4, 8, 16, 32]) len=4
)
Get devices for measurement successfully!
No: 1   GFLOPS: 9.49/9.49       result: MeasureResult(costs=(6.909e-06,), error_no=0, all_cost=0.8302960395812988, timestamp=1582792874.7657251) [('bx', 2), ('by', 16), ('step_k', 16), ('v', 4)],None,51
No: 2   GFLOPS: 1.26/9.49       result: MeasureResult(costs=(5.21092e-05,), error_no=0, all_cost=0.8451621532440186, timestamp=1582792875.139594)        [('bx', 2), ('by', 8), ('step_k', 32), ('v', 4)],None,60
No: 3   GFLOPS: 8.16/9.49       result: MeasureResult(costs=(8.0358e-06,), error_no=0, all_cost=0.8374130725860596, timestamp=1582792875.5084713)        [('bx', 4), ('by', 16), ('step_k', 32), ('v', 8)],None,136
No: 4   GFLOPS: 3.28/9.49       result: MeasureResult(costs=(2.00002e-05,), error_no=0, all_cost=0.8476271629333496, timestamp=1582792875.8783343)       [('bx', 8), ('by', 64), ('step_k', 32), ('v', 32)],None,287
No: 5   GFLOPS: 11.91/11.91     result: MeasureResult(costs=(5.5013999999999995e-06,), error_no=0, all_cost=0.8303091526031494, timestamp=1582792876.246076)     [('bx', 8), ('by', 16), ('step_k', 4), ('v', 4)],None,29
No: 6   GFLOPS: 8.19/11.91      result: MeasureResult(costs=(7.998199999999999e-06,), error_no=0, all_cost=0.8223295211791992, timestamp=1582792876.610781)      [('bx', 8), ('by', 32), ('step_k', 4), ('v', 16)],None,176
No: 7   GFLOPS: 5.85/11.91      result: MeasureResult(costs=(1.12028e-05,), error_no=0, all_cost=0.8253054618835449, timestamp=1582792876.9790044)       [('bx', 8), ('by', 64), ('step_k', 2), ('v', 32)],None,239
No: 8   GFLOPS: 9.88/11.91      result: MeasureResult(costs=(6.6352000000000004e-06,), error_no=0, all_cost=0.8190181255340576, timestamp=1582792877.3432121)    [('bx', 8), ('by', 64), ('step_k', 4), ('v', 4)],None,35
No: 9   GFLOPS: 12.63/12.63     result: MeasureResult(costs=(5.1884e-06,), error_no=0, all_cost=0.817601203918457, timestamp=1582792877.7091498) [('bx', 2), ('by', 16), ('step_k', 2), ('v', 8)],None,87
No: 10  GFLOPS: 0.00/12.63      result: MeasureResult(costs=(TVMError('Traceback (most recent call last):\n  File "/usr/tvm/src/target/source/codegen_cuda.cc", line 225\nTVMError: Cannot convert type float16x16 to CUDA type\n[bt] (0) /usr/tvm/build/libtvm.so(dmlc::StackTrace[abi:cxx11](unsigned long)+0x1f5) [0x7f01eb51b6a5]\n[bt] (1) /usr/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x3e) [0x7f01eb51c3ae]\n[bt] (2) /usr/tvm/build/libtvm.so(tvm::codegen::CodeGenCUDA::PrintType(tvm::runtime::DataType, std::ostream&)+0xd2) [0x7f01eb8e3912]\n[bt] (3) /usr/tvm/build/libtvm.so(tvm::codegen::CodeGenC::GetBufferRef[abi:cxx11](tvm::runtime::DataType, tvm::tir::VarNode const*, tvm::PrimExpr)+0x232) [0x7f01eb8d59e2]\n[bt] (4) /usr/tvm/build/libtvm.so(tvm::codegen::CodeGenC::GetVecLoad[abi:cxx11](tvm::runtime::DataType, tvm::tir::VarNode const*, tvm::PrimExpr)+0x3c) [0x7f01eb8cb73c]\n[bt] (5) /usr/tvm/build/libtvm.so(tvm::codegen::CodeGenC::VisitExpr_(tvm::tir::LoadNode const*, std::ostream&)+0x1a8) [0x7f01eb8cf9e8]\n[bt] (6) /usr/tvm/build/libtvm.so(tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)+0xaf) [0x7f01eb8ccf9f]\n[bt] (7) /usr/tvm/build/libtvm.so(tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)+0x47) [0x7f01eb8dce27]\n[bt] (8) /usr/tvm/build/libtvm.so(tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::StoreNode const*)+0x3ad) [0x7f01eb8d041d]\n[bt] (9) /usr/tvm/build/libtvm.so(tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const+0x56) [0x7f01eb5e0d36]\n',),), error_no=2, all_cost=0.027663707733154297, timestamp=1582792874.311063)   [('bx', 2), ('by', 16), ('step_k', 8), ('v', 16)],None,183

Best config:

[('bx', 2), ('by', 16), ('step_k', 2), ('v', 4)],None,15
Finish loading 288 records
Cannot find config for target=cuda, workload=None. A fallback configuration is used, which may bring great performance regression.
produce compute {
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 4
  // attr [compute.local] storage_scope = "local"
  allocate compute.local[float32 * 8]
  // attr [A.shared] storage_scope = "shared"
  allocate A.shared[float16 * 192]
  // attr [B.shared] storage_scope = "shared"
  allocate B.shared[float16 * 256]
  // attr [A.shared.local] storage_scope = "local"
  allocate A.shared.local[float16 * 16]
  // attr [B.shared.local] storage_scope = "local"
  allocate B.shared.local[float16 * 128]
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 2
  // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
  // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 8
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 2
  produce compute.local {
    for (j.c.init, 0, 8) {
      compute.local[j.c.init] = 0f
    }
    // attr [iter_var(k.outer, )] pragma_tensor_core = 1
    for (k.outer, 0, 2) {
      produce A.shared {
        for (ax0.ax1.outer.fused.outer, 0, 2) {
          // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 8
          // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
          // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 2
          A.shared[ramp(((((ax0.ax1.outer.fused.outer*96) + (floordiv(threadIdx.y, 2)*24)) + (floormod(threadIdx.y, 2)*8)) + (threadIdx.x*4)), 1, 4)] = A[ramp(((((((blockIdx.y*256) + (ax0.ax1.outer.fused.outer*128)) + (floordiv(threadIdx.y, 2)*32)) + (k.outer*16)) + (floormod(threadIdx.y, 2)*8)) + (threadIdx.x*4)), 1, 4)]
        }
      }
      produce B.shared {
        for (ax0.ax1.outer.fused.outer, 0, 4) {
          // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 8
          // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
          // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 2
          B.shared[ramp((((ax0.ax1.outer.fused.outer*64) + (threadIdx.y*8)) + (threadIdx.x*4)), 1, 4)] = B[ramp(((((((k.outer*512) + (ax0.ax1.outer.fused.outer*128)) + (floordiv(threadIdx.y, 2)*32)) + (blockIdx.x*16)) + (floormod(threadIdx.y, 2)*8)) + (threadIdx.x*4)), 1, 4)]
        }
      }
      produce A.shared.local {
        for (ax1, 0, 16) {
          A.shared.local[ax1] = A.shared[((threadIdx.y*24) + ax1)]
        }
      }
      produce B.shared.local {
        for (ax0, 0, 16) {
          for (ax1, 0, 8) {
            B.shared.local[((ax0*8) + ax1)] = B.shared[(((ax0*16) + (threadIdx.x*8)) + ax1)]
          }
        }
      }
      for (k.inner.inner, 0, 16) {
        for (j.c, 0, 8) {
          compute.local[j.c] = (compute.local[j.c] + (float32(A.shared.local[k.inner.inner])*float32(B.shared.local[((k.inner.inner*8) + j.c)])))
        }
      }
    }
  }
  for (j.inner.inner.inner, 0, 8) {
    compute[(((((blockIdx.y*256) + (threadIdx.y*32)) + (blockIdx.x*16)) + (threadIdx.x*8)) + j.inner.inner.inner)] = compute.local[j.inner.inner.inner]
  }
}

CUDA source:

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
#include <cuda_fp16.h>
__device__ half max(half a, half b)
{
  return __hgt(__half(a), __half(b)) ? a : b;
}
__device__ half min(half a, half b)
{
  return __hlt(__half(a), __half(b)) ? a : b;
}
#else

typedef unsigned short uint16_t;
typedef unsigned char uint8_t;
typedef signed char int8_t;
typedef int int32_t;
typedef unsigned long long uint64_t;
typedef unsigned int uint32_t;

#define TVM_FORCE_INLINE inline __attribute__((always_inline))
#define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__
#define TVM_ALIGNED(x) __attribute__ ((aligned(x)))
#define TVM_HALF_OPERATOR(RTYPE, OP)                              \
  TVM_XINLINE RTYPE operator OP (half a, half b) {                \
    return RTYPE(float(a) OP float(b));                           \
  }                                                               \
  template<typename T>                                            \
  TVM_XINLINE RTYPE operator OP (half a, T b) {                   \
    return RTYPE(float(a) OP float(b));                           \
  }                                                               \
  template<typename T>                                            \
  TVM_XINLINE RTYPE operator OP (T a, half b) {                   \
    return RTYPE(float(a) OP float(b));                           \
  }

#define TVM_HALF_ASSIGNOP(AOP, OP)                                \
  template<typename T>                                            \
  TVM_XINLINE half operator AOP (const T& a) {                    \
    return *this = half(float(*this) OP float(a));                \
  }                                                               \
  template<typename T>                                            \
  TVM_XINLINE half operator AOP (const volatile T& a) volatile {  \
    return *this = half(float(*this) OP float(a));                \
  }

class TVM_ALIGNED(2) half {
 public:
  uint16_t half_;

  static TVM_XINLINE half Binary(uint16_t value) {
    half res;
    res.half_ = value;
    return res;
  }

  TVM_XINLINE half() {}

  TVM_XINLINE half(const float& value) { constructor(value); }
  TVM_XINLINE explicit half(const double& value) { constructor(value); }
  TVM_XINLINE explicit half(const int8_t& value) { constructor(value); }
  TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }
  TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }
  TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }
  TVM_XINLINE explicit half(const long long& value) { constructor(value); }
  TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }

  TVM_XINLINE operator float() const {                          \
    return float(half2float(half_));                            \
  }                                                             \
  TVM_XINLINE operator float() const volatile {                 \
    return float(half2float(half_));                            \
  }


  TVM_HALF_ASSIGNOP(+=, +)
  TVM_HALF_ASSIGNOP(-=, -)
  TVM_HALF_ASSIGNOP(*=, *)
  TVM_HALF_ASSIGNOP(/=, /)

  TVM_XINLINE half operator+() {
    return *this;
  }

  TVM_XINLINE half operator-() {
    return half(-float(*this));
  }

  TVM_XINLINE half operator=(const half& a) {
    half_ = a.half_;
    return a;
  }

  template<typename T>
  TVM_XINLINE half operator=(const T& a) {
    return *this = half(a);
  }

  TVM_XINLINE half operator=(const half& a) volatile {
    half_ = a.half_;
    return a;
  }

  template<typename T>
  TVM_XINLINE half operator=(const T& a) volatile {
    return *this = half(a);
  }

 private:
  union Bits {
    float f;
    int32_t si;
    uint32_t ui;
  };

  static int const fp16FractionBits = 10;
  static int const fp32FractionBits = 23;
  static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits);   // == 0x7fffff
  static int32_t const fp32HiddenBit = 1 << fp32FractionBits;   // == 0x800000
  static int const shift = fp32FractionBits - fp16FractionBits;   // == 13
  static int const shiftSign = 16;
  static int32_t const expAdjust = 127 - 15;   // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)

  static int32_t const infN = 0x7F800000;   // flt32 infinity
  static int32_t const maxN = 0x477FFFFF;   // max flt32 that's a flt16 normal after >> by shift
  static int32_t const minN = 0x38800000;   // min flt16 normal as a flt32
  static int32_t const maxZ = 0x33000000;   // max fp32 number that's still rounded to zero in fp16
  static int32_t const signN = 0x80000000;  // flt32 sign bit

  static int32_t const infC = infN >> shift;
  static int32_t const nanN = (infC + 1) << shift;   // minimum flt16 nan as a flt32
  static int32_t const maxC = maxN >> shift;
  static int32_t const minC = minN >> shift;
  static int32_t const signC = signN >> shiftSign;  // flt16 sign bit

  static int32_t const mulN = 0x52000000;  // (1 << 23) / minN
  static int32_t const mulC = 0x33800000;  // minN / (1 << (23 - shift))

  static int32_t const subC = 0x003FF;  // max flt32 subnormal down shifted
  static int32_t const norC = 0x00400;  // min flt32 normal down shifted

  static int32_t const maxD = infC - maxC - 1;
  static int32_t const minD = minC - subC - 1;

  TVM_XINLINE uint16_t float2half(const float& value) const {
    Bits v;
    v.f = value;
    uint32_t sign = v.si & signN;    // grab sign bit
    v.si ^= sign;                    // clear sign bit from v
    sign >>= shiftSign;              // logical shift sign to fp16 position

    if (v.si <= maxZ) {
      // Handle eventual zeros here to ensure
      // vshift will not exceed 32 below.
      v.ui = 0;
    } else if (v.si < minN) {
      // Handle denorms
      uint32_t exp32 = v.ui >> fp32FractionBits;
      int32_t exp16 = exp32 - expAdjust;
      // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
      // Smaller (so negative) exp16 values should result in greater right shifts.
      uint32_t vshift = 1 - exp16;
      uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
      v.ui = significand >> vshift;
      v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
    } else if (v.si <= maxN) {
      // Handle norms
      v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
      v.ui -= expAdjust << fp32FractionBits;
    } else if (v.si <= infN) {
      v.si = infN;
    } else if (v.si < nanN) {
      v.si = nanN;
    }

    v.ui >>= shift;
    return sign | (v.ui & 0x7fff);
  }

  // Same as above routine, except for addition of volatile keyword
  TVM_XINLINE uint16_t float2half(
    const volatile float& value) const volatile {
    Bits v;
    v.f = value;
    uint32_t sign = v.si & signN;    // grab sign bit
    v.si ^= sign;                    // clear sign bit from v
    sign >>= shiftSign;              // logical shift sign to fp16 position

    if (v.si <= maxZ) {
      // Handle eventual zeros here to ensure
      // vshift will not exceed 32 below.
      v.ui = 0;
    } else if (v.si < minN) {
      // Handle denorms
      uint32_t exp32 = v.ui >> fp32FractionBits;
      int32_t exp16 = exp32 - expAdjust;
      // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
      // Smaller (so negative) exp16 values should result in greater right shifts.
      uint32_t vshift = 1 - exp16;
      uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
      v.ui = significand >> vshift;
      v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
    } else if (v.si <= maxN) {
      // Handle norms
      v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
      v.ui -= expAdjust << fp32FractionBits;
    } else if (v.si <= infN) {
      v.si = infN;
    } else if (v.si < nanN) {
      v.si = nanN;
    }

    v.ui >>= shift;
    return sign | (v.ui & 0x7fff);
  }

  TVM_XINLINE float half2float(const uint16_t& value) const {
    Bits v;
    v.ui = value;
    int32_t sign = v.si & signC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;
  }

  TVM_XINLINE float half2float(
    const volatile uint16_t& value) const volatile {
    Bits v;
    v.ui = value;
    int32_t sign = v.si & signC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;
  }

  template<typename T>
  TVM_XINLINE void constructor(const T& value) {
    half_ = float2half(float(value));
  }
};

TVM_HALF_OPERATOR(half, +)
TVM_HALF_OPERATOR(half, -)
TVM_HALF_OPERATOR(half, *)
TVM_HALF_OPERATOR(half, /)
TVM_HALF_OPERATOR(bool, >)
TVM_HALF_OPERATOR(bool, <)
TVM_HALF_OPERATOR(bool, >=)
TVM_HALF_OPERATOR(bool, <=)

TVM_XINLINE half __float2half_rn(const float a) {
  return half(a);
}
#endif


// Pack two half values.
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
  unsigned v0 = *((unsigned short *)&x);
  unsigned v1 = *((unsigned short *)&y);
  return (v0 << 16) | v1;
}
extern "C" __global__ void default_function_kernel0( half* __restrict__ A,  half* __restrict__ B,  float* __restrict__ compute) {
   float compute_local[8];
  __shared__ half A_shared[192];
  __shared__ half B_shared[256];
   half A_shared_local[16];
   half B_shared_local[128];
  for (int j_c_init = 0; j_c_init < 8; ++j_c_init) {
    compute_local[(j_c_init)] = 0.000000e+00f;
  }
  for (int k_outer = 0; k_outer < 2; ++k_outer) {
    __syncthreads();
    for (int ax0_ax1_outer_fused_outer = 0; ax0_ax1_outer_fused_outer < 2; ++ax0_ax1_outer_fused_outer) {
      (( uint2*)(A_shared + (((((ax0_ax1_outer_fused_outer * 96) + ((((int)threadIdx.y) >> 1) * 24)) + ((((int)threadIdx.y) & 1) * 8)) + (((int)threadIdx.x) * 4)))))[0] = (( uint2*)(A + (((((((((int)blockIdx.y) * 256) + (ax0_ax1_outer_fused_outer * 128)) + ((((int)threadIdx.y) >> 1) * 32)) + (k_outer * 16)) + ((((int)threadIdx.y) & 1) * 8)) + (((int)threadIdx.x) * 4)))))[0];
    }
    for (int ax0_ax1_outer_fused_outer1 = 0; ax0_ax1_outer_fused_outer1 < 4; ++ax0_ax1_outer_fused_outer1) {
      (( uint2*)(B_shared + ((((ax0_ax1_outer_fused_outer1 * 64) + (((int)threadIdx.y) * 8)) + (((int)threadIdx.x) * 4)))))[0] = (( uint2*)(B + (((((((k_outer * 512) + (ax0_ax1_outer_fused_outer1 * 128)) + ((((int)threadIdx.y) >> 1) * 32)) + (((int)blockIdx.x) * 16)) + ((((int)threadIdx.y) & 1) * 8)) + (((int)threadIdx.x) * 4)))))[0];
    }
    __syncthreads();
    for (int ax1 = 0; ax1 < 16; ++ax1) {
      A_shared_local[(ax1)] = A_shared[(((((int)threadIdx.y) * 24) + ax1))];
    }
    for (int ax0 = 0; ax0 < 16; ++ax0) {
      for (int ax11 = 0; ax11 < 8; ++ax11) {
        B_shared_local[(((ax0 * 8) + ax11))] = B_shared[((((ax0 * 16) + (((int)threadIdx.x) * 8)) + ax11))];
      }
    }
    for (int k_inner_inner = 0; k_inner_inner < 16; ++k_inner_inner) {
      for (int j_c = 0; j_c < 8; ++j_c) {
        compute_local[(j_c)] = (compute_local[(j_c)] + (((float)A_shared_local[(k_inner_inner)]) * ((float)B_shared_local[(((k_inner_inner * 8) + j_c))])));
      }
    }
  }
  for (int j_inner_inner_inner = 0; j_inner_inner_inner < 8; ++j_inner_inner_inner) {
    compute[((((((((int)blockIdx.y) * 256) + (((int)threadIdx.y) * 32)) + (((int)blockIdx.x) * 16)) + (((int)threadIdx.x) * 8)) + j_inner_inner_inner))] = compute_local[(j_inner_inner_inner)];
  }
}

This might be the cause. And this looks odd, because there is a valid record, like the one below. I don't know why the fallback happened.

Thanks!

Could the following TVMError info be the cause?

result: MeasureResult(costs=(TVMError('Traceback (most recent call last):\n  File "/usr/tvm/src/target/source/codegen_cuda.cc", line 225\nTVMError: Cannot convert type float16x16 to CUDA type\n[bt] (0) /usr/tvm/build/libtvm.so(dmlc::StackTrace[abi:cxx11](unsigned long)+0x1f5) [0x7f01eb51b6a5]\n[bt] (1) /usr/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x3e) [0x7f01eb51c3ae]\n[bt] (2) /usr/tvm/build/libtvm.so(tvm::codegen::CodeGenCUDA::PrintType(tvm::runtime::DataType, std::ostream&)+0xd2) [0x7f01eb8e3912]\n[bt] (3) /usr/tvm/build/libtvm.so(tvm::codegen::CodeGenC::GetBufferRef[abi:cxx11](tvm::runtime::DataType, tvm::tir::VarNode const*, tvm::PrimExpr)+0x232) [0x7f01eb8d59e2]\n[bt] (4) /usr/tvm/build/libtvm.so(tvm::codegen::CodeGenC::GetVecLoad[abi:cxx11](tvm::runtime::DataType, tvm::tir::VarNode const*, tvm::PrimExpr)+0x3c) [0x7f01eb8cb73c]\n[bt] (5) /usr/tvm/build/libtvm.so(tvm::codegen::CodeGenC::VisitExpr_(tvm::tir::LoadNode const*, std::ostream&)+0x1a8) [0x7f01eb8cf9e8]\n[bt] (6) /usr/tvm/build/libtvm.so(tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)+0xaf) [0x7f01eb8ccf9f]\n[bt] (7) /usr/tvm/build/libtvm.so(tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)+0x47) [0x7f01eb8dce27]\n[bt] (8) /usr/tvm/build/libtvm.so(tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::StoreNode const*)+0x3ad) [0x7f01eb8d041d]\n[bt] (9) /usr/tvm/build/libtvm.so(tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const+0x56) [0x7f01eb5e0d36]\n',),), error_no=2, all_cost=0.027663707733154297, timestamp=1582792874.311063)   [('bx', 2), ('by', 16), ('step_k', 8), ('v', 16)],None,183

Or is there any debug method I can try?

@Orion34C may help to debug this issue.

Hi guys, I have got a similar error to the one @songy describes. The codegen does not generate wmma code.

TVM version: 0.8.dev0

CUDA version: 10.0

GPU: NVIDIA T4

https://tvm.apache.org/docs/tutorials/optimize/opt_matmul_auto_tensorcore.html

No: 280	GFLOPS: 0.00/1089.87	result: MeasureResult(costs=(TVMError('Traceback (most recent call last):\n  [bt] (8) /liuhe/tvm/build/libtvm.so(tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)+0x6c) [0x7f50d2c449bc]\n  [bt] (7) /liuhe/tvm/build/libtvm.so(tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::StoreNode const*)+0x859) [0x7f50d32a18b9]\n  [bt] (6) /liuhe/tvm/build/libtvm.so(tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&)+0x15a) [0x7f50d32ac8ea]\n  [bt] (5) /liuhe/tvm/build/libtvm.so(tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)+0x9b) [0x7f50d329d1eb]\n  [bt] (4) /liuhe/tvm/build/libtvm.so(tvm::codegen::CodeGenC::VisitExpr_(tvm::tir::LoadNode const*, std::ostream&)+0x291) [0x7f50d32a0981]\n  [bt] (3) /liuhe/tvm/build/libtvm.so(tvm::codegen::CodeGenC::GetVecLoad(tvm::runtime::DataType, tvm::tir::VarNode const*, tvm::PrimExpr)+0x2f) [0x7f50d3298fcf]\n  [bt] (2) /liuhe/tvm/build/libtvm.so(tvm::codegen::CodeGenC::GetBufferRef(tvm::runtime::DataType, tvm::tir::VarNode const*, tvm::PrimExpr)+0x1ae) [0x7f50d32a4fce]\n  [bt] (1) /liuhe/tvm/build/libtvm.so(tvm::codegen::CodeGenCUDA::PrintType(tvm::runtime::DataType, std::ostream&)+0x105) [0x7f50d32b7855]\n  [bt] (0) /liuhe/tvm/build/libtvm.so(+0x119bfb0) [0x7f50d32b5fb0]\n  File "/liuhe/tvm/src/target/source/codegen_cuda.cc", line 267\nTVMError: Cannot convert type float16x16 to CUDA type',),), error_no=2, all_cost=0.06485867500305176, timestamp=1614065870.018728)	[('bx', 8), ('by', 8), ('step_k', 8), ('v', 16)],None,182

Config result:

Best config:
[('bx', 2), ('by', 32), ('step_k', 16), ('v', 8)],None,126
Finish loading 288 records
primfn(A_1: handle, B_1: handle, compute_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {compute: Buffer(compute_2: Pointer(float32), float32, [32, 512], []),
            A: Buffer(A_2: Pointer(float16), float16, [32, 512], []),
            B: Buffer(B_2: Pointer(float16), float16, [512, 512], [])}
  buffer_map = {A_1: A, B_1: B, compute_1: compute} {
  attr [IterVar(blockIdx.y: int32, (nullptr), "ThreadIndex", "blockIdx.y")] "thread_extent" = 1;
  attr [compute.local: Pointer(float32)] "storage_scope" = "local";
  allocate(compute.local, float32, [8]);
  attr [A.shared: Pointer(float16)] "storage_scope" = "shared";
  allocate(A.shared, float16, [8448]);
  attr [B.shared: Pointer(float16)] "storage_scope" = "shared";
  allocate(B.shared, float16, [4096]);
  attr [A.shared.local: Pointer(float16)] "storage_scope" = "local";
  allocate(A.shared.local, float16, [16]);
  attr [B.shared.local: Pointer(float16)] "storage_scope" = "local";
  allocate(B.shared.local, float16, [128]);
  attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 32;
  attr [IterVar(threadIdx.z: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
  attr [IterVar(threadIdx.y: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 32;
  attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 2 {
    for (j.c.init: int32, 0, 8) {
      compute.local[j.c.init] = 0f32
    }
    attr [IterVar(k.outer: int32, (nullptr), "CommReduce", "")] "pragma_tensor_core" = 1;
    for (k.outer, 0, 2) {
      for (ax0.ax1.outer.fused.outer: int32, 0, 16) {
        attr [IterVar(threadIdx.y_1: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 32;
        attr [IterVar(threadIdx.z_1: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
        attr [IterVar(threadIdx.x_1: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 2;
        A.shared[ramp(((((ax0.ax1.outer.fused.outer*528) + (floordiv(threadIdx.y_1, 16)*264)) + (floormod(threadIdx.y_1, 16)*16)) + (threadIdx.x_1*8)), 1, 8)] = (float16x8*)A_2[ramp((((((ax0.ax1.outer.fused.outer*1024) + (floordiv(threadIdx.y_1, 16)*512)) + (k.outer*256)) + (floormod(threadIdx.y_1, 16)*16)) + (threadIdx.x_1*8)), 1, 8)]
      }
      for (ax0.ax1.outer.fused.outer_1: int32, 0, 8) {
        attr [IterVar(threadIdx.y_2: int32, (nullptr), "ThreadIndex", "threadIdx.y")] "thread_extent" = 32;
        attr [IterVar(threadIdx.z_2: int32, (nullptr), "ThreadIndex", "threadIdx.z")] "thread_extent" = 1;
        attr [IterVar(threadIdx.x_2: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 2;
        B.shared[ramp((((ax0.ax1.outer.fused.outer_1*512) + (threadIdx.y_2*16)) + (threadIdx.x_2*8)), 1, 8)] = (float16x8*)B_2[ramp((((((k.outer*131072) + (ax0.ax1.outer.fused.outer_1*16384)) + (threadIdx.y_2*512)) + (blockIdx.x*16)) + (threadIdx.x_2*8)), 1, 8)]
      }
      for (k.inner.outer: int32, 0, 16) {
        for (ax1: int32, 0, 16) {
          A.shared.local[ax1] = (float16*)A.shared[(((threadIdx.y*264) + (k.inner.outer*16)) + ax1)]
        }
        for (ax0: int32, 0, 16) {
          for (ax1_1: int32, 0, 8) {
            B.shared.local[((ax0*8) + ax1_1)] = (float16*)B.shared[((((k.inner.outer*256) + (ax0*16)) + (threadIdx.x*8)) + ax1_1)]
          }
        }
        for (k.inner.inner: int32, 0, 16) {
          for (j.c: int32, 0, 8) {
            compute.local[j.c] = ((float32*)compute.local[j.c] + (cast(float32, (float16*)A.shared.local[k.inner.inner])*cast(float32, (float16*)B.shared.local[((k.inner.inner*8) + j.c)])))
          }
        }
      }
    }
    for (j.inner.inner.inner: int32, 0, 8) {
      compute_2[((((threadIdx.y*512) + (blockIdx.x*16)) + (threadIdx.x*8)) + j.inner.inner.inner)] = (float32*)compute.local[j.inner.inner.inner]
    }
  }
}

CUDA code:

extern "C" __global__ void default_function_kernel0(half* __restrict__ A, half* __restrict__ B, float* __restrict__ compute) {
  float compute_local[8];
  __shared__ half A_shared[8448];
  __shared__ half B_shared[4096];
  half A_shared_local[16];
  half B_shared_local[128];
  for (int j_c_init = 0; j_c_init < 8; ++j_c_init) {
    compute_local[(j_c_init)] = 0.000000e+00f;
  }
  for (int k_outer = 0; k_outer < 2; ++k_outer) {
    __syncthreads();
    for (int ax0_ax1_outer_fused_outer = 0; ax0_ax1_outer_fused_outer < 16; ++ax0_ax1_outer_fused_outer) {
      ((uint4*)(A_shared + (((((ax0_ax1_outer_fused_outer * 528) + ((((int)threadIdx.y) >> 4) * 264)) + ((((int)threadIdx.y) & 15) * 16)) + (((int)threadIdx.x) * 8)))))[0] = ((uint4*)(A + ((((((ax0_ax1_outer_fused_outer * 1024) + ((((int)threadIdx.y) >> 4) * 512)) + (k_outer * 256)) + ((((int)threadIdx.y) & 15) * 16)) + (((int)threadIdx.x) * 8)))))[0];
    }
    for (int ax0_ax1_outer_fused_outer1 = 0; ax0_ax1_outer_fused_outer1 < 8; ++ax0_ax1_outer_fused_outer1) {
      ((uint4*)(B_shared + ((((ax0_ax1_outer_fused_outer1 * 512) + (((int)threadIdx.y) * 16)) + (((int)threadIdx.x) * 8)))))[0] = ((uint4*)(B + ((((((k_outer * 131072) + (ax0_ax1_outer_fused_outer1 * 16384)) + (((int)threadIdx.y) * 512)) + (((int)blockIdx.x) * 16)) + (((int)threadIdx.x) * 8)))))[0];
    }
    __syncthreads();
    for (int k_inner_outer = 0; k_inner_outer < 16; ++k_inner_outer) {
      for (int ax1 = 0; ax1 < 16; ++ax1) {
        A_shared_local[(ax1)] = A_shared[((((((int)threadIdx.y) * 264) + (k_inner_outer * 16)) + ax1))];
      }
      for (int ax0 = 0; ax0 < 16; ++ax0) {
        for (int ax11 = 0; ax11 < 8; ++ax11) {
          B_shared_local[(((ax0 * 8) + ax11))] = B_shared[(((((k_inner_outer * 256) + (ax0 * 16)) + (((int)threadIdx.x) * 8)) + ax11))];
        }
      }
      for (int k_inner_inner = 0; k_inner_inner < 16; ++k_inner_inner) {
        for (int j_c = 0; j_c < 8; ++j_c) {
          compute_local[(j_c)] = (compute_local[(j_c)] + (((float)A_shared_local[(k_inner_inner)]) * ((float)B_shared_local[(((k_inner_inner * 8) + j_c))])));
        }
      }
    }
  }
  for (int j_inner_inner_inner = 0; j_inner_inner_inner < 8; ++j_inner_inner_inner) {
    compute[(((((((int)threadIdx.y) * 512) + (((int)blockIdx.x) * 16)) + (((int)threadIdx.x) * 8)) + j_inner_inner_inner))] = compute_local[(j_inner_inner_inner)];
  }
}