InferBound error (domain already inferred) for the split op

Hi all:

Recently I ran into an InferBound error on the split op:

TVMError: Check failed: match: iter_var(blockIdx.x, , blockIdx.x) domain already inferred, cannot prove their extents are the same floordiv(((({any_dim|any_dim>=0}*{any_dim|any_dim>=0})*({any_dim|any_dim>=0} - (floordiv({any_dim|any_dim>=0}, 4)*3))) + 511), 512) vs floordiv(((({any_dim|any_dim>=0}*{any_dim|any_dim>=0})*floordiv({any_dim|any_dim>=0}, 4)) + 511), 512)
Error during compile function

The error can be reproduced with the following code snippet:

import tvm
from tvm import relay
import numpy as np

def test_split():
    """Build and run a dynamic-shape split + multiply graph on the Relay VM.

    Two rank-3 float32 inputs with fully dynamic (``relay.Any``) dimensions
    are declared. ``data`` is passed to ``relay.op.split(data, 4, 2)`` and
    the first two outputs are each multiplied by ``data_1``. The resulting
    tuple is compiled for the CUDA target through the "vm" executor, run on
    random inputs, and the result is printed.
    """

    def _dynamic_rank3_type():
        # A fresh relay.Any() is created per dimension on every call, so the
        # two inputs do not share any symbolic dimension objects.
        return tvm.ir.TensorType(
            shape=(relay.Any(), relay.Any(), relay.Any()),
            dtype="float32",
        )

    # Symbolic inputs of the Relay function.
    data = relay.var("data", _dynamic_rank3_type())
    data_1 = relay.var("data_1", _dynamic_rank3_type())

    # Only outputs 0 and 1 of the split are consumed.
    pieces = relay.op.split(data, 4, 2)
    products = [tvm.relay.op.multiply(pieces[idx], data_1) for idx in (0, 1)]

    module = tvm.IRModule()
    module["main"] = relay.Function(
        [data, data_1],
        tvm.relay.expr.Tuple(products),
    )

    # Compilation / execution settings.
    cuda_target = tvm.target.Target("cuda -libs=cublas,cudnn")
    device = tvm.gpu(0)
    executor = relay.create_executor(
        "vm", mod=module, ctx=device, target=cuda_target
    )

    # Concrete inputs: the last axis of `data` is 256 and that of `data_1`
    # is 64.
    feed = {
        "data": np.random.randint(256, size=(1, 1, 256)).astype(np.float32),
        "data_1": np.random.randint(64, size=(1, 1, 64)).astype(np.float32),
    }

    run = executor.evaluate()
    result_tvm = run(**feed)
    print("=======tvm result=========")
    print(result_tvm)

# Run the reproduction only when this file is executed as a script.
if __name__ == "__main__":
    test_split()

It seems that the domain of the IterVar was inferred twice, but the two range->extent values cannot be proven to be the same.

Could anyone give me some advice on this issue? Why can't TVM correctly infer the bounds in this case?

Best regards!

@junrushao Please take a look :slight_smile: :blush: