Check failed when using relay.build

TVM version 0.8.0
Device: android-amd64-pc
Hi, I am trying to build a library from a relay.Function, but I run into an error during the build.

  1: tvm::runtime::TypedPackedFunc<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::Attrs const&, tvm::runtime::Array<tvm::te::Tensor, void> const&, tvm::Type const&)>::AssignTypedLambda<tvm::runtime::Array<tvm::te::Tensor, void> (*)(tvm::Attrs const&, tvm::runtime::Array<tvm::te::Tensor, void> const&, tvm::Type const&)>(tvm::runtime::Array<tvm::te::Tensor, void> (*)(tvm::Attrs const&, tvm::runtime::Array<tvm::te::Tensor, void> const&, tvm::Type const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  0: tvm::runtime::Array<tvm::te::Tensor, void> tvm::relay::Pool2DGradCompute<tvm::relay::AvgPool2DAttrs, (tvm::topi::nn::PoolType)0>(tvm::Attrs const&, tvm::runtime::Array<tvm::te::Tensor, void> const&, tvm::Type const&)
  File "/home/timmy/Documents/research/tvm/src/relay/op/nn/pooling.cc", line 855
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (param != nullptr) is false: 

Why does this happen?
I found the C++ code, and it seems that the `param` pointer in Pool2DGradCompute is null.
Here is the code that triggers the error:

import numpy as np
import os

import tvm
from tvm import relay, auto_scheduler
from tvm.relay import data_dep_optimization as ddo
import tvm.relay.testing
from tvm.contrib import graph_executor
from tvm.contrib.utils import tempdir
# Cross-compilation target for Android arm64. In the code shown it is only
# referenced by the commented-out tune(...) call in main(); local_run() builds
# for its own x86_64 target instead.
target = tvm.target.Target("llvm -mtriple=arm64-linux-android")

# Script configuration. None of these values are read by the code visible in
# this script. NOTE(review): layout "NHWC" does not match the (1, 3, 128, 128)
# input shape used in main(), which looks NCHW; confirm whether these are
# leftovers from a template.
network = "mobilenet"
use_sparse = False
batch_size = 1
layout = "NHWC"
dtype = "float32"


def load_module(
    graph_path="./ir_zoos/mbv2_float/bias_only-1x3x128x128.ir",
    param_path="./ir_zoos/mbv2_float/weights.params",
):
    """Parse a Relay text-format expression and load its saved parameters.

    Args:
        graph_path: Path to a file holding a Relay expression in text form.
            The file is expected to lack the ``#[version = ...]`` header,
            because one is prepended here before parsing.
        param_path: Path to a serialized parameter blob readable by
            ``tvm.runtime.load_param_dict``.

    Returns:
        ``(mod, params)``: ``mod`` is an ``IRModule`` built from the parsed
        expression, and ``params`` maps ``"v" + <stored key>`` to each array.
        Each new key is printed as it is added.
    """
    semver_header = '#[version = "0.0.5"]\n'
    with open(graph_path, "r", encoding="utf-8") as fp:
        code = fp.read()
    expr = tvm.parser.parse_expr(semver_header + code)
    mod = tvm.IRModule.from_expr(expr)

    with open(param_path, "rb") as fp:
        raw = fp.read()  # avoid shadowing the builtin ``bin``
    loaded = dict(tvm.runtime.load_param_dict(raw))

    # The "v" prefix presumably makes the keys match the Relay var name hints
    # in the parsed graph (relay.build binds params by name). Confirm this
    # against the .ir file.
    new_param = {}
    for key, value in loaded.items():
        name = "v" + key
        print(name)
        new_param[name] = value
    return mod, new_param

def local_run(mod, param, size, input_name, target=None, label_shape=(1, 10)):
    """Build *mod* for the host CPU, feed random inputs, run it once and print output 0.

    Args:
        mod: Relay ``IRModule`` to compile.
        param: Mapping of parameter name to array. It is bound into the module
            at build time via ``relay.build(..., params=...)``, so the weights
            do not have to be set as executor inputs.
        size: Shape of the random float32 tensor fed to ``input_name``.
        input_name: Name of the graph input that receives the random data.
        target: Compilation target. Defaults to
            ``llvm -mtriple=x86_64-pc-linux-gnu``, which runs on ``tvm.cpu()``.
        label_shape: Shape of the random float32 tensor fed to the ``"label"``
            input.
    """
    if target is None:
        target = tvm.target.Target("llvm -mtriple=x86_64-pc-linux-gnu")
    # Pass the loaded parameters; without them the weight inputs are never set.
    lib = relay.build(mod, target=target, params=param)
    dev = tvm.cpu()
    data_tvm = tvm.nd.array(np.random.uniform(size=size).astype("float32"))
    module = graph_executor.GraphModule(lib["default"](dev))
    module.set_input(input_name, data_tvm)
    label_tvm = tvm.nd.array(np.random.uniform(size=label_shape).astype("float32"))
    module.set_input("label", label_tvm)
    module.run()
    output = module.get_output(0)
    print(output)
def main():
    """Load the module and weights from disk, then run them locally on random data.

    The graph is fed a (1, 3, 128, 128) float32 tensor through the input named
    "input".
    """
    mod, params = load_module()
    input_shape = (1, 3, 128, 128)
    local_run(mod, params, input_shape, "input")


main()

And I suspect this is the part of the function that causes the error:

%185 = reshape(%184, newshape=[1, 1280, 1, 1]) /* ty=Tensor[(1, 1280, 1, 1), float32] */;
  %186 = less(%183, %164) /* ty=Tensor[(1, 1280, 4, 4), bool] */;
  %187 = zeros(shape=[1, 1280, 4, 4], dtype="float32") /* ty=Tensor[(1, 1280, 4, 4), float32] */;
  %188 = nn.avg_pool2d_grad(%185, %165, pool_size=[4, 4], dilation=[]) /* ty=Tensor[(1, 1280, 4, 4), float32] */;

However, if I build a Relay function that contains only avg_pool2d_grad, as shown below, there is no error. Why?

import tvm.relay
import numpy as np
from tvm import relay
import tvm
target = tvm.target.Target("llvm -mtriple=x86_64-pc-linux-gnu")

# Minimal repro: a Relay function containing only nn.avg_pool2d_grad.
# avg_pool2d_grad(out_grad, data, pool_size): x plays out_grad, y plays data.
grad_type = tvm.relay.TensorType([1, 1, 1, 1], dtype="float32")
data_type = tvm.relay.TensorType([1, 1, 4, 4], dtype="float32")
x = tvm.relay.var("x", grad_type)
y = tvm.relay.var("y", data_type)
expr = relay.nn.avg_pool2d_grad(x, y, (4, 4))

func = tvm.relay.Function([x, y], expr)
print(func)

mod = tvm.IRModule.from_expr(func)
lib = relay.build(mod, target)
print("successfully built")