Motivation:
The existing TensorIntrin has "reduce_init" and "reduce_update" properties to support tensorization of the reduce_axis == 0 and reduce_axis > 0 cases respectively, which already fits many use cases well. However, support for activation fusion is still missing, because there is no facility to handle the "reduce_axis == reduce_dim - 1" case, i.e. the last iteration of the reduce axis.
Suppose we want to tensorize a matmul_with_relu op. The code would look like this:
import tvm
from tvm import te


def intrin_gemm(m, n, k):
    a = te.placeholder((m, k), name="a")
    b = te.placeholder((k, n), name="b")
    k_axis = te.reduce_axis((0, k), name="k")
    c = te.compute((m, n), lambda i, j: te.sum(a[i, k_axis] * b[k_axis, j], axis=k_axis), name="c")

    a_buffer = tvm.tir.decl_buffer(a.shape, a.dtype, name="a_buffer", offset_factor=1, strides=[te.var("s1"), 1])
    b_buffer = tvm.tir.decl_buffer(b.shape, b.dtype, name="b_buffer", offset_factor=1, strides=[te.var("s2"), 1])
    c_buffer = tvm.tir.decl_buffer(c.shape, c.dtype, name="c_buffer", offset_factor=1, strides=[te.var("s3"), 1])

    def intrin_func(ins, outs):
        # Statement for the first iteration of the reduce axis (k.outer == 0).
        def _body():
            ib = tvm.tir.ir_builder.create()
            ib.emit(
                tvm.tir.call_packed("tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0])
            )
            return ib.get()

        # Statement for the remaining iterations (k.outer > 0).
        def _update():
            ib = tvm.tir.ir_builder.create()
            ib.emit(
                tvm.tir.call_packed("tvm.contrib.cblas.matmul_with_relu", ins[0], ins[1], outs[0])
            )
            return ib.get()

        # (body, reduce_init, reduce_update)
        return _body(), None, _update()

    return te.decl_tensor_intrin(c.op, intrin_func, binds={a: a_buffer, b: b_buffer, c: c_buffer})
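For reference, here is a minimal sketch of how this intrinsic could be applied with tensorize to produce a loop nest like the one below. This part is not from the original snippet: the workload names, the 16x16x16 shapes, and the split factor of 4 are only inferred from the TIR dump and are illustrative.
# Hypothetical workload and schedule; shapes and names are assumptions, not from the original post.
M, N, K = 16, 16, 16
A = te.placeholder((M, K), name="A")
B = te.placeholder((K, N), name="B")
k = te.reduce_axis((0, K), name="k")
C = te.compute((M, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")

s = te.create_schedule(C.op)
ko, ki = s[C].split(C.op.reduce_axis[0], factor=4)   # -> "k.outer" with 4 iterations
i, j = C.op.axis
s[C].reorder(ko, i, j, ki)                           # keep the reduction chunk outermost
s[C].tensorize(i, intrin_gemm(M, N, 4))              # replace the (i, j, ki) body with the intrinsic
print(tvm.lower(s, [A, B, C], simple_mode=True))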
And the corresponding TIR generated is:
for (k.outer: int32, 0, 4) {
if @tir.likely((0 < k.outer), dtype=bool) {
@tir.tvm_call_packed("tvm.contrib.cblas.matmul_with_relu", @tir.tvm_stack_make_array(A_2, @tir.tvm_stack_make_shape(16, 4, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, (k.outer*4), dtype=handle), @tir.tvm_stack_make_array(B_2, @tir.tvm_stack_make_shape(4, 16, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, (k.outer*64), dtype=handle), @tir.tvm_stack_make_array(C_2, @tir.tvm_stack_make_shape(16, 16, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, 0, dtype=handle), dtype=int32)
} else {
@tir.tvm_call_packed("tvm.contrib.cblas.matmul", @tir.tvm_stack_make_array(A_2, @tir.tvm_stack_make_shape(16, 4, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, (k.outer*4), dtype=handle), @tir.tvm_stack_make_array(B_2, @tir.tvm_stack_make_shape(4, 16, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, (k.outer*64), dtype=handle), @tir.tvm_stack_make_array(C_2, @tir.tvm_stack_make_shape(16, 16, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, 0, dtype=handle), dtype=int32)
}
}
Sadly, the above TIR is actually WRONG, because it invokes "matmul_with_relu" at every iteration where k.outer != 0. We only want "matmul_with_relu" to be invoked at the k.outer == 3 iteration, but that is not possible with the existing TensorIntrin implementation.
Proposal
My proposal is to add a "reduce_last" property to TensorIntrin, which represents the last iteration of the reduce axis. The API would look like this:
    # Inside intrin_gemm; the placeholder and buffer declarations are unchanged.
    def intrin_func(ins, outs):
        # Statement for the first iteration of the reduce axis (k.outer == 0).
        def _body():
            ib = tvm.tir.ir_builder.create()
            ib.emit(
                tvm.tir.call_packed("tvm.contrib.cblas.matmul_with_bias", ins[0], ins[1], outs[0])
            )
            return ib.get()

        # Statement for the intermediate iterations of the reduce axis.
        def _update():
            ib = tvm.tir.ir_builder.create()
            ib.emit(
                tvm.tir.call_packed("tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0])
            )
            return ib.get()

        # Handle the last iteration of the reduce axis by generating the statement for "reduce_last".
        def _last():
            ib = tvm.tir.ir_builder.create()
            ib.emit(
                tvm.tir.call_packed("tvm.contrib.cblas.matmul_with_relu", ins[0], ins[1], outs[0])
            )
            return ib.get()

        # (body, reduce_init, reduce_update, reduce_last)
        return _body(), None, _update(), _last()

    return te.decl_tensor_intrin(c.op, intrin_func, binds={a: a_buffer, b: b_buffer, c: c_buffer})
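Note that the schedule side would presumably stay unchanged: the same split/reorder/tensorize pattern sketched above could be reused as-is, with the lowering pass guarding the "reduce_last" statement so that it fires only on the final iteration of the tensorized reduce axis.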
And the generated TIR would look like this:
for (k.outer: int32, 0, 4) {
if @tir.likely((k.outer < 3), dtype=bool) {
if @tir.likely((0 < k.outer), dtype=bool) {
@tir.tvm_call_packed("tvm.contrib.cblas.matmul", @tir.tvm_stack_make_array(A_2, @tir.tvm_stack_make_shape(16, 4, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, (k.outer*4), dtype=handle), @tir.tvm_stack_make_array(B_2, @tir.tvm_stack_make_shape(4, 16, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, (k.outer*64), dtype=handle), @tir.tvm_stack_make_array(C_2, @tir.tvm_stack_make_shape(16, 16, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, 0, dtype=handle), dtype=int32)
} else {
@tir.tvm_call_packed("tvm.contrib.cblas.matmul", @tir.tvm_stack_make_array(A_2, @tir.tvm_stack_make_shape(16, 4, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, (k.outer*4), dtype=handle), @tir.tvm_stack_make_array(B_2, @tir.tvm_stack_make_shape(4, 16, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, (k.outer*64), dtype=handle), @tir.tvm_stack_make_array(C_2, @tir.tvm_stack_make_shape(16, 16, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, 0, dtype=handle), dtype=int32)
}
} else {
@tir.tvm_call_packed("tvm.contrib.cblas.matmul_with_relu", @tir.tvm_stack_make_array(A_2, @tir.tvm_stack_make_shape(16, 4, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, (k.outer*4), dtype=handle), @tir.tvm_stack_make_array(B_2, @tir.tvm_stack_make_shape(4, 16, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, (k.outer*64), dtype=handle), @tir.tvm_stack_make_array(C_2, @tir.tvm_stack_make_shape(16, 16, dtype=handle), @tir.tvm_stack_make_shape(16, 1, dtype=handle), 2, 0f32, 0, dtype=handle), dtype=int32)
}
}
Thanks, comments are appreciated.
cc @tqchen .