Using Numpy’s argmax()
A simple overview of using an often-misunderstood yet useful function in Python: Numpy's argmax(). Read the what, the how, and the why of argmax() here.
And argmax() returns... 9?!? Yep, that's correct!
Explaining argmax()
Recall the 5 Ws: who, what, when, where, and why.
When approaching a question, framing it with the proper 'W' can mean the difference between getting the answer you are looking for and absolute confusion.
Consider the following:
- When is your name?
- Where do you do for a living?
- Why street do you live on?
These questions make very little sense as they are posed, and so the answer one is looking for might not come as quickly as hoped.
Similar confusion can arise when one is writing Python code and uses Numpy's argmax() function. argmax is useful when working with arrays of any number of dimensions, matrices included, and searching for the maximum value. However, and rather importantly, argmax returns the indices of the maximum values along an axis, as opposed to the maximum values themselves.
argmax will tell you the where, not the what. This is a critical point that is often misunderstood by those using the Numpy library, and one which can lead to frustration.
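To make the distinction concrete, here is a minimal sketch that asks both questions of the same array: np.max answers the what, np.argmax answers the where. The array values are arbitrary and chosen only so that the maximum is not at the end.

```python
import numpy as np

# Arbitrary example array whose maximum (50) sits at position 1
a = np.array([10, 50, 30, 20])

where = np.argmax(a)  # the position of the maximum
what = np.max(a)      # the maximum value itself

print(where)     # 1
print(what)      # 50
print(a[where])  # 50: indexing with the position recovers the value
```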
Understanding argmax()
Let's have a look at how Numpy's argmax is designed to work.
According to the Numpy documentation, argmax "[r]eturns the indices of the maximum values along an axis." This means that the actual maximum values are not returned, only the positions of those maximum values.
Its important parameters are the input array in which to locate the maximum values (a) and an optional axis along which to search (axis). One can also optionally pass an output array in which to place the result (out), and a boolean (keepdims) which, when True, retains the reduced axis in the result with size one. If no axis is specified, the returned index is into the flattened input array.
argmax returns indices into the input array. When an axis is given, the result has the same shape as the input with that axis removed; when no axis is given, a single integer index is returned.
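The shape rules above can be checked with a small sketch. The 2 x 3 array is made up for illustration, and the keepdims argument assumes NumPy 1.22 or later, where it was added to argmax.

```python
import numpy as np

# Made-up 2 x 3 array (2 rows, 3 columns)
a = np.array([[3, 9, 1],
              [7, 2, 8]])

# No axis: one index into the flattened array [3, 9, 1, 7, 2, 8]
flat_idx = np.argmax(a)
print(flat_idx)  # 1 (the 9)

# With an axis, that axis is removed from the result's shape
by_col = np.argmax(a, axis=0)  # one row index per column
by_row = np.argmax(a, axis=1)  # one column index per row
print(by_col, by_col.shape)  # [1 0 1] (3,)
print(by_row, by_row.shape)  # [1 2] (2,)

# keepdims=True keeps the reduced axis with length 1 (NumPy >= 1.22)
kept = np.argmax(a, axis=1, keepdims=True)
print(kept.shape)  # (2, 1)
```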
Why argmax()?
According to Jason Brownlee in an article on Machine Learning Mastery:
Argmax is most commonly used in machine learning for finding the class with the largest predicted probability.
[...]
The most common situation for using argmax that you will encounter in applied machine learning is in finding the index of an array that results in the largest value.
If we consider the predicted probabilities of class membership, argmax can determine the index position in an array that holds the maximum value and, hence, the class with the highest predicted probability. This is clearly useful for machine learning.
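As a rough sketch of that use case, assume a model has produced one row of class probabilities per sample. The probabilities below are invented for illustration; argmax with axis=1 picks the most probable class for each row.

```python
import numpy as np

# Hypothetical predicted probabilities: 3 samples (rows) x 4 classes (columns)
probs = np.array([
    [0.10, 0.70, 0.15, 0.05],
    [0.60, 0.20, 0.10, 0.10],
    [0.05, 0.05, 0.30, 0.60],
])

# For each sample, the column index (class) with the largest probability
predicted_class = np.argmax(probs, axis=1)
print(predicted_class)  # [1 0 3]

# Use the indices to look up the corresponding probabilities
confidence = probs[np.arange(len(probs)), predicted_class]
print(confidence)  # [0.7 0.6 0.6]
```

Note that the prediction is the index (the class label), while the probability has to be looked up separately, which is the where-versus-what distinction again.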
Let's have a look at how argmax works in practice.
Using argmax()
We'll start with a one-dimensional array, move on to two dimensions, and then specify an axis.
Single Dimension
After importing Numpy in the usual manner, we create a single-dimension array and pass it to argmax. Look at the array and code to figure out what you think the output should be.
import numpy as np
# A one-dimensional array (single brackets)
a = [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ]
max_idx = np.argmax(a)
print(max_idx)
9
Does this output match what you expected?
If you recall that argmax returns an index as opposed to a value, and that indexing begins at 0, we see that we have been returned index 9 (which is actually the tenth position), holding the value '10', the maximum value in the array.
Make sense?
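Indexing the array with the returned position recovers the value, which is a handy sanity check. The sketch below also shows a documented detail worth knowing: if the maximum appears more than once, argmax returns the index of its first occurrence. The second array is made up for illustration.

```python
import numpy as np

a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
max_idx = np.argmax(a)
print(max_idx, a[max_idx])  # 9 10: the position, then the value stored there

# Made-up array with a repeated maximum: the first occurrence wins
b = np.array([4, 10, 2, 10])
print(np.argmax(b))  # 1, not 3
```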
Multiple Dimensions
Let's have a look at how argmax works with an array of multiple dimensions.
import numpy as np
a = [[ 1, 2, 3, 4, 5 ], [ 6, 7, 8, 9, 10 ]]
max_idx = np.argmax(a)
print(max_idx)
9
Now we have a 2-dimensional array. We did not pass an axis parameter, and so the default behavior of argmax is to treat the multidimensional array as flattened into a single dimension, and return the index of the maximum value in this flattened array.
As such, the result is the same as for the single-dimension array in the previous example, and it should now be apparent why.
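Because the flat index 9 refers to the flattened array, it does not directly say which row and column the maximum is in. A minimal sketch using np.unravel_index, which converts a flat index back into per-dimension coordinates for a given shape:

```python
import numpy as np

a = np.array([[1, 2, 3, 4, 5],
              [6, 7, 8, 9, 10]])

flat_idx = np.argmax(a)  # index into the flattened array
row, col = np.unravel_index(flat_idx, a.shape)

print(flat_idx, row, col)   # 9 1 4
print(a[row, col])          # 10
print(a.ravel()[flat_idx])  # 10: same element, reached through the flat view
```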
Specifying an Axis
But what if we specify an axis?
First, recall that axis 0 runs down the rows (it indexes which row you are in), while axis 1 runs across the columns. Let's see what happens when we pass axis=0 to argmax.
import numpy as np
a = [[ 1, 2, 3, 4, 5 ], [ 6, 7, 8, 9, 10 ]]
max_idx = np.argmax(a, axis=0)
print(max_idx)
[1 1 1 1 1]
What's this all about?
We specified axis=0, and so argmax will be returning, for each column of the multidimensional array, the row index at which that column's maximum value occurs.
What are the maximum values? Searching along axis 0 means each column is considered separately, so there is one maximum per column: 6, 7, 8, 9, and 10.
Where are those maximum values? Every one of them is in the second row.
Since axis=0 asks which row each column's maximum is in, Numpy reports one row index per column. Recalling that indexing begins with 0, the second row has index 1, which is exactly what we see for all five columns. Looking back at our array, where every value in the second row is larger than the value above it, argmax is correct.
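Pairing np.max and np.argmax with the same axis makes the per-column reasoning explicit: the first gives the what for each column, the second the where. A short sketch using the same array:

```python
import numpy as np

a = np.array([[1, 2, 3, 4, 5],
              [6, 7, 8, 9, 10]])

# axis=0 reduces over the rows, leaving one result per column
col_max = np.max(a, axis=0)        # the what: 6, 7, 8, 9, 10
col_argmax = np.argmax(a, axis=0)  # the where: 1, 1, 1, 1, 1

for col, (row, value) in enumerate(zip(col_argmax, col_max)):
    print(f"column {col}: max {value} is in row {row}")
```

This prints one line per column, from "column 0: max 6 is in row 1" through "column 4: max 10 is in row 1".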
What would you expect the output to be if we instead specified axis=1? I will leave this as an exercise for the reader.
Summary
In this post we learned what argmax is and why we would use it, and we covered several examples using Numpy's argmax function.
You should now know how argmax works.
Matthew Mayo (@mattmayo13) is a Data Scientist and the Editor-in-Chief of KDnuggets, the seminal online Data Science and Machine Learning resource. His interests lie in natural language processing, algorithm design and optimization, unsupervised learning, neural networks, and automated approaches to machine learning. Matthew holds a Master's degree in computer science and a graduate diploma in data mining. He can be reached at editor1 at kdnuggets[dot]com.