- auto B_data = reinterpret_cast<void*>(static_cast<char*>(B->data) + B->byte_offset);
- auto C_data = reinterpret_cast<void*>(static_cast<char*>(C->data) + C->byte_offset);
- CHECK_CUBLAS_ERROR(cublasGemmStridedBatchedEx(
- hdl, CUBLASBooleanToTranspose(transb), CUBLASBooleanToTranspose(transa),
- ColumnCount3D(B, transb), RowCount3D(A, transa), ColumnCount3D(A, transa), alpha_ptr, B_data,
- cuda_in_type, ColumnStride3D(B), B_size, A_data, cuda_in_type, ColumnStride3D(A), A_size,
- beta_ptr, C_data, cuda_out_type, ColumnStride3D(C), C_size, batch_size, cuda_out_type, algo));
- }
-
- // matrix multiplication for row major
- TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul").set_body([](TVMArgs args, TVMRetValue* ret) {
- DLTensor* A = args[0];
- DLTensor* C = args[2];
-
- CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
-
- CUBLASTryEnableTensorCore(entry_ptr->handle);
-
- if (TypeEqual(A->dtype, C->dtype)) {
- ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) ||
- TypeMatch(A->dtype, kDLFloat, 64));