The issue of data consistency about "whileloop"

I tried to write a control flow program with relay, but I found a data consistency problem in my program. I don’t know what caused it. Here is the code using relay and pytorch.









from tvm import relay

from tvm.relay.frontend.common import infer_shape,infer_type

import numpy as np

from tvm.relay.scope_builder import ScopeBuilder

import tvm

from tvm.relay.loops import while_loop

def pytorch_test(num_iterations=2):

     import torch

     scores = np.load(“scores.npy”)

     log_nu = np.load(“log_nu.npy”)

     log_mu = np.load(“log_mu.npy”)

     log_mu = torch.from_numpy(log_mu)

     log_nu = torch.from_numpy(log_nu)

      scores = torch.from_numpy(scores)

     u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)

     for _ in range(num_iterations):

          u = log_mu - torch.logsumexp(scores + u.unsqueeze(1), dim=2)

           v = log_nu - torch.logsumexp(scores + v.unsqueeze(2), dim=1)

      print(“torch:”,scores + u.unsqueeze(2) + v.unsqueeze(1))





def log_sinkhorn_normalization_while_loop(scores, log_mu, log_nu,num_iterations):

      u_var = relay.var(“u”)

     v_var = relay.var(“v”)

      i = relay.var(“i”)

     def cond(i,u_var,v_var):

            return i<num_iterations

     def body(i,u_var,v_var):

            new_num_iterations_var = i + relay.const(1,“int32”)

            new_u_var = log_mu - relay.logsumexp(scores+relay.expand_dims(v_var, axis=1),axis=2)

           new_v_var = log_nu - relay.logsumexp(scores+relay.expand_dims(u_var, axis=2),axis=1)

           return new_num_iterations_var,new_u_var,new_v_var

     loop = while_loop(cond, [i,u_var,v_var], body)

     u, v = relay.zeros_like(log_mu), relay.zeros_like(log_nu)

     uv = loop(relay.const(0,“int32”),u,v)

     u = relay.expand_dims(relay.TupleGetItem(uv,1), 2)

     v = relay.expand_dims(relay.TupleGetItem(uv,2), 1)

      return scores + v + u

def relay_test(num_iterations=2):

     scores = relay.var(“scores”,shape=[2000,65,65],dtype=“float32”)

     log_mu = relay.var(“log_mu”,shape=[2000,65],dtype=“float32”)

     log_nu = relay.var(“log_nu”,shape=[2000,65],dtype=“float32”)

     num_iterations=relay.const(num_iterations,“int32”)

     func = log_sinkhorn_normalization_while_loop(scores, log_mu, log_nu,num_iterations)

     mod = tvm.IRModule({})

     mod = mod.from_expr(func)

     scores = np.load(“scores.npy”)

     log_nu = np.load(“log_nu.npy”)

     log_mu = np.load(“log_mu.npy”)

     out = relay.create_executor(“vm”, device=tvm.cpu(0), target=“llvm”, mod=mod).evaluate()(*[scores,log_mu,log_nu])

     print(“relay:”,out)







num_iterations=10

scores = np.random.rand(2000,65,65).astype(“float32”)

np.save(“scores.npy”,scores)

log_nu = np.random.rand(2000,65).astype(“float32”)

np.save(“log_nu.npy”,log_nu)

log_mu = np.random.rand(2000,65).astype(“float32”)

np.save(“log_mu.npy”,log_mu)

relay_test(num_iterations=num_iterations)

pytorch_test(num_iterations=num_iterations)’