`x.argmax(0)` gives the indices along the first axis (axis 0) of the maximum values. Use `np.indices` to generate the indices for the other axes.
```python
>>> import numpy as np
>>> x = np.array([[[1,2],[0,1]],[[3,4],[6,7]]])
>>> x.argmax(0)
array([[1, 1],
       [1, 1]])
>>> a1, a2 = np.indices((2,2))
>>> (x.argmax(0),a1,a2)
(array([[1, 1],
        [1, 1]]),
 array([[0, 0],
        [1, 1]]),
 array([[0, 1],
        [0, 1]]))
>>> x[x.argmax(0),a1,a2]
array([[3, 4],
       [6, 7]])
>>> x[a1,x.argmax(1),a2]
array([[1, 2],
       [6, 7]])
>>> x[a1,a2,x.argmax(2)]
array([[2, 1],
       [4, 7]])
```
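To see what the fancy indexing in `x[a1, x.argmax(1), a2]` actually does, here is a small sketch that spells it out with explicit loops. The names `am`, `fancy` and `loop` are only for illustration. All three index arrays have the same shape, so each output element is looked up independently.

```python
import numpy as np

x = np.array([[[1, 2], [0, 1]], [[3, 4], [6, 7]]])
am = x.argmax(1)                 # shape (2, 2): best j for each (i, k)
a1, a2 = np.indices(am.shape)    # a1[p, q] == p, a2[p, q] == q

fancy = x[a1, am, a2]

# Explicit-loop equivalent: all three index arrays have shape (2, 2),
# so output element (p, q) is x[a1[p, q], am[p, q], a2[p, q]].
loop = np.empty(am.shape, dtype=x.dtype)
for p in range(am.shape[0]):
    for q in range(am.shape[1]):
        loop[p, q] = x[a1[p, q], am[p, q], a2[p, q]]

print(fancy)
print(np.array_equal(fancy, loop), np.array_equal(fancy, x.max(1)))
```

The last line also confirms that picking values at the argmax positions reproduces `x.max(1)`.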
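If `x` has other dimensions, generate `a1` and `a2` (plus one more index array per extra axis) appropriately. They should come from `np.indices` applied to the shape of `x` with the argmax axis removed. Here every such shape is `(2, 2)`, which is why `np.indices((2,2))` works for all three axes.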
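A minimal sketch of that generalization, for any number of dimensions and any axis. `take_argmax` is a hypothetical helper name, not a NumPy function. It builds one index grid per remaining axis with `np.indices` and slots the `argmax` result in at the reduced axis's position.

```python
import numpy as np

def take_argmax(x, axis):
    """Hypothetical helper: values of x at x.argmax(axis), built with np.indices."""
    axis = axis % x.ndim                   # accept negative axes
    am = x.argmax(axis)                    # shape: x.shape without `axis`
    grids = list(np.indices(am.shape))     # one index grid per remaining axis
    grids.insert(axis, am)                 # argmax goes in the reduced axis's slot
    return x[tuple(grids)]

x = np.array([[[1, 2], [0, 1]], [[3, 4], [6, 7]]])
print(take_argmax(x, 1))                   # same values as x[a1, x.argmax(1), a2]

# Works for other shapes too, e.g. a random 4-D array
y = np.random.default_rng(0).integers(0, 100, size=(3, 4, 5, 2))
print(all(np.array_equal(take_argmax(y, ax), y.max(ax)) for ax in range(y.ndim)))
```

On large arrays, `np.indices(..., sparse=True)` (NumPy 1.17+) returns broadcastable index arrays instead of full grids and gives the same result with less memory.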
The official documentation does not say much about how to use `argmax` this way, but earlier SO threads have discussed it. I got this general idea from "Using numpy.argmax() on multidimensional arrays".
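NumPy 1.15 and later also provide `np.take_along_axis`, a different built-in that does the index bookkeeping for you. It expects the index array to keep the reduced axis with length 1, so the sketch below re-inserts that axis with `np.expand_dims` and squeezes it out afterwards. It reproduces the three results of the `np.indices` approach on the same `x`.

```python
import numpy as np

x = np.array([[[1, 2], [0, 1]], [[3, 4], [6, 7]]])

results = {}
for axis in range(x.ndim):
    # take_along_axis needs an index array with the same ndim as x
    idx = np.expand_dims(x.argmax(axis), axis=axis)
    # pick the values, then drop the length-1 axis again
    results[axis] = np.take_along_axis(x, idx, axis=axis).squeeze(axis)
    print(axis, results[axis].tolist())
```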