Using TVM to create a Bfloat16 datatype: kernel dies

Hello. I’m trying to create a Bfloat16 datatype using TVM, taking the example shown here as a reference. Instead of using myfloat32, I use bfloat16 in order to recreate this format. The problem comes at this step: apparently the kernel dies when I execute compiled = ex.evaluate(program):

# Lower an Add of two 16-bit "bfloat" operands to a call to Custom16Add
tvm.target.datatype.register_op(
    tvm.target.datatype.create_lower_func({16: "Custom16Add"}),
    "Add",
    "llvm",
    "bfloat",
)

# Lower a Cast from 16-bit "bfloat" to 32-bit float to a call to Custom16ToFloat
tvm.target.datatype.register_op(
    tvm.target.datatype.create_lower_func({(16, 32): "Custom16ToFloat"}),
    "Cast",
    "llvm",
    "bfloat",
    "float",
)
# Now, we can run our program without errors.
with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
    compiled = ex.evaluate(program)
    z_output_bfloat16 = compiled(x_input, y_input)   

print("z: {}".format(z_output_bfloat16))

print("x:\t\t{}".format(x_input))
print("y:\t\t{}".format(y_input))
print("z (float32):\t{}".format(z_output))
print("z (bfloat16):\t{}".format(z_output_bfloat16))

However, when I try the same steps with Custom32 instead of Custom16 in the register_op calls, it appears to work. I’ve also tried changing the Custom32 string to other names, and whenever I do, the kernel dies again.
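In case it narrows things down: as I understand it, the name passed to create_lower_func becomes a call to an extern function, so it has to resolve to a symbol at runtime. A quick way I found to check whether a name resolves in a loaded shared library, using ctypes (libm and sqrtf are just stand-ins here, and Custom16Add would have to live in whatever library actually provides it):

```python
import ctypes
import ctypes.util

# Load a shared library and probe it for a symbol by name.
# Replace "m"/"sqrtf" with the real library and the function name
# passed to create_lower_func (e.g. "Custom16Add").
libm = ctypes.CDLL(ctypes.util.find_library("m"))

print(hasattr(libm, "sqrtf"))        # True: the symbol exists in libm
print(hasattr(libm, "Custom16Add"))  # False: libm does not provide it
```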

What could be the solution to this? Is there another way to cast float to bfloat16 and vice versa? Is there any way to manipulate the bits of the datatype?
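For context on the last question: outside of TVM, the float32-to-bfloat16 conversion itself is just bit truncation (bfloat16 keeps the top 16 bits of a float32), which can be sketched in numpy. This is my own simplification, with truncation instead of round-to-nearest:

```python
import numpy as np

def float32_to_bfloat16_bits(x):
    """Truncate float32 to bfloat16 by keeping the top 16 bits.
    Returns the raw bfloat16 bit pattern as uint16 (no rounding)."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits >> 16).astype(np.uint16)

def bfloat16_bits_to_float32(b):
    """Widen a bfloat16 bit pattern back to float32 by zero-filling
    the low 16 mantissa bits."""
    return (b.astype(np.uint32) << 16).view(np.float32)

x = np.array([1.0, 3.140625, -2.5], dtype=np.float32)
roundtrip = bfloat16_bits_to_float32(float32_to_bfloat16_bits(x))
print(roundtrip)  # values exactly representable in bfloat16 survive the round trip
```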

Thank you