Bringing your own data type to TVM

Hi,

First of all, I am completely new to TVM, so thanks in advance for your help.

I have an ONNX model that I want to run with my own data type, and I found out here that this is possible with TVM.

So I based my approach on the bring_your_own_datatypes.py tutorial and wrote the following code:

import numpy as np
import tvm
from tvm.contrib import graph_runtime
import os
import onnx
from tvm import relay
import ctypes
from tvm.relay.frontend.change_datatype import ChangeDatatype


# Make the float_st implementation available and register the custom type.
# RTLD_GLOBAL makes the library's exported symbols globally visible for later
# symbol resolution (presumably the Floatst* functions referenced by the
# lowering rules registered below).
ctypes.CDLL("./float_st.so", ctypes.RTLD_GLOBAL)
tvm.target.datatype.register("float_st", 150)

ctx = tvm.cpu()

# NOTE(review): "moel.onnx" looks like a typo for "model.onnx" -- confirm the
# actual file name.
onnx_model = onnx.load("./moel.onnx")

# Import the ONNX graph into Relay. from_onnx takes a mapping from input name
# to input shape and returns the Relay module plus the model's parameters.
_input_shapes = {"input_1": (100, 128, 3)}
module, params = relay.frontend.from_onnx(onnx_model, _input_shapes)

# Graph executor over the freshly imported module; the converted entry
# function is supplied later through `ex.evaluate(expr)`.
ex = tvm.relay.create_executor("graph", mod=module)


def convert_ndarray(dst_dtype, array):
    """Cast ``array`` to ``dst_dtype`` by evaluating a one-op Relay function.

    Args:
        dst_dtype: dtype string to cast to (e.g. a ``custom[...]`` dtype).
        array: array to convert; its ``shape`` and ``dtype`` define the
            input variable of the generated cast function.

    Returns:
        The value produced by running the cast with the graph executor.
    """
    src_var = relay.var("x", shape=array.shape, dtype=str(array.dtype))
    cast_fn = relay.Function([src_var], src_var.astype(dst_dtype))
    # Vectorization is not implemented for custom datatypes, so switch it off
    # while the cast is compiled and executed.
    no_vectorize = tvm.transform.PassContext(config={"tir.disable_vectorize": True})
    with no_vectorize:
        converted = relay.create_executor("graph").evaluate(cast_fn)(array)
    return converted


# Lowering rules for Cast between float and float_st on the llvm target.
# Each table maps (source bits, destination bits) to the name of the function
# that performs the conversion (presumably exported by float_st.so).
for _cast_table, _cast_src, _cast_dst in (
    ({(32, 128): "FloatToFloatst"}, "float", "float_st"),
    ({(128, 32): "FloatstToFloat"}, "float_st", "float"),
):
    tvm.target.datatype.register_op(
        tvm.target.datatype.create_lower_func(_cast_table),
        "Cast",
        "llvm",
        _cast_src,
        _cast_dst,
    )



# Convert the model from float32 to float_st. The trailing 128 in the custom
# dtype string is the bit width, matching the 128-bit keys used in the
# lowering tables.
src_dtype = "float32"
dst_dtype = "custom[float_st]128"

# Type the imported module before running the rewrite passes.
module = relay.transform.InferType()(module)

# Currently, custom datatypes only work if you run SimplifyInference beforehand.
# One application is enough; running it twice back-to-back adds nothing.
module = tvm.relay.transform.SimplifyInference()(module)

# Run type inference before changing datatype.
module = tvm.relay.transform.InferType()(module)

# Change datatype from float to float_st and store the converted function back
# into the module. Re-running InferType then actually type-checks the converted
# graph (re-inferring the untouched module would check nothing new), so a
# dtype mismatch is reported here instead of later at evaluation time.
#
# NOTE(review): a "data types custom[float_st]128 and float32 do not match in
# BroadcastRel" error at this point means some operator in the converted graph
# still produces float32. Presumably ChangeDatatype only rewrites function
# parameters and constants, so ops that carry their own dtype attribute (e.g. a
# cast/zeros/full emitted by the ONNX frontend) keep float32. Inspect
# `print(module["main"])` for such ops -- TODO confirm against ChangeDatatype.
cdtype = ChangeDatatype(src_dtype, dst_dtype)
module["main"] = cdtype.visit(module["main"])
module = tvm.relay.transform.InferType()(module)
# The converted, type-checked entry function that is evaluated below.
expr = module["main"]

# We need to convert our input. The shape must match the one given to
# from_onnx for "input_1".
data_shape = [100, 128, 3]
input_data = np.random.uniform(size=data_shape).astype(src_dtype)
input_st = convert_ndarray(dst_dtype, input_data)

# We also convert the parameters:
params = {name: convert_ndarray(dst_dtype, value) for name, value in params.items()}


# Lowering rules for the remaining operations float_st needs on the llvm
# target: float literals, the generic if_then_else / call_pure_extern
# intrinsics, arithmetic and math ops, and the type's minimum value.
_datatype = tvm.target.datatype

_datatype.register_op(
    _datatype.create_lower_func({128: "FloatToFloatst"}),
    "FloatImm",
    "llvm",
    "float_st",
)
_datatype.register_op(
    _datatype.lower_ite, "Call", "llvm", "float_st", intrinsic_name="tir.if_then_else"
)
_datatype.register_op(
    _datatype.lower_call_pure_extern,
    "Call",
    "llvm",
    "float_st",
    intrinsic_name="tir.call_pure_extern",
)

# (TIR op name, implementing function, intrinsic name or None for plain ops),
# registered in this exact order.
# NOTE(review): no lowering is registered for "Add". If the model contains
# additions (bias adds, residual adds, ...), lowering presumably fails for
# float_st; if float_st.so exports an add function, register it here too.
for _op_name, _impl_name, _intrinsic in (
    ("Mul", "FloatstMul", None),
    ("Div", "FloatstDiv", None),
    ("Call", "FloatstSqrt", "tir.sqrt"),
    ("Sub", "FloatstSub", None),
    ("Call", "FloatstExp", "tir.exp"),
    ("Max", "FloatstMax", None),
):
    _extra_kwargs = {} if _intrinsic is None else {"intrinsic_name": _intrinsic}
    _datatype.register_op(
        _datatype.create_lower_func({128: _impl_name}),
        _op_name,
        "llvm",
        "float_st",
        **_extra_kwargs,
    )

_datatype.register_min_func(
    _datatype.create_min_lower_func({128: "MinFloatst"}, "float_st"),
    "float_st",
)



# Vectorization is not implemented with custom datatypes, so disable it while
# the converted model is compiled and run. The float_st output is then cast
# back to src_dtype and its first ten values are printed.
_no_vectorize = tvm.transform.PassContext(config={"tir.disable_vectorize": True})
with _no_vectorize:
    _raw_output = ex.evaluate(expr)(input_st, **params)
    result_myfloat = convert_ndarray(src_dtype, _raw_output).asnumpy()
    print(result_myfloat.flatten()[:10])

But when I try to run it, I get the following error:

data types custom[float_st]128 and float32do not match in BroadcastRel

data types custom[float_st]128 and float32do not match in BroadcastRel

note: run with TVM_BACKTRACE=1 environment variable to display a backtrace.

Do you have any idea what is causing this? I uploaded the ONNX model as well as float_st.so so that the error can be reproduced: https://file.io/R2dLjpdoSZWG

I should mention that I tested my data type wrapper (float_st.so in the code) with bring_your_own_datatypes.py, which runs the MobileNet model, and it worked.