Consider the following simple TorchScript function:
@torch.jit.script
def fn(x):
  a = torch.zeros([1,2])
  a[0,0] = torch.matmul(x,x)
  return a
The resulting TorchScript graph (fn.graph) is:
graph(%x.1 : Tensor):
  %1 : bool = prim::Constant[value=0]()
  %2 : NoneType = prim::Constant()
  %3 : int = prim::Constant[value=1]() # /tmp/ipykernel_2640894/1952481818.py:3:19
  %4 : int = prim::Constant[value=2]() # /tmp/ipykernel_2640894/1952481818.py:3:21
  %5 : int = prim::Constant[value=0]() # /tmp/ipykernel_2640894/1952481818.py:4:4
  %6 : int[] = prim::ListConstruct(%3, %4)
  %a.1 : Tensor = aten::zeros(%6, %2, %2, %2, %2) # /tmp/ipykernel_2640894/1952481818.py:3:6
  %8 : Tensor = aten::matmul(%x.1, %x.1) # /tmp/ipykernel_2640894/1952481818.py:4:11
  %9 : Tensor = aten::select(%a.1, %5, %5) # /tmp/ipykernel_2640894/1952481818.py:4:2
  %10 : Tensor = aten::select(%9, %5, %5) # /tmp/ipykernel_2640894/1952481818.py:4:2
  %11 : Tensor = aten::copy_(%10, %8, %1) # /tmp/ipykernel_2640894/1952481818.py:4:2
  return (%a.1)
As you can see, the indexed assignment a[0,0] = ... is lowered to two aten::select calls followed by aten::copy_, which is not supported by TVM. How should the code be altered to avoid aten::copy_?
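A minimal sketch of one possible rewrite, assuming x is a 1-D float32 tensor (so torch.matmul(x, x) is a 0-d scalar tensor): instead of writing into a view of a, build the [1, 2] result out-of-place with torch.cat. The name fn_no_copy is hypothetical, and this has not been checked against TVM's PyTorch frontend; it only shows that aten::copy_ no longer appears in the TorchScript graph.

```python
import torch


@torch.jit.script
def fn_no_copy(x):
    # Hypothetical out-of-place variant of fn: no indexed assignment, so
    # TorchScript does not emit aten::select + aten::copy_.
    # Assumes x is 1-D, so matmul(x, x) is a 0-d tensor.
    val = torch.matmul(x, x).reshape([1, 1])  # plays the role of a[0, 0]
    rest = torch.zeros([1, 1])                # plays the role of a[0, 1]
    return torch.cat([val, rest], dim=1)      # shape [1, 2]


# The printed graph should contain aten::matmul and aten::cat, but no aten::copy_.
print(fn_no_copy.graph)
```

One difference from the original: fn always returns float32 because copy_ casts into the zeros tensor, while this sketch returns whatever dtype matmul produces, so inputs of other dtypes may need an explicit cast.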