I have a question: is it possible for a Relay op to return a scalar? I have tried a couple of methods, and both of them failed… My PyTorch code is as follows:
import torch
@torch.jit.script
def max_item(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return ``x`` if its scalar value is strictly greater than ``y``'s, else ``y``.

    The comparison goes through ``.item()``, so each input must hold exactly
    one element. Because the test is a strict ``>``, ties (and any comparison
    involving NaN) select ``y``. Scripting keeps this data-dependent branch
    as a ``prim::If`` node when the function is called from traced code.
    """
    return x if x.item() > y.item() else y
def bar(x, y, z):
    """Add ``z`` to whichever of ``x`` / ``y`` max_item() selects."""
    selected = max_item(x, y)
    return selected + z
# Three single-element inputs: max_item() compares them via .item(), which
# only works on one-element tensors.
a = torch.randn([1])
b = torch.randn([1])
c = torch.randn([1])
# Plain ASCII quotes: the typographic quotes in the original paste are a
# SyntaxError in Python.
print("a.item():", a.item())
print("b.item():", b.item())
print("c.item():", c.item())
# Trace bar(); because max_item() is scripted, its data-dependent branch is
# kept as a prim::If node in the traced graph rather than frozen to the
# branch taken for these example inputs.
traced_graph = torch.jit.trace(bar, (a, b, c))
# Debug helper: print the graph the JIT produces for these example inputs.
print(traced_graph.graph_for(a, b, c))
JIT Graph:
graph(%x : Tensor,
%y.1 : Tensor,
%z : Tensor):
%3 : int = prim::Constant[value=1]() # torch_script.py:13:0
%4 : Scalar = aten::item(%x) # torch_script.py:5:7
%5 : Scalar = aten::item(%y.1) # torch_script.py:5:18
%6 : bool = aten::gt(%4, %5) # torch_script.py:5:7
%y : Tensor = prim::If(%6) # torch_script.py:5:4
block0():
-> (%x)
block1():
-> (%y.1)
%12 : Tensor = aten::add(%y, %z, %3) # torch_script.py:13:0
return (%12)