I'm trying to work through the old GEMM tutorial. One section describes flattening access to the B matrix, as follows:
```python
packedB = te.compute(
    (N / bn, K, bn), lambda bigN, k, littleN: B[k, bigN * bn + littleN], name="packedB"
)
C = te.compute(
    (M, N),
    lambda m, n: te.sum(A[m, k] * packedB[n // bn, k, tvm.tir.indexmod(n, bn)], axis=k),
    name="C",
)
```
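To make the index mapping concrete, here is a plain-Python sketch (no TVM involved; the small sizes are hypothetical and chosen only for illustration) of what `packedB` and `C` compute. `packedB` regroups the columns of `B` into blocks of `bn`, and `C` reads them back with `n // bn` and `n % bn` (the role played by `tvm.tir.indexmod` above). It assumes `N` is divisible by `bn`. In the tutorial itself, `k` is defined earlier as a reduction axis, `k = te.reduce_axis((0, K), "k")`.

```python
# Plain-Python illustration of the packedB / C index mapping (no TVM).
# Sizes are hypothetical; N must be divisible by bn.
M, N, K, bn = 2, 4, 3, 2

A = [[m * K + k + 1 for k in range(K)] for m in range(M)]  # M x K
B = [[k * N + n + 1 for n in range(N)] for k in range(K)]  # K x N

# packedB[bigN][k][littleN] = B[k][bigN * bn + littleN], shape (N // bn, K, bn)
packedB = [
    [[B[k][bigN * bn + littleN] for littleN in range(bn)] for k in range(K)]
    for bigN in range(N // bn)
]

# C[m][n] = sum over k of A[m][k] * packedB[n // bn][k][n % bn]
C_packed = [
    [sum(A[m][k] * packedB[n // bn][k][n % bn] for k in range(K)) for n in range(N)]
    for m in range(M)
]

# Reference result: ordinary matrix multiply on the unpacked B
C_ref = [
    [sum(A[m][k] * B[k][n] for k in range(K)) for n in range(N)]
    for m in range(M)
]

print(C_packed == C_ref)  # True
```

Because `packedB[n // bn][k][n % bn]` is just `B[k][(n // bn) * bn + n % bn]`, which equals `B[k][n]`, the packed computation gives the same result as the plain one; only the memory layout of `B` changes.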
`M`, `N`, `K`, and `bn` are fixed values. Running this code results in the following error:
```
---> 94 C = te.compute(
     95     (M, N),
     96     lambda m, n: te.sum(A[m, k] * packedB[n // bn, k, tvm.tir.indexmod(n, bn)], axis=k),
     97     name="C",
     98 )
...
TypeError: Mismatched type on argument #1 when calling: `tir.ProducerLoad(0: tir.DataProducer, 1: Array<PrimExpr>, 2: Span) -> tir.ProducerLoad`. Expected `Array<PrimExpr>` but got `Array[index 1: tir.LoopRV]`
```
Any ideas on how to fix it?