Hi, how can I pass a scalar value as an input argument in TE? I tried the code below:
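```python
import tvm
from tvm import te

n = te.const(10)
m = te.const(5)
A = te.placeholder((n, m), name="A")
# The scalar I want to pass in (currently a compile-time constant)
scale = te.const(2, "int32")
# Element-wise: B = A + scale
B = te.compute(A.shape, lambda *i: A(*i) + scale, name="B")
s = te.create_schedule(B.op)
fapi = tvm.lower(s, [A, B], simple_mode=True)
print(fapi)
```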
The printed result is:
```
@main = primfn(A_1: handle, B_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {A: Buffer(A_2: Pointer(float32), float32, [10, 5], []),
             B: Buffer(B_2: Pointer(float32), float32, [10, 5], [])}
  buffer_map = {A_1: A, B_1: B} {
  for (i0: int32, 0, 10) {
    for (i1: int32, 0, 5) {
      let cse_var_1: int32 = ((i0*5) + i1)
      B_3: Buffer(B_2, float32, [50], [])[cse_var_1] = (A_3: Buffer(A_2, float32, [50], [])[cse_var_1] + 2f32)
    }
  }
}
```
However, this is not what I want: `scale` is folded into the body as the literal `2f32`, and the PrimFunc has no scalar parameter. I don't know how to write this in TE so that `scale` becomes a scalar value among the input arguments of the PrimFunc.