FoldScaleAxis fails on quantized mobileBERT

FoldScaleAxis fails on quantized mobileBERT with an out-of-stack error caused by the recursive graph traversal the pass uses. Would it be possible for the author(s) to refactor the pass into a non-recursive implementation? There was a discussion regarding this:
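To illustrate why a very deep graph breaks a recursive traversal while an iterative one copes, here is a minimal, TVM-independent sketch in plain Python. `Node`, `make_chain`, `post_order_recursive` and `post_order_iterative` are hypothetical illustrations, not TVM APIs. CPython's recursion limit stands in for the native thread stack that a recursive C++ pass exhausts; the iterative version keeps its work list on the heap, so its depth is bounded by memory rather than by the call stack.

```python
from dataclasses import dataclass, field
from typing import List, Set


@dataclass(eq=False)  # eq=False keeps identity-based hashing, so nodes can go in a set
class Node:
    """Hypothetical stand-in for a graph node (e.g. one cast op); not a TVM class."""
    name: str
    inputs: List["Node"] = field(default_factory=list)


def make_chain(length: int) -> Node:
    """Build input -> cast0 -> cast1 -> ..., like a long chain of casts."""
    node = Node("input")
    for i in range(length):
        node = Node(f"cast{i}", [node])
    return node


def post_order_recursive(root: Node) -> List[str]:
    """Recursive DFS: call-stack depth grows with the depth of the graph."""
    order: List[str] = []
    visited: Set[Node] = set()

    def visit(node: Node) -> None:
        if node in visited:
            return
        visited.add(node)
        for inp in node.inputs:
            visit(inp)
        order.append(node.name)

    visit(root)
    return order


def post_order_iterative(root: Node) -> List[str]:
    """Same post-order, using an explicit stack instead of recursion."""
    order: List[str] = []
    visited: Set[Node] = set()
    # Each entry is (node, inputs_already_pushed).
    stack = [(root, False)]
    while stack:
        node, expanded = stack.pop()
        if expanded:
            # All inputs have been emitted, so the node itself can be emitted.
            order.append(node.name)
            continue
        if node in visited:
            continue
        visited.add(node)
        stack.append((node, True))
        # Push inputs in reverse so they are processed left to right.
        for inp in reversed(node.inputs):
            if inp not in visited:
                stack.append((inp, False))
    return order


if __name__ == "__main__":
    chain = make_chain(10_000)  # deeper than CPython's default recursion limit of 1000
    try:
        post_order_recursive(chain)
    except RecursionError:
        print("recursive traversal: RecursionError")
    print("iterative traversal visited", len(post_order_iterative(chain)), "nodes")
```

Both functions produce the same post-order on graphs small enough for the recursive one; the iterative version simply tracks explicitly (via the `expanded` flag) what the call stack would otherwise track implicitly.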

Below is a (rather artificial) demonstrator. It fails in the FoldScaleAxis pass and could be added to a test file under tests/python/relay/. In my case, about 8000 cast operations are enough to trigger the out-of-stack exception.

from tvm import relay
from tvm.relay import transform

# _get_positive_scale() and run_opt_pass() are the helpers already defined in the
# FoldScaleAxis test file this test is meant to be added to.


def test_fold_bwd_stress():
    """Simple stress testcase."""

    def before(x, conv_weight, out_bias, out_scale, in_channels, channels):
        args = [x, conv_weight, out_bias]
        out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2)
        y = relay.nn.conv2d(
            x,
            conv_weight,
            channels=channels,
            kernel_size=(3, 3),
            padding=(1, 1),
        )
        y = relay.add(y, out_bias)
        y = relay.nn.relu(y)
        y = relay.multiply(y, out_scale)

        # A long chain of casts makes the graph very deep; about 8000 of them
        # were enough to exhaust the stack in my setup.
        for _ in range(1, 9999):
            y = relay.cast(y, dtype="int32")

        return relay.Function(args, y)

    def check(shape, in_channels, channels):
        x = relay.var("x", shape=shape)
        weight = relay.var("weight")
        out_bias = relay.var("out_bias", shape=(channels,))
        out_scale = relay.const(_get_positive_scale((channels, 1, 1)))

        y1 = before(x, weight, out_bias, out_scale, in_channels, channels)
        y1 = run_opt_pass(y1, transform.InferType())
        type_dict = {x.name_hint: x.checked_type for x in y1.params}
        weight = relay.var("weight", type_dict["weight"])
        # The out-of-stack exception is raised inside this pass.
        y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())

    check((2, 4, 10, 10), 4, 8)