Hi everyone! I am playing with TensorIR, and the simple code below produces an unexpected result, which is also non-deterministic. The intended result should be [128, 128, …]: 128 x 1. What leads to this weird behavior?
from tvm.script import tir as T
import tvm
import numpy as np
@T.prim_func
def func(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) -> None:
for i in T.serial(128):
for k in T.parallel(128):
with T.block("B"):
vi, vk = T.axis.remap("SR", [i, k])
T.reads(B[vi], A[vi, vk])
T.writes(B[vi])
with T.init():
B[vi] = T.float32(0)
B[vi] = B[vi] + A[vi, vk]
if __name__ == '__main__':
A = tvm.nd.array(np.ones((128, 128), dtype=np.float32))
B = tvm.nd.array(np.zeros((128,), dtype=np.float32))
mod = tvm.build(func)
mod(A, B)
print(B)
The output:
[128. 62. 88. 62. 66. 88. 62. 84. 62. 88. 84. 84. 88. 106.
66. 84. 84. 84. 62. 84. 62. 66. 106. 84. 106. 84. 110. 128.
84. 106. 66. 84. 84. 84. 84. 62. 84. 84. 110. 88. 66. 84.
66. 106. 66. 106. 106. 106. 40. 106. 66. 84. 88. 84. 88. 88.
62. 62. 84. 84. 84. 66. 66. 106. 84. 66. 66. 62. 66. 84.
88. 106. 66. 84. 106. 84. 84. 84. 66. 106. 84. 106. 84. 88.
84. 106. 84. 84. 44. 84. 44. 66. 62. 84. 88. 84. 88. 106.
84. 84. 66. 66. 62. 84. 62. 84. 62. 84. 84. 88. 84. 84.
84. 84. 128. 88. 106. 106. 88. 84. 62. 84. 62. 84. 62. 62.
88. 62.]