The dtype of my original model is float32. I want some operators, such as bias_add, to compute in bfloat16, so the bias, which is a ConstantNode, has to be converted from float32 to bfloat16. However, after the conversion in my pass, a check fails at runtime:
Check failed: ret == 0 (-1 vs. 0) : Assert fail: (((tir.tvm_struct_get(arg1, 0, 5) == (uint8)1) && (tir.tvm_struct_get(arg1, 0, 6) == (uint8)16)) && (tir.tvm_struct_get(arg1, 0, 7) == (uint16)1)), arg1.dtype is expected to be uint16
Take the following IR as an example:
def @main(%x1: Tensor[(2), float32]) -> Tensor[(2), float32] {
  %2 = fn (%p0: Tensor[(2), float32], %p1: Tensor[(2), bfloat16], Primitive=1) -> Tensor[(2), float32] {
    %0 = cast(%p0, dtype="bfloat16") /* ty=Tensor[(2), bfloat16] */;
    %1 = add(%0, %p1) /* ty=Tensor[(2), bfloat16] */;
    cast(%1, dtype="float32") /* ty=Tensor[(2), float32] */
  };
  %2(%x1, meta[relay.Constant][0] /* ty=Tensor[(2), bfloat16] */) /* ty=Tensor[(2), float32] */
}
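For context, the assertion in the error asks for uint16, so the compiled function seems to expect the bfloat16 argument as raw 16-bit storage. Below is a minimal sketch, in plain Python, of the float32 to bfloat16 bit conversion this would imply. The helper names `float32_to_bf16_bits` and `bf16_bits_to_float32` are hypothetical (they are not TVM APIs), and round-to-nearest-even is an assumption on my side; I do not know which rounding TVM itself uses.

```python
import struct


def float32_to_bf16_bits(value: float) -> int:
    """Hypothetical helper: bfloat16 bit pattern of `value` as an int in [0, 0xFFFF]."""
    # Reinterpret the float32 encoding of `value` as a 32-bit unsigned integer.
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    # NaN: truncate and force a mantissa bit so the result stays NaN.
    if (bits & 0x7F800000) == 0x7F800000 and (bits & 0x007FFFFF) != 0:
        return ((bits >> 16) | 0x0040) & 0xFFFF
    # Round to nearest, ties to even (assumed rounding mode), then keep
    # the upper 16 bits: sign, 8 exponent bits, 7 mantissa bits.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bf16_bits_to_float32(bits16: int) -> float:
    """Hypothetical helper: widen a bfloat16 bit pattern back to a float32 value."""
    return struct.unpack("<f", struct.pack("<I", (bits16 & 0xFFFF) << 16))[0]


# Example: convert a two-element float32 bias to 16-bit patterns.
bias_f32 = [0.5, -2.0]
bias_bits = [float32_to_bf16_bits(v) for v in bias_f32]
```

The resulting integers could be stored in a uint16 array. Whether the constant is supposed to be stored that way, rather than with a bfloat16 dtype, is the part I am unsure about.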
Can the dtype of a ConstantNode not be bfloat16?
Thanks!