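I used hybrid script to write the `tf.where` function as below:

```python
import tvm

@tvm.hybrid.script
def where_compute(A):
    """
    Parameters
    ----------
    A : tvm.Tensor
        with shape [m,]

    Returns
    -------
    Output : tvm.Tensor
        with shape [m,1]
    """
    m = A.shape[0]
    j = 0  # number of nonzero entries found so far
    res = output_tensor((m, 1), "int64")
    # write the index of every nonzero entry into the next free row
    for i in range(m):
        if A[i]:
            res[j, 0] = i
            j = j + 1
    # pad the remaining rows with -1
    for k in range(j, m):
        res[k, 0] = -1
        j = j + 1
    return res
```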
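For checking the lowered kernel, it may help to have a plain reference for the intended output. With a single argument, `tf.where` returns a dynamically shaped `[num_true, 1]` int64 result for a 1-D input. Because the hybrid output above has a static shape `[m, 1]`, the unused rows are filled with -1 instead. The sketch below is a hypothetical NumPy reference (`where_reference` and `where_loop` are illustrative names, not TVM APIs). It also shows the same loop with `j` as an ordinary Python integer, which is the scalar behaviour being asked about.

```python
import numpy as np

def where_reference(a):
    """Hypothetical NumPy reference: indices of the nonzero entries of a
    1-D array, stored in an (m, 1) int64 column and padded with -1."""
    m = a.shape[0]
    res = np.full((m, 1), -1, dtype=np.int64)
    idx = np.flatnonzero(a)  # positions i where a[i] != 0
    res[: idx.size, 0] = idx
    return res

def where_loop(a):
    """Same algorithm as where_compute, with j kept as a plain scalar."""
    m = a.shape[0]
    res = np.empty((m, 1), dtype=np.int64)
    j = 0  # plain Python int, no one-element buffer
    for i in range(m):
        if a[i]:
            res[j, 0] = i
            j += 1
    for k in range(j, m):
        res[k, 0] = -1
    return res

a = np.array([0, 3, 0, 1, 5], dtype=np.int32)
print(where_reference(a).ravel())  # [ 1  3  4 -1 -1]
```

Comparing the compiled kernel's output against `where_reference` on a few random inputs is a quick way to confirm that any rewrite of `j` keeps the semantics.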
The IR generated is like this:
In the IR, `j` appears as `j(0)`, i.e. it is lowered to a one-element buffer. Is there a way to turn `j(0)` into a plain `j` so that it stays a scalar?