You can follow the example in VTA to override the implementation for a specific target: https://github.com/apache/incubator-tvm/blob/master/vta/python/vta/top/op.py#L63
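The general pattern of overriding an op implementation for a specific target is to register a strategy under that target's key. A minimal sketch (the "mytarget" key and the particular topi compute/schedule are placeholders for illustration, not what the linked VTA file uses):

from tvm import topi
from tvm.relay.op import op as _op
from tvm.relay.op import strategy as _strategy

@_strategy.conv2d_strategy.register("mytarget")
def conv2d_strategy_mytarget(attrs, inputs, out_type, target):
    """conv2d strategy for the (hypothetical) mytarget device."""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        _strategy.wrap_compute_conv2d(topi.nn.conv2d_nchw),
        _strategy.wrap_topi_schedule(topi.generic.schedule_conv2d_nchw),
        name="conv2d_nchw.mytarget")
    return strategy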
Thank you very much. I have another problem: I want to implement the BN op with te.extern, like cuDNN, and I don't want to unpack BN. Is there any example of this, or how should I do it? Please help me @haichen
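One note on not unpacking BN: it is Relay's SimplifyInference pass that decomposes nn.batch_norm into primitive ops during relay.build, so to keep batch_norm as a single op that reaches your te.extern implementation, that pass can be disabled. A sketch, assuming mod/params come from your network:

import tvm
from tvm import relay

# Disable the pass that rewrites nn.batch_norm into primitive ops
with tvm.transform.PassContext(opt_level=3, disabled_pass=["SimplifyInference"]):
    lib = relay.build(mod, target="llvm", params=params)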
I have created a new strategy for BN, as follows:
@override_native_generic_func("batch_norm_strategy")
def batch_norm_strategy(attrs, inputs, out_type, target):
    """batch_norm generic strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_batch_norm(topi.nn.batch_norm),
        wrap_topi_schedule(topi.generic.schedule_injective),
        name="batch_norm.generic")
    return strategy


@batch_norm_strategy.register("ssnpu")
def batch_norm_strategy_ssnpu(attrs, inputs, out_type, target):
    """batch_norm ssnpu strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_batch_norm(topi.ssnpu.batch_norm_vp),
        wrap_topi_schedule(topi.ssnpu.schedule_batch_norm_vp),
        name="batch_norm.ssnpu",
        plevel=15)
    return strategy
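wrap_compute_batch_norm is not a stock helper in the strategy module, so it has to be defined as well. A minimal sketch of what such a wrapper could look like, assuming the topi compute simply takes the five input tensors:

def wrap_compute_batch_norm(topi_compute):
    """Wrap a batch_norm topi compute into an FTVMCompute."""
    def _compute_batch_norm(attrs, inputs, out_type):
        # inputs = [data, gamma, beta, moving_mean, moving_var]
        out = topi_compute(*inputs)
        # FTVMCompute must return a list of output tensors
        return list(out) if isinstance(out, (list, tuple)) else [out]
    return _compute_batch_norm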
def batch_norm_vp(data, gamma, beta, mean, variance):
    """Batch_norm operator on ssnpu"""
    return te.extern(
        data.shape,
        [data, gamma, beta, mean, variance],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.ssnpu.batch_norm.forward",
            ins[0], ins[1], ins[2], ins[3], ins[4], outs[0]),
        dtype=data.dtype,
        name="C")
But I got a problem when I run the net:
tvm._ffi.base.TVMError: Traceback (most recent call last):
[bt] (8) /home/wda/tvm_0622/build/libtvm.so(tvm::runtime::GraphRuntime::Run()+0x79) [0x7fc3ce6ed8b7]
[bt] (7) /home/wda/tvm_0622/build/libtvm.so(std::function<void ()>::operator()() const+0x32) [0x7fc3cdb83460]
[bt] (6) /home/wda/tvm_0622/build/libtvm.so(+0x2862878) [0x7fc3ce6f3878]
[bt] (5) /home/wda/tvm_0622/build/libtvm.so(+0x285ff21) [0x7fc3ce6f0f21]
[bt] (4) /home/wda/tvm_0622/build/libtvm.so(tvm::runtime::PackedFunc::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const+0x30) [0x7fc3cda28c4e]
[bt] (3) /home/wda/tvm_0622/build/libtvm.so(std::function<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const+0x5a) [0x7fc3cd9d6d48]
[bt] (2) /home/wda/tvm_0622/build/libtvm.so(+0x27ebdc6) [0x7fc3ce67cdc6]
[bt] (1) /home/wda/tvm_0622/build/libtvm.so(+0x27eaa3d) [0x7fc3ce67ba3d]
[bt] (0) /home/wda/tvm_0622/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x37) [0x7fc3cd96d191]
File "/home/wda/tvm_0622/src/runtime/library_module.cc", line 78
TVMError: Check failed: ret == 0 (-1 vs. 0) : Assert fail: (num_args == 6), fused_nn_batch_norm: num_args should be 6
The problem has been solved. My batch_norm compute was wrong: Relay's batch_norm returns a tuple of three tensors (the normalized output plus the moving mean and variance), so the fused function is called with 5 inputs and 3 outputs (8 arguments), while my original te.extern declared only a single output and therefore asserted num_args == 6. Declaring three outputs fixes it:
    return te.extern(
        # TODO: data.shape[1] may need to be modified (channel axis assumed to be axis 1, NCHW)
        [data.shape, [data.shape[1]], [data.shape[1]]],
        [data, gamma, beta, mean, variance],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.ssnpu.batch_norm.forward",
            ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], outs[1], outs[2]),
        dtype=[data.dtype, data.dtype, data.dtype],
        name="C")
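For completeness, the extern call dispatches to a packed function named "tvm.contrib.ssnpu.batch_norm.forward", which must be registered in the runtime. A real ssnpu backend would register it from C++ with TVM_REGISTER_GLOBAL; below is only a NumPy-based Python stand-in that can be used while testing (NCHW layout and the epsilon value are assumptions):

import numpy as np
import tvm

# Hypothetical Python stand-in for the ssnpu runtime function.
@tvm.register_func("tvm.contrib.ssnpu.batch_norm.forward")
def _batch_norm_forward(data, gamma, beta, mean, variance, out, out_mean, out_var):
    eps = 1e-5                                   # assumed epsilon
    x = data.asnumpy()
    g = gamma.asnumpy().reshape(1, -1, 1, 1)     # assumes NCHW layout
    b = beta.asnumpy().reshape(1, -1, 1, 1)
    m = mean.asnumpy().reshape(1, -1, 1, 1)
    v = variance.asnumpy().reshape(1, -1, 1, 1)
    out.copyfrom(((x - m) / np.sqrt(v + eps) * g + b).astype(x.dtype))
    # Inference-mode batch_norm passes the running statistics through unchanged.
    out_mean.copyfrom(mean.asnumpy())
    out_var.copyfrom(variance.asnumpy())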