FoldScaleAxis fails on quantized MobileBERT with an out-of-stack error due to the recursive graph traversal it uses. Would it be possible for the author(s) to refactor the pass to a non-recursive implementation?
There was a discussion regarding this:
Below is a (rather artificial) demonstrator that fails with the FoldScaleAxis pass; it could be added to the tests/python/relay/test_pass_fold_scale_axis.py file. In my case about ~8000 cast operations are enough to trigger the out-of-stack exception.
```python
def test_fold_bwd_stress():
    """Simple stress testcase."""

    def before(x, conv_weight, out_bias, out_scale, in_channels, channels):
        args = [x, conv_weight, out_bias]
        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
        y = relay.nn.conv2d(
            x,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
        )
        y = relay.add(y, out_bias)
        y = relay.nn.relu(y)
        y = relay.multiply(y, out_scale)
        # Long chain of casts to force deep recursion in the pass.
        for _ in range(1, 9999):
            y = relay.cast(y, dtype="int32")
        return relay.Function(args, y)

    def check(shape, in_channels, channels):
        x = relay.var("x", shape=shape)
        weight = relay.var("weight")
        out_bias = relay.var("out_bias", shape=(channels,))
        out_scale = relay.const(_get_positive_scale((channels, 1, 1)))

        y1 = before(x, weight, out_bias, out_scale, in_channels, channels)
        y1 = run_opt_pass(y1, transform.InferType())
        type_dict = {x.name_hint: x.checked_type for x in y1.params}
        weight = relay.var("weight", type_dict["weight"])
        y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())

    check((2, 4, 10, 10), 4, 8)
```
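For reference, the recursion-to-iteration refactor being requested can be sketched generically with an explicit work stack. This is a hypothetical, TVM-free illustration (the `Node` class and function names are mine, not Relay's): a recursive post-order visitor consumes one Python stack frame per chained operator and overflows on deep graphs, while the explicit-stack version visits nodes in the same order without growing the call stack.

```python
class Node:
    """Toy expression node: a single-input operator chain, like repeated casts."""
    def __init__(self, op, inputs=()):
        self.op = op
        self.inputs = list(inputs)

def visit_recursive(node, visited=None):
    # Naive recursion: call depth equals chain length, so a chain of
    # ~8000 casts exceeds Python's default recursion limit (~1000).
    if visited is None:
        visited = []
    for inp in node.inputs:
        visit_recursive(inp, visited)
    visited.append(node.op)
    return visited

def visit_iterative(root):
    # Explicit stack with an "expanded" flag: a node is pushed once to
    # schedule its children and once more to emit it, producing the same
    # post-order as the recursive version with O(1) call-stack depth.
    visited = []
    stack = [(root, False)]
    while stack:
        node, expanded = stack.pop()
        if expanded:
            visited.append(node.op)
        else:
            stack.append((node, True))
            for inp in reversed(node.inputs):
                stack.append((inp, False))
    return visited

# Build a chain far deeper than the default recursion limit.
chain = Node("var")
for i in range(10000):
    chain = Node(f"cast{i}", [chain])

order = visit_iterative(chain)  # succeeds on a 10001-node chain
print(order[0], order[-1])      # prints "var cast9999"
```

If I read the codebase correctly, Relay's `MixedModeVisitor` applies a similar non-recursive expansion to dataflow regions, so porting FoldScaleAxis onto that infrastructure may be the natural fix.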