Expressing nested reduce operations

I’m having trouble expressing a nested reduce operation.

The input is an (N, C, H, W) tensor, and I want to slide a window around each pixel, summing the squared Euclidean distance (taken across channels) between the center pixel and every other pixel in the window. I’ve expressed this as follows:

import tvm
from tvm import te

# parameters matching the lowered IR further below
batch, in_channel = 256, 256
in_size_h = in_size_w = 14
kernel, stride = 3, 1
kernel_center = kernel // 2

X = te.placeholder((batch, in_channel, in_size_h, in_size_w), name="X")  # input
rc = te.reduce_axis((0, in_channel), name="rc")
ry = te.reduce_axis((0, kernel), name="ry")
rx = te.reduce_axis((0, kernel), name="rx")

def reduce_window(nn, yy, xx):
    _y = yy * stride
    _x = xx * stride
    _ry = _y + (ry - kernel_center)
    _rx = _x + (rx - kernel_center)
    return te.sum(
        te.sum(
            (X[nn, rc, _y, _x] - X[nn, rc, _ry, _rx])*(X[nn, rc, _y, _x] - X[nn, rc, _ry, _rx]),
            axis=rc
        ),
        where=te.all(_rx >= 0, _ry >= 0, _rx < in_size_w, _ry < in_size_h),
        axis=[ry, rx]
    )

Y = te.compute(
    (batch, in_size_h, in_size_w),
    reduce_window,
    name="Y",
)
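For clarity, here is a plain-NumPy reference of what I am trying to express (my own sketch; `window_ssd` is a made-up name, and I assume stride = 1, an odd kernel, and that out-of-bounds neighbors contribute zero, as in the `where` clause above):

```python
import numpy as np

def window_ssd(x, kernel):
    """For every pixel, sum over a kernel x kernel window the squared
    Euclidean distance (across channels) between the center pixel and
    each in-bounds neighbor.  Assumes stride = 1 and an odd kernel."""
    n, c, h, w = x.shape
    kc = kernel // 2  # kernel_center
    y = np.zeros((n, h, w), dtype=x.dtype)
    for nn in range(n):
        for yy in range(h):
            for xx in range(w):
                acc = 0.0
                for ry in range(kernel):
                    for rx in range(kernel):
                        _ry, _rx = yy + ry - kc, xx + rx - kc
                        if 0 <= _ry < h and 0 <= _rx < w:
                            d = x[nn, :, yy, xx] - x[nn, :, _ry, _rx]
                            acc += float(d @ d)
                y[nn, yy, xx] = acc
    return y
```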

Running tvm.build(s, [X, Y], "cuda") on this fails with:

File "/[...]/tvm/include/tvm/tir/expr_functor.h", line 155
TVMError: Do not have a default for tir.Reduce

Running tvm.lower(...) on the code produces IR that still has a reduce(…) in the innermost statement, where I expected it to be turned into a loop:

primfn(X_1: handle, Y_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {Y: Buffer(Y_2: Pointer(float32), float32, [256, 14, 14], []),
             X: Buffer(X_2: Pointer(float32), float32, [256, 256, 14, 14], [])}
  buffer_map = {X_1: X, Y_1: Y} {
  attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 256;
  attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 14;
  for (xx: int32, 0, 14) {
    Y_2[(((blockIdx.x*196) + (threadIdx.x*14)) + xx)] = 0f32
    for (ry: int32, 0, 3) {
      for (rx: int32, 0, 3) {
        if ((((1 <= (xx + rx)) && (1 <= (threadIdx.x + ry))) && ((xx + rx) < 15)) && ((threadIdx.x + ry) < 15)) {
          Y_2[(((blockIdx.x*196) + (threadIdx.x*14)) + xx)] = ((float32*)Y_2[(((blockIdx.x*196) + (threadIdx.x*14)) + xx)] + reduce(meta[tir.CommReducer][0], [(((float32*)X_2[((((blockIdx.x*50176) + (rc: int32*196)) + (threadIdx.x*14)) + xx)] - (float32*)X_2[(((((((blockIdx.x*50176) + (rc*196)) + (threadIdx.x*14)) + (ry*14)) + xx) + rx) - 15)])*((float32*)X_2[((((blockIdx.x*50176) + (rc*196)) + (threadIdx.x*14)) + xx)] - (float32*)X_2[(((((((blockIdx.x*50176) + (rc*196)) + (threadIdx.x*14)) + (ry*14)) + xx) + rx) - 15)]))], [IterVar(rc, [0:256], "CommReduce", "")], 0, []))
        }
      }
    }
  }
}

An older post said that nested reductions are not supported with Tensor Expressions. Is this still the case? If so, how would you go about expressing this computation? Are there other ways of writing it?
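One workaround I am considering is to split the computation into two chained te.compute stages, so that each stage contains only a single reduction: first reduce over rc into an intermediate (batch, kernel, kernel, H, W) tensor, then reduce over ry and rx. Here is that factoring in plain NumPy (my own sketch; `window_ssd_two_stage` is a made-up name, and I assume stride = 1 and an odd kernel):

```python
import numpy as np

def window_ssd_two_stage(x, kernel):
    """Same computation, factored the way two chained te.compute calls
    would express it: stage 1 reduces over channels only, stage 2
    reduces over the window offsets only."""
    n, c, h, w = x.shape
    kc = kernel // 2
    # Stage 1: D[n, ry, rx, y, x] =
    #   sum_c (X[n,c,y,x] - X[n,c,y+ry-kc,x+rx-kc])^2,
    # with out-of-bounds neighbors contributing zero.
    d = np.zeros((n, kernel, kernel, h, w), dtype=x.dtype)
    for ry in range(kernel):
        for rx in range(kernel):
            dy, dx = ry - kc, rx - kc
            ys = slice(max(0, -dy), min(h, h - dy))
            xs = slice(max(0, -dx), min(w, w - dx))
            diff = (x[:, :, ys, xs]
                    - x[:, :, ys.start + dy:ys.stop + dy,
                           xs.start + dx:xs.stop + dx])
            d[:, ry, rx, ys, xs] = (diff * diff).sum(axis=1)
    # Stage 2: reduce over the window axes.
    return d.sum(axis=(1, 2))
```

This does materialize an intermediate that is kernel² times larger than the output, but each stage now contains a single reduction, which is what I understand TE supports. Is this the recommended approach?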