I am starting my work with a simple compute graph, built with the Relax nn frontend (tvm.relax.frontend.nn), as follows:

from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import op

class NNModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, v):
        x = op.softmax(x, axis=-1)
        x = op.matmul(x, v)
        return x
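For reference, the export call itself looks roughly like this (I elided the spec above, so the shapes passed here are reconstructed as a sketch):

from tvm.relax.frontend import nn

# Export the nn.Module to a Relax IRModule; the spec shapes match the module shown below.
mod, _ = NNModule().export_tvm(
    spec={
        "forward": {
            "x": nn.spec.Tensor((5, 5), "float32"),
            "v": nn.spec.Tensor((5, 10), "float32"),
        }
    }
)
mod.show()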
After invoking NNModule().export_tvm(...), I get the following module (the softmax output has shape [5, 5], and the matmul computes [5, 5] * [5, 10] -> [5, 10]):
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor((5, 5), dtype="float32"), v: R.Tensor((5, 10), dtype="float32")) -> R.Tensor((5, 10), dtype="float32"):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            softmax: R.Tensor((5, 5), dtype="float32") = R.nn.softmax(x, axis=-1)
            matmul: R.Tensor((5, 10), dtype="float32") = R.matmul(softmax, v, out_dtype="void")
            gv: R.Tensor((5, 10), dtype="float32") = matmul
            R.output(gv)
        return gv
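The legalization step is just the standard pass, applied roughly like this:

from tvm import relax

# Lower the high-level Relax ops (nn.softmax, matmul) into TIR prim_funcs.
mod = relax.transform.LegalizeOps()(mod)
mod.show()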
After applying relax.transform.LegalizeOps(), I get:
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def matmul(softmax: T.Buffer((T.int64(5), T.int64(5)), "float32"), v: T.Buffer((T.int64(5), T.int64(10)), "float32"), matmul: T.Buffer((T.int64(5), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(5), T.int64(10), T.int64(5)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(softmax[v_i0, v_k], v[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0.0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + softmax[v_i0, v_k] * v[v_k, v_i1]

    @T.prim_func(private=True)
    def softmax(x: T.Buffer((T.int64(5), T.int64(5)), "float32"), T_softmax_norm: T.Buffer((T.int64(5), T.int64(5)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        T_softmax_maxelem = T.alloc_buffer((T.int64(5),))
        T_softmax_exp = T.alloc_buffer((T.int64(5), T.int64(5)))
        T_softmax_expsum = T.alloc_buffer((T.int64(5),))
        for i0, k in T.grid(T.int64(5), T.int64(5)):
            with T.block("T_softmax_maxelem"):
                v_i0, v_k = T.axis.remap("SR", [i0, k])
                T.reads(x[v_i0, v_k])
                T.writes(T_softmax_maxelem[v_i0])
                with T.init():
                    T_softmax_maxelem[v_i0] = T.float32(-340282346638528859811704183484516925440.0)
                T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], x[v_i0, v_k])
        for i0, i1 in T.grid(T.int64(5), T.int64(5)):
            with T.block("T_softmax_exp"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(x[v_i0, v_i1], T_softmax_maxelem[v_i0])
                T.writes(T_softmax_exp[v_i0, v_i1])
                T_softmax_exp[v_i0, v_i1] = T.exp(x[v_i0, v_i1] - T_softmax_maxelem[v_i0])
        for i0, k in T.grid(T.int64(5), T.int64(5)):
            with T.block("T_softmax_expsum"):
                v_i0, v_k = T.axis.remap("SR", [i0, k])
                T.reads(T_softmax_exp[v_i0, v_k])
                T.writes(T_softmax_expsum[v_i0])
                with T.init():
                    T_softmax_expsum[v_i0] = T.float32(0.0)
                T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T_softmax_exp[v_i0, v_k]
        for i0, i1 in T.grid(T.int64(5), T.int64(5)):
            with T.block("T_softmax_norm"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_softmax_exp[v_i0, v_i1], T_softmax_expsum[v_i0])
                T.writes(T_softmax_norm[v_i0, v_i1])
                T.block_attr({"axis": 1})
                T_softmax_norm[v_i0, v_i1] = T_softmax_exp[v_i0, v_i1] / T_softmax_expsum[v_i0]

    @R.function
    def forward(x: R.Tensor((5, 5), dtype="float32"), v: R.Tensor((5, 10), dtype="float32")) -> R.Tensor((5, 10), dtype="float32"):
        R.func_attr({"num_input": 2})
        cls = Module
        with R.dataflow():
            softmax = R.call_tir(cls.softmax, (x,), out_sinfo=R.Tensor((5, 5), dtype="float32"))
            matmul = R.call_tir(cls.matmul, (softmax, v), out_sinfo=R.Tensor((5, 10), dtype="float32"))
            gv: R.Tensor((5, 10), dtype="float32") = matmul
            R.output(gv)
        return gv
My question is: can I fuse this softmax and matmul into one TIR function while keeping their original computations unchanged? What I want is something like this:
...
def fused_softmax_matmul(x, v):
    # the original softmax computation, just like what is shown above
    for ...:
        for ...:
            ...
    # followed by the original matmul computation, just like what is shown above
    for ...:
        for ...:
            for ...:
                ...

# and in the Relax forward function:
with R.dataflow():
    fused_softmax_matmul = R.call_tir(cls.fused_softmax_matmul, (x, v), ...)
    R.output(fused_softmax_matmul)
return fused_softmax_matmul
...
I tried FuseOpsByPattern, but with it I only manage to match and group the pattern; I don't end up with a fused function that contains the actual computations. I am new to TVM, so I haven't been able to achieve this on my own. Is it possible, and if so, how can I do it?
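For context, my FuseOpsByPattern attempt looked roughly like this (the pattern below is a sketch reconstructed from memory, applied to the module before legalization, so it may not be exactly what I ran):

from tvm import relax
from tvm.relax.dpl import is_op, wildcard

# Sketch of the softmax -> matmul pattern I tried to match.
softmax_pat = is_op("relax.nn.softmax")(wildcard())
matmul_pat = is_op("relax.matmul")(softmax_pat, wildcard())

# This groups the matched ops into a composite function,
# but it does not give me one fused TIR function with both computations.
mod = relax.transform.FuseOpsByPattern([("softmax_matmul", matmul_pat)])(mod)
mod.show()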