import tvm
from tvm import relay
def get_mod():
sb = relay.ScopeBuilder()
mod = tvm.IRModule()
f = relay.GlobalVar("f")
main = relay.GlobalVar("main")
# Recursive function f
ti32 = relay.scalar_type("int32")
n = relay.var("n", ti32)
x = relay.var("x", shape=(1, 100), dtype="float32")
with sb.if_scope(relay.equal(n, relay.const(0, ti32))):
sb.ret(relay.zeros(shape=(1,100), dtype="float32"))
with sb.else_scope():
sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.tanh(x)))
mod[f] = relay.Function([n, x], sb.get())
# mod = tvm.relay.transform.InferType()(mod) # Uncommenting this line works
n1 = relay.var("n1", ti32)
y = relay.var("y", shape=(1, 100), dtype="float32")
out = f(n1, y)
mod[main] = relay.Function([n1, y], out)
return mod
mod = get_mod()
print(mod)
mod = tvm.relay.transform.InferType()(mod)
print(mod)
If I have main function call recursive function f
, it fails InferType. If I uncomment the intermediate InferType, it works. Trying to understand if this is missing InferType feature, or something fundamental about how InferType is done.