"Not in feed graph" consumer message when using the scan operation

I’m trying to use the scan operation in my work. There are two steps:

  1. run a scan operation (with multiple states, following the instructions here); the recurrence is spelled out below
  2. transpose the resulting matrix
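
Concretely (writing out the recurrence from the code below, with the first row of X initializing both states):

    s1[t, i] = 2 * s2[t-1, i]
    s2[t, i] = s1[t, i] + s2[t-1, i] + X[t, i]

and the final output is the transpose of s1.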

However, the following message pops up. Does it indicate a problem?

[16:42:54] /home/jojo6174/tvm-installation/tvm/src/te/schedule/bound.cc:119: not in feed graph consumer = scan(scan, 0x372e260)

Here is the code:

from __future__ import absolute_import, print_function

import numpy as np
import tvm
import tvm.testing
from tvm import te

def scan_example():
    m = te.var("m")
    n = te.var("n")

    X = te.placeholder((m, n), name="X")
    # placeholders for the two recurrent states carried across scan steps
    s_state1 = te.placeholder((m, n))
    s_state2 = te.placeholder((m, n))

    # both states start from the first row of X
    s_init1 = te.compute((1, n), lambda _, i: X[0, i])
    s_init2 = te.compute((1, n), lambda _, i: X[0, i])

    # update rules: s1 reads the previous step of state2; s2 reads the current
    # step of state1, the previous step of state2, and the current row of X
    s_update1 = te.compute((m, n), lambda t, i: s_state2[t-1, i] * 2, name="s1")
    s_update2 = te.compute((m, n), lambda t, i: s_state1[t, i] + s_state2[t-1, i] + X[t, i], name="s2")

    # two-state scan over the time axis
    s_scan1, s_scan2 = tvm.te.scan([s_init1, s_init2], [s_update1, s_update2], [s_state1, s_state2], inputs=[X])

    # step 2: transpose the first scan output
    T = te.compute((n, m), lambda i, j: s_scan1[j, i])

    s = te.create_schedule(T.op)

    nf = 32  # split factor (threads per block)

    block_x = te.thread_axis('blockIdx.x')
    thread_x = te.thread_axis('threadIdx.x')

    block_y = te.thread_axis('blockIdx.y')
    thread_y = te.thread_axis('threadIdx.y')

    # bind the init, update, and transpose stages to GPU threads
    xo, xi = s[s_init1].split(s_init1.op.axis[1], factor=nf)
    s[s_init1].bind(xo, block_x)
    s[s_init1].bind(xi, thread_x)

    xo, xi = s[s_init2].split(s_init2.op.axis[1], factor=nf)
    s[s_init2].bind(xo, block_x)
    s[s_init2].bind(xi, thread_x)

    xo, xi = s[s_update1].split(s_update1.op.axis[1], factor=nf)
    s[s_update1].bind(xo, block_x)
    s[s_update1].bind(xi, thread_x)

    xo, xi = s[s_update2].split(s_update2.op.axis[1], factor=nf)
    s[s_update2].bind(xo, block_x)
    s[s_update2].bind(xi, thread_x)

    xo, xi = s[T].split(T.op.axis[1], factor=nf)
    s[T].bind(xo, block_y)
    s[T].bind(xi, thread_y)

    fscan = tvm.build(s, [X, T], "cuda", name="myscan")
    ctx = tvm.gpu(0)
    # concrete sizes for the symbolic m and n
    n = 10
    m = 10

    a_np = np.arange(m * n).reshape(m, n).astype(s_scan1.dtype)
    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(np.zeros((n, m), dtype=s_scan1.dtype), ctx)
    fscan(a, b)
    print(b.asnumpy())
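
In case it helps with checking, here is a plain NumPy version of the recurrence I expect the scan to compute, so the device result can be compared against it. My reading of the semantics is that s_state1[t, i] inside s_update2 refers to the value s_update1 just produced at step t; please correct me if that is wrong.

def scan_ref(x):
    # NumPy reimplementation of the two-state recurrence, for checking the kernel
    m, n = x.shape
    st1 = np.empty_like(x)
    st2 = np.empty_like(x)
    st1[0] = x[0]
    st2[0] = x[0]
    for t in range(1, m):
        st1[t] = st2[t - 1] * 2                # s_update1
        st2[t] = st1[t] + st2[t - 1] + x[t]    # s_update2
    return st1.T                               # the transpose step (T)

# inside scan_example, after fscan(a, b):
# tvm.testing.assert_allclose(b.asnumpy(), scan_ref(a_np), rtol=1e-5)

Note that the path in the message (src/te/schedule/bound.cc) is the bound-inference pass, which runs as part of tvm.build, so the message appears at build time rather than when the kernel executes.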

@jojo6174 I also encountered a similar problem. Did you solve it? @FrozenGene I would like to know what causes this message.

tvm_tutorial/tvm/src/te/schedule/bound.cc:119: not in feed graph consumer = compute(p0_red_temp.repl, body=[T.reduce(T.comm_reducer(lambda argmax_lhs_0, argmax_lhs_1, argmax_rhs_0, argmax_rhs_1: (T.Select(argmax_lhs_1 > argmax_rhs_1 or argmax_lhs_1 == argmax_rhs_1 and argmax_lhs_0 < argmax_rhs_0, argmax_lhs_0, argmax_rhs_0), T.Select(argmax_lhs_1 > argmax_rhs_1, argmax_lhs_1, argmax_rhs_1)), [-1, T.float32(-3.4028234663852886e+38)]), source=[p0_red_temp.rf.v0[k3_inner_v, ax0, ax1, ax2], p0_red_temp.rf.v1[k3_inner_v, ax0, ax1, ax2]], init=[], axis=[T.iter_var(k3_inner_v, T.Range(0, 32), "CommReduce", "")], condition=T.bool(True), value_index=0), T.reduce(T.comm_reducer(lambda argmax_lhs_0, argmax_lhs_1, argmax_rhs_0, argmax_rhs_1: (T.Select(argmax_lhs_1 > argmax_rhs_1 or argmax_lhs_1 == argmax_rhs_1 and argmax_lhs_0 < argmax_rhs_0, argmax_lhs_0, argmax_rhs_0), T.Select(argmax_lhs_1 > argmax_rhs_1, argmax_lhs_1, argmax_rhs_1)), [-1, T.float32(-3.4028234663852886e+38)]), source=[p0_red_temp.rf.v0[k3_inner_v, ax0, ax1, ax2], p0_red_temp.rf.v1[k3_inner_v, ax0, ax1, ax2]], init=[], axis=[T.iter_var(k3_inner_v, T.Range(0, 32), "CommReduce", "")], condition=T.bool(True), value_index=1)], axis=[T.iter_var(ax0, T.Range(0, 12), "DataPar", ""), T.iter_var(ax1, T.Range(0, 160), "DataPar", ""), T.iter_var(ax2, T.Range(0, 328), "DataPar", "")], reduce_axis=[T.iter_var(k3_inner_v, T.Range(0, 32), "CommReduce", "")], tag=, attrs={})