Feature Request: Support for Scaled Dot Product in TVM CUDA Backend Tensor Cores for FP8 MatMul

Dear TVM Community,

I hope this message finds you well. I am writing to request support for scaled dot product functionality in the TVM CUDA backend Tensor Cores for FP8 matrix multiplication (matmul). This feature is particularly relevant for leveraging the capabilities introduced in NVIDIA architectures starting from SM89, as detailed in the CUDA Parallel Thread Execution documentation.

    Examples of mma with block scale, from the PTX ISA documentation:

.reg .b32 %Ra<4>, %Rb<4>;
.reg .f32 %Rc<4>, %Rd<4>;
.reg .b32 scaleAData, scaleBData;
mma.sync.aligned.m16n8k64.row.col.kind::mxf4.block_scale.f32.e2m1.e2m1.f32.ue8m0
  {%Rd0, %Rd1, %Rd2, %Rd3},
  {%Ra0, %Ra1, %Ra2, %Ra3},
  {%Rb0, %Rb1},
  {%Rc0, %Rc1, %Rc2, %Rc3},
  scaleAData, {2, 1}, scaleBData, {2, 3};

.reg .b32 %Ra<4>, %Rb<4>;
.reg .f32 %Rc<4>, %Rd<4>;
.reg .b32 scaleAData, scaleBData;
.reg .u16 bidA, bidB, tidA, tidB;
mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3
  {%Rd0, %Rd1, %Rd2, %Rd3},
  {%Ra0, %Ra1, %Ra2, %Ra3},
  {%Rb0, %Rb1},
  {%Rc0, %Rc1, %Rc2, %Rc3},
  scaleAData, {bidA, tidA}, scaleBData, {bidB, tidB};

.reg .b32 %Ra<4>, %Rb<4>;
.reg .f32 %Rc<4>, %Rd<4>;
.reg .b32 scaleAData, scaleBData;
mma.sync.aligned.m16n8k32.row.col.kind::mxf8f6f4.block_scale.scale_vec::1X.f32.e3m2.e2m1.f32.ue8m0
  {%Rd0, %Rd1, %Rd2, %Rd3},
  {%Ra0, %Ra1, %Ra2, %Ra3},
  {%Rb0, %Rb1},
  {%Rc0, %Rc1, %Rc2, %Rc3},
  scaleAData, {0, 1}, scaleBData, {0, 1};

.reg .b32 %Ra<4>, %Rb<4>;
.reg .f32 %Rc<4>, %Rd<4>;
.reg .b32 scaleAData, scaleBData;
mma.sync.aligned.m16n8k32.row.col.kind::mxf8f6f4.block_scale.scale_vec::1X.f32.e4m3.e5m2.f32.ue8m0
  {%Rd0, %Rd1, %Rd2, %Rd3},
  {%Ra0, %Ra1, %Ra2, %Ra3},
  {%Rb0, %Rb1},
  {%Rc0, %Rc1, %Rc2, %Rc3},
  scaleAData, {0, 1}, scaleBData, {0, 0};

Background

TVM already supports FP8 matmul, which is a great step forward in optimizing low-precision computations. However, the current implementation does not appear to support the scaled dot product feature that NVIDIA introduced in SM89. This feature, as demonstrated in the MMA (Matrix Multiply-Accumulate) instructions, allows for more efficient and flexible matrix multiplications by incorporating scaling factors.
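
To make the semantics concrete, here is a rough NumPy reference of what a block-scaled matmul computes. This is illustrative only; the block size of 32 along K follows the MX formats, and the per-block scale layouts are assumptions:

import numpy as np

def scaled_matmul_ref(A, B, scale_A, scale_B, C, blk=32):
    # A: (M, K), B: (K, N), C: (M, N)
    # scale_A: (M, K // blk), scale_B: (K // blk, N) per-block scale factors
    A_deq = A * np.repeat(scale_A, blk, axis=1)  # dequantize A block-wise
    B_deq = B * np.repeat(scale_B, blk, axis=0)  # dequantize B block-wise
    return A_deq @ B_deq + C

Conceptually, the block_scale MMA variants above perform this dequantize-and-accumulate inside the Tensor Core, with the ue8m0/ue4m3 scale factors supplied through the scaleAData/scaleBData operands.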

For reference, Triton exposes this functionality through its dot_scaled function: Triton dot_scaled Documentation

The Triton API provides a convenient way to perform scaled matrix multiplication, as in the sketch below:

import triton
import triton.language as tl

@triton.jit
def scaled_matmul(a, b, scale_a, scale_b):
    # Triton implementation of scaled matmul
    ...
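
Below is a rough, unverified sketch of what such a kernel could look like with tl.dot_scaled. The tile shapes, pointer arithmetic, the uint8 carriers for the fp8 payloads and e8m0 scales, and the (M, K // 32) / (N, K // 32) scale layouts are all assumptions to be checked against the installed Triton version:

import triton
import triton.language as tl

@triton.jit
def scaled_matmul_kernel(a_ptr, a_scale_ptr, b_ptr, b_scale_ptr, out_ptr,
                         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                         BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # fp8 payloads and e8m0 scales are carried as uint8; the format strings
    # in dot_scaled tell the compiler how to interpret them (assumption:
    # one scale per 32 elements along K, row-major tiles).
    a = tl.load(a_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :])
    a_scale = tl.load(a_scale_ptr + offs_m[:, None] * (BLOCK_K // 32)
                      + tl.arange(0, BLOCK_K // 32)[None, :])
    b_scale = tl.load(b_scale_ptr + offs_n[:, None] * (BLOCK_K // 32)
                      + tl.arange(0, BLOCK_K // 32)[None, :])
    acc = tl.dot_scaled(a, a_scale, "e4m3", b, b_scale, "e4m3")
    tl.store(out_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], acc)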

This functionality is also demonstrated in the Python example you can find here:
Scaled MM Example

In PyTorch, a similar effect is achieved using torch._scaled_mm, which applies scaling factors to the inputs before performing the matrix multiplication. This allows for precise control over the numerical behavior of the computation, especially in low-precision formats like FP8.
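
For example, a minimal per-tensor-scaled FP8 matmul might look like the following sketch. The exact _scaled_mm signature and return type have changed across PyTorch releases, and the shapes and scale values here are arbitrary:

import torch

M, K, N = 128, 256, 64
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# _scaled_mm expects the second operand in column-major layout.
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()
scale_a = torch.tensor(1.0, device="cuda")  # per-tensor fp32 scales
scale_b = torch.tensor(1.0, device="cuda")

out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
                       out_dtype=torch.bfloat16)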

Feature Request

TVM's FP8 MMA tensor intrinsics currently emit the Tensor Core instruction through a T.ptx_mma call along these lines:

T.ptx_mma(
    mma_prefix,
    "row",
    "col",
    a_dtype_abbrv,
    b_dtype_abbrv,
    out_dtype_abbrv,
    A.data,
    A.elem_offset + tx * lift(local_size),
    B.data,
    B.elem_offset + tx * lift(local_size),
    C.data,
    C.elem_offset + tx * lift(local_size_out),
    False,
    dtype=out_dtype,
)

The request is to extend this intrinsic so that scaleA and scaleB operands can be supplied when A/B are FP8/FP4, as sketched below.
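
For illustration only, such an extension could accept two extra operands. The scale_A/scale_B parameters and their placement are hypothetical and do not exist in TVM today:

T.ptx_mma(
    mma_prefix,
    "row",
    "col",
    a_dtype_abbrv,            # e.g. "e4m3" or "e2m1"
    b_dtype_abbrv,
    out_dtype_abbrv,
    A.data,
    A.elem_offset + tx * lift(local_size),
    B.data,
    B.elem_offset + tx * lift(local_size),
    C.data,
    C.elem_offset + tx * lift(local_size_out),
    False,
    scale_A.data,             # hypothetical: ue8m0/ue4m3 block scales for A
    scale_B.data,             # hypothetical: block scales for B
    dtype=out_dtype,
)

The lowering could then emit the kind::*.block_scale variants of mma shown above, passing scaleAData/scaleBData and the corresponding byte/thread selectors.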

Why This Matters

This feature would enable:
Improved Performance: By leveraging the scaling factors, we can achieve better utilization of Tensor Cores and reduce numerical inaccuracies in low-precision computations.
Better Compatibility: Aligning TVM with NVIDIA’s latest hardware capabilities will make it a more attractive choice for developers working on cutting-edge AI models.
Enhanced Flexibility: Users would have more control over the numerical behavior of their FP8 computations, enabling fine-tuning for specific use cases.

Call for Help

If anyone is interested in working on this feature or has insights into how it could be implemented, I would love to hear from you. Please feel free to share your thoughts, suggestions, or even a rough plan for how this could be approached.

Thank you for your time and consideration. I look forward to your responses!

Best regards,
Zhiyuan

It seems like scaleData only works on Blackwell. FP8 starts working on sm_89, but that architecture doesn't have a scaleData buffer.

Yes! It's on sm_102a; I misread it before.