Cast te.compute's dtype without astype?

I use the code below and found that matmul.dtype is uint8, but I want it to be int16 or int32 to avoid data overflow. Is there a way to do this? I have tried "astype", but this forced cast takes a lot of time. So is there a way to set matmul.dtype myself, rather than have it auto-inferred, to save time?

from tvm import te

# N, L, M: matrix dimensions (assumed to be defined elsewhere)
A = te.placeholder((N, L), name="A", dtype="uint8")
B = te.placeholder((L, M), name="B", dtype="uint8")
k = te.reduce_axis((0, L), name="k")
matmul = te.compute(
    (N, M),
    lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
    name="matmul",
    attrs={"layout_free_placeholders": [B]},  # enable automatic layout transform for tensor B
)
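The overflow concern above can be demonstrated in plain numpy (a minimal sketch; the matrix values here are hypothetical, not from the original post): when both the products and the accumulation stay in uint8, every intermediate wraps modulo 256, whereas casting the inputs to int32 gives the true result.

```python
import numpy as np

# Hypothetical small matmul with uint8 inputs.
A = np.full((2, 3), 200, dtype=np.uint8)
B = np.full((3, 2), 200, dtype=np.uint8)

# Everything kept in uint8: each product 200*200 = 40000 wraps to 64,
# and the sum over k also accumulates in uint8.
wrapped = (A[:, None, :] * B.T[None, :, :]).sum(axis=-1, dtype=np.uint8)

# Inputs cast to int32 first: products and accumulation are exact.
correct = A.astype(np.int32) @ B.astype(np.int32)

print(wrapped[0, 0])  # 192  (3 * 64 mod 256)
print(correct[0, 0])  # 120000  (3 * 40000)
```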

@yzhliu @tqchen can you guys help? Many thanks!

You can apply astype on A and B:

lambda i, j: te.sum(A[i, k].astype('int32') * B[k, j].astype('int32'), axis=k)

Since the cast is part of the reduction expression, it is fused into the multiply rather than run as a separate pass over A and B.

If your hardware has mixed-precision operations, you can tensorize during scheduling (see example).
