I followed the instructions as you suggested but I get following errors.
Please find my code here as well.
import torch
from tvm.topi import topi
import tvm
from tvm import te
from tvm.contrib import dlpack
def _codegen_function(name):
d1 = te.var('d1') # D1 -> # of rows of first matrix
d2 = te.var('d2') # D2 -> # of columns of first matrix
bsz = te.var('bsz') # bsz and d3 can be variables without impact on performance
d3 = te.var('d3')
A = te.placeholder((bsz, d1, d3), name='A', dtype='float32')
B = te.placeholder((bsz, d2, d3), name='B', dtype='float32')
R = topi.nn.batch_matmul(A, B)
s = te.create_schedule(R.op)
return tvm.build(s, [A, B, R], name=name, target = 'cuda')
if __name__ == "__main__":
bsz = 12
d1 = 2048
d2 = 1024
d3 = 64
bmm1 = _codegen_function('bmm1')
bmm1_pytorch = dlpack.to_pytorch_func(bmm1) # wrap it as a pytorch function
A = torch.randn(bsz, d1, d3, device='cuda')
B = torch.randn(bsz, d2, d3, device='cuda')
R = B.new_empty(bsz, d1, d2) # allocate memory for the result tensor
bmm1_pytorch(A, B, R)
Error:
TVMError: Traceback (most recent call last):
10: TVMFuncCall
9: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)>::AssignTypedLambda<tvm::$_5>(tvm::$_5, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
8: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
7: tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&)
6: tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
5: tvm::transform::Pass::operator()(tvm::IRModule) const
4: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
3: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
2: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
1: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
0: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::tir::transform::VerifyMemory()::$_0>(tvm::tir::transform::VerifyMemory()::$_0)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
Did you forget to bind?
Variable `B` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Variable `A` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Variable `T_batch_matmul_NT` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Variable `T_batch_matmul_NT` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
Variable `T_batch_matmul_NT` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
File "/home/gakolhe/tvm/src/tir/analysis/verify_memory.cc", line 214
RuntimeError: Memory verification failed with the following errors:
PrimFunc([A, B, T_batch_matmul_NT]) attrs={"from_legacy_te_schedule": (bool)1, "global_symbol": "bmm1", "tir.noalias": (bool)1, "target": cuda -keys=cuda,gpu -arch=sm_60 -max_num_threads=1024 -thread_warp_size=32} {
for (b, 0, {batch|batch>=0}) {
for (i, 0, d1) {
for (j, 0, d2) {
T_batch_matmul_NT[(((b*stride) + (i*stride)) + (j*stride))] = 0f
for (k, 0, d3) {
T_batch_matmul_NT[(((b*stride) + (i*stride)) + (j*stride))] = (T_batch_matmul_NT[(((b*stride) + (i*stride)) + (j*stride))] + (A[(((b*stride) + (i*stride)) + (k*stride))]*B[(((b*stride) + (j*stride)) + (k*stride))]))
}
}
}
}
}