Hi,
my question is about vectorizing the input to intrinsic (tensorized) functions.
I’m trying to write a TVM schedule to sum a 64-element vector. Specifically, I want the sum reduction to be split into the steps described by the following Python code:
import numpy as np

def special_sum(input, n):
    acc = np.zeros(16)
    for ii in range(0, n, 16):
        acc += input[ii:ii+16]
    res = sum(acc)
    return res

n = 64
input = np.ones(n)
result = special_sum(input, n)
print(result)
It first reduces the 64-element vector (“input”) to a 16-element vector of partial sums (“acc”), and then reduces that 16-element vector to the scalar result (“res”). With n = 64 and an all-ones input, it prints 64.0.
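For reference, the same two-step reduction can be written directly in NumPy, which makes the intended data flow explicit (assuming n is a multiple of 16):

    acc = input.reshape(-1, 16).sum(axis=0)  # 64 elements -> 16 partial sums
    res = acc.sum()                          # 16 partial sums -> 1 scalar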
The following TVM code describes this calculation schedule (the first summation step is vectorized, and the final summation to a scalar is tensorized):
#!/usr/bin/env python3
import tvm
from tvm import te

def intrin_sum(l):
    a = te.placeholder((l,), name="a")
    k = te.reduce_axis((0, l), name="k")
    c = te.compute((1,), lambda i: te.sum(a[k], axis=k), name="c")
    Ab = tvm.tir.decl_buffer(a.shape, a.dtype, name="A", offset_factor=1, strides=[1])
    Cb = tvm.tir.decl_buffer(c.shape, c.dtype, name="C", offset_factor=1, strides=[1])

    def intrin_func(ins, outs):
        aa, = ins
        cc = outs[0]

        def _body():
            ib = tvm.tir.ir_builder.create()
            ib.emit(
                tvm.tir.call_extern(
                    "int32",
                    "sum",
                    cc.access_ptr("w"),
                    aa.access_ptr("r"),
                )
            )
            return ib.get()

        def _reduce_reset():
            ib = tvm.tir.ir_builder.create()
            ib.emit(tvm.tir.call_extern("int32", "sum_reset", cc.access_ptr("w")))
            return ib.get()

        def _reduce_update():
            return _body()

        return _body(), _reduce_reset(), _reduce_update()

    return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, c: Cb})

def define_comp(n):
    assert n % 16 == 0, "input vector length must be multiple of 16"
    # 64x1
    A = te.placeholder((n,), name="a")
    # 16x1
    k0 = te.reduce_axis((0, n // 16), name="k0")
    B = te.compute((16,), lambda i: te.sum(A[k0 * 16 + i], axis=k0), name="b")
    # 1x1
    k1 = te.reduce_axis((0, 16), name="k1")
    C = te.compute((1,), lambda i: te.sum(B[k1], axis=k1), name="c")
    s = te.create_schedule(C.op)
    return s, (A, B, C)

def define_schedule(n):
    s, (A, B, C) = define_comp(n)
    # Vectorize partial summation
    s[B].vectorize(s[B].op.axis[0])
    # Tensorize final summation to scalar
    intrin = intrin_sum(16)
    s[C].tensorize(s[C].op.reduce_axis[0], intrin)
    return s, (A, C)

def main():
    s, args = define_schedule(64)
    print(tvm.lower(s, args, simple_mode=True))

if __name__ == "__main__":
    main()
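(Note: the extern symbols "sum" and "sum_reset" are assumed to be provided at link time; since the script only calls tvm.lower and never builds or runs the module, they do not need to exist to inspect the IR.)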
It results in the following IR:
@main = primfn(a_1: handle, c_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {a: Buffer(a_2: Pointer(float32), float32, [64], []),
             c: Buffer(c_2: Pointer(float32), float32, [1], [])}
  buffer_map = {a_1: a, c_1: c}
  preflattened_buffer_map = {a_1: a_3: Buffer(a_2, float32, [64], []), c_1: c_3: Buffer(c_2, float32, [1], [])} {
  allocate(b: Pointer(global float32), float32, [16]), storage_scope = global {
    b_1: Buffer(b, float32, [16], [], align=64)[ramp(0, 1, 16)] = broadcast(0f32, 16)
    for (k0: int32, 0, 4) {
      b_1[ramp(0, 1, 16)] = (b_1[ramp(0, 1, 16)] + a[ramp((k0*16), 1, 16)])
    }
    @tir.call_extern("sum", @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float32), c_2, 0, 1, 2, dtype=handle), @tir.tvm_access_ptr(@tir.type_annotation(, dtype=float32), b, 0, 16, 1, dtype=handle), dtype=int32)
  }
}
TVM generates the tensor b (the second argument of the extern call to sum) as a vector of 16 scalars. I want b to instead be a vector variable with 16 lanes and length 1. I think this should be achievable by declaring the second argument to the intrinsic function (aa.access_ptr("r")) as a vector; I am, however, not sure how to do that.
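Concretely, what I have in mind is something like declaring the intrinsic’s input with a 16-lane vector dtype, so the extern call would receive a single float32x16 value instead of 16 float32 values. This is only a sketch of the intent; I do not know whether decl_tensor_intrin/tensorize can actually match a pattern like this:

    # Hypothetical sketch: one element of a 16-lane vector type
    # instead of 16 scalar elements (untested, may not be matchable).
    a = te.placeholder((1,), dtype="float32x16", name="a")
    Ab = tvm.tir.decl_buffer(a.shape, a.dtype, name="A", offset_factor=1, strides=[1])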
I tried to change the declaration of the second argument to
aa.access_ptr("r", content_lanes=16)
but that resulted in the following error message:
Check failed: index_lanes == value_dtype.lanes() (1 vs. 16) :
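For completeness, the only change relative to the full script above is this one call inside _body():

    ib.emit(
        tvm.tir.call_extern(
            "int32",
            "sum",
            cc.access_ptr("w"),
            aa.access_ptr("r", content_lanes=16),  # changed: request a 16-lane view
        )
    )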
Could it be related to the tensorization being along a reduction axis? (I know it is not possible to vectorize along a reduction axis.)
How do I make the input argument to a tensorized function vectorized?
Does anyone have ideas?
Thanks!