I noticed that DietCode uses dynamic batched matrix multiplication to measure the performance of its automatic optimization.
Here is my code:
def build_bm(tensor_a, tensor_b, transpose_b=False):
    """Build a Relax module with one function "bm" computing relu(batch_matmul(a, b)).

    Parameters
    ----------
    tensor_a : relax.Var
        Left operand, laid out as [batch, M, K].
    tensor_b : relax.Var
        Right operand. Laid out as [batch, K, N] when ``transpose_b`` is
        False (the default here), or [batch, N, K] when it is True.
    transpose_b : bool
        Forwarded to ``topi.nn.batch_matmul``. TOPI itself defaults to
        ``transpose_b=True`` (NT layout); with symbolic dimensions a layout
        mismatch is not caught by TOPI's shape assertion and silently yields
        a wrong result, so the flag is passed explicitly.

    Returns
    -------
    tvm.IRModule
        The module produced by the block builder.
    """
    bb = relax.BlockBuilder()
    with bb.function("bm", [tensor_a, tensor_b]):
        # Keyword arguments given to emit_te are forwarded to the TE function.
        gv0 = bb.emit_te(
            tvm.topi.nn.batch_matmul,
            tensor_a,
            tensor_b,
            transpose_b=transpose_b,
        )
        gv1 = bb.emit_te(topi.nn.relu, gv0)
        bb.emit_func_output(gv1)
    return bb.get()
if __name__ == "__main__":
    # Symbolic (dynamic) dimensions shared by the two operands.
    dim_m = tir.Var("m", "int64")
    dim_k = tir.Var("k", "int64")

    # Relax variables for the batched operands: batch of 192, dynamic M and K.
    lhs_info = relax.TensorStructInfo([192, dim_m, dim_k], "float32")
    rhs_info = relax.TensorStructInfo([192, dim_k, dim_m], "float32")
    tensor_a = relax.Var("tensor_a", lhs_info)
    tensor_b = relax.Var("tensor_b", rhs_info)

    # Build the module, compile it for the CPU and load it into a profiling VM.
    mod = build_bm(tensor_a, tensor_b)
    target = tvm.target.Target("llvm", host="llvm")
    ex = relax.build(mod, target)
    vm = relax.VirtualMachine(ex, tvm.cpu(), profile=True)

    # Concrete inputs binding m=100 and k=64 at run time.
    data_np = np.random.rand(192, 100, 64).astype(np.float32)
    data = tvm.nd.array(data_np)
    weight_np = np.random.rand(192, 64, 100).astype(np.float32)
    weight = tvm.nd.array(weight_np)

    report = vm.profile("bm", data, weight)
    print(report)
In Relax, is there any way to optimize dynamic batched matrix multiplication?
BTW, I’m trying to run this script, https://github.com/apache/tvm/blob/unity/apps/relax_examples/e2e_auto_tir.py, which uses meta_schedule to tune the ResNet model, but I get the following error:
input_name: input0
input_shape: [1, 3, 224, 224]
input_dtype: float32
INFO:tvm.meta_schedule.runner.local_runner:LocalRunner: max_workers = 1
Traceback (most recent call last):
File "e2e_auto_tir.py", line 253, in <module>
main()
File "e2e_auto_tir.py", line 194, in main
db = ms.relax_integration.tune_relax(
File "/root/work/tvm-unity/python/tvm/meta_schedule/relax_integration.py", line 236, in tune_relax
all_tasks = extract_tasks(mod, target, params, module_equality=module_equality)
File "/root/work/tvm-unity/python/tvm/meta_schedule/relax_integration.py", line 98, in extract_tasks
mod = BindParams("main", params)(mod)
File "/root/work/tvm-unity/python/tvm/ir/transform.py", line 238, in __call__
return _ffi_transform_api.RunPass(self, mod)
File "/root/work/tvm-unity/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
raise_last_ffi_error()
File "/root/work/tvm-unity/python/tvm/_ffi/base.py", line 476, in raise_last_ffi_error
raise py_err
File "/root/work/tvm-unity/src/relax/transform/bind_params.cc", line 206, in operator()
[=](IRModule mod, PassContext pc) { return BindParam(std::move(mod), func_name, params); };
File "/root/work/tvm-unity/src/relax/transform/bind_params.cc", line 185, in tvm::relax::BindParam(tvm::IRModule, tvm::runtime::String, tvm::runtime::Map<tvm::runtime::ObjectRef, tvm::runtime::ObjectRef, void, void>)
Function f_after_bind = FunctionBindParams(GetRef<Function>(relax_f), bind_params);
File "/root/work/tvm-unity/src/relax/transform/bind_params.cc", line 163, in tvm::relax::FunctionBindParams(tvm::relax::Function, tvm::runtime::Map<tvm::runtime::ObjectRef, tvm::runtime::ObjectRef, void, void> const&)
auto [bind_dict, symbolic_var_map] = NormalizeBindings(func, untyped_params);
File "/root/work/tvm-unity/src/relax/transform/bind_params.cc", line 143, in tvm::relax::NormalizeBindings(tvm::relax::Function const&, tvm::runtime::Map<tvm::runtime::ObjectRef, tvm::runtime::ObjectRef, void, void> const&)
relax_var_remap.Set(normalize_key(key), normalize_value(value));
File "/root/work/tvm-unity/src/relax/transform/bind_params.cc", line 104, in operator()
CHECK(it != string_lookup.end())
tvm._ffi.base.TVMError: Traceback (most recent call last):
4: operator()
at /root/work/tvm-unity/src/relax/transform/bind_params.cc:206
3: tvm::relax::BindParam(tvm::IRModule, tvm::runtime::String, tvm::runtime::Map<tvm::runtime::ObjectRef, tvm::runtime::ObjectRef, void, void>)
at /root/work/tvm-unity/src/relax/transform/bind_params.cc:185
2: tvm::relax::FunctionBindParams(tvm::relax::Function, tvm::runtime::Map<tvm::runtime::ObjectRef, tvm::runtime::ObjectRef, void, void> const&)
at /root/work/tvm-unity/src/relax/transform/bind_params.cc:163
1: tvm::relax::NormalizeBindings(tvm::relax::Function const&, tvm::runtime::Map<tvm::runtime::ObjectRef, tvm::runtime::ObjectRef, void, void> const&)
at /root/work/tvm-unity/src/relax/transform/bind_params.cc:143
0: operator()
at /root/work/tvm-unity/src/relax/transform/bind_params.cc:104
File "/root/work/tvm-unity/src/relax/transform/bind_params.cc", line 104
TVMError: Check failed: (it != string_lookup.end()) is false: Function does not have parameter with name "aten::add__7.num_batches_tracked". Function parameters are named ["input0"]
Does anyone know how to solve this? Thank you.