Autotuning arithmetic expression rewrite? Limitation of auto_scheduler?

I’m working on a kernel that computes the equivalent of:

C[n, ci, co, h, w] = sum(axis=[kh, kw], A[n, ci, h * stride_h + kh, w * stride_w + kw] * B[n, co, kh, kw]) -- 2D convolution without reduction over the input channel
D[ci, co, h, w] = reduce(axis=n, C)
E[co, ci, h, w] = transpose(axis=[ci, co], D)
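To make the index expressions concrete, here is a small NumPy reference for the three steps. It follows the indexing of the TE code below, where `kh, kw` index the output and the output positions (`roh, row` in TE, `h, w` in the einsum) are reduced. The helper name `staged_reference`, the shapes and the stride values are hypothetical, chosen only for illustration. It is a readability aid, not how TVM evaluates the expression.

```python
import numpy as np

def staged_reference(A, B, kernel_h, kernel_w, stride_h, stride_w):
    """Hypothetical helper: compute C, D and E step by step with NumPy."""
    batch, in_channel = A.shape[:2]
    out_channel, out_h, out_w = B.shape[1:]
    # Step C: per-sample correlation; input channels are NOT reduced.
    C = np.zeros((batch, in_channel, out_channel, kernel_h, kernel_w),
                 dtype=np.result_type(A, B))
    for kh in range(kernel_h):
        for kw in range(kernel_w):
            # Slice of A that lines up with B for this (kh, kw) output position.
            window = A[:, :, kh * stride_h : kh * stride_h + out_h,
                       kw * stride_w : kw * stride_w + out_w]
            # Reduce over the spatial axes (h, w) only; keep n, ci and co.
            C[:, :, :, kh, kw] = np.einsum("nihw,nohw->nio", window, B)
    D = C.sum(axis=0)              # Step D: reduce over the batch axis n.
    E = D.transpose(1, 0, 2, 3)    # Step E: (ci, co, kh, kw) -> (co, ci, kh, kw).
    return C, D, E

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3, 8, 8))   # (batch, in_channel, height, width)
B = rng.standard_normal((2, 4, 5, 5))   # (batch, out_channel, out_h, out_w)
C, D, E = staged_reference(A, B, kernel_h=2, kernel_w=2, stride_h=2, stride_w=2)
print(C.shape, D.shape, E.shape)  # (2, 3, 4, 2, 2) (3, 4, 2, 2) (4, 3, 2, 2)
```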

This can be computed in three steps (1):

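```python
from tvm import te

# Reduction axes for this version. The shapes (batch, in_channel, ...), the
# strides, out_dtype and the placeholders A and B are defined elsewhere.
rn = te.reduce_axis((0, batch), name="rn")
roh = te.reduce_axis((0, out_height), name="roh")
row = te.reduce_axis((0, out_width), name="row")

# Step C: per-sample correlation over (roh, row); input channels are not reduced.
conv = te.compute(
    [batch, in_channel, out_channel, kernel_h, kernel_w],
    lambda nn, ci, co, kh, kw: te.sum(
        A[
            nn,
            ci,
            kh * stride_h + roh,
            kw * stride_w + row
        ].astype(out_dtype)
        * B[nn, co, roh, row].astype(out_dtype),
        axis=[roh, row],
    ),
    name="conv",
)

# Step D: reduce over the batch axis.
conv_reduced_batch = te.compute(
    [in_channel, out_channel, kernel_h, kernel_w],
    lambda ci, co, kh, kw: te.sum(
        conv[rn, ci, co, kh, kw].astype(out_dtype),
        axis=[rn]
    ),
    name="conv_reduced_batch",
)

# Step E: swap the channel axes, (ci, co) -> (co, ci).
result = te.compute(
    [out_channel, in_channel, kernel_h, kernel_w],
    lambda co, ci, kh, kw: conv_reduced_batch[ci, co, kh, kw].astype(out_dtype),
    name="result",
)
```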

or in a single compute (2):

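```python
from tvm import te

rn = te.reduce_axis((0, batch), name="rn")
roh = te.reduce_axis((0, out_height), name="roh")
row = te.reduce_axis((0, out_width), name="row")

# Single stage: the batch and spatial reductions are folded into one te.sum,
# and the output is written directly in (co, ci, kh, kw) order.
result = te.compute(
    [out_channel, in_channel, kernel_h, kernel_w],
    lambda co, ci, kh, kw: te.sum(
        A[
            rn,
            ci,
            kh * stride_h + roh,
            kw * stride_w + row
        ].astype(out_dtype)
        * B[rn, co, roh, row].astype(out_dtype),
        axis=[rn, roh, row],
    ),
    name="result",
)
```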
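Formulation (2) folds the batch reduction into the same sum as the spatial reduction. In NumPy terms this is a single contraction over `(n, oh, ow)`. The sketch below builds a strided view of `A` with `numpy.lib.stride_tricks.sliding_window_view` (NumPy 1.20 or newer) and contracts it with one `einsum`. The helper name `fused_reference` and the shapes are hypothetical, and the comparison against an explicit loop nest only illustrates that the single-stage form computes the same values; it says nothing about how either form performs after tuning.

```python
import itertools
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def fused_reference(A, B, kernel_h, kernel_w, stride_h, stride_w):
    """Hypothetical helper: result[co, ci, kh, kw] as ONE reduction over (n, oh, ow)."""
    out_h, out_w = B.shape[2:]
    # windows[n, ci, p, q, oh, ow] == A[n, ci, p + oh, q + ow]
    windows = sliding_window_view(A, (out_h, out_w), axis=(2, 3))
    # Keep only the window origins p = kh * stride_h and q = kw * stride_w.
    windows = windows[:, :, ::stride_h, ::stride_w][:, :, :kernel_h, :kernel_w]
    # A single contraction over n, oh and ow, written directly in (co, ci, kh, kw) order.
    return np.einsum("niklhw,nohw->oikl", windows, B)

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3, 8, 8))
B = rng.standard_normal((2, 4, 5, 5))
result = fused_reference(A, B, kernel_h=2, kernel_w=2, stride_h=2, stride_w=2)

# Brute-force loop nest mirroring the te.compute in (2), for comparison.
expected = np.zeros_like(result)
for co, ci, kh, kw, n, oh, ow in itertools.product(
        range(4), range(3), range(2), range(2), range(2), range(5), range(5)):
    expected[co, ci, kh, kw] += A[n, ci, kh * 2 + oh, kw * 2 + ow] * B[n, co, oh, ow]
print(result.shape, np.allclose(result, expected))  # (4, 3, 2, 2) True
```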

I was expecting (2) to be faster than (1) after autotuning for a CUDA target, but the reverse happens: (2) is about 80% slower than (1). Why could this be? Is it a limitation of auto_scheduler’s ability to rewrite tensor expressions, or am I missing something?

I have not tested these compute declarations before, so I am not sure about the reasons either.

I am only sure that they are different inputs for the auto-scheduler. They are not structurally equivalent (one is a three-stage compute DAG, the other a single stage), so the auto-scheduler will generate different sketches for them. The final performance depends on many factors, such as the data layout and the rules used by the auto-scheduler.
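One way to see that (1) and (2) present different loop structures, even though they compute the same values, is to count each stage's spatial (output) iterations versus its reduction iterations. Roughly speaking, the tiling rules treat these two kinds of loops differently, and on CUDA it is the spatial loops that get bound to blocks and threads. The helper `stage_loop_extents` and the shapes below are hypothetical. This back-of-the-envelope sketch does not model the auto-scheduler and does not by itself explain the measured slowdown.

```python
from math import prod

def stage_loop_extents(batch, in_channel, out_channel, kernel_h, kernel_w, out_h, out_w):
    """Hypothetical helper: (spatial iterations, reduction iterations) per stage."""
    out_elems = prod([out_channel, in_channel, kernel_h, kernel_w])
    return {
        # Formulation (1): three separate stages.
        "(1) conv": (batch * out_elems, out_h * out_w),
        "(1) conv_reduced_batch": (out_elems, batch),
        "(1) result (transpose)": (out_elems, 1),  # pure copy: no reduction loop, counted as 1
        # Formulation (2): one stage reducing over n, oh and ow together.
        "(2) result": (out_elems, batch * out_h * out_w),
    }

# Made-up shapes, for illustration only.
extents = stage_loop_extents(batch=32, in_channel=64, out_channel=64,
                             kernel_h=3, kernel_w=3, out_h=56, out_w=56)
for stage, (spatial, reduction) in extents.items():
    print(f"{stage:24s} spatial={spatial:>9d} reduction={reduction:>6d}")
```

With these shapes, the single stage in (2) has 32× fewer independent outputs than the first stage of (1), but a 32× longer reduction per output. That is one plausible reason the two inputs lead to different sketches; whether it accounts for the measured gap would need profiling.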

Thanks for the answer, that makes sense. Sketches are generated based on schedule rewriting rules (inline, tiling, unroll, vectorization, etc.) and not on arithmetic expression rewrites, correct?

The auto-scheduler analyzes the arithmetic expression and performs the schedule rewriting. However, the analysis is heuristic-based, so the two forms in your post lead to different analysis results, even though they are mathematically equivalent.

Thanks for the extra clarification.