vfdff
June 17, 2025, 1:07am
1
In LLVM we have combine passes that transform some operations into others, such as mul + add → mla/madd.
So I made a simple demo to test this in TVM (TBE) and found that it is not done in the TIR passes (the passes shared by TVM and TBE). Does TVM have a similar pass to combine operators?
PS: I don't want this to be done at the graph level, because the mul + add pattern may be part of a larger operator rather than a set of independent operators.
(py310) /data/zhongyunde/source/test/dsl # cat testMadd.py
# import tvm
import te.lang.cce
from tbe import tvm
from tbe import dsl
from tbe.common.platform import set_current_compile_soc_info
set_current_compile_soc_info("Ascend310P3")
shape = (280000,280000)
dtype = "float16"
data = tvm.placeholder(shape, name="data", dtype=dtype)
with tvm.target.cce():
    const_val = tvm.const(-1, "float32")
    mul_val = te.lang.cce.vmuls(data, const_val)
    res = te.lang.cce.vadds(mul_val, const_val)
    sch = dsl.auto_schedule(res)
print(res)
print(sch)
config = {"name": "abs_28_28_float16", "tensor_list": [data, res]}
dsl.build(sch, config)
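For comparison, here is roughly what the combine would look like at the TIR expression level in plain TVM (a sketch using stock TVM APIs, independent of TBE; tir.fma is the builtin a combined multiply-add would be expressed with):

# Sketch: "before" and "after" forms of the combine as TIR expressions.
import tvm
from tvm import tir

a = tir.Var("a", "float32")
b = tir.Var("b", "float32")
c = tir.Var("c", "float32")

before = a * b + c                                      # separate mul + add
after = tir.call_intrin("float32", "tir.fma", a, b, c)  # fused multiply-add builtin

print(before)
print(after)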
vfdff
June 17, 2025, 2:22am
2
Based on TVM 0.18, it generates the vfmadd213ps instruction (fused multiply-add) in the disassembly of the function vector_mul_add_compute_, so TVM can apply this optimization to a custom operator defined with te. But I still don't know which pass is responsible for this optimization.
00000000000016e0 <vector_mul_add_compute_>:
...
174e: 62 72 5d 48 a8 21       vfmadd213ps (%rcx),%zmm4,%zmm12
1754: 62 72 55 48 a8 69 01    vfmadd213ps 0x40(%rcx),%zmm5,%zmm13
...
test for tvm: python vector_matmul_add5.py
(tvm0.18_py310_zyd) root@j00595921debug2-cc95c9977-q752v:/home/zhongyunde# cat vector_matmul_add5.py
import tvm
from tvm import te
import numpy as np
def vector_mul_add(n, dtype="float32"):
    """
    Element-wise vector fused multiply-add (FMA):
    result[i] = a[i] * b[i] + c[i] for i in 0..n-1
    """
    a = te.placeholder((n,), dtype=dtype, name="a")  # Input vector a
    b = te.placeholder((n,), dtype=dtype, name="b")  # Input vector b
    c = te.placeholder((n,), dtype=dtype, name="c")  # Input vector c
    # Element-wise multiplication: mul[i] = a[i] * b[i]
    mul = te.compute((n,), lambda i: a[i] * b[i], name="mul")
    # Element-wise addition: result[i] = mul[i] + c[i]
    result = te.compute((n,), lambda i: mul[i] + c[i], name="result")
    return [a, b, c, mul, result]
# Main program
n = 128 # Vector length (should be multiple of 16 for AVX512 full utilization)
a, b, c, mul, result = vector_mul_add(n) # Unpack tensors from function
s = te.create_schedule(result.op) # Create schedule for the computation graph
# Get operation objects for scheduling
mul_op = mul.op
# Apply vectorization for AVX512-FMA instructions:
# - Maps to 16-wide float32 operations using 'vfmadd231ps'
s[mul_op].vectorize(mul_op.axis[0]) # Vectorize multiplication
s[result].vectorize(result.op.axis[0]) # Vectorize addition
# Compile with AVX512-FMA support for Skylake-AVX512 architecture
# -mcpu=skylake-avx512 enables:
# * 512-bit vector registers (ZMM0-ZMM31)
# * FMA3 instructions for fused multiply-add
# * 16 single-precision floats processed per instruction
target = "llvm -mcpu=skylake-avx512"
with tvm.transform.PassContext(opt_level=3):
    func = tvm.build(s, [a, b, c, result], target, name="vector_mul_add")
# Export optimized binary as a shared library
lib_path = "vector_mul_add.so"
func.export_library(lib_path)
print(f"Optimized binary with AVX512-FMA support exported to: {lib_path}")
vfdff
June 24, 2025, 2:08am
3
I found an RFC link that may be related, so is this combine now expected to be done in the graph phase?
Basic Block Normal Form
vfdff
June 25, 2025, 11:49am
4
Now I have found the function MakeFMA in src/tir/transforms/lower_intrin.cc; maybe this is the right place to handle the combine.
PrimExpr MakeFMA(const PrimExpr& a, const PrimExpr& b, const PrimExpr& c, const AddNode* op) {
  // emit fma instruction: a * b + c
  PrimExpr lhs = SwapBroadcastCast(a);
  PrimExpr rhs = SwapBroadcastCast(b);

  if (fma_ != nullptr && op->dtype.is_float()) {
    PrimExpr r = fma_(Call(op->dtype, builtin::fma(), {lhs, rhs, c}));
    if (r.defined()) return this->VisitExpr(r);
  } else {
    if (!lhs.same_as(a) || !rhs.same_as(b)) {
      PrimExpr mul = this->VisitExpr(Mul(lhs, rhs));
      return Add(mul, this->VisitExpr(c));
    }
  }
  return IRMutatorWithAnalyzer::VisitExpr_(op);
}
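One way to check whether this path actually fires for a given target (a sketch; tvm.tir.transform.Apply, PrimFunc.with_attr and the target-attribute handling are assumed from mainline TVM and may vary by version) is to run tir.LowerIntrin by hand on a lowered module and look for a tir.fma call. As far as I can tell, fma_ here comes from the tvm.intrin.rule.<target>.fma registry entry, so if no fma shows up at this point, the vfmadd seen in the assembly is most likely produced by LLVM's backend rather than by a TIR pass:

# Sketch: run tir.LowerIntrin manually and look for tir.fma in the output.
import tvm
from tvm import te

n = 128
a = te.placeholder((n,), "float32", name="a")
b = te.placeholder((n,), "float32", name="b")
c = te.placeholder((n,), "float32", name="c")
out = te.compute((n,), lambda i: a[i] * b[i] + c[i], name="out")
s = te.create_schedule(out.op)
s[out].vectorize(out.op.axis[0])

mod = tvm.lower(s, [a, b, c, out])

# LowerIntrin reads the "target" attribute of each PrimFunc to pick its rules.
tgt = tvm.target.Target("llvm -mcpu=skylake-avx512")
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tgt))(mod)
mod = tvm.tir.transform.LowerIntrin()(mod)

# grep the printed TIR for "fma": present -> combined in TIR, absent -> done by LLVM
print(mod)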