I tried using take but it doesn’t work.
I think what I need here is take_along_axis, right?
https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html
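Assuming the goal is to pick a different set of indices for each row (for example, indices produced by `np.argsort` or `np.argmax` along an axis), the snippet below sketches the difference between the two functions. The arrays `a` and `idx` are made-up example data, not taken from the original question. `np.take` with `axis=1` applies the whole index array to every row. `np.take_along_axis` pairs each row of the index array with the matching row of the data.

```python
import numpy as np

# Hypothetical example data: 2 rows, 3 columns
a = np.array([[10, 30, 20],
              [60, 40, 50]])

# Per-row sort order; same ndim as `a`, as take_along_axis requires
idx = np.argsort(a, axis=1)  # [[0, 2, 1], [1, 2, 0]]

# np.take applies ALL of idx to every row, so the result gains an extra
# dimension: shape is a.shape[:1] + idx.shape == (2, 2, 3)
taken = np.take(a, idx, axis=1)

# np.take_along_axis uses row i of idx only for row i of a -> shape (2, 3)
sorted_rows = np.take_along_axis(a, idx, axis=1)
print(sorted_rows)  # [[10 20 30] [40 50 60]]

# Reductions like argmax drop the axis, so restore it with [:, None]
# before passing the indices to take_along_axis
best = np.take_along_axis(a, np.argmax(a, axis=1)[:, None], axis=1)
print(best)  # [[30] [60]]
```

If the indices come from a reduction such as `np.argmax(a, axis=1)`, they have one fewer dimension than `a`. Adding it back with `[:, None]` (or `np.expand_dims(..., axis=1)`) is what makes them valid input for `take_along_axis`.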