Incorrect gradient with `relay.transform.gradient`

Relay seems to return an incorrect gradient if we round-trip the module from A-normal form to graph-normal form and back to A-normal form. The error disappears if I swap the order of the "good" and "bad" pipelines below, because the original module appears to be modified in place. Maybe this is related: https://github.com/apache/incubator-tvm/issues/6624

from tvm import relay
import tvm

# Build a module whose only function sums a 50-element vector.
N = 50
x = relay.var('x', shape=(N,))
z = relay.sum(x)
func = relay.Function([x], z)
mod = tvm.IRModule()
mod['func'] = func
# Take the (higher-order) gradient of func and expose it as main.
grad = relay.transform.gradient(mod['func'], mod)
mod['main'] = grad

print("### Bad:")
mod_bad = tvm.transform.Sequential([
    relay.transform.ToANormalForm(),
    relay.transform.ToGraphNormalForm(),
    relay.transform.ToANormalForm(),
    relay.transform.PartialEvaluate(),
    relay.transform.DeadCodeElimination(),
    relay.transform.ToGraphNormalForm(),
])(mod)

with tvm.transform.PassContext(opt_level=3):
    mod_bad, params = relay.optimize(mod_bad, 'llvm')

print(mod_bad.astext())

print("### Good:")
mod_good = tvm.transform.Sequential([
    relay.transform.PartialEvaluate(),
    relay.transform.DeadCodeElimination(),
    relay.transform.ToGraphNormalForm(),
])(mod)

with tvm.transform.PassContext(opt_level=3):
    mod_good, params = relay.optimize(mod_good, 'llvm')

print(mod_good.astext())

The output is below. In the bad case the backward closure reduces to zeros_like, so the returned gradient of sum is all zeros; in the good case it correctly broadcasts ones over the input:

### Bad:
#[version = "0.0.5"]
def @main(%x: Tensor[(50), float32]) -> (float32, (Tensor[(50), float32],)) {
  %0 = fn (%p0: Tensor[(50), float32], Primitive=1) -> float32 {
    sum(%p0) /* ty=float32 */
  };
  %1 = %0(%x) /* ty=float32 */;
  %2 = fn (%p01: Tensor[(50), float32], Primitive=1) -> Tensor[(50), float32] {
    zeros_like(%p01) /* ty=Tensor[(50), float32] */
  };
  %3 = %2(%x) /* ty=Tensor[(50), float32] */;
  %4 = (%3,);
  (%1, %4)
}

### Good:
#[version = "0.0.5"]
def @main(%x: Tensor[(50), float32]) -> (float32, (Tensor[(50), float32],)) {
  %0 = fn (%p0: Tensor[(50), float32], Primitive=1) -> float32 {
    sum(%p0) /* ty=float32 */
  };
  %1 = %0(%x) /* ty=float32 */;
  %6 = fn (%p01: Tensor[(50), float32], %p1: float32, Primitive=1) -> Tensor[(50), float32] {
    %2 = zeros_like(%p01) /* ty=Tensor[(50), float32] */;
    %3 = ones_like(%p1) /* ty=float32 */;
    %4 = expand_dims(%3, axis=0) /* ty=Tensor[(1), float32] */;
    %5 = broadcast_to_like(%4, %p01) /* ty=Tensor[(50), float32] */;
    add(%2, %5) /* ty=Tensor[(50), float32] */
  };
  %7 = %6(%x, %1) /* ty=Tensor[(50), float32] */;
  %8 = (%7,);
  (%1, %8)
}

I am using a very recent commit from master (21002cd094c34716b1e07a63ed76f53dadd60e23) on Arch Linux.

Hi, can you try disabling the DeadCodeElimination pass in both cases? This may be an issue with references being removed, which is a known bug at the moment.

See https://github.com/apache/incubator-tvm/issues/6803 for info.
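
Concretely, I mean something like this for the failing pipeline (an untested sketch based on your snippet, with only the DeadCodeElimination pass dropped):

mod_bad = tvm.transform.Sequential([
    relay.transform.ToANormalForm(),
    relay.transform.ToGraphNormalForm(),
    relay.transform.ToANormalForm(),
    relay.transform.PartialEvaluate(),
    # relay.transform.DeadCodeElimination(),  # disabled for this experiment
    relay.transform.ToGraphNormalForm(),
])(mod)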

@altanh That's not it. The output doesn't change if I remove the DeadCodeElimination pass. I also just rebuilt TVM, and the issue is still there on current master. (I had to add mod = relay.transform.InferType()(mod) before the gradient call, though.)
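
For completeness, the only change to the reproduction on current master is the extra type-inference step before taking the gradient:

mod = tvm.IRModule()
mod['func'] = func
mod = relay.transform.InferType()(mod)  # now required before gradient on current master
grad = relay.transform.gradient(mod['func'], mod)
mod['main'] = grad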

I looked into it further, and it seems that ToGraphNormalForm might also be unsound with references… Can you try disabling that as well? (Although I'm guessing you want to keep that pass.) If you do want to keep ToGraphNormalForm, you can try calling gradient with mode="first_order", since the first-order gradient does not create references. Hope this helps!
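
Something like this (an untested sketch reusing the setup from your first snippet):

mod = tvm.IRModule()
mod['func'] = func
mod = relay.transform.InferType()(mod)
# First-order gradient avoids the references created by the higher-order mode.
grad = relay.transform.gradient(mod['func'], mod, mode="first_order")
mod['main'] = grad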