DFPatternRewrite is unable to detect depthwise convolution

I’m using TVM’s DFPatternRewrite class to match a pattern in the following IR:

  %0 = nn.conv2d(%data, %weight1, padding=[1, 1, 1, 1], groups=304, channels=304, kernel_size=[3, 3]) /* ty=Tensor[(1, 304, 128, 128), float32] */;
  %1 = nn.bias_add(%0, %bias) /* ty=Tensor[(1, 304, 128, 128), float32] */;
  %2 = nn.relu(%1) /* ty=Tensor[(1, 304, 128, 128), float32] */;
  %3 = nn.conv2d(%2, %weight2, padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 128, 128), float32] */;
  %4 = nn.bias_add(%3, %bias1) /* ty=Tensor[(1, 256, 128, 128), float32] */;
  nn.relu(%4) /* ty=Tensor[(1, 256, 128, 128), float32] */

I’m able to match the regular conv2d with DFPatternRewrite, but the depthwise convolution in the pattern is not matched.

Here is the C++ source code I used:

class CLS: public DFPatternRewrite {
  public:
   CLS() {
     x_ = IsWildcard();
     const1_ = IsWildcard();
     const2_ = IsWildcard();
     const3_ = IsWildcard();
     const4_ = IsWildcard();
  
     Map<String, ObjectRef> attrs;
     attrs.Set("groups", Integer(static_cast<int>(304)));
     // Also tried with groups = 1 but the pattern isn't getting recognised
     // attrs.Set("groups", Integer(static_cast<int>(1)));
  
     auto conv2d = IsOp("nn.conv2d").HasAttr(attrs);
     auto maybedepthwise = IsOp("nn.conv2d");
     auto biasadd = IsOp("nn.bias_add");
     auto relu = IsOp("nn.relu");
     pattern_ = relu({biasadd({conv2d({relu({biasadd({maybedepthwise({x_, const1_}), const2_})}), const3_}), const4_})});
   }

   Expr Callback(const Expr& pre, const Expr& post,
                 const Map<DFPattern, Array<Expr>>& node_map) const override {
     std::cout << "depthwise conv2d detected " << std::endl;
     return post;
   }

   DFPattern x_, const1_, const2_, const3_, const4_;
};

The callback function isn’t being hit.

Sharing my test case:

import tvm
from tvm import relay
data = relay.var("data", relay.TensorType((1, 304, 128, 128), "float32"))
weight1 = relay.var("weight1", relay.TensorType((304, 1, 3, 3), "float32"))
bias = relay.var("bias", shape=(304,))
conv = relay.nn.conv2d(data=data, weight=weight1, kernel_size=(3, 3),
        channels=304, groups=304,  padding=(1, 1, 1, 1))
bias_add = relay.nn.bias_add(conv, bias)
relu = relay.nn.relu(bias_add)

weight2 = relay.var("weight2", relay.TensorType((256, 304, 1, 1), "float32"))
conv = relay.nn.conv2d(data=relu, weight=weight2, kernel_size=(1, 1),
        channels=256, padding=(0, 0, 0, 0))
bias2 = relay.var("bias", shape=(256,))
bias_add = relay.nn.bias_add(conv, bias2)
relu = relay.nn.relu(bias_add)
func = relay.Function([data, weight1, weight2, bias, bias2], relu)

mod = tvm.IRModule()
mod["main"] = func

mod = relay.transform.CLS()(mod)

Any help would be appreciated on this. Thank you!

Taking a look! Sorry for the delay.

Ah, that’s a bit subtle. You want to match on the attributes of the call to the innermost conv2d, not on the attributes of the conv2d operator itself. The DFPatternMatcher machinery helpfully interprets HasAttr either way, so attaching HasAttr directly to IsOp("nn.conv2d") constrains the operator rather than the call.
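To make the operator-versus-call distinction concrete, here is a toy model in plain Python. It does not use TVM at all: the Op and Call classes and the has_attr helper are hypothetical stand-ins invented for illustration, on the assumption that per-call settings such as groups live on the call node and not on the shared operator.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    """Hypothetical stand-in for an operator such as nn.conv2d."""
    name: str
    attrs: dict = field(default_factory=dict)  # operator-level attributes only


@dataclass
class Call:
    """Hypothetical stand-in for one call to an operator."""
    op: Op
    args: tuple
    attrs: dict = field(default_factory=dict)  # per-call attributes, e.g. groups


def has_attr(node, required):
    """Check the required attributes against whichever node is passed in."""
    return all(node.attrs.get(key) == value for key, value in required.items())


conv2d_op = Op("nn.conv2d")
depthwise_call = Call(conv2d_op, ("data", "weight1"), {"groups": 304, "channels": 304})

# Constraint attached to the operator: groups is not stored there, so no match.
matched_on_op = has_attr(depthwise_call.op, {"groups": 304})

# Constraint attached to the call: groups is found and compared.
matched_on_call = has_attr(depthwise_call, {"groups": 304})
```

In this toy model only the second check succeeds, which mirrors why the constraint has to be placed on the call pattern.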

This works for me:

class TestRewriter : public DFPatternRewrite {
 public:
  TestRewriter() {
    x_ = IsWildcard();
    const1_ = IsWildcard();
    const2_ = IsWildcard();
    const3_ = IsWildcard();
    const4_ = IsWildcard();

    auto biasadd = IsOp("nn.bias_add");
    auto relu = IsOp("nn.relu");
    auto conv2d = IsOp("nn.conv2d");

    Map<String, ObjectRef> attrs;
    attrs.Set("groups", Integer(304));
    auto maybedepthwise = conv2d({x_, const1_}).HasAttr(attrs);

    pattern_ = relu({biasadd(
        {conv2d({relu({biasadd({maybedepthwise, const2_})}), const3_}), const4_})});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    LOG(INFO) << "depthwise conv2d detected ";
    auto attrs = runtime::make_object<InitOpAttrs>();
    attrs->shape = Array<Integer>({Integer(1), Integer(256), Integer(128), Integer(128)});
    attrs->dtype = DataType::Float(32);
    return Call(Op::Get("zeros"), {}, Attrs(attrs));
  }

  DFPattern x_, const1_, const2_, const3_, const4_;
};

I’ve sent this in as a unit test in https://github.com/apache/tvm/pull/10533.

Thank you @mbs-octoml and @mbrookhart for your quick help on this!

I would like to ask one more question. In the present case, we detect the depthwise conv by checking that the groups attribute equals 304. However, the groups value differs from one depthwise conv operator to another. Is there a generic way to detect a depthwise conv using DFPatternMatcher?

I ask because when I change the groups value to something other than 304, the pattern is no longer detected.

    attrs.Set("groups", Integer(304));

No. Depthwise by definition means groups==channels, and to know channels we would need to specify the input shape. I would recommend leaving groups unspecified in the pattern for the depthwise conv, and then checking in the rewrite callback that groups matches channels.
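As a rough sketch of the check that would go in the callback, here is the groups==channels test in plain Python. It is not TVM code: is_depthwise_conv2d and its parameters are hypothetical names, and it assumes the data layout string (for example "NCHW") names the channel axis with the letter C.

```python
def is_depthwise_conv2d(groups, data_shape, data_layout="NCHW"):
    """Return True when groups equals the number of input channels.

    Hypothetical helper: data_shape is the shape of the conv2d input and
    data_layout is assumed to mark the channel axis with the letter "C".
    """
    channels = data_shape[data_layout.index("C")]
    return groups == channels


# First conv2d in the test case: groups=304 on a (1, 304, 128, 128) input.
first_conv = is_depthwise_conv2d(304, (1, 304, 128, 128))

# Second conv2d: groups left at 1 on the same input shape.
second_conv = is_depthwise_conv2d(1, (1, 304, 128, 128))
```

With the pattern left free of any groups constraint, the callback would return the expression unchanged whenever a check like this fails.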

Thanks @mbrookhart!

While going through the TVM code, I found the function IsDepthwiseConv2D useful.

It does what you said (checking groups==channels) and also checks that batch==1.
