Removing an IR node from a Relay function

Hi,

I have a relay function that looks like this:

    Function fn (%input: Tensor[(1, 2, 3, 4), float32]) -> Tensor[(1, 2, 3, 4), float32] {
      %0 = qnn.quantize(%input, 0.0374455f /* ty=float32 */, -71 /* ty=int32 */, out_dtype="uint8") /* ty=Tensor[(1, 2, 3, 4), uint8] */;
      %1 = qnn.add(%0, meta[relay.Constant][0] /* ty=Tensor[(1, 2, 3, 4), uint8] */, 1f /* ty=float32 */, 0 /* ty=int32 */, 1f /* ty=float32 */, 0 /* ty=int32 */, 1f /* ty=float32 */, 0 /* ty=int32 */) /* ty=Tensor[(1, 2, 3, 4), uint8] */;
      qnn.dequantize(%1, 0.284927f /* ty=float32 */, -93 /* ty=int32 */) /* ty=Tensor[(1, 2, 3, 4), float32] */
    }

However, I do some preprocessing, so I don't need the quantize and dequantize nodes. To remove them, I walk the graph using Rewrite_ from MixedModeMutator.

  1. I can remove the dequantize node without any issue: I return the dequantize call's input (in this case, the add call) instead of the dequantize call itself.
  2. I also take the quantization parameters from the quantize node, propagate them to add, and change the types as needed.
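Item 1 works because a post-order mutator substitutes whatever the rewrite returns for the original call in every consumer, so returning a call's input splices that call out of the graph. The sketch below models this in plain Python with a toy IR. `Var`, `Call`, `PostOrderRewriter`, `StripDequantize`, and `reachable_ops` are hypothetical names, not TVM classes, and `rewrite(pre, post)` only loosely mirrors the `pre`/`post` arguments of `MixedModeMutator`'s `Rewrite_`.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union

# Hypothetical toy IR for illustration only; these are NOT TVM's Relay classes.
@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple[Any, ...]
    attrs: Tuple[Tuple[str, Any], ...] = ()

Expr = Union[Var, Call]

class PostOrderRewriter:
    """Visits arguments first, then calls rewrite(pre, post), where `post`
    is a copy of `pre` whose arguments have already been rewritten."""

    def __init__(self) -> None:
        self.memo: Dict[int, Expr] = {}  # each original node is rewritten once

    def visit(self, expr: Expr) -> Expr:
        if id(expr) in self.memo:
            return self.memo[id(expr)]
        if isinstance(expr, Call):
            post = Call(expr.op, tuple(self.visit(a) for a in expr.args), expr.attrs)
            out = self.rewrite(expr, post)
        else:
            out = expr  # variables are left unchanged
        self.memo[id(expr)] = out
        return out

    def rewrite(self, pre: Call, post: Call) -> Expr:
        return post  # default: keep the call

def reachable_ops(expr: Expr) -> List[str]:
    """Sorted op names of all calls reachable from `expr`."""
    seen, ops, stack = set(), [], [expr]
    while stack:
        e = stack.pop()
        if isinstance(e, Call) and id(e) not in seen:
            seen.add(id(e))
            ops.append(e.op)
            stack.extend(e.args)
    return sorted(ops)

class StripDequantize(PostOrderRewriter):
    def rewrite(self, pre: Call, post: Call) -> Expr:
        if post.op == "qnn.dequantize":
            # Return the already-rewritten input: consumers now see qnn.add.
            return post.args[0]
        return post

# The function body above; `const` stands in for meta[relay.Constant][0].
inp, const = Var("input"), Var("const")
q = Call("qnn.quantize", (inp,), (("scale", 0.0374455), ("zero_point", -71)))
body = Call("qnn.dequantize",
            (Call("qnn.add", (q, const)),),
            (("scale", 0.284927), ("zero_point", -93)))

new_body = StripDequantize().visit(body)
print(reachable_ops(body))      # ['qnn.add', 'qnn.dequantize', 'qnn.quantize']
print(reachable_ops(new_body))  # ['qnn.add', 'qnn.quantize']
```

Note that the quantize call is still reachable after this pass: nothing is deleted, it simply stays referenced by add.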

However, I am having trouble removing the quantize node.

When I walk the graph, the first node I encounter is quantize. At that point, I just set up a state variable. When Rewrite_ reaches add, I replace it and return a new, modified add expression that looks like this:

    expression free_var %input: Tensor[(1, 2, 3, 4), int8];
    qnn.add(%input, meta[relay.Constant][0] /* ty=Tensor[(1, 2, 3, 4), int8] */, 1f /* ty=float32 */, 0 /* ty=int32 */, 1f /* ty=float32 */, 0 /* ty=int32 */, 1f /* ty=float32 */, 0 /* ty=int32 */) /* ty=Tensor[(1, 2, 3, 4), int8] */

That is what I want. At this point, I am only changing the input of add and modifying the required types. I am relying on the quantize node, which add supposedly no longer uses, being deleted as an unused node. However, it does not get deleted: after VisitExpr finishes, the function body looks like this:

    free_var %input: Tensor[(1, 2, 3, 4), int8];
    %0 = qnn.quantize(%input, 0.0374455f /* ty=float32 */, -71 /* ty=int32 */, out_dtype="uint8") /* ty=Tensor[(1, 2, 3, 4), uint8] */;
    qnn.add(%0, meta[relay.Constant][0] /* ty=Tensor[(1, 2, 3, 4), int8] */, 1f /* ty=float32 */, 0 /* ty=int32 */, 1f /* ty=float32 */, 0 /* ty=int32 */, 1f /* ty=float32 */, 0 /* ty=int32 */) /* ty=Tensor[(1, 2, 3, 4), int8] */

How do I remove the following node from the graph?

    %0 = qnn.quantize(%input, 0.0374455f /* ty=float32 */, -71 /* ty=int32 */, out_dtype="uint8") /* ty=Tensor[(1, 2, 3, 4), uint8] */;
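In a functional IR like Relay, a mutator does not delete nodes; a call disappears only when nothing in the returned expression still references it. So `%0` survives for as long as the rewritten add keeps it as an argument, regardless of any state variable set when quantize was visited. One possible approach, sketched below with a hypothetical toy IR in plain Python (not TVM's API): make the quantize case of the rewrite return its input, exactly like the dequantize case, and read the quantization parameters from `pre`'s argument (the original quantize call) instead of from a state variable. `RewriteAddOnly` shows one way the observed symptom can arise; `StripQuantize` and the `lhs_scale`/`lhs_zero_point` attribute names are illustrative assumptions. Mapping this onto `MixedModeMutator`, whose hook is, as far as I know, `Rewrite_(const CallNode* pre, const Expr& post)`, should be verified against your TVM version.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union

# Hypothetical toy IR for illustration only; these are NOT TVM's Relay classes.
@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple[Any, ...]
    attrs: Tuple[Tuple[str, Any], ...] = ()

Expr = Union[Var, Call]

class PostOrderRewriter:
    """Arguments are rewritten first; rewrite(pre, post) gets the original
    call and a copy whose arguments are the rewritten ones."""

    def __init__(self) -> None:
        self.memo: Dict[int, Expr] = {}

    def visit(self, expr: Expr) -> Expr:
        if id(expr) in self.memo:
            return self.memo[id(expr)]
        if isinstance(expr, Call):
            post = Call(expr.op, tuple(self.visit(a) for a in expr.args), expr.attrs)
            out = self.rewrite(expr, post)
        else:
            out = expr
        self.memo[id(expr)] = out
        return out

    def rewrite(self, pre: Call, post: Call) -> Expr:
        return post

def reachable_ops(expr: Expr) -> List[str]:
    """Sorted op names of all calls reachable from `expr`."""
    seen, ops, stack = set(), [], [expr]
    while stack:
        e = stack.pop()
        if isinstance(e, Call) and id(e) not in seen:
            seen.add(id(e))
            ops.append(e.op)
            stack.extend(e.args)
    return sorted(ops)

def is_call(e: Any, op: str) -> bool:
    return isinstance(e, Call) and e.op == op

class RewriteAddOnly(PostOrderRewriter):
    """Rebuilds add but never rewrites quantize, so add still points at it."""
    def rewrite(self, pre: Call, post: Call) -> Expr:
        if post.op == "qnn.dequantize":
            return post.args[0]
        if post.op == "qnn.add":
            # New attrs, but post.args[0] is still the quantize call.
            return Call("qnn.add", post.args, (("dtype", "int8"),))
        return post

class StripQuantize(PostOrderRewriter):
    def rewrite(self, pre: Call, post: Call) -> Expr:
        if post.op in ("qnn.quantize", "qnn.dequantize"):
            return post.args[0]  # splice the node out
        if post.op == "qnn.add" and is_call(pre.args[0], "qnn.quantize"):
            # Parameters come from the ORIGINAL quantize call (pre side);
            # post.args[0] is already %input at this point.
            q = dict(pre.args[0].attrs)
            return Call("qnn.add", post.args,
                        (("lhs_scale", q["scale"]), ("lhs_zero_point", q["zero_point"])))
        return post

inp, const = Var("input"), Var("const")  # const stands in for the Relay constant
q = Call("qnn.quantize", (inp,), (("scale", 0.0374455), ("zero_point", -71)))
body = Call("qnn.dequantize",
            (Call("qnn.add", (q, const)),),
            (("scale", 0.284927), ("zero_point", -93)))

print(reachable_ops(RewriteAddOnly().visit(body)))  # ['qnn.add', 'qnn.quantize']
fixed = StripQuantize().visit(body)
print(reachable_ops(fixed))                         # ['qnn.add']
```

Reading the parameters from `pre` also removes the dependence on visit order that a state variable introduces. Changing `%input`'s dtype to int8 is outside what this toy models.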