[RELAY PASS][CONSTANT FOLDING] Question on conv2d + bias(constant) + mul(constant) folding

Hi, I have run into the following scenario:

%60 = fn (%p08: Tensor[(1, 1280, 720, 20), float32], %p17: Tensor[(3, 3, 20, 20), float32], %p27: Tensor[(1, 1, 1, 20), float32], %p34: Tensor[(1, 1), float32], %p42: Tensor[(1, 1280, 720, 20), float32], hash="ef71cf679b24179e", data_layout="NHWC", kernel_layout="HWIO", Primitive=1, out_layout="") -> Tensor[(1, 1280, 720, 20), float32] {

%11 = nn.conv2d(%p08, %p17, padding=[1, 1, 1, 1], kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO") /* ty=Tensor[(1, 1280, 720, 20), float32] */;

%12 = add(%11, %p27) /* ty=Tensor[(1, 1280, 720, 20), float32] */;
%13 = multiply(%12, %p34) /* ty=Tensor[(1, 1280, 720, 20), float32] */;
add(%13, %p42) /* ty=Tensor[(1, 1280, 720, 20), float32] */

};

%61 = %60(%59, meta[relay.Constant][31] /* ty=Tensor[(3, 3, 20, 20), float32] */, meta[relay.Constant][32] /* ty=Tensor[(1, 1, 1, 20), float32] */, meta[relay.Constant][33] /* ty=Tensor[(1, 1), float32] */, %57);

After experimenting with the passes, I found that the multiply cannot be folded into the conv2d and add. Is there any way to get this multiply folded?

Thank you.

There is a pass, FoldScaleAxis, that does this optimization. Does it work in this case?

It does not seem to work. I assumed it would behave like conv2d + batch_norm, so I made a small example with just a conv2d followed by an element-wise multiply, but it was not folded the way conv2d + batch_norm is. From the IR log, it looks like LowTE is not triggered somehow?

/src/ir/transform.cc:412: Executing module pass : InferType with opt level: 0
/src/relay/transforms/type_infer.cc:823: tvm::relay::transform::InferType
/src/ir/transform.cc:412: Executing module pass : InferType with opt level: 0
/src/relay/transforms/type_infer.cc:823: tvm::relay::transform::InferType
/src/ir/transform.cc:412: Executing module pass : InferType with opt level: 0
/src/relay/transforms/type_infer.cc:823: tvm::relay::transform::InferType
/src/relay/ir/transform.cc:133: Executing function pass : SimplifyInference with opt level: 0
/src/ir/transform.cc:412: Executing module pass : InferType with opt level: 0
/src/relay/transforms/type_infer.cc:823: tvm::relay::transform::InferType
/src/relay/ir/transform.cc:133: Executing function pass : FoldConstant with opt level: 2
/src/ir/transform.cc:412: Executing module pass : InferType with opt level: 0
/src/relay/transforms/type_infer.cc:823: tvm::relay::transform::InferType
/src/ir/transform.cc:412: Executing module pass : InferType with opt level: 0
/src/relay/transforms/type_infer.cc:823: tvm::relay::transform::InferType
/src/relay/ir/transform.cc:133: Executing function pass : BackwardFoldScaleAxis with opt level: 3
/src/ir/transform.cc:412: Executing module pass : InferType with opt level: 0
/src/relay/transforms/type_infer.cc:823: tvm::relay::transform::InferType
/src/ir/transform.cc:412: Executing module pass : InferType with opt level: 0
/src/relay/transforms/type_infer.cc:823: tvm::relay::transform::InferType
/src/relay/ir/transform.cc:133: Executing function pass : ForwardFoldScaleAxis with opt level: 3
/src/ir/transform.cc:412: Executing module pass : InferType with opt level: 0
/src/relay/transforms/type_infer.cc:823: tvm::relay::transform::InferType
/src/relay/ir/transform.cc:133: Executing function pass : FoldConstant with opt level: 2
/src/ir/transform.cc:412: Executing module pass : InferType with opt level: 0
/src/relay/transforms/type_infer.cc:823: tvm::relay::transform::InferType


fn (%data: Tensor[(1, 3, 224, 224), float32], %weight, %gamma: Tensor[(1), float32]) {
  %0 = nn.conv2d(%data, %weight, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]);
  multiply(%0, %gamma)
}

def @main(%data: Tensor[(1, 3, 224, 224), float32]) -> Tensor[(1, 16, 224, 224), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0], padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]);
  multiply(%0, meta[relay.Constant][1])
}

output:

def @main(%data: Tensor[(1, 3, 224, 224), float32]) -> Tensor[(1, 16, 224, 224), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 224, 224), float32] */;
  multiply(%0, meta[relay.Constant][1] /* ty=Tensor[(1), float32] */) /* ty=Tensor[(1, 16, 224, 224), float32] */
}
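For what it's worth, the algebraic identity that such a fold relies on can be checked numerically without TVM: scaling a conv2d output per output channel is equivalent to pre-scaling the weights (and bias) by the same factors, so the multiply could in principle be folded into the constant weights. A minimal NumPy sketch (naive NCHW/OIHW convolution; all shapes and names here are illustrative, not taken from the IR above):

```python
import numpy as np

def conv2d_nchw(x, w):
    # Naive "valid" cross-correlation, NCHW data / OIHW weights.
    n, ci, h, wd = x.shape
    co, _, kh, kw = w.shape
    out = np.zeros((n, co, h - kh + 1, wd - kw + 1), dtype=x.dtype)
    for o in range(co):
        for i in range(out.shape[2]):
            for j in range(out.shape[3]):
                out[:, o, i, j] = np.sum(
                    x[:, :, i:i + kh, j:j + kw] * w[o], axis=(1, 2, 3))
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 3, 8, 8)).astype("float32")
w = rng.standard_normal((4, 3, 3, 3)).astype("float32")
b = rng.standard_normal((4,)).astype("float32")
s = rng.standard_normal((4,)).astype("float32")  # per-output-channel scale

# (conv2d(x, W) + b) * s  ==  conv2d(x, W * s) + b * s
lhs = (conv2d_nchw(x, w) + b[None, :, None, None]) * s[None, :, None, None]
rhs = conv2d_nchw(x, w * s[:, None, None, None]) + (b * s)[None, :, None, None]
assert np.allclose(lhs, rhs, atol=1e-4)
```

When the weights, bias, and scale are all constants (as in the IR above), the right-hand side can be constant-folded into a single conv2d plus add at compile time.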


Same issue here, any progress on this problem?

This issue should be solved by "[Relay][transform][SimplifyExpr] simplify adjacent muls and adds with constants" by yangulei, Pull Request #13213 in apache/tvm (github.com).
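That PR targets chains of adjacent multiplies and adds whose operands are constants. The underlying rewrite is simply (x + c1) * c2 == x * c2 + c1 * c2, which lets a chain of constant muls/adds collapse to a single multiply and a single add. A quick NumPy sketch of the identity (constant names here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal((2, 4)).astype("float32")
c1 = rng.standard_normal((4,)).astype("float32")  # add constant
c2 = rng.standard_normal((4,)).astype("float32")  # mul constant

# Original chain: add with a constant, then multiply with a constant.
original = (x + c1) * c2
# Simplified form: one multiply and one add, with c1 * c2 pre-folded.
simplified = x * c2 + c1 * c2
assert np.allclose(original, simplified, atol=1e-5)
```

Because c1 * c2 involves only constants, constant folding can evaluate it at compile time, leaving just one multiply and one add in the graph.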