Wrong implementation of relay.nn.cross_entropy_with_logits

Cross-entropy should equal NLLLoss(LogSoftmax(pred), label), but Relay's current implementation of `cross_entropy_with_logits` only computes the NLL-loss part and never applies the log-softmax to its input.
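In formulas, for a batch of $N$ logit rows $x_n$ and one-hot targets $y_n$, the expected cross-entropy and the quantity Relay appears to return are shown below. The second line is inferred from the outputs reported in this post, not read from the TVM source.

```latex
\mathrm{CE}(x, y) = -\frac{1}{N}\sum_{n=1}^{N}\sum_{c} y_{n,c}\,\log\operatorname{softmax}(x_n)_c,
\qquad
\log\operatorname{softmax}(x_n)_c = x_{n,c} - \log\sum_{k} e^{x_{n,k}}

\text{Relay (observed)}:\quad -\frac{1}{N}\sum_{n=1}^{N}\sum_{c} y_{n,c}\,x_{n,c}
```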

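```python
import torch
import torch.nn.functional as F

# prepare the data
tx = torch.randn(1, 10)     # logits: batch of 1, 10 classes
tty = torch.zeros(1, 10)    # one-hot target (class 1), the form Relay takes
tty[:, 1] = 1
ty = tty.argmax(-1).long()  # class-index target, the form PyTorch takes

# PyTorch's results
F.cross_entropy(tx, ty).item()
# 3.778721332550049

F.nll_loss(F.log_softmax(tx, dim=-1), ty).item()
# 3.778721332550049
```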
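PyTorch's losses take the class index `ty`, while the Relay call is fed the one-hot `tty`. Both encodings select the same term, so the comparison is like-for-like. A minimal NumPy sketch (no PyTorch or TVM; the helper names are hypothetical, and random values stand in for log-probabilities):

```python
import numpy as np

def nll_from_indices(logp, idx):
    # PyTorch-style NLL: take each row's entry at its target index, negate, average over the batch
    return -np.mean(logp[np.arange(logp.shape[0]), idx])

def nll_from_onehot(logp, onehot):
    # One-hot form: -sum(logp * y) / batch_size
    return -np.sum(logp * onehot) / logp.shape[0]

rng = np.random.default_rng(0)
logp = rng.standard_normal((3, 10))  # stand-in log-probabilities, batch of 3
idx = np.array([1, 4, 7])            # class-index targets (like ty)
onehot = np.eye(10)[idx]             # one-hot targets (like tty)

print(nll_from_indices(logp, idx), nll_from_onehot(logp, onehot))  # the two values agree
print(onehot.argmax(-1))  # recovers idx, as tty.argmax(-1) does in the post
```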

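TVM's results (the Relay op takes a target with the same shape as the predictions, so the one-hot `tty` is passed instead of `ty`):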

```python
import tvm
from tvm import relay
from tvm.contrib import graph_executor

x = relay.var("x", shape=[1, 10], dtype="float32")
y = relay.var("y", shape=[1, 10], dtype="float32")
z = relay.nn.cross_entropy_with_logits(x, y)
fn = relay.Function([x, y], z)
mod = tvm.IRModule.from_expr(fn)
mod = relay.transform.InferType()(mod)
lib = relay.build(mod, target="llvm")
g = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))

ttx = tx  # feed the raw logits

dx = ttx.numpy()
dy = tty.numpy()

g.set_input("x", dx)
g.set_input("y", dy)

g.run()
g.get_output(0)
# <tvm.nd.NDArray shape=(), cpu(0)>
# array(0.45164683, dtype=float32)
```
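The 0.45164683 above is consistent with the op computing only `-sum(x * y) / batch_size` on the raw inputs. That formula is an assumption inferred from the reported numbers, not read from the TVM source. Under it, a one-hot target reduces the loss to the negated logit of the target class, so this output would correspond to `tx[0, 1] ≈ -0.4516`. A NumPy sketch (the helper `nll_only` is hypothetical):

```python
import numpy as np

def nll_only(x, onehot):
    # Assumed Relay behaviour (hypothetical helper): -sum(x * y) / batch_size,
    # with no log-softmax, so raw logits are treated as if they were log-probabilities
    return -np.sum(x * onehot) / x.shape[0]

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 10)).astype(np.float32)  # raw logits, like tx
onehot = np.zeros_like(x)
onehot[:, 1] = 1                                     # one-hot target, like tty

# With a single one-hot row, the result is just the negated logit of the target class
print(nll_only(x, onehot), -x[0, 1])
```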

After manually applying `log_softmax` to the input before feeding it in, the result matches PyTorch:

```python
ttx = F.log_softmax(tx, dim=-1)  # convert logits to log-probabilities before feeding Relay

dx = ttx.numpy()
dy = tty.numpy()

g.set_input("x", dx)
g.set_input("y", dy)

g.run()
g.get_output(0)
# <tvm.nd.NDArray shape=(), cpu(0)>
# array(3.7787213, dtype=float32)
```
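The match after applying `log_softmax` follows from the same assumed formula: fed log-probabilities, `-sum(x * y) / batch_size` becomes the batch mean of `logsumexp(x_n) - x_n[target]`, which is the textbook cross-entropy. A NumPy sketch checking this identity on a small random batch (helper names are hypothetical; no TVM is involved):

```python
import numpy as np

def log_softmax(x, axis=-1):
    # Stable log-softmax: subtract the row max before exponentiating
    m = np.max(x, axis=axis, keepdims=True)
    return x - m - np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))

def nll_only(x, onehot):
    # Assumed Relay behaviour: -sum(x * y) / batch_size, no log-softmax
    return -np.sum(x * onehot) / x.shape[0]

def cross_entropy_reference(x, idx):
    # Cross-entropy with class-index targets: mean of logsumexp(x_n) - x_n[target]
    m = np.max(x, axis=-1)
    lse = m + np.log(np.sum(np.exp(x - m[:, None]), axis=-1))
    return np.mean(lse - x[np.arange(x.shape[0]), idx])

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 10))
idx = rng.integers(0, 10, size=4)
onehot = np.eye(10)[idx]

raw = nll_only(x, onehot)                 # raw logits, as in the first TVM run
fixed = nll_only(log_softmax(x), onehot)  # log_softmax applied first, as in the second run
ref = cross_entropy_reference(x, idx)
print(raw, fixed, ref)  # fixed and ref agree; raw differs by the mean logsumexp
```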