I’m using TVM’s DFPatternRewrite class to match a pattern consisting of the Relay IR shown below:
%0 = nn.conv2d(%data, %weight1, padding=[1, 1, 1, 1], groups=304, channels=304, kernel_size=[3, 3]) /* ty=Tensor[(1, 304, 128, 128), float32] */;
%1 = nn.bias_add(%0, %bias) /* ty=Tensor[(1, 304, 128, 128), float32] */;
%2 = nn.relu(%1) /* ty=Tensor[(1, 304, 128, 128), float32] */;
%3 = nn.conv2d(%2, %weight2, padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 128, 128), float32] */;
%4 = nn.bias_add(%3, %bias1) /* ty=Tensor[(1, 256, 128, 128), float32] */;
nn.relu(%4) /* ty=Tensor[(1, 256, 128, 128), float32] */
I’m able to match a plain conv2d using DFPatternRewrite, but it fails to match the depthwise convolution in the pattern.
Below is the C++ source code I am using:
class CLS: public DFPatternRewrite {
public:
CLS() {
x_ = IsWildcard();
const1_ = IsWildcard();
const2_ = IsWildcard();
const3_ = IsWildcard();
const4_ = IsWildcard();
Map<String, ObjectRef> attrs;
attrs.Set("groups", Integer(static_cast<int>(304)));
// Also tried with groups = 1 but the pattern isn't getting recognised
// attrs.Set("groups", Integer(static_cast<int>(1)));
auto conv2d = IsOp("nn.conv2d").HasAttr(attrs);
auto maybedepthwise = IsOp("nn.conv2d");
auto biasadd = IsOp("nn.bias_add");
auto relu = IsOp("nn.relu");
auto inner = relu({biasadd({maybedepthwise({x_, const1_}), const2_})});
pattern_ = relu({biasadd({conv2d({inner, const3_}), const4_})});
}
Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
std::cout << "depthwise conv2d detected " << std::endl;
return post;
}
private:
// Sub-patterns bound in the constructor and referenced by pattern_.
DFPattern x_, const1_, const2_, const3_, const4_;
};
The callback function isn’t being hit.
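For reference, here is a minimal Python-side sketch of what I believe is the equivalent pattern, written with tvm.relay.dataflow_pattern (the variable names here are my own). I only use it to sanity-check whether this pattern shape can match at all, independently of the C++ pass:
from tvm.relay.dataflow_pattern import is_op, wildcard
x = wildcard()
w1 = wildcard()
b1 = wildcard()
w2 = wildcard()
b2 = wildcard()
# Inner conv2d -> bias_add -> relu (the depthwise block in the IR above).
depthwise = is_op("nn.conv2d")(x, w1)
inner = is_op("nn.relu")(is_op("nn.bias_add")(depthwise, b1))
# Outer conv2d -> bias_add -> relu; the groups check sits on this conv2d,
# mirroring where HasAttr is attached in my C++ pattern.
outer_conv = is_op("nn.conv2d")(inner, w2).has_attr({"groups": 304})
pattern = is_op("nn.relu")(is_op("nn.bias_add")(outer_conv, b2))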
Sharing my test case:
import tvm
from tvm import relay
data = relay.var("data", relay.TensorType((1, 304, 128, 128), "float32"))
weight1 = relay.var("weight1", relay.TensorType((304, 1, 3, 3), "float32"))
bias = relay.var("bias", shape=(304,))
conv = relay.nn.conv2d(data=data, weight=weight1, kernel_size=(3, 3),
channels=304, groups=304, padding=(1, 1, 1, 1))
bias_add = relay.nn.bias_add(conv, bias)
relu = relay.nn.relu(bias_add)
weight2 = relay.var("weight2", relay.TensorType((256, 304, 1, 1), "float32"))
conv = relay.nn.conv2d(data=relu, weight=weight2, kernel_size=(1, 1),
channels=256, padding=(0, 0, 0, 0))
bias2 = relay.var("bias2", shape=(256,))
bias_add = relay.nn.bias_add(conv, bias2)
relu = relay.nn.relu(bias_add)
func = relay.Function([data, weight1, weight2, bias, bias2], relu)
mod = tvm.IRModule()
mod["main"] = func
mod = relay.transform.CLS()(mod)
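With the Python sketch above loaded in the same session, I can also check the pattern shape directly against the test function (again just a sanity check, not the actual pass):
# Check whether the Python pattern sketch matches the body of the test function.
print(pattern.match(func.body))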
Any help would be appreciated. Thank you!