[RFC][Tensorize] Add "reduce_last" property for TensorIntrin to support activation fusion

Motivation:

The existing TensorIntrin has “reduce_init” and “reduce_update” properties to support tensorization of the reduce_axis == 0 and reduce_axis > 0 cases specifically, which already covers many use cases. However, support for activation fusion is still missing, because there is no facility to handle the “reduce_axis == reduce_dim - 1” case, i.e. the last iteration of the reduce axis.

Suppose we want to tensorize a matmul_with_relu op; the code would look like this:

import tvm
from tvm import te


# Tensor intrinsic computing a full (m, n) GEMM tile over a length-k reduction,
# offloaded to cblas via packed calls.
def intrin_gemm(m, n, k):
  a = te.placeholder((m, k), name="a")
  b = te.placeholder((k, n), name="b")
  k_axis = te.reduce_axis((0, k), name="k")
  c = te.compute((m, n), lambda i, j: te.sum(a[i, k_axis] * b[k_axis, j], axis=k_axis), name="c")
  # Buffers with symbolic strides, so the intrinsic matches sub-tiles of larger tensors.
  a_buffer = tvm.tir.decl_buffer(a.shape, a.dtype, name="a_buffer", offset_factor=1, strides=[te.var("s1"), 1])
  b_buffer = tvm.tir.decl_buffer(b.shape, b.dtype, name="b_buffer", offset_factor=1, strides=[te.var("s2"), 1])
  c_buffer = tvm.tir.decl_buffer(c.shape, c.dtype, name="c_buffer", offset_factor=1, strides=[te.var("s3"), 1])

  def intrin_func(ins, outs):

    # Emitted at the first iteration of the reduce axis (reduce_init is None).
    def _body():
      ib = tvm.tir.ir_builder.create()
      ib.emit(
        tvm.tir.call_packed("tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0])
      )

      return ib.get()

    # Emitted at every subsequent iteration of the reduce axis.
    def _update():
      ib = tvm.tir.ir_builder.create()
      ib.emit(
        tvm.tir.call_packed("tvm.contrib.cblas.matmul_with_relu", ins[0], ins[1], outs[0])
      )

      return ib.get()

    # (body, reduce_init, reduce_update)
    return _body(), None, _update()

  return te.decl_tensor_intrin(c.op, intrin_func, binds={a: a_buffer, b: b_buffer, c: c_buffer})
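
For reference, a driver schedule like the following sketch can exercise the intrinsic (the 16x16x16 shapes and the reduce split factor of 4 are chosen to match the TIR dump below; the exact reorder/tensorize placement is illustrative):

import tvm
from tvm import te

M = N = K = 16
A = te.placeholder((M, K), name="A")
B = te.placeholder((K, N), name="B")
k = te.reduce_axis((0, K), name="k")
C = te.compute((M, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")

s = te.create_schedule(C.op)
i, j = C.op.axis
ko, ki = s[C].split(C.op.reduce_axis[0], factor=4)
s[C].reorder(ko, i, j, ki)               # k.outer becomes the outermost loop
s[C].tensorize(i, intrin_gemm(M, N, 4))  # intrin covers the (i, j, k.inner) sub-nest
print(tvm.lower(s, [A, B, C], simple_mode=True))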

And the corresponding TIR generated is:

for (k.outer: int32, 0, 4) {
    if @tir.likely((0 < k.outer), dtype=bool) {
      @tir.tvm_call_packed("tvm.contrib.cblas.matmul_with_relu", @tir.tvm_stack_make_array(A_2, @tir.tvm_stack_make_shape(16, 4, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, (k.outer*4), dtype=handle), @tir.tvm_stack_make_array(B_2, @tir.tvm_stack_make_shape(4, 16, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, (k.outer*64), dtype=handle), @tir.tvm_stack_make_array(C_2, @tir.tvm_stack_make_shape(16, 16, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, 0, dtype=handle), dtype=int32)
    } else {
      @tir.tvm_call_packed("tvm.contrib.cblas.matmul", @tir.tvm_stack_make_array(A_2, @tir.tvm_stack_make_shape(16, 4, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, (k.outer*4), dtype=handle), @tir.tvm_stack_make_array(B_2, @tir.tvm_stack_make_shape(4, 16, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, (k.outer*64), dtype=handle), @tir.tvm_stack_make_array(C_2, @tir.tvm_stack_make_shape(16, 16, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, 0, dtype=handle), dtype=int32)
    }
  }

Sadly, the above TIR is actually WRONG: it invokes “matmul_with_relu” at every k.outer != 0 iteration. We only want “matmul_with_relu” to be invoked at the k.outer == 3 (i.e. last) iteration, but that is not expressible with the existing TensorIntrin implementation.

Proposal:

My proposal is to add a “reduce_last” property to TensorIntrin, which represents the last iteration of the reduce axis. The API would look like this:

  def intrin_func(ins, outs):

    # Emitted at the first iteration of the reduce axis.
    def _body():
      ib = tvm.tir.ir_builder.create()
      ib.emit(
        tvm.tir.call_packed("tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0])
      )

      return ib.get()

    # Emitted at the intermediate iterations of the reduce axis.
    def _update():
      ib = tvm.tir.ir_builder.create()
      ib.emit(
        tvm.tir.call_packed("tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0])
      )

      return ib.get()

    # Emitted at the last iteration of the reduce axis, by generating the statement for "reduce_last".
    def _last():
      ib = tvm.tir.ir_builder.create()
      ib.emit(
        tvm.tir.call_packed("tvm.contrib.cblas.matmul_with_relu", ins[0], ins[1], outs[0])
      )

      return ib.get()

    # (body, reduce_init, reduce_update, reduce_last)
    return _body(), None, _update(), _last()

  return te.decl_tensor_intrin(c.op, intrin_func, binds={a: a_buffer, b: b_buffer, c: c_buffer})

And the generated TIR would look like:

for (k.outer: int32, 0, 4) {
    if @tir.likely((k.outer < 3), dtype=bool) {
      if @tir.likely((0 < k.outer), dtype=bool) {
        @tir.tvm_call_packed("tvm.contrib.cblas.matmul", @tir.tvm_stack_make_array(A_2, @tir.tvm_stack_make_shape(16, 4, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, (k.outer*4), dtype=handle), @tir.tvm_stack_make_array(B_2, @tir.tvm_stack_make_shape(4, 16, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, (k.outer*64), dtype=handle), @tir.tvm_stack_make_array(C_2, @tir.tvm_stack_make_shape(16, 16, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, 0, dtype=handle), dtype=int32)
      } else {
        @tir.tvm_call_packed("tvm.contrib.cblas.matmul", @tir.tvm_stack_make_array(A_2, @tir.tvm_stack_make_shape(16, 4, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, (k.outer*4), dtype=handle), @tir.tvm_stack_make_array(B_2, @tir.tvm_stack_make_shape(4, 16, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, (k.outer*64), dtype=handle), @tir.tvm_stack_make_array(C_2, @tir.tvm_stack_make_shape(16, 16, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, 0, dtype=handle), dtype=int32)
      }
    } else {
      @tir.tvm_call_packed("tvm.contrib.cblas.matmul_with_relu", @tir.tvm_stack_make_array(A_2, @tir.tvm_stack_make_shape(16, 4, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, (k.outer*4), dtype=handle), @tir.tvm_stack_make_array(B_2, @tir.tvm_stack_make_shape(4, 16, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, (k.outer*64), dtype=handle), @tir.tvm_stack_make_array(C_2, @tir.tvm_stack_make_shape(16, 16, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, 0, dtype=handle), dtype=int32)
    }
  }

Thanks, comments are appreciated.

cc @tqchen.

Thanks @zhuwenxi for the RFC. I can see the need to fuse activations after the reduction. Regarding the proposed reduce_last, my question is how we should declare the computation. Right now the assumption is that the reduction contains init and update steps. From the example provided,
c = te.compute((m, n), lambda i, j: te.sum(a[i, k_axis] * b[k_axis, j], axis=k_axis), name="c")
this doesn’t match the intrin body that contains the relu in the last reduction step. Alternatively, we can use the IR builder to build the reduction loop, then add the activation in the last iteration.

FYI, the new TensorIR will allow more flexible tensorization ops (the tensorization schedule primitive will be upstreamed soon). Hopefully it can address these limitations of tensorization in TE.

From the example provided, c = te.compute((m, n), lambda i, j: te.sum(a[i, k_axis] * b[k_axis, j], axis=k_axis), name="c") this doesn’t match the intrin body that contains the relu in the last reduction step

Yeah, but that’s because TE has a restriction that the reduction must appear at the top level of the compute; otherwise compilation fails.
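
For example, something like the following sketch fails to compile, because the te.sum is wrapped in an elementwise max and is no longer the top-level expression (reusing the names from the RFC code above):

# Rejected by TE: the reduce must be the top-level expression of the compute.
c = te.compute(
    (m, n),
    lambda i, j: te.max(te.sum(a[i, k_axis] * b[k_axis, j], axis=k_axis), tvm.tir.const(0, "float32")),
    name="c",
)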

we can use the IR builder to build the reduction loop, then add the activation in the last iteration.

Exactly, the IR builder can represent the reduction loop, the if-else clauses and the activations. However, TIR generated by the IR builder currently cannot be auto-tuned. From this point of view, I agree we should expect TensorIR and MetaScheduler to solve the problem fundamentally.

Makes sense to me. Thank you, @vinx13.

Yeah, but that’s because TE has a restriction that the reduction must appear at the top level of the compute; otherwise compilation fails.

I understand this is indeed a limitation of the reduction. The use case in this example is one we previously ignored - we assumed another loop would be used to perform the activation (or, in the CUDA case, it can be fused into the shared -> global copy phase).

Exactly, the IR builder can represent the reduction loop, the if-else clauses and the activations. However, TIR generated by the IR builder currently cannot be auto-tuned.

While using the IR builder to write the whole kernel is not tunable, we can write only the tensor intrin part with the IR builder while keeping the outer loops tunable. You can use the autotvm cfg to get the current split factor, and then use it to declare the tensor intrin on the fly (taking the reduction loop length as an argument).
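
A rough sketch of that idea, reusing intrin_gemm from the RFC above (the cfg wiring follows the usual autotvm template pattern and is illustrative):

from tvm import autotvm

# Inside an @autotvm.template function, after building schedule s for C:
cfg = autotvm.get_config()
k = C.op.reduce_axis[0]
cfg.define_split("tile_k", k, num_outputs=2)
ko, ki = cfg["tile_k"].apply(s, C, k)
i, j = C.op.axis
s[C].reorder(ko, i, j, ki)
# Declare the intrinsic on the fly, so its reduction length always matches
# the inner split factor chosen for this tuning trial.
s[C].tensorize(i, intrin_gemm(16, 16, cfg["tile_k"].size[-1]))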

Indeed this would add complexity to the schedule; I’m suggesting it as an alternative that minimizes changes and avoids breaking the reduction semantics.

we can write only the tensor intrin part with the IR builder while keeping the outer loops tunable

Yes, and I think that’s exactly what the “block” statement of TensorIR is designed for: isolating the untunable part.

Since we do have an alternative way to address the activation fusion problem, I agree we should not break the reduction semantics. We can close this discussion now, thank you!
