Relay seems to return an incorrect gradient if we roundtrip from A-normal form to graph-normal form and back to A-normal form. The error disappears if I swap the order of the good and bad pipelines below, because the original module appears to be modified in place (see the sketch after the repro script). This may be related: https://github.com/apache/incubator-tvm/issues/6624
```python
import tvm
from tvm import relay

# Build a module whose gradient we know analytically: d(sum(x))/dx = ones.
N = 50
x = relay.var('x', shape=(N,))
z = relay.sum(x)
func = relay.Function([x], z)
mod = tvm.IRModule()
mod['func'] = func
grad = relay.transform.gradient(mod['func'], mod)
mod['main'] = grad

print("### Bad:")
# Roundtrip ANF -> GNF -> ANF before the cleanup passes.
mod_bad = tvm.transform.Sequential([
    relay.transform.ToANormalForm(),
    relay.transform.ToGraphNormalForm(),
    relay.transform.ToANormalForm(),
    relay.transform.PartialEvaluate(),
    relay.transform.DeadCodeElimination(),
    relay.transform.ToGraphNormalForm(),
])(mod)
with tvm.transform.PassContext(opt_level=3):
    mod_bad, params = relay.optimize(mod_bad, 'llvm')
print(mod_bad.astext())

print("### Good:")
# Same cleanup passes without the normal-form roundtrip.
mod_good = tvm.transform.Sequential([
    relay.transform.PartialEvaluate(),
    relay.transform.DeadCodeElimination(),
    relay.transform.ToGraphNormalForm(),
])(mod)
with tvm.transform.PassContext(opt_level=3):
    mod_good, params = relay.optimize(mod_good, 'llvm')
print(mod_good.astext())
```
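To support the in-place suspicion, here is a minimal sketch (assuming a well-behaved pass pipeline should leave its input module untouched; run it right after building `mod`, before the pipelines above) that snapshots the module text before and after the roundtrip:

```python
# Snapshot the module before and after the normal-form roundtrip.
# A pure pass pipeline should leave `mod` itself unchanged.
before = mod.astext()
tvm.transform.Sequential([
    relay.transform.ToANormalForm(),
    relay.transform.ToGraphNormalForm(),
    relay.transform.ToANormalForm(),
])(mod)
after = mod.astext()
print("mod mutated in place:", before != after)
```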
The output is below. Note that the bad variant computes the gradient of the sum as just `zeros_like(%x)`, i.e. all zeros, while the good variant correctly builds broadcast ones.
```
### Bad:
#[version = "0.0.5"]
def @main(%x: Tensor[(50), float32]) -> (float32, (Tensor[(50), float32],)) {
  %0 = fn (%p0: Tensor[(50), float32], Primitive=1) -> float32 {
    sum(%p0) /* ty=float32 */
  };
  %1 = %0(%x) /* ty=float32 */;
  %2 = fn (%p01: Tensor[(50), float32], Primitive=1) -> Tensor[(50), float32] {
    zeros_like(%p01) /* ty=Tensor[(50), float32] */
  };
  %3 = %2(%x) /* ty=Tensor[(50), float32] */;
  %4 = (%3,);
  (%1, %4)
}
### Good:
#[version = "0.0.5"]
def @main(%x: Tensor[(50), float32]) -> (float32, (Tensor[(50), float32],)) {
  %0 = fn (%p0: Tensor[(50), float32], Primitive=1) -> float32 {
    sum(%p0) /* ty=float32 */
  };
  %1 = %0(%x) /* ty=float32 */;
  %6 = fn (%p01: Tensor[(50), float32], %p1: float32, Primitive=1) -> Tensor[(50), float32] {
    %2 = zeros_like(%p01) /* ty=Tensor[(50), float32] */;
    %3 = ones_like(%p1) /* ty=float32 */;
    %4 = expand_dims(%3, axis=0) /* ty=Tensor[(1), float32] */;
    %5 = broadcast_to_like(%4, %p01) /* ty=Tensor[(50), float32] */;
    add(%2, %5) /* ty=Tensor[(50), float32] */
  };
  %7 = %6(%x, %1) /* ty=Tensor[(50), float32] */;
  %8 = (%7,);
  (%1, %8)
}
```
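The wrong result can also be seen numerically. A sketch (assuming the debug interpreter can execute the optimized modules; `x_np` is just a random test input) that compares both variants against the analytic gradient of `sum`, which is all ones:

```python
import numpy as np

x_np = np.random.rand(N).astype('float32')
for name, m in [('bad', mod_bad), ('good', mod_good)]:
    # Evaluate main, which returns (sum(x), (dsum/dx,)).
    res = relay.create_executor('debug', mod=m, ctx=tvm.cpu(), target='llvm').evaluate()(x_np)
    grad_x = res[1][0].asnumpy()
    # d(sum(x))/dx should be all ones; the bad variant yields zeros.
    print(name, np.allclose(grad_x, np.ones(N, dtype='float32')))
```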
I am using a very recent commit from master (21002cd094c34716b1e07a63ed76f53dadd60e23) on Arch Linux.