この記事では配列の中で、最大の値をもった要素のインデックスを返す関数、np.argmaxを紹介します。
np.ndarrayの最大の要素を返すnp.sum関数、最小の要素を返すnp.min関数に対応する形で、それらの要素のインデックスを返すnp.argmaxとnp.argminがNumPyには用意されています。
機械学習ではソフトマックス関数の出力値をargmaxしたりしますね。
意外に活躍の機会が多いargmax関数、是非この記事で覚えましょう。
※ この記事のコードはPython 3.7, Ubuntu 18.04で動作確認しました。
np.argmaxの使い方
※この記事のコードは、jupyter notebookやjuputer labを使って書かれています。
コードを試すときは是非これらを使ってみてください。
np.argmax(一次元配列)
np.argmaxは最大値のindexを返す関数です。
使い方は非常に簡単なので、実際にコードを見てみましょう。
# コード In [1]: import numpy as np
# コード In [2]: a = np.arange(10) np.random.shuffle(a) a
# 出力結果 Out [2]: array([3, 6, 0, 1, 2, 5, 9, 4, 8, 7])
ここで、0~9までの要素を持った配列aを作り、これをシャッフルしました。
(※シャッフルした結果はここに書かれているものとは違うものになると思います。)
では、np.argmaxでindexを確認してみましょう。
# コード In [3]: np.argmax(a)
# 出力結果 Out [3]: 6
この配列では9が最大値ですね。
9は0スタートのindexで6番目にあるので、np.argmaxで正しい値が取れていることがわかります。
np.argmax(多次元配列)
多次元配列に対しても同様に最大値のindexを取ることができます。
また、他のNumPy関数と同様に、axisパラメータを指定する事もできます。
# コード In [4]: b = np.reshape(a, (2,5)) b
# 出力結果 Out [4]: array([[3, 6, 0, 1, 2], [5, 9, 4, 8, 7]])
# コード In [5]: res1 = np.argmax(b) res2 = np.argmax(b, axis=0) res3 = np.argmax(b, axis=1) print(res1,"\n", res2,"\n", res3)
# 出力結果 [5]: 6 [1 1 1 1 1] [1 1]
axisを指定しないときは、多次元配列をflattenした(一次元配列に直した)状態での最大値indexを返します。
axisを指定すると、列ごとや行ごとに最大値のindexを返します。
この辺のaxisに関する使い方は、np.maxのような関数と同様の使い方です。
N番目に大きい要素のindexを見つける方法
np.argmaxでは、N番目に大きい要素のindexを見つけることはできません。
ではどうするのかというと、np.sortとnp.whereをあわせて使うと同様の結果になります。
np.sortは昇順に要素を並べる関数で、np.whereは、引数として指定した条件に合った要素のindexを返す関数です。
# コード In [6]: N = 2 np.where(a==np.sort(a)[-N])
# 出力結果 Out [6]: (array([8]),)
np.argminの使い方
np.argminの使い方はnp.argmaxと同じです。
違いは最小値のindexを取ってくるというところだけです。
まずは一次元配列のargminです。
# コード In [7]: np.argmin(a)
# 出力結果 Out [7]: 2
次に多次元配列のargminです。
# コード In [8]: res1 = np.argmin(b) res2 = np.argmin(b, axis=0) res3 = np.argmin(b, axis=1) print(res1,"\n", res2,"\n", res3)
# 出力結果 [8]: 2 [0 0 0 0 0] [2 2]
まとめ
この記事では、配列の最大要素のインデックスを返すnp.argmaxと、最小要素のインデックスを返すnp.argminについて紹介しました。
機械学習の実装で非常に役に立つ関数です。
是非とも覚えて使いこなしてください!