TVM Support for FP8? A Discussion

Hey everyone,

I’d like to open up a conversation about the possibility of TVM supporting the FP8 datatype. CUTLASS, one of TVM’s third-party libraries, already includes FP8 support, and based on the small test below it appears to work on both host and device just like the built-in datatypes.

#include <iostream>
#include <cuda_runtime.h>
#include "cutlass/numeric_types.h"

using cutlass::float_e4m3_t;

__global__ void vector_add_kernel(float_e4m3_t *a, float_e4m3_t *b, float_e4m3_t *c)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    c[i] = a[i] + b[i];
}
// main
int main(){
    for (int i = -8; i < 8; ++i)
    {
        float f = static_cast<float>(i);

        cutlass::float_e4m3_t x = static_cast<cutlass::float_e4m3_t>(i);
        cutlass::float_e4m3_t y = static_cast<cutlass::float_e4m3_t>(f);
        cutlass::float_e4m3_t z = x + y;
        std::cout << "i = " << i << ", f = " << f << ", x = " << x
                  << ", y = " << y << ", z = " << z << std::endl;
        std::cout << "sizeof float_e4m3_t is " << sizeof(cutlass::float_e4m3_t) << std::endl;
    }
    // define input a, b and output c
    float_e4m3_t *a, *b, *c;
    // define device input a, b and output c
    float_e4m3_t *d_a, *d_b, *d_c;
    // define size
    int size = 10 * sizeof(float_e4m3_t);
    // malloc host memory
    a = (float_e4m3_t *)malloc(size);
    b = (float_e4m3_t *)malloc(size);
    c = (float_e4m3_t *)malloc(size);
    // malloc device memory
    cudaMalloc((void **)&d_a, size);
    cudaMalloc((void **)&d_b, size);
    cudaMalloc((void **)&d_c, size);
    // initialize input data
    for (int i = 0; i < 10; ++i)
    {
        a[i] = static_cast<float_e4m3_t>(i);
        b[i] = static_cast<float_e4m3_t>(i);
    }
    // copy data from host to device
    cudaMemcpy(d_a, a, size, cudaMemcpyHostToDevice);
    cudaMemcpy(d_b, b, size, cudaMemcpyHostToDevice);
    // define block and grid
    dim3 block(10);
    dim3 grid(1);
    // launch kernel
    vector_add_kernel<<<grid, block>>>(d_a, d_b, d_c);
    // copy data from device to host
    cudaMemcpy(c, d_c, size, cudaMemcpyDeviceToHost);
    // print result
    for (int i = 0; i < 10; ++i)
    {
        std::cout << "c[" << i << "] = " << c[i] << std::endl;
    }
    // free host memory
    free(a);
    free(b);
    free(c);
    // free device memory
    cudaFree(d_a);
    cudaFree(d_b);
    cudaFree(d_c);
    
    return 0;
}

Considering the compatibility and the growing interest in FP8, especially for post-training quantization of large language models (LLM PTQ), it seems like a worthwhile addition to the TVM ecosystem.

Do you think adding FP8 support to TVM would benefit the community? Are there any potential challenges or obstacles that could make its implementation difficult?

Looking forward to hearing others’ thoughts and opinions on this topic!

Of course FP8 is going to help a lot, but the real problem at the moment is that we have no access to FP8-capable hardware yet :frowning:

There seems to be a misunderstanding here. Although there isn’t any hardware available to us that can run FP8 Tensor Cores, we can still implement FP8 computations on CUDA cores, or even on CPUs, using suitable data structures. I believe this approach is still meaningful and worth considering.
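For example, a minimal host-only sketch (nothing here requires an FP8-capable GPU; as far as I can tell CUTLASS emulates the arithmetic through float):

#include <iostream>
#include "cutlass/numeric_types.h"

int main() {
    // float_e4m3_t is a 1-byte struct; its arithmetic is emulated in software,
    // so this runs on any CPU.
    cutlass::float_e4m3_t x(1.5f);
    cutlass::float_e4m3_t y(0.25f);
    cutlass::float_e4m3_t z = x + y;
    std::cout << "z = " << static_cast<float>(z)            // 1.75 is exact in e4m3
              << ", sizeof = " << sizeof(z) << std::endl;   // 1 byte of storage
    return 0;
}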

Sure thing! I’m actually super interested in fp8 as well as int4

I suppose the H100 has fp8 support? According to the PTX doc there is a wgmma.mma_async intrinsic:

FP8 floating point type

wgmma.mma_async.sync.aligned.shape.dtype.atype.btype  d, a-desc, b-desc, scale-d, imm-scale-a, imm-scale-b;

wgmma.mma_async.sync.aligned.shape.dtype.atype.btype  d, a, b-desc, scale-d, imm-scale-a, imm-scale-b;

.shape   = {.m64n8k32, .m64n16k32, .m64n24k32, .m64n32k32,
            .m64n40k32, .m64n48k32, .m64n56k32, .m64n64k32,
            .m64n72k32, .m64n80k32, .m64n88k32, .m64n96k32,
            .m64n104k32, .m64n112k32, .m64n120k32, .m64n128k32,
            .m64n136k32, .m64n144k32, .m64n152k32, .m64n160k32,
            .m64n168k32, .m64n176k32, .m64n184k32, .m64n192k32,
            .m64n200k32, .m64n208k32, .m64n216k32, .m64n224k32,
            .m64n232k32, .m64n240k32, .m64n248k32, .m64n256k32};
.atype  = {.e4m3, .e5m2};
.btype  = {.e4m3, .e5m2};
.dtype  = {.f16, .f32};

Yes, fp8 is important for TVM; we can start with supporting the native e4m3 and e5m2 (or more generally eAmB) data types.

Yes, my statement had some issues. The H100 does have FP8 Tensor Cores, but it is difficult to get access to personally. What I wanted to express is that supporting FP8 should not be solely for the sake of Hopper Tensor Cores. We should expect FP8 LLM models to emerge, and they shouldn’t be restricted to running only on Hopper. For instance, we can use tir.Cast to cast FP8 to FP16 or FP32, which allows for an efficient implementation on previous GPU architectures while reducing the size of checkpoints. I think this idea is quite interesting.
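As a rough illustration of what I mean (just a sketch; the kernel and parameter names are made up), the checkpoint stays in 1-byte e4m3 while the compute happens in fp16/fp32 registers, so nothing here needs Hopper:

#include <cuda_fp16.h>
#include "cutlass/numeric_types.h"

__global__ void scale_by_fp8_weight(const cutlass::float_e4m3_t *w,
                                    const __half *x, __half *y, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        // decompress: 1-byte e4m3 -> float via CUTLASS's software conversion
        float wf = static_cast<float>(w[i]);
        y[i] = __float2half(wf * __half2float(x[i]));
    }
}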

Definitely something worth doing. We can start with a software implementation, like the existing bfloat16 support.

Given there are two variants of fp8, we might need two type codes here.

One quick, pragmatic way to get started is to reuse the type code of bfloat16. Here is one actionable plan that I think would get the related support in place very quickly.

One strawman proposal is as follows (for the two most frequently used FP8 formats):

  • kFloat, bits=8 corresponds to the forward variant e4m3. Aka we use float8 for e4m3.
  • kBfloat, bits=8 corresponds to the backward variant e5m2. Aka we use bfloat8 for e5m2; this also aligns with the fact that bfloat comes with more exponent bits.

Of course we can also introduce more type codes if needed, but these two likely cover the common need.
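To make the mapping concrete, here is a sketch of how the strawman convention would read at the DLPack level (this is only an illustration of the proposed convention, not something that exists today):

#include <dlpack/dlpack.h>

// Reuse the existing type codes with bits = 8 (proposed convention only):
DLDataType fp8_e4m3 = {kDLFloat, 8, 1};    // "float8"  -> e4m3
DLDataType fp8_e5m2 = {kDLBfloat, 8, 1};   // "bfloat8" -> e5m2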

To get a first version of the implementation, we can update bf16_legalize.cc, which should come with everything we need:

  • Update BF16ComputeLegalize to LowBitsFPComputeLegalize
  • Update BF16StorageLegalize to LowBitsFPStorageLegalize

Note that while we legalize bfloat16 to fp32 for compute, FP8 can be legalized to fp16 (or to fp32 if fp16 is not available, depending on the system config or target).
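For concreteness, a hedged sketch of roughly what such legalized fp8 arithmetic could look like once lowered to CUDA, assuming fp16 is available on the target (illustrative only, not the actual pass output):

#include <cuda_fp16.h>
#include "cutlass/numeric_types.h"

__device__ cutlass::float_e4m3_t fp8_add(cutlass::float_e4m3_t a,
                                         cutlass::float_e4m3_t b) {
    // storage stays e4m3; the addition itself is legalized to fp16
    __half ha = __float2half(static_cast<float>(a));
    __half hb = __float2half(static_cast<float>(b));
    return static_cast<cutlass::float_e4m3_t>(__half2float(__hadd(ha, hb)));
}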

Love to know what folks think about this

This is exactly what I was looking for :partying_face:. I think there are also some considerations on the GPU side, and I have some naive comments, such as: at which level should data casting occur? When not using LDMATRIX, we can perform the cast at the register level, which reduces read and write requests for both global and shared memory. If we need to use MMA + LDMATRIX instructions to tensorize, it seems that casting can only be done at the shared-memory level.
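To make the shared-memory option concrete, a small sketch (names are illustrative, and the ldmatrix/MMA part is elided): the fp8 tile is upcast to fp16 while being staged into shared memory, so the tensorized part only ever sees fp16.

#include <cuda_fp16.h>
#include "cutlass/numeric_types.h"

__global__ void stage_fp8_tile_as_fp16(const cutlass::float_e4m3_t *gmem,
                                       int tile_elems) {
    extern __shared__ __half smem[];
    for (int i = threadIdx.x; i < tile_elems; i += blockDim.x) {
        smem[i] = __float2half(static_cast<float>(gmem[i]));  // cast on the way in
    }
    __syncthreads();
    // ... ldmatrix loads and fp16 MMA on the staged tile would go here ...
}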

Or, if we want to deploy models quantized by GPTQ, we even need to consider how to design the compression and decompression of int3/int4 weights inside higher-bit data types…
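For instance, a purely illustrative sketch of the decompression side (this is not GPTQ’s actual packing or zero-point scheme): two 4-bit weights packed per byte are unpacked and dequantized to fp16 in registers before use.

#include <cuda_fp16.h>
#include <cstdint>

__device__ void unpack_int4x2(uint8_t packed, __half scale, __half out[2]) {
    // low and high nibble, assuming an offset-8 (zero point = 8) encoding
    int lo = static_cast<int>(packed & 0xF) - 8;
    int hi = static_cast<int>(packed >> 4) - 8;
    out[0] = __hmul(__int2half_rn(lo), scale);
    out[1] = __hmul(__int2half_rn(hi), scale);
}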

Great point, I agree casting should likely be done at shared memory in most cases, and we can start with this assumption.

At the moment we indeed legalize the internal storage of bfloat16 to float32; to enable more use of registers, we can change that behavior with more fine-grained control. For example, we can always write a mixed-precision program that starts with fp8 and casts to fp16 explicitly for fine-grained control.

We should either create e5m2 and e4m3 as separate type codes, or revamp the float representation altogether (e.g. bfloat16 = float_e8m7, _Float16 = float_e5m10, etc.). Since the type code is an exposed part of DataType, we should not reuse the code of bfloat for 8-bit floats.

The RTX 40 series appears to have FP8 support, according to https://images.nvidia.com/aem-dam/Solutions/geforce/ada/nvidia-ada-gpu-architecture.pdf

Indeed normally a new type code would be desirable.

In this case bfloat is only defined for bits=16 and is not defined for bits=8, so that type code can be reused. If there were a single IEEE-standard float8, we could likely just reuse the kFloat flag here.

Agree that perhaps a more composable encoding of fp could be interesting, although that would need a bit more thinking, mainly on the DLPack compatibility part.

Yes, bfloat only makes sense for 16 bits. This is why the type code kBFloat should not be used for anything else (see the POLA principle). If we really want some way of representing 8-bit floats, and it can’t wait, we should tack them onto kFloat and use some unreasonable non-power-of-2 values for the bits field to distinguish between e5m2 and e4m3.

Edit: To elaborate:

  1. “kBFloat” could be assumed to imply “bits=16”, so it’s not unreasonable to assume that there may be code out there (in some downstream repo) that checks the code ignoring the bits. We don’t want to break it if we don’t have to.
  2. “kFloat” with strange bits would be easy to track down (and refactor later), assuming that we don’t want to introduce new type codes at the moment.
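Purely hypothetically, the kFloat-with-unusual-bits idea could look something like this at the DLPack level (the bits values below are made up for illustration and are not a proposed standard; both types would still occupy one byte of storage, the odd values only serve as tags):

#include <dlpack/dlpack.h>

// Hypothetical tags only: kFloat with a non-power-of-2 bits value marks the variant.
DLDataType e4m3_tag = {kDLFloat, 9, 1};    // deliberately odd value stands in for e4m3
DLDataType e5m2_tag = {kDLFloat, 11, 1};   // a different odd value stands in for e5m2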

Got it, I don’t have strong feelings on this; I think we can try to have some extra type codes in this case, to be careful.

Good news, FP8 capabilities on Ada are starting to get exposed. See
