Hey everyone,
I’d like to open up a conversation about the possibility of TVM supporting the FP8 datatype. CUTLASS, one of TVM’s third-party dependencies, already provides FP8 types (e.g. float_e4m3_t), and in my tests they behave much like built-in datatypes on both host and device. Here is a small CUDA example:
#include <iostream>
#include <cuda_runtime.h>
#include "cutlass/numeric_types.h"

using cutlass::float_e4m3_t;

// Element-wise vector addition on FP8 (e4m3) operands.
__global__ void vector_add_kernel(float_e4m3_t *a, float_e4m3_t *b, float_e4m3_t *c, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        c[i] = a[i] + b[i];
    }
}

int main()
{
    std::cout << "sizeof(float_e4m3_t) is " << sizeof(float_e4m3_t) << std::endl;

    // Round-trip a few small integers through FP8 on the host.
    for (int i = -8; i < 8; ++i)
    {
        float f = static_cast<float>(i);
        float_e4m3_t x = static_cast<float_e4m3_t>(i);
        float_e4m3_t y = static_cast<float_e4m3_t>(f);
        float_e4m3_t z = x + y;
        std::cout << "i = " << i << ", f = " << f
                  << ", x = " << x << ", y = " << y << ", z = " << z << std::endl;
    }

    // Host input a, b and output c.
    float_e4m3_t *a, *b, *c;
    // Device input d_a, d_b and output d_c.
    float_e4m3_t *d_a, *d_b, *d_c;

    // Number of elements and buffer size in bytes.
    const int n = 10;
    const int size = n * sizeof(float_e4m3_t);

    // Allocate host memory.
    a = (float_e4m3_t *)malloc(size);
    b = (float_e4m3_t *)malloc(size);
    c = (float_e4m3_t *)malloc(size);

    // Allocate device memory.
    cudaMalloc((void **)&d_a, size);
    cudaMalloc((void **)&d_b, size);
    cudaMalloc((void **)&d_c, size);

    // Initialize input data.
    for (int i = 0; i < n; ++i)
    {
        a[i] = static_cast<float_e4m3_t>(i);
        b[i] = static_cast<float_e4m3_t>(i);
    }

    // Copy inputs from host to device.
    cudaMemcpy(d_a, a, size, cudaMemcpyHostToDevice);
    cudaMemcpy(d_b, b, size, cudaMemcpyHostToDevice);

    // Launch one block of n threads.
    dim3 block(n);
    dim3 grid(1);
    vector_add_kernel<<<grid, block>>>(d_a, d_b, d_c, n);

    // Copy the result back to the host.
    cudaMemcpy(c, d_c, size, cudaMemcpyDeviceToHost);

    // Print the result.
    for (int i = 0; i < n; ++i)
    {
        std::cout << "c[" << i << "] = " << c[i] << std::endl;
    }

    // Free host memory.
    free(a);
    free(b);
    free(c);

    // Free device memory.
    cudaFree(d_a);
    cudaFree(d_b);
    cudaFree(d_c);
    return 0;
}
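In case anyone wants to reproduce this, I built it with something along the lines of `nvcc -std=c++17 -I<cutlass>/include fp8_vector_add.cu -o fp8_vector_add`; the file name and include path are just placeholders for my local setup.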
Considering this compatibility and the growing interest in FP8, especially for post-training quantization (PTQ) of large language models, it seems like a worthwhile addition to the TVM ecosystem.
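To make the PTQ angle a little more concrete, below is a minimal host-only sketch of per-tensor FP8 quantization built on the same cutlass::float_e4m3_t type as the example above. The toy values, the single per-tensor scale, and the choice of 448 (e4m3's largest finite value) as the target range are illustrative assumptions on my part, not a proposal for how TVM should expose FP8.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>
#include "cutlass/numeric_types.h"

using cutlass::float_e4m3_t;

int main()
{
    // Toy "weight tensor" standing in for a real LLM weight matrix.
    std::vector<float> w = {0.013f, -0.42f, 1.7f, -3.1f, 0.0009f, 2.25f, -0.07f, 0.5f};

    // Per-tensor scale: map the largest magnitude onto 448, e4m3's largest finite value.
    float amax = 0.0f;
    for (float v : w)
        amax = std::max(amax, std::fabs(v));
    float scale = (amax > 0.0f) ? amax / 448.0f : 1.0f;

    // Quantize to FP8, dequantize back to float, and report the rounding error.
    for (float v : w)
    {
        float_e4m3_t q = static_cast<float_e4m3_t>(v / scale);
        float deq = static_cast<float>(q) * scale;
        std::printf("orig = % .6f, dequant = % .6f, abs_err = %.6f\n", v, deq, std::fabs(v - deq));
    }
    return 0;
}

The point is just that the basic building blocks (conversions and arithmetic) are already available in CUTLASS; the open question is how best to surface an FP8 dtype through TVM itself.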
Do you think adding FP8 support to TVM would benefit the community? Are there any potential challenges or obstacles that could make its implementation difficult?
Looking forward to hearing your thoughts and opinions on this topic!