I’m trying to use the scan operation in my work. There are two steps:
- a scan operation (with multiple states, following the instructions here)
- transpose matrix
However, the following message popped up. Does it indicate a problem?
[16:42:54] /home/jojo6174/tvm-installation/tvm/src/te/schedule/bound.cc:119: not in feed graph consumer = scan(scan, 0x372e260)
The code is as follows:
from __future__ import absolute_import, print_function
import timeit
import torch
import time
import numpy as np
import tvm
import tvm.testing
from tvm import te
def scan_example():
    """Build and run a two-state TVM scan followed by a transpose on CUDA.

    The scan carries two states over the first axis of ``X`` (shape
    ``(m, n)``):

        s1[t, i] = s2[t-1, i] * 2
        s2[t, i] = s1[t, i] + s2[t-1, i] + X[t, i]

    with both states initialised from ``X[0, :]``.  The first scan output is
    then transposed into ``T`` (shape ``(n, m)``), the kernel is built for
    the ``"cuda"`` target, executed on a 10x10 input, and the result is
    printed.  Returns ``None``.
    """
    # Symbolic extents; kept distinct from the concrete sizes used at run time.
    m = te.var("m")
    n = te.var("n")
    X = te.placeholder((m, n), name="X")

    # State placeholders stand for the scan's running values inside the
    # update expressions.
    s_state1 = te.placeholder((m, n))
    s_state2 = te.placeholder((m, n))
    s_init1 = te.compute((1, n), lambda _, i: X[0, i])
    s_init2 = te.compute((1, n), lambda _, i: X[0, i])
    s_update1 = te.compute(
        (m, n), lambda t, i: s_state2[t - 1, i] * 2, name="s1"
    )
    s_update2 = te.compute(
        (m, n),
        lambda t, i: s_state1[t, i] + s_state2[t - 1, i] + X[t, i],
        name="s2",
    )
    # NOTE(review): s_scan2 is never consumed by any later op (only s_scan1
    # feeds T).  This is presumably what the
    # "not in feed graph consumer = scan(...)" log line refers to -- confirm
    # against te/schedule/bound.cc.
    s_scan1, s_scan2 = te.scan(
        [s_init1, s_init2],
        [s_update1, s_update2],
        [s_state1, s_state2],
        inputs=[X],
    )
    T = te.compute((n, m), lambda i, j: s_scan1[j, i])

    s = te.create_schedule(T.op)
    nf = 32
    block_x = te.thread_axis("blockIdx.x")
    thread_x = te.thread_axis("threadIdx.x")
    block_y = te.thread_axis("blockIdx.y")
    thread_y = te.thread_axis("threadIdx.y")

    # Every scan stage gets the same schedule: split its second axis by nf
    # and bind the outer/inner parts to blockIdx.x / threadIdx.x.
    for stage in (s_init1, s_init2, s_update1, s_update2):
        outer, inner = s[stage].split(stage.op.axis[1], factor=nf)
        s[stage].bind(outer, block_x)
        s[stage].bind(inner, thread_x)

    # The transpose is split the same way but bound to the y thread axes.
    outer, inner = s[T].split(T.op.axis[1], factor=nf)
    s[T].bind(outer, block_y)
    s[T].bind(inner, thread_y)

    fscan = tvm.build(s, [X, T], "cuda", name="myscan")
    ctx = tvm.gpu(0)

    # Concrete sizes for the symbolic m (rows of X) and n (columns of X).
    num_rows = 10
    num_cols = 10
    # X is declared (m, n), so the input must be (num_rows, num_cols); the
    # transposed output T is (n, m), i.e. (num_cols, num_rows).
    a_np = (
        np.arange(num_rows * num_cols)
        .reshape(num_rows, num_cols)
        .astype(s_scan1.dtype)
    )
    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(np.zeros((num_cols, num_rows), dtype=s_scan1.dtype), ctx)
    fscan(a, b)
    print(b.asnumpy())