[Relax] How to optimize matrix multiplication in Relax

I noticed that DietCode uses dynamic batched matrix multiplication to measure the performance of automatic optimization.

Here is my code:

import numpy as np

import tvm
from tvm import relax, tir, topi


def build_bm(tensor_a, tensor_b):
    bb = relax.BlockBuilder()

    with bb.function("bm", [tensor_a, tensor_b]):
        # topi.nn.batch_matmul defaults to transpose_b=True; pass
        # transpose_b=False to match tensor_b's [192, k, m] layout
        gv0 = bb.emit_te(topi.nn.batch_matmul, tensor_a, tensor_b, transpose_b=False)
        gv1 = bb.emit_te(topi.nn.relu, gv0)
        bb.emit_func_output(gv1)

    mod = bb.get()
    return mod


if __name__ == "__main__":
    # symbolic dimensions
    m, k = tir.Var("m", "int64"), tir.Var("k", "int64")
    # create data and weight variables
    tensor_a = relax.Var("tensor_a", relax.TensorStructInfo([192, m, k], "float32"))
    tensor_b = relax.Var("tensor_b", relax.TensorStructInfo([192, k, m], "float32"))

    mod = build_bm(tensor_a, tensor_b)
    target = tvm.target.Target("llvm", host="llvm")
    ex = relax.build(mod, target)

    vm = relax.VirtualMachine(ex, tvm.cpu(), profile=True)

    data = tvm.nd.array(np.random.rand(192, 100, 64).astype(np.float32))
    weight = tvm.nd.array(np.random.rand(192, 64, 100).astype(np.float32))
    
    print(vm.profile("bm", data, weight))

In Relax, is there any way to optimize dynamic batched matrix multiplication?
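
For reference, the static-shape tuning flow from the e2e_auto_tir.py script mentioned below looks roughly like the sketch here (API names from the unity branch; work_dir and the trial count are placeholders). What I don't see is how to make this flow handle the symbolic m and k:

from tvm import meta_schedule as ms

# Sketch: tune and rebuild the module with meta_schedule (static shapes).
# work_dir and max_trials_global are placeholder values.
db = ms.relax_integration.tune_relax(
    mod=mod,
    params={},  # no weights to bind in this toy example
    target=target,
    work_dir="./tune_tmp",
    max_trials_global=64,
)
ex = ms.relax_integration.compile_relax(db, mod=mod, target=target, params={})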

BTW, I’m trying to run this script https://github.com/apache/tvm/blob/unity/apps/relax_examples/e2e_auto_tir.py, which uses meta_schedule to tune a ResNet model, but I get the following error:

  input_name: input0
  input_shape: [1, 3, 224, 224]
  input_dtype: float32
INFO:tvm.meta_schedule.runner.local_runner:LocalRunner: max_workers = 1
Traceback (most recent call last):
  File "e2e_auto_tir.py", line 253, in <module>
    main()
  File "e2e_auto_tir.py", line 194, in main
    db = ms.relax_integration.tune_relax(
  File "/root/work/tvm-unity/python/tvm/meta_schedule/relax_integration.py", line 236, in tune_relax
    all_tasks = extract_tasks(mod, target, params, module_equality=module_equality)
  File "/root/work/tvm-unity/python/tvm/meta_schedule/relax_integration.py", line 98, in extract_tasks
    mod = BindParams("main", params)(mod)
  File "/root/work/tvm-unity/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/root/work/tvm-unity/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/root/work/tvm-unity/python/tvm/_ffi/base.py", line 476, in raise_last_ffi_error
    raise py_err
  File "/root/work/tvm-unity/src/relax/transform/bind_params.cc", line 206, in operator()
    [=](IRModule mod, PassContext pc) { return BindParam(std::move(mod), func_name, params); };
  File "/root/work/tvm-unity/src/relax/transform/bind_params.cc", line 185, in tvm::relax::BindParam(tvm::IRModule, tvm::runtime::String, tvm::runtime::Map<tvm::runtime::ObjectRef, tvm::runtime::ObjectRef, void, void>)
    Function f_after_bind = FunctionBindParams(GetRef<Function>(relax_f), bind_params);
  File "/root/work/tvm-unity/src/relax/transform/bind_params.cc", line 163, in tvm::relax::FunctionBindParams(tvm::relax::Function, tvm::runtime::Map<tvm::runtime::ObjectRef, tvm::runtime::ObjectRef, void, void> const&)
    auto [bind_dict, symbolic_var_map] = NormalizeBindings(func, untyped_params);
  File "/root/work/tvm-unity/src/relax/transform/bind_params.cc", line 143, in tvm::relax::NormalizeBindings(tvm::relax::Function const&, tvm::runtime::Map<tvm::runtime::ObjectRef, tvm::runtime::ObjectRef, void, void> const&)
    relax_var_remap.Set(normalize_key(key), normalize_value(value));
  File "/root/work/tvm-unity/src/relax/transform/bind_params.cc", line 104, in operator()
    CHECK(it != string_lookup.end())
tvm._ffi.base.TVMError: Traceback (most recent call last):
  4: operator()
        at /root/work/tvm-unity/src/relax/transform/bind_params.cc:206
  3: tvm::relax::BindParam(tvm::IRModule, tvm::runtime::String, tvm::runtime::Map<tvm::runtime::ObjectRef, tvm::runtime::ObjectRef, void, void>)
        at /root/work/tvm-unity/src/relax/transform/bind_params.cc:185
  2: tvm::relax::FunctionBindParams(tvm::relax::Function, tvm::runtime::Map<tvm::runtime::ObjectRef, tvm::runtime::ObjectRef, void, void> const&)
        at /root/work/tvm-unity/src/relax/transform/bind_params.cc:163
  1: tvm::relax::NormalizeBindings(tvm::relax::Function const&, tvm::runtime::Map<tvm::runtime::ObjectRef, tvm::runtime::ObjectRef, void, void> const&)
        at /root/work/tvm-unity/src/relax/transform/bind_params.cc:143
  0: operator()
        at /root/work/tvm-unity/src/relax/transform/bind_params.cc:104
  File "/root/work/tvm-unity/src/relax/transform/bind_params.cc", line 104
TVMError: Check failed: (it != string_lookup.end()) is false: Function does not have parameter with name "aten::add__7.num_batches_tracked".  Function parameters are named ["input0"]

Does anyone know how to solve it? From the error message, BindParams seems to be handed PyTorch state-dict names (num_batches_tracked is a BatchNorm buffer) that do not appear among the Relax function's parameters. Thank you.

I used the following command to run the script:

python e2e_auto_tir.py --workload='resnet_18' --input-shape='[1,3,224,224]' --target='llvm' --work-dir=$(pwd) --num-trials=3

Hi @Elam,

Were you able to solve this issue? I am having similar errors using the same script.

The default autotuning may have some issues with dynamic shapes. We introduced a recent module called dlight that helps resolve some of that in LLM use cases.
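
A minimal sketch of how dlight is typically applied (assuming a CUDA target, since dlight's built-in rules are GPU schedules; ApplyDefaultSchedule and the gpu.Matmul/gpu.Fallback rules are from the unity branch):

import tvm
from tvm import dlight as dl
from tvm import relax

# Sketch: schedule the TIR functions in `mod` with dlight's default GPU rules.
with tvm.target.Target("cuda"):
    mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),    # handles (batched) matmul, including dynamic shapes
        dl.gpu.Fallback(),  # generic rule for anything the others don't match
    )(mod)

ex = relax.build(mod, target="cuda")

For the llvm target used in the script above, these GPU rules do not apply, so a different scheduling strategy would be needed.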