https://tvm.apache.org/docs/how_to/work_with_schedules/tensorize.html I am trying to reproduce it.
from __future__ import absolute_import, print_function
import tvm
from tvm import te
import tvm.testing
import numpy as np
# Problem size: A is (N, L), B is (M, L) and C is (N, M).
N, M, L = 1024, 512, 64
A = te.placeholder((N, L), name="A")
B = te.placeholder((M, L), name="B")
# Reduction axis running over the shared dimension of length L.
k = te.reduce_axis((0, L), name="k")
# C[i, j] = sum over k of A[i, k] * B[j, k].
C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[j, k], axis=k), name="C")
s = te.create_schedule(C.op)
# print(tvm.lower(s, [A, B, C], simple_mode=True))
# Split the second axis of C by `factor` and order the loops as
# (x, yo, yi, z), which leaves yi and the reduction axis z innermost.
factor = 16
x, y = C.op.axis
(z,) = C.op.reduce_axis
yo, yi = s[C].split(y, factor=factor)
s[C].reorder(x, yo, yi, z)
# print(tvm.lower(s, [A, B, C], simple_mode=True))
def intrin_gemv(m, l):
    """Declare a tensor intrinsic for a vector-matrix product.

    The declared computation is ``c[i] = sum_k a[k] * b[i, k]`` with ``a`` of
    shape ``(l,)``, ``b`` of shape ``(m, l)`` and ``c`` of shape ``(m,)``.
    Its lowering is a single extern call to ``gemv_update``.

    Parameters
    ----------
    m : number of rows of ``b`` (and length of ``c``).
    l : length of ``a`` (and number of columns of ``b``).

    Returns
    -------
    The object produced by ``te.decl_tensor_intrin`` for this computation.
    """
    # Symbolic description of the computation the intrinsic replaces.
    vec = te.placeholder((l,), name="a")
    mat = te.placeholder((m, l), name="b")
    red = te.reduce_axis((0, l), name="k")
    out = te.compute(
        (m,),
        lambda i: te.sum(vec[red] * mat[i, red], axis=red),
        name="c",
    )

    # Buffer declarations; the leading stride of the matrix buffer is the
    # symbolic variable "s1".
    vec_buf = tvm.tir.decl_buffer(
        vec.shape, vec.dtype, name="A", offset_factor=1, strides=[1]
    )
    row_stride = te.var("s1")
    mat_buf = tvm.tir.decl_buffer(
        mat.shape, mat.dtype, name="B", offset_factor=1, strides=[row_stride, 1]
    )
    out_buf = tvm.tir.decl_buffer(
        out.shape, out.dtype, name="C", offset_factor=1, strides=[1]
    )

    def intrin_func(ins, outs):
        """Emit the extern call ``gemv_update(c, a, b, m, l, stride_of_b)``."""
        builder = tvm.tir.ir_builder.create()
        src_vec, src_mat = ins
        dst = outs[0]
        extern_call = tvm.tir.call_extern(
            "int32",
            "gemv_update",
            dst.access_ptr("w"),
            src_vec.access_ptr("r"),
            src_mat.access_ptr("r"),
            m,
            l,
            src_mat.strides[0],
        )
        builder.emit(extern_call)
        return builder.get()

    buffer_binds = {vec: vec_buf, mat: mat_buf, out: out_buf}
    return te.decl_tensor_intrin(out.op, intrin_func, binds=buffer_binds)
# Create the gemv intrinsic with m=factor and l=L, then tensorize stage C at
# axis yi with it.
gemv = intrin_gemv(factor, L)
s[C].tensorize(yi, gemv)
# print(tvm.lower(s, [A, B, C], simple_mode=True))
def gemv_impl(options=None, cc=None):
    """Compile the C implementation of ``gemv_update`` to LLVM IR.

    Parameters
    ----------
    options : list of str, optional
        Extra command-line options forwarded to ``clang.create_llvm``.
        Omitted from the call when ``None``.
    cc : str, optional
        Clang executable forwarded to ``clang.create_llvm``. Omitted from the
        call when ``None``, in which case ``create_llvm`` picks the compiler.

    Returns
    -------
    The value returned by ``clang.create_llvm`` for the embedded C source.

    Notes
    -----
    ``gemv_update`` accumulates into ``cc`` with ``+=`` and never clears it,
    so the result is only correct when the output buffer starts at zero.

    NOTE(review): an ``error: expected type`` reported at a ``ptr noundef``
    parameter when TVM loads this IR presumably means the clang used here is
    newer than the LLVM that TVM was built against (the IR uses the opaque
    ``ptr`` type, which that LLVM's parser rejects). Candidate workarounds,
    to be confirmed for the local toolchain: pass ``cc`` pointing at the
    clang that ships with TVM's LLVM, or pass
    ``options=["-Xclang", "-no-opaque-pointers"]`` for clang releases that
    still accept that flag.
    """
    cc_code = """
extern "C" int gemv_update(float *cc, float *aa, float *bb, int m, int l, int stride) {
for (int i = 0; i < m; ++i) {
for (int j = 0; j < l; ++j) {
cc[i] += aa[j] * bb[i * stride + j];
}
}
return 0;
}
"""
    from tvm.contrib import utils, clang

    temp = utils.tempdir()
    ll_path = temp.relpath("temp.ll")

    # Forward only what the caller supplied, so the default call is exactly
    # the original ``create_llvm(cc_code, output=ll_path)``.
    extra_kwargs = {}
    if options is not None:
        extra_kwargs["options"] = list(options)
    if cc is not None:
        extra_kwargs["cc"] = cc

    # Create LLVM ir from c source code
    ll_code = clang.create_llvm(cc_code, output=ll_path, **extra_kwargs)
    return ll_code
# Attach the LLVM IR returned by gemv_impl() to axis x through the
# "import_llvm" pragma.
s[C].pragma(x, "import_llvm", gemv_impl())
# print(tvm.lower(s, [A, B, C], simple_mode=True))
# NOTE(review): the traceback reported for this script points at this
# tvm.build call ("error: expected type" at a `ptr noundef` parameter), so the
# failure presumably happens while the imported IR is parsed -- confirm.
func = tvm.build(s, [A, B, C], target="llvm", name="gemv")
from tvm.topi.utils import get_const_tuple
# Run the built function on random inputs and compare with NumPy's a @ b.T.
dtype = A.dtype
dev = tvm.device("cpu", 0)
a = np.random.uniform(size=get_const_tuple(A.shape)).astype(dtype)
b = np.random.uniform(size=get_const_tuple(B.shape)).astype(dtype)
# The output starts as zeros; gemv_update accumulates into it with `+=`.
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=dtype), dev)
func(tvm.nd.array(a, dev), tvm.nd.array(b, dev), c)
tvm.testing.assert_allclose(c.numpy(), np.dot(a, b.T), rtol=1e-3)
But the build fails with the following error:
Traceback (most recent call last):
File "try.py", line 76, in <module>
func = tvm.build(s, [A, B, C], target="llvm", name="gemv")
File "/Users/linj/Desktop/tvm/python/tvm/driver/build_module.py", line 294, in build
rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
File "/Users/linj/Desktop/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
raise_last_ffi_error()
File "/Users/linj/Desktop/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
File "/Users/linj/Desktop/tvm/src/target/llvm/llvm_instance.cc", line 176
error: expected type
define i32 @gemv_update(ptr noundef %0, ptr noundef %1, ptr noundef %2, i32 noundef %3, i32 noundef %4, i32 noundef %5) #0 {