I use the code below and found that matmul.dtype is uint8, but I want it to be int16 or int32 to avoid data overflow. I wonder if there is a way to do this. I have tried astype, but this forced cast takes a lot of time. So is there any way to set matmul.dtype myself, rather than having it inferred automatically, to save time?
import tvm
from tvm import te

N, L, M = 1024, 1024, 1024  # example shapes; my real sizes differ

A = te.placeholder((N, L), name="A", dtype="uint8")
B = te.placeholder((L, M), name="B", dtype="uint8")
k = te.reduce_axis((0, L), name="k")
matmul = te.compute(
    (N, M),
    lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
    name="matmul",
    attrs={"layout_free_placeholders": [B]},  # enable automatic layout transform for tensor B
)
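For reference, this is roughly what I mean by the astype attempt: a minimal sketch that casts each element to int32 inside the compute lambda, so the reduction accumulates in int32 instead of uint8. The name matmul_i32 and the fresh reduce axis k2 are just for illustration.

# Sketch of the astype approach I tried: widen each operand
# inside the lambda so te.sum accumulates in int32.
k2 = te.reduce_axis((0, L), name="k2")
matmul_i32 = te.compute(
    (N, M),
    lambda i, j: te.sum(
        A[i, k2].astype("int32") * B[k2, j].astype("int32"), axis=k2
    ),
    name="matmul_i32",
)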