Hello, to get started, I have some example code for a tensor intrinsic function like this:
from __future__ import absolute_import, print_function
import tvm
from tvm import te
import numpy as np
from tvm.topi.utils import get_const_tuple
# Target device: CPU, device index 0.
ctx = tvm.context("cpu", 0)
# Vector length, and the split factor (later also passed to intrin_add as the tile size).
M = 8
factor = 4
# Element-wise addition C[i] = A[i] + B[i] over two length-M inputs.
A = te.placeholder((M,), name='A')
B = te.placeholder((M,), name='B')
C = te.compute((M,), lambda i: A[i] + B[i], name='C')
s = te.create_schedule(C.op)
# C has a single axis; split it into an outer loop and an inner loop of `factor` iterations.
x, = C.op.axis
xo, xi = s[C].split(x, factor=factor)
# Print the lowered IR before any tensorization is applied.
print(tvm.lower(s, [A, B, C], simple_mode=True))
# Element type of A; no dtype is passed to te.placeholder, so this is the library default
# (presumably float32, matching the float* C signature -- TODO confirm).
dtype = A.dtype
def intrin_add(m):
    """Declare a tensor intrinsic that adds two length-``m`` vectors.

    The intrinsic matches the computation ``c[i] = a[i] + b[i]`` over ``m``
    elements and replaces it with a call to the external symbol ``add``.

    Parameters
    ----------
    m : int
        Number of elements handled by one invocation of the intrinsic. It is
        also forwarded as the last argument of the extern call.

    Returns
    -------
    The object returned by ``te.decl_tensor_intrin``, suitable for passing to
    ``tensorize``.
    """
    # Pattern the intrinsic matches: element-wise sum of two 1-D tensors.
    a = te.placeholder((m,), name='a')
    b = te.placeholder((m,), name='b')
    c = te.compute((m,), lambda i: a[i] + b[i], name='c')

    # Buffers bound to the pattern tensors; declared with unit stride and
    # offset_factor=1.
    Ab = tvm.tir.decl_buffer(a.shape, a.dtype,
                             name="A",
                             offset_factor=1,
                             strides=[1])
    Bb = tvm.tir.decl_buffer(b.shape, b.dtype,
                             name="B",
                             offset_factor=1,
                             strides=[1])
    Cb = tvm.tir.decl_buffer(c.shape, c.dtype,
                             name="C",
                             offset_factor=1,
                             strides=[1])

    def intrin_func(ins, outs):
        """Return the extern call that implements one matched tile."""
        aa, bb = ins
        cc = outs[0]
        # Inputs are passed as read pointers, the output as a write pointer,
        # followed by the element count.
        return tvm.tir.call_extern('int32', 'add',
                                   aa.access_ptr('r'),
                                   bb.access_ptr('r'),
                                   cc.access_ptr('w'),
                                   m)

    return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb})
# Build the intrinsic for a tile of `factor` elements and map the inner loop xi onto it.
add = intrin_add(factor)
s[C].tensorize(xi, add)
# Print the lowered IR again, now with the tensorized inner loop.
print(tvm.lower(s, [A, B, C], simple_mode=True))
# NOTE(review): the built function is named "add", the same name as the extern symbol the
# intrinsic calls -- confirm the two do not clash, and that the C `add` implementation is
# linked or loaded before the function is run (not shown here).
func = tvm.build(s, [A, B, C], target="llvm -mcpu=core-avx2", name="add")
# Random inputs and a zero-initialised output buffer on ctx.
a = np.random.uniform(size=get_const_tuple(A.shape)).astype(dtype)
b = np.random.uniform(size=get_const_tuple(B.shape)).astype(dtype)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=dtype), ctx)
func(tvm.nd.array(a, ctx), tvm.nd.array(b, ctx), c)
# Compare the result with NumPy's element-wise sum.
np.testing.assert_allclose(c.asnumpy(), np.add(a, b), rtol=1e-3)
while the extern C function looks like:
// Writes a[k] + b[k] into c[k] for every k in [0, M) and returns 0.
// Declared extern "C" so the symbol is exported under the unmangled name "add".
extern "C" int add(const float* a, const float* b, float* c, int M) {
  int idx = 0;
  while (idx < M) {
    c[idx] = a[idx] + b[idx];
    ++idx;
  }
  return 0;
}
Now I want to pass a constant input array `dd` to the intrinsic function so that it looks like:
def intrin_func(ins, outs):
    """Return the extern call to ``add`` with the extra array argument.

    ``dd`` and ``m`` are not defined in this function; ``dd`` is meant to be
    produced at the placeholder below and ``m`` is read from the enclosing
    scope.
    """
    lhs, rhs = ins
    result = outs[0]
    ##
    # compute dd here
    ##
    call_args = (
        lhs.access_ptr('r'),
        rhs.access_ptr('r'),
        result.access_ptr('w'),
        dd,
        m,
    )
    return tvm.tir.call_extern('int32', 'add', *call_args)
and correspondingly the C function’s API becomes
// Desired signature: same as before, with an additional read-only float array d placed before M.
extern "C" int add(const float* a, const float* b, float* c, const float* d, int M)
I have tried multiple approaches, e.g. passing a list or an ndarray to the function, but they all failed. It seems that the intrinsic function requires its arguments to be a string or a `PrimExpr`, and I didn't see any type that is compatible with a list or an array.
Can anyone help? Thanks!