Dealing with aten::copy_

Consider the following simple TorchScript function:

import torch

@torch.jit.script
def fn(x):
  a = torch.zeros([1, 2])
  a[0, 0] = torch.matmul(x, x)

  return a
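
The IR below can be reproduced by printing the scripted function's graph attribute:

print(fn.graph)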

The output is the following TorchScript graph, which is what TVM's Relay frontend has to convert:

graph(%x.1 : Tensor):
  %1 : bool = prim::Constant[value=0]()
  %2 : NoneType = prim::Constant()
  %3 : int = prim::Constant[value=1]() # /tmp/ipykernel_2640894/1952481818.py:3:19
  %4 : int = prim::Constant[value=2]() # /tmp/ipykernel_2640894/1952481818.py:3:21
  %5 : int = prim::Constant[value=0]() # /tmp/ipykernel_2640894/1952481818.py:4:4
  %6 : int[] = prim::ListConstruct(%3, %4)
  %a.1 : Tensor = aten::zeros(%6, %2, %2, %2, %2) # /tmp/ipykernel_2640894/1952481818.py:3:6
  %8 : Tensor = aten::matmul(%x.1, %x.1) # /tmp/ipykernel_2640894/1952481818.py:4:11
  %9 : Tensor = aten::select(%a.1, %5, %5) # /tmp/ipykernel_2640894/1952481818.py:4:2
  %10 : Tensor = aten::select(%9, %5, %5) # /tmp/ipykernel_2640894/1952481818.py:4:2
  %11 : Tensor = aten::copy_(%10, %8, %1) # /tmp/ipykernel_2640894/1952481818.py:4:2
  return (%a.1)

As you can see, the graph contains aten::copy_, which is not supported by TVM's PyTorch frontend. How should the code be altered to avoid the aten::copy_?
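
For context, here is a minimal sketch of how the failing conversion can be reproduced. The input shape is my assumption (x is taken to be 1-D, so torch.matmul(x, x) is a dot product yielding a scalar), and the scripted function is wrapped in a module since relay.frontend.from_pytorch expects a ScriptModule:

import torch
import tvm
from tvm import relay

class Wrapper(torch.nn.Module):
  def forward(self, x):
    return fn(x)

scripted = torch.jit.script(Wrapper())
shape_list = [("x", (2,))]  # assumed input shape

# The frontend scans the TorchScript graph and reports any operators it
# has no converter for, aten::copy_ among them, instead of producing a
# Relay module.
mod, params = relay.frontend.from_pytorch(scripted, shape_list)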

I managed to work around the problem by constructing the tensor with torch.stack() instead of assigning into a pre-allocated tensor.
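
For anyone hitting the same issue, here is a sketch of what that workaround can look like for the example above (the exact arrangement is my guess, chosen to reproduce the original [1, 2] output):

import torch

@torch.jit.script
def fn(x):
  # Build the result functionally instead of writing into a slot of a
  # pre-allocated tensor, so no in-place aten::copy_ is emitted.
  m = torch.matmul(x, x)                     # 0-d tensor when x is 1-D
  a = torch.stack([m, torch.zeros_like(m)])  # shape (2,)
  return a.unsqueeze(0)                      # shape (1, 2)

Printing fn.graph for this version shows aten::matmul, aten::zeros_like, aten::stack, and aten::unsqueeze in place of the select/copy_ pattern, so the unsupported operator disappears.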