Relay Transform Pass - Replace Conv2D with Bitserial Conv2D Operators

I’m trying to implement a simple transform pass in Relay that replaces all occurrences of the Conv2D operator (“nn.conv2d”) with the binary Conv2D operator (“nn.bitserial_conv2d”). I’m defining my AST traverser as a MixedModeMutator with essentially the following Rewrite_ method:

[Screenshot of the Rewrite_ method]

I confirmed that the traverser visits every Conv2D operator and calls the MakeBinaryConv2D function each time, as expected. However, the module produced by the pass still contains almost all of the original Conv2D operators (154 in my test model), and only 4 binary Conv2D operators appear in it.

Is there something I’m missing here? My goal is simply to replace an existing CallNode with a newly created CallNode. Isn’t the CallNode overload of Rewrite_ supposed to do exactly that?
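The following is a minimal, self-contained C++ toy model (not TVM code) of how a post-order mutator in the spirit of MixedModeMutator calls Rewrite_ with both the original node (`pre`) and a copy whose arguments have already been rewritten (`post`). It assumes, without confirmation from the screenshot, that the replacement call is built from `pre`'s arguments. In that case every rewrite made below a rewritten node is discarded, so only rewrites nearest the output on each path survive while most original ops remain, a pattern consistent with the symptom above. All names (`Node`, `ToyMutator`, `MakeChain`, `CountOp`) are hypothetical. In real TVM code the analogous fix would presumably be to read the arguments from `post.as<CallNode>()` rather than from `pre`.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical toy IR (not TVM): an operator name plus its input nodes.
struct Node {
  std::string op;
  std::vector<std::shared_ptr<Node>> args;
};
using NodePtr = std::shared_ptr<Node>;

NodePtr MakeNode(const std::string& op, std::vector<NodePtr> args = {}) {
  return std::make_shared<Node>(Node{op, std::move(args)});
}

// Hypothetical post-order mutator: children are rewritten first, then
// Rewrite(pre, post) sees the original node and a copy of it whose
// arguments are the already-rewritten children. Results are memoized.
class ToyMutator {
 public:
  explicit ToyMutator(bool build_from_post) : build_from_post_(build_from_post) {}

  NodePtr Mutate(const NodePtr& pre) {
    auto it = memo_.find(pre.get());
    if (it != memo_.end()) return it->second;
    std::vector<NodePtr> new_args;
    bool changed = false;
    for (const NodePtr& arg : pre->args) {
      NodePtr mutated = Mutate(arg);
      changed = changed || (mutated != arg);
      new_args.push_back(mutated);
    }
    NodePtr post = changed ? MakeNode(pre->op, new_args) : pre;
    NodePtr result = Rewrite(pre, post);
    memo_[pre.get()] = result;
    return result;
  }

 private:
  NodePtr Rewrite(const NodePtr& pre, const NodePtr& post) {
    if (pre->op != "nn.conv2d") return post;
    // pre->args point at the ORIGINAL children, so rewrites made below this
    // node are dropped; post->args hold the already-rewritten children.
    const std::vector<NodePtr>& args = build_from_post_ ? post->args : pre->args;
    return MakeNode("nn.bitserial_conv2d", args);
  }

  bool build_from_post_;
  std::unordered_map<const Node*, NodePtr> memo_;
};

// Builds input -> (nn.conv2d -> nn.relu) repeated num_convs times.
NodePtr MakeChain(int num_convs) {
  assert(num_convs > 0 && "need at least one conv2d");
  NodePtr x = MakeNode("input");
  for (int i = 0; i < num_convs; ++i) {
    x = MakeNode("nn.relu", {MakeNode("nn.conv2d", {x})});
  }
  return x;
}

// Counts distinct nodes with the given op reachable from root.
std::size_t CountOp(const NodePtr& root, const std::string& op) {
  std::unordered_set<const Node*> seen;
  std::vector<const Node*> stack{root.get()};
  std::size_t count = 0;
  while (!stack.empty()) {
    const Node* n = stack.back();
    stack.pop_back();
    if (!seen.insert(n).second) continue;
    if (n->op == op) ++count;
    for (const NodePtr& a : n->args) stack.push_back(a.get());
  }
  return count;
}
```

With a chain of five conv2d/relu pairs, building from `pre` leaves one bitserial op and four original conv2d ops, while building from `post` rewrites all five.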
