Hi,
First of all, I am completely new in TVM and thanks in advance for your help.
I have an ONNX model that I want to run with my own data type, and I found out here that this is possible with TVM.
So I got the idea from bring_your_own_datatypes.py
and developed the following code:
import numpy as np
import tvm
from tvm.contrib import graph_runtime
import os
import onnx
from tvm import relay
import ctypes
from tvm.relay.frontend.change_datatype import ChangeDatatype
#register the custom type with TVM
# Load the shared library implementing the float_st arithmetic with
# RTLD_GLOBAL so the extern functions referenced by the lowering rules
# below (FloatToFloatst, FloatstMul, ...) can be resolved by the JIT.
ctypes.CDLL('./float_st.so', ctypes.RTLD_GLOBAL)
# Register "float_st" under type code 150 — presumably a code chosen to
# avoid clashing with TVM's built-in type codes; verify it is unused.
tvm.target.datatype.register("float_st", 150)
ctx = tvm.cpu()
# load example onnx model
# NOTE(review): './moel.onnx' looks like a typo for './model.onnx' —
# confirm the file name on disk.
onnx_model = onnx.load('./moel.onnx')
# convert to relay, needs the onnx model and input layer name and shape
# assumes the ONNX graph's input tensor is named "input_1" with shape
# (100, 128, 3) — TODO confirm against the model.
module, params = relay.frontend.from_onnx( onnx_model, {"input_1": (100,128,3)} )
# NOTE(review): this executor captures `module` BEFORE the datatype
# conversion performed later in the script; only the explicit `expr`
# passed to ex.evaluate(...) carries the converted types.
ex = tvm.relay.create_executor("graph", mod=module)
def convert_ndarray(dst_dtype, array):
    """Cast a TVM NDArray to ``dst_dtype``.

    Builds a one-operator Relay program (a single cast) and evaluates it
    on the graph executor, returning the converted NDArray.
    """
    source = relay.var("x", shape=array.shape, dtype=str(array.dtype))
    cast_fn = relay.Function([source], source.astype(dst_dtype))
    # Vectorization is not implemented for custom datatypes, so it must
    # be disabled while compiling the cast program.
    with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
        return relay.create_executor("graph").evaluate(cast_fn)(array)
# Lower Cast nodes between float32 and the 128-bit custom type by calling
# the extern conversion routines compiled into float_st.so.
for bit_key, extern_fn, from_type, to_type in (
    ((32, 128), "FloatToFloatst", "float", "float_st"),
    ((128, 32), "FloatstToFloat", "float_st", "float"),
):
    tvm.target.datatype.register_op(
        tvm.target.datatype.create_lower_func({bit_key: extern_fn}),
        "Cast",
        "llvm",
        from_type,
        to_type,
    )
src_dtype = "float32"
dst_dtype = "custom[float_st]128"

# Currently, custom datatypes only work if you run simplify_inference
# beforehand (it folds ops like batch-norm into simpler primitives that
# have lowering rules).  The original script ran it twice; once suffices.
module = relay.transform.InferType()(module)
module = tvm.relay.transform.SimplifyInference()(module)
# Run type inference before changing datatype
module = tvm.relay.transform.InferType()(module)
# Change datatype from float to float_st and re-infer types
cdtype = ChangeDatatype(src_dtype, dst_dtype)
expr = cdtype.visit(module["main"])
# BUG FIX: type inference must be re-run on the *converted* expression,
# not on the old (still-float32) module.  Re-inferring the stale module
# leaves `expr` without checked types in the new dtype, which is what
# produces errors of the form
#   "data types custom[float_st]128 and float32 do not match in BroadcastRel".
expr = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(expr))["main"]
# We need to convert our input:
data_shape = [100, 128, 3]
input = np.random.uniform(size=data_shape).astype('float32')
input_st = convert_ndarray(dst_dtype, input)
# We also convert the parameters:
params = {k: convert_ndarray(dst_dtype, v) for k, v in params.items()}
#register all the needed functions:
# FloatImm: lowers float constants appearing in the graph via the same
# float32 -> float_st conversion routine (key 128 = dest bit width).
tvm.target.datatype.register_op(
tvm.target.datatype.create_lower_func({128: "FloatToFloatst"}),
"FloatImm",
"llvm",
"float_st",
)
# if_then_else on float_st values: use TVM's generic lowering helper.
tvm.target.datatype.register_op(
tvm.target.datatype.lower_ite, "Call", "llvm", "float_st", intrinsic_name="tir.if_then_else"
)
# Pass pure extern calls through unchanged.
tvm.target.datatype.register_op(
tvm.target.datatype.lower_call_pure_extern,
"Call",
"llvm",
"float_st",
intrinsic_name="tir.call_pure_extern",
)
# Elementwise multiply -> extern FloatstMul.
tvm.target.datatype.register_op(
tvm.target.datatype.create_lower_func({128: "FloatstMul"}),
"Mul",
"llvm",
"float_st",
)
# Elementwise divide -> extern FloatstDiv.
tvm.target.datatype.register_op(
tvm.target.datatype.create_lower_func({128: "FloatstDiv"}),
"Div",
"llvm",
"float_st",
)
# sqrt intrinsic -> extern FloatstSqrt.
tvm.target.datatype.register_op(
tvm.target.datatype.create_lower_func({128: "FloatstSqrt"}),
"Call",
"llvm",
"float_st",
intrinsic_name="tir.sqrt",
)
# Elementwise subtract -> extern FloatstSub.
tvm.target.datatype.register_op(
tvm.target.datatype.create_lower_func({128: "FloatstSub"}),
"Sub",
"llvm",
"float_st",
)
# exp intrinsic -> extern FloatstExp.
tvm.target.datatype.register_op(
tvm.target.datatype.create_lower_func({128: "FloatstExp"}),
"Call",
"llvm",
"float_st",
intrinsic_name="tir.exp",
)
# Elementwise max -> extern FloatstMax (used e.g. by ReLU after
# simplify_inference — TODO confirm which ops in this model need it).
tvm.target.datatype.register_op(
tvm.target.datatype.create_lower_func({128: "FloatstMax"}),
"Max",
"llvm",
"float_st",
)
# Minimum-value function for float_st, needed when TVM must materialize
# the dtype's smallest representable value (e.g. for max-reductions).
tvm.target.datatype.register_min_func(
tvm.target.datatype.create_min_lower_func({128: "MinFloatst"}, "float_st"),
"float_st",
)
# Vectorization is not implemented with custom datatypes.
with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
    # Run the converted graph, then cast the float_st output back to
    # float32 so it can be inspected as a NumPy array.
    raw_output = ex.evaluate(expr)(input_st, **params)
    result_myfloat = convert_ndarray(src_dtype, raw_output).asnumpy()
# print first 10 elements
print(result_myfloat.flatten()[:10])
but when I try to run it I get the following error:
data types custom[float_st]128 and float32 do not match in BroadcastRel
data types custom[float_st]128 and float32 do not match in BroadcastRel
note: run with
TVM_BACKTRACE=1
environment variable to display a backtrace.
Do you have any idea? I uploaded the onnx
model as well as float_st.so
for reproducing the error. https://file.io/R2dLjpdoSZWG
I should mention that I tested my data type wrapper(float_st.so
in the code) with bring_your_own_datatypes.py
which runs the mobilenet
model and it worked.