I was trying to define argmax using `comm_reducer`, but I couldn't find a way to define a `comm_reducer` that computes only the index of the maximum value without also computing the value itself.
For example, take the `comm_reducer` for argmax defined in `tests/python/integration/test_reduce.py`:
```python
import tvm
from tvm import te

def fcombine(x, y):
    # x and y are (index, value) pairs; keep the pair with the larger value
    lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
    rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
    return lhs, rhs

def fidentity(t0, t1):
    # Identity element: index -1 and the smallest representable value of type t1
    return tvm.tir.const(-1, t0), tvm.te.min_value(t1)

argmax = te.comm_reducer(fcombine, fidentity, name="argmax")

m = te.size_var("m")
n = te.size_var("n")
idx = te.placeholder((m, n), name="idx", dtype="int32")
val = te.placeholder((m, n), name="val", dtype="float32")
k = te.reduce_axis((0, n), "k")
# The reducer yields two outputs: the argmax index (T0) and the max value (T1)
T0, T1 = te.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name="T")
s = te.create_schedule(T0.op)
irm = tvm.lower(s, [T0, T1, idx, val])
print(irm)
```
This generates the following IR:
```
primfn(T_2: handle, T_3: handle, idx_1: handle, val_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {val: Buffer(val_2: Pointer(float32), float32, [m: int32, n: int32], [stride: int32, stride_1: int32], type=auto_broadcast),
             idx: Buffer(idx_2: Pointer(int32), int32, [m, n], [stride_2: int32, stride_3: int32], type=auto_broadcast),
             T_1: Buffer(T_4: Pointer(float32), float32, [m], [stride_4: int32], type=auto_broadcast),
             T: Buffer(T_5: Pointer(int32), int32, [m], [stride_5: int32], type=auto_broadcast)}
  buffer_map = {T_2: T, T_3: T_1, idx_1: idx, val_1: val} {
  for (i: int32, 0, m) {
    T_5[(i*stride_5)] = -1
    T_4[(i*stride_4)] = -3.40282e+38f32
    for (k: int32, 0, n) {
      T_5[(i*stride_5)] = @tir.if_then_else(((float32*)T_4[(i*stride_4)] < (float32*)val_2[((i*stride) + (k*stride_1))]), (int32*)idx_2[((i*stride_2) + (k*stride_3))], (int32*)T_5[(i*stride_5)], dtype=int32)
      T_4[(i*stride_4)] = max((float32*)T_4[(i*stride_4)], (float32*)val_2[((i*stride) + (k*stride_1))])
    }
  }
}
```
Since the argmax reducer has to take both the index and the value, the output of `comm_reducer` is also two tensors: the max index (`T_5`, the int32 buffer `T`, in the above IR) and the max value (`T_4`, the float32 buffer `T_1`, in the above IR).
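To see why the value ends up in the reducer state, here is a plain-Python model of the fold that the lowered loop performs for one row. It is an illustration only, not TVM code: `fcombine` and `fidentity` mirror the reducer above on ordinary tuples, `argmax_row` is a hypothetical helper, and `float("-inf")` stands in for `te.min_value("float32")`.

```python
def fidentity():
    # Identity element, mirroring tvm.tir.const(-1, t0) and te.min_value(t1)
    return (-1, float("-inf"))

def fcombine(x, y):
    # x and y are (index, value) pairs; keep the pair with the larger value.
    # Ties keep x, matching Select(x[1] >= y[1], ...)
    return x if x[1] >= y[1] else y

def argmax_row(idx_row, val_row):
    # Fold fcombine over the row, starting from the identity element,
    # like the inner k loop in the IR (acc plays the role of T_5/T_4)
    acc = fidentity()
    for k in range(len(val_row)):
        acc = fcombine(acc, (idx_row[k], val_row[k]))
    return acc

print(argmax_row([0, 1, 2, 3], [0.5, 2.0, -1.0, 2.0]))  # (1, 2.0)
```

Every step compares `acc[1]` against the incoming value, so the running maximum has to travel alongside the index; a combiner that only saw indices would have nothing to compare unless it could look the values up elsewhere.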
I also noticed that although the documentation of `argmax` in topi says it returns only the indices of the maximum values along the axis, the IR generated for it also computes the maximum values and then discards them.
What I wanted to ask is whether it is possible to define a `comm_reducer` for argmax that does not compute the maximum value, so that both the memory allocated for the max value and the extra computation can be avoided.
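For comparison, one way to avoid a separate value accumulator is to keep only the index and read the current maximum back from the input through it, at the cost of an extra load per step. The sketch below shows that idea in plain Python with a hypothetical helper `argmax_index_only`; it is not a TVM `comm_reducer`, because here the comparison needs direct access to the input row, whereas `fcombine` above only receives its two `(index, value)` arguments. Whether TVM can express this is exactly the open question.

```python
def argmax_index_only(val_row):
    # State is a single position; -1 means "no element seen yet"
    best = -1
    for k in range(len(val_row)):
        # Re-read the current maximum from the input via the stored index
        # instead of keeping it in a separate buffer (strict > keeps the first max on ties)
        if best == -1 or val_row[k] > val_row[best]:
            best = k
    return best
```

The trade-off is one extra read of `val_row[best]` per iteration in exchange for dropping the value accumulator (`T_4` in the IR above).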
Thanks in advance