InferType question: type inference fails when `main` calls another global function

import tvm
from tvm import relay

def get_mod(*, infer_intermediate=False):
    """Build a module in which ``main`` calls a recursive global function ``f``.

    ``f(n, x)`` returns ``zeros((1, 100), float32)`` when ``n == 0`` and
    otherwise returns ``f(n - 1, tanh(x))``. ``main(n1, y)`` returns
    ``f(n1, y)``.

    Args:
        infer_intermediate: Keyword-only. When True, run
            ``relay.transform.InferType`` on the module right after ``f`` is
            added and before ``main`` is added. This replaces the previously
            commented-out line. Per this report, enabling it lets the final
            InferType over the whole module succeed. The default (False)
            keeps the original behaviour and reproduces the failure.

    Returns:
        A ``tvm.IRModule`` containing the global functions ``f`` and ``main``.
    """
    ti32 = relay.scalar_type("int32")
    shape, dtype = (1, 100), "float32"

    mod = tvm.IRModule()
    f = relay.GlobalVar("f")
    main = relay.GlobalVar("main")

    # Recursive function f: base case yields zeros, otherwise recurse on tanh(x).
    sb = relay.ScopeBuilder()
    n = relay.var("n", ti32)
    x = relay.var("x", shape=shape, dtype=dtype)
    with sb.if_scope(relay.equal(n, relay.const(0, ti32))):
        sb.ret(relay.zeros(shape=shape, dtype=dtype))
    with sb.else_scope():
        sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.tanh(x)))
    mod[f] = relay.Function([n, x], sb.get())

    if infer_intermediate:
        # Type-check f on its own before main (which calls f) is added.
        # NOTE(review): why this avoids the later failure is not established
        # here; it is only the observed workaround from this report.
        mod = tvm.relay.transform.InferType()(mod)

    # main simply forwards its arguments to f.
    n1 = relay.var("n1", ti32)
    y = relay.var("y", shape=shape, dtype=dtype)
    mod[main] = relay.Function([n1, y], f(n1, y))
    return mod

# Repro 1: main -> recursive f, with no intermediate InferType.
mod = get_mod()
print(mod)  # module before type inference
infer_type = tvm.relay.transform.InferType()
mod = infer_type(mod)  # per the report, this is the step that fails
print(mod)  # module after type inference

If `main` calls the recursive function `f`, InferType fails. If I uncomment the intermediate InferType call (the one run after `f` is added to the module and before `main` is added), it works. I'm trying to understand whether this is a missing InferType feature or something fundamental about how InferType works.

@zhiics @jroesch @masahi @yzhliu

I tried something simpler: `main` calling `f1`, which just applies tanh. This also seems to fail without the intermediate InferType.


def get_mod1(*, infer_intermediate=False):
    """Build a minimal module in which ``main`` calls a non-recursive global ``f1``.

    ``f1(x)`` returns ``tanh(x)`` and ``main(y)`` returns ``f1(y)``. Both
    parameters are declared with shape (1, 100) and the default dtype of
    ``relay.var``.

    Args:
        infer_intermediate: Keyword-only. When True, run
            ``relay.transform.InferType`` on the module right after ``f1`` is
            added and before ``main`` is added. This replaces the previously
            commented-out line. The default (False) keeps the original
            behaviour.

    Returns:
        A ``tvm.IRModule`` containing the global functions ``f1`` and ``main``.
    """
    mod = tvm.IRModule()
    f1 = relay.GlobalVar("f1")
    main = relay.GlobalVar("main")

    # f1 just applies tanh to its input.
    x = relay.var("x", shape=(1, 100))
    mod[f1] = relay.Function([x], relay.tanh(x))

    if infer_intermediate:
        # Type-check f1 on its own before main (which calls f1) is added.
        mod = tvm.relay.transform.InferType()(mod)

    # main simply forwards its argument to f1.
    y = relay.var("y", shape=(1, 100))
    mod[main] = relay.Function([y], f1(y))
    return mod

# Repro 2: main -> f1 (tanh only), with no intermediate InferType.
mod = get_mod1()
print(mod)  # module before type inference
infer_type = tvm.relay.transform.InferType()
mod = infer_type(mod)
print(mod)  # module after type inference