Which pass is expected to fuse or combine operators in TVM?

  • In LLVM, there is a combine pass that transforms certain operator sequences into a single operator, such as mul + add → mla/madd. I wrote a simple demo to test this in TVM (TBE) and found that it is not done in the TIR passes (the passes common to TVM and TBE). Does TVM have a similar pass that combines operators?

PS: I don't want this to be done at the graph level, because such an operator may be part of another operator rather than an independent operator.
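To make the kind of combine described above concrete, here is a minimal sketch of a peephole combiner over a toy three-address instruction list. It is plain Python, not LLVM, TVM or TBE code; the Instr type and the combine_mul_add helper are hypothetical. It assumes SSA form (each register is defined exactly once) and only fuses a mul whose result is read by that single add, so the mul can be deleted rather than duplicated. For integer operands, mla gives the same result as a separate mul and add, so no rounding concern is modeled here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Instr:
    op: str       # "mul", "add" or "mla"
    dst: str      # destination register (defined exactly once: SSA assumed)
    srcs: tuple   # source registers


def count_uses(prog):
    """Return how many times each register is read as a source."""
    uses = {}
    for ins in prog:
        for s in ins.srcs:
            uses[s] = uses.get(s, 0) + 1
    return uses


def combine_mul_add(prog):
    """Rewrite `t = mul x, y; r = add t, z` (or `add z, t`) into
    `r = mla x, y, z` when t is read only by that add."""
    uses = count_uses(prog)
    muls = {ins.dst: ins for ins in prog if ins.op == "mul"}
    fused = set()  # destinations of mul instructions that were folded away
    out = []
    for ins in prog:
        if ins.op == "add":
            for i, s in enumerate(ins.srcs):
                mul = muls.get(s)
                if mul is not None and uses[s] == 1:
                    addend = ins.srcs[1 - i]
                    out.append(Instr("mla", ins.dst, (*mul.srcs, addend)))
                    fused.add(s)
                    break
            else:  # no single-use mul operand: keep the add unchanged
                out.append(ins)
        else:
            out.append(ins)
    # The folded muls are now dead; drop them.
    return [ins for ins in out if not (ins.op == "mul" and ins.dst in fused)]


prog = [
    Instr("mul", "t0", ("a", "b")),
    Instr("add", "r0", ("t0", "c")),   # t0 has one use -> fused
    Instr("mul", "t1", ("a", "b")),
    Instr("add", "r1", ("t1", "c")),   # t1 has two uses -> kept
    Instr("add", "r2", ("t1", "r1")),
]
for ins in combine_mul_add(prog):
    print(ins.dst, "=", ins.op, ", ".join(ins.srcs))
# r0 = mla a, b, c
# t1 = mul a, b
# r1 = add t1, c
# r2 = add t1, r1
```

The single-use guard is the design choice worth noting: fusing a mul that has other readers would keep the mul alive anyway and only add work.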

  • Test for TBE:
(py310) /data/zhongyunde/source/test/dsl # cat testMadd.py 
# import tvm
import te.lang.cce
from tbe import tvm
from tbe import dsl
from tbe.common.platform import set_current_compile_soc_info


set_current_compile_soc_info("Ascend310P3")

shape = (280000,280000)
dtype = "float16"

data = tvm.placeholder(shape, name="data", dtype=dtype)

with tvm.target.cce():
    const_val = tvm.const(-1, "float32")
    # res = data * (-1) + (-1): a vmuls followed by a vadds, i.e. the mul + add pattern to combine
    mul_val = te.lang.cce.vmuls(data, const_val)
    res = te.lang.cce.vadds(mul_val, const_val)
    sch = dsl.auto_schedule(res)
    print(res)
    print(sch)

config = { "name" : "abs_28_28_float16", "tensor_list" : [data,res]}

dsl.build(sch, config)
  • Based on TVM 0.18, the TVM test below generates the vfmadd213ps instruction (fused multiply-add) in the disassembled function vector_mul_add_compute_, so TVM can apply this optimization to a custom operator defined with TE. However, I still don't know which pass is responsible for it.
00000000000016e0 <vector_mul_add_compute_>:
...
    174e:       62 72 5d 48 a8 21       vfmadd213ps (%rcx),%zmm4,%zmm12
    1754:       62 72 55 48 a8 69 01    vfmadd213ps 0x40(%rcx),%zmm5,%zmm13
...
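An FMA instruction such as vfmadd213ps computes a*b + c with a single rounding step, whereas a separate multiply and add round twice. The sketch below shows a double-precision case where the two results differ. It emulates the FMA with exact rational arithmetic from Python's fractions module rather than calling any TVM or LLVM API, and the helper names are hypothetical. Because fusing can change floating-point results like this, compilers generally fuse a floating-point mul + add only when floating-point contraction is permitted.

```python
from fractions import Fraction


def separate_mul_add(a: float, b: float, c: float) -> float:
    """Multiply, round to double, then add and round again (two roundings)."""
    return a * b + c


def fused_mul_add(a: float, b: float, c: float) -> float:
    """Compute a*b + c exactly, then round to double once, as an FMA does."""
    return float(Fraction(a) * Fraction(b) + Fraction(c))


a = 1.0 + 2.0**-30
b = 1.0 - 2.0**-30
c = -1.0
# Exact a*b = 1 - 2**-60, which is not representable as a double.
print(separate_mul_add(a, b, c))              # 0.0: a*b rounds to 1.0 before the add
print(fused_mul_add(a, b, c) == -(2.0**-60))  # True: the small term survives
```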
  • Test for TVM (run with: python vector_matmul_add5.py):
(tvm0.18_py310_zyd) root@j00595921debug2-cc95c9977-q752v:/home/zhongyunde# cat vector_matmul_add5.py 
import tvm
from tvm import te
import numpy as np

def vector_mul_add(n, dtype="float32"):
    """
    Element-wise vector fused multiply-add (FMA):
    result[i] = a[i] * b[i] + c[i] for i in 0..n-1
    """
    a = te.placeholder((n,), dtype=dtype, name="a")  # Input vector a
    b = te.placeholder((n,), dtype=dtype, name="b")  # Input vector b
    c = te.placeholder((n,), dtype=dtype, name="c")  # Input vector c
    
    # Element-wise multiplication: mul[i] = a[i] * b[i]
    mul = te.compute((n,), lambda i: a[i] * b[i], name="mul")
    
    # Element-wise addition: result[i] = mul[i] + c[i]
    result = te.compute((n,), lambda i: mul[i] + c[i], name="result")
    
    return [a, b, c, mul, result]

# Main program
n = 128  # Vector length (should be multiple of 16 for AVX512 full utilization)
a, b, c, mul, result = vector_mul_add(n)  # Unpack tensors from function
s = te.create_schedule(result.op)  # Create schedule for the computation graph

# Get operation objects for scheduling
mul_op = mul.op

# Apply vectorization for AVX512-FMA instructions:
# - 16-wide float32 operations; the fused form shows up as a vfmadd*ps variant
#   (vfmadd213ps in the disassembly above)
s[mul_op].vectorize(mul_op.axis[0])  # Vectorize multiplication
s[result].vectorize(result.op.axis[0])  # Vectorize addition

# Compile with AVX512-FMA support for Skylake-AVX512 architecture
# -mcpu=skylake-avx512 enables:
#   * 512-bit vector registers (ZMM0-ZMM31)
#   * FMA3 instructions for fused multiply-add
#   * 16 single-precision floats processed per instruction
target = "llvm -mcpu=skylake-avx512"
with tvm.transform.PassContext(opt_level=3):
    func = tvm.build(s, [a, b, c, result], target, name="vector_mul_add")

# Export optimized binary as a shared library
lib_path = "vector_mul_add.so"
func.export_library(lib_path)
print(f"Optimized binary with AVX512-FMA support exported to: {lib_path}")
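One way to check whether the combine happened, without reading the whole listing, is to disassemble the exported library (for example with objdump -d vector_mul_add.so) and count FMA mnemonics per function. A minimal sketch, assuming GNU objdump's default output layout; count_fma is a hypothetical helper, and the sample input is the excerpt shown earlier in this post.

```python
import re

# AVX/AVX-512 FMA mnemonics: vfmadd213ps, vfmadd231ps, vfnmsub132sd, ...
FMA_RE = re.compile(r"\bvfn?m(?:add|sub)(?:132|213|231)[ps][sd]\b")
# Function header lines in `objdump -d` output, e.g. "00000000000016e0 <name>:"
HEADER_RE = re.compile(r"^[0-9a-f]+ <(.+)>:$")


def count_fma(disasm: str) -> dict:
    """Count FMA instructions per function in `objdump -d` text."""
    counts, current = {}, None
    for line in disasm.splitlines():
        header = HEADER_RE.match(line.strip())
        if header:
            current = header.group(1)
            counts.setdefault(current, 0)
        elif current is not None and FMA_RE.search(line):
            counts[current] += 1
    return counts


sample = """\
00000000000016e0 <vector_mul_add_compute_>:
    174e:       62 72 5d 48 a8 21       vfmadd213ps (%rcx),%zmm4,%zmm12
    1754:       62 72 55 48 a8 69 01    vfmadd213ps 0x40(%rcx),%zmm5,%zmm13
"""
print(count_fma(sample))  # {'vector_mul_add_compute_': 2}
```

To run it on the real library, pass it the output of subprocess.run(["objdump", "-d", "vector_mul_add.so"], capture_output=True, text=True).stdout.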

I found an RFC that may be related (Basic Block Normal Form). Does that mean this is now expected to be done in the graph phase?

I have now found the function MakeFMA in src/tir/transforms/lower_intrin.cc; this may be a suitable place to implement the combination.

  PrimExpr MakeFMA(const PrimExpr& a, const PrimExpr& b, const PrimExpr& c, const AddNode* op) {
    // emit fma instruction: a * b + c
    PrimExpr lhs = SwapBroadcastCast(a);
    PrimExpr rhs = SwapBroadcastCast(b);

    if (fma_ != nullptr && op->dtype.is_float()) {
      PrimExpr r = fma_(Call(op->dtype, builtin::fma(), {lhs, rhs, c}));
      if (r.defined()) return this->VisitExpr(r);
    } else {
      if (!lhs.same_as(a) || !rhs.same_as(b)) {
        PrimExpr mul = this->VisitExpr(Mul(lhs, rhs));
        return Add(mul, this->VisitExpr(c));
      }
    }
    return IRMutatorWithAnalyzer::VisitExpr_(op);
  }
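The gating in MakeFMA can be modeled on a toy expression tree. The sketch below is a simplified Python mirror, not TVM code: Var, Mul, Add, Fma, lower_fma and the fma_rule callback are hypothetical stand-ins for the TIR nodes and the target's fma_ lowering rule. As in MakeFMA, a multiply feeding an add is fused only when the dtype is floating point and a target rule is present and returns a result. This sketch checks both operands of the add, and it omits the SwapBroadcastCast handling and the non-FMA fallback branch.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class Var:
    name: str
    dtype: str


@dataclass(frozen=True)
class Mul:
    a: object
    b: object


@dataclass(frozen=True)
class Add:
    a: object
    b: object


@dataclass(frozen=True)
class Fma:  # a * b + c
    a: object
    b: object
    c: object


def dtype_of(e) -> str:
    # Toy typing rule: a node has the dtype of its first operand.
    return e.dtype if isinstance(e, Var) else dtype_of(e.a)


# Hypothetical target hook: returns a lowered expression, or None to refuse.
FmaRule = Optional[Callable[[Fma], Optional[object]]]


def lower_fma(e, fma_rule: FmaRule):
    """Bottom-up rewrite of Add(Mul(x, y), z) / Add(z, Mul(x, y)) into Fma(x, y, z),
    only for float dtypes and only if fma_rule exists and returns a result."""
    if isinstance(e, Var):
        return e
    if isinstance(e, Mul):
        return Mul(lower_fma(e.a, fma_rule), lower_fma(e.b, fma_rule))
    if isinstance(e, Fma):
        return Fma(*(lower_fma(x, fma_rule) for x in (e.a, e.b, e.c)))
    if not isinstance(e, Add):
        raise TypeError(f"unsupported node: {e!r}")
    a, b = lower_fma(e.a, fma_rule), lower_fma(e.b, fma_rule)
    if fma_rule is not None and dtype_of(e).startswith("float"):
        for mul, addend in ((a, b), (b, a)):
            if isinstance(mul, Mul):
                fused = fma_rule(Fma(mul.a, mul.b, addend))
                if fused is not None:  # like `if (r.defined())` in MakeFMA
                    return fused
    return Add(a, b)


def keep_fma(call):
    # Hypothetical target rule that accepts every fma call as-is.
    return call


x, y, z = (Var(n, "float32") for n in "xyz")
print(lower_fma(Add(Mul(x, y), z), keep_fma))  # fused into an Fma node
print(lower_fma(Add(Mul(x, y), z), None))      # no rule: stays Add(Mul(...), ...)
```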