Question

I have a numpy array and I'd like to take the argmax over all axes except the first. I have (I think) a solution, but I wonder whether there's a better way to do it.

import numpy as np

def argmax(array):
    # Argmax along all axes except the first (i.e. axis 0)
    last_axis = len(array.shape) - 1
    if last_axis == 0:
        # 1-D input: nothing to reduce, so every element is its own row maximum
        return (np.arange(array.size),)
    if last_axis == 1:
        return (range(array.shape[0]), list(np.argmax(array, axis=1)))
    index_array = np.argmax(array, axis=last_axis)
    smaller_array = np.amax(array, axis=last_axis)
    assert index_array.shape == smaller_array.shape
    argmax_smaller_array = argmax(smaller_array)
    return argmax_smaller_array + (list(index_array[argmax_smaller_array]), )

Some examples:

a = np.arange(12).reshape((6, 2))
a[5, 0] = 22
argmax(a)
a[argmax(a)]

b = np.arange(18).reshape((3, 3, 2))
b[0, 0, 0] = 55
b[argmax(b)]
np.all(b[argmax(b)] == np.array([np.max(b[0]), np.max(b[1]), np.max(b[2])]))  # True
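The check on the last line builds the per-row maxima one row at a time. As a small sketch, `np.max` also accepts a tuple of axes, so the same values come from a single call. Flattening the trailing axes first gives the same result. This is enough when only the maxima are needed, not their positions.

```python
import numpy as np

# Rebuild the second example from the question.
b = np.arange(18).reshape((3, 3, 2))
b[0, 0, 0] = 55

# Maximum over every axis except axis 0 in one call:
# range(1, b.ndim) lists all trailing axes, whatever the array's rank.
row_max = np.max(b, axis=tuple(range(1, b.ndim)))

# Same values by collapsing the trailing axes into one and reducing that axis.
row_max_flat = b.reshape((b.shape[0], -1)).max(axis=1)

print(row_max)  # [55 11 17]
```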

I'm just starting out with numpy and I wonder whether there's an easier way to do this. Am I rewriting something that already exists?

Was it helpful?

Solution

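Your method seems OK, but it computes a lot of intermediate results you don't need (an argmax and an amax for every trailing axis, plus one level of recursion per axis). If you flatten all axes after the first into one, a single argmax call does the work, and np.unravel_index converts the flat positions back into per-axis indices: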

import numpy as np

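def argmax(array):
    shape = array.shape
    # Collapse every axis after the first into a single axis
    array = array.reshape((shape[0], -1))
    # One argmax per row, as positions within the flattened trailing axes
    ravelmax = np.argmax(array, axis=1)
    # Turn flat positions back into indices over shape[1:], prepend the row indices
    return (np.arange(shape[0]),) + np.unravel_index(ravelmax, shape[1:])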
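A quick usage sketch that reuses the two arrays from the question. The helper `argmax_loop` is hypothetical and is not part of the answer. It is a slow reference that calls np.argmax once per row in plain Python, so you can cross-check the vectorized version.

```python
import numpy as np

def argmax(array):
    # The answer's approach: flatten the trailing axes, one argmax per row,
    # then map flat positions back to indices over shape[1:].
    shape = array.shape
    flat = array.reshape((shape[0], -1))
    ravelmax = np.argmax(flat, axis=1)
    return (np.arange(shape[0]),) + np.unravel_index(ravelmax, shape[1:])

def argmax_loop(array):
    # Hypothetical reference version: one np.argmax per row along axis 0.
    rows = [np.unravel_index(np.argmax(row), array.shape[1:]) for row in array]
    # Transpose the per-row index tuples into one index array per trailing axis.
    return (np.arange(array.shape[0]),) + tuple(np.array(axis) for axis in zip(*rows))

a = np.arange(12).reshape((6, 2))
a[5, 0] = 22
b = np.arange(18).reshape((3, 3, 2))
b[0, 0, 0] = 55

print(a[argmax(a)])  # [ 1  3  5  7  9 22]
print(b[argmax(b)])  # [55 11 17]
```

Both versions flatten each row in C order, and np.argmax returns the first occurrence of the maximum. They should therefore break ties the same way.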
Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow
scroll top