機械学習手法「SVM(サポートベクトルマシン)」でクラス分類に挑戦

Deep Learningが流行る前に大流行していた機械学習手法のSVM(サポートベクトルマシン)をご存知ですか?

高速で、少ないデータでも良い性能が期待でき、データ解析の実務でも使える分類アルゴリズムだと言えます。ここではSVMについて、あまり詳しく深入りせずに、初心者がscikit-learnを使って試せるようになるまでをサポートします!

この記事では

  • SVMとは?
  • SVMの使い方は?
  • SVMのチューニング方法

について解説しますので、「SVMって何?使ってみたい!」と思った方は是非読み進めてください!

この記事のコードは、Python 3.7、Scikit-learn 0.19で動作確認しました。

目次

SVMとは

SVM(サポートベクトルマシン/サポートベクターマシン)は機械学習モデルの一種で、非常に強力なアルゴリズムです。教師あり学習で、分類や回帰に使われます。

Deep Learningが流行る前の世代ではその非常に高い汎化性能と使いやすさから、本当に広い分野で使われていました。

マージン最大化という考え方、カーネル法という非線形への拡張などが肝ですが、scikit-learnを使うことで簡単に使うことができるようになりました。詳しく解説を知りたい方は、以下のPDFが非常に丁寧でおすすめです(要:大学数学)

マージン最大化について概要を知りたい方は以下のページがわかりやすいのでおすすめです。

また、侍エンジニアのマンツーマンレッスンでも、SVMについてわかりやすい解説ができます!気になる方は上のリンクをチェック!

SVMの使い方

目的

SVMによるクラス分類について勉強していきましょう。そのために、ここではdigitsという手書き数字認識データセットを使い、クラス分類タスクを行います。

0~9までの数字を手書きした8×8ピクセルのデータが収められています。ここではSVMを使ってこの画像データからラベル(0~9の数字)を予測します。

準備

まずはライブラリをimportしましょう。

from sklearn import datasets
import numpy as np
import matplotlib.pyplot as plt
from sklearn.svm import SVC
from sklearn.svm import SVR
from sklearn.model_selection import train_test_split,GridSearchCV

SVCがここでの主役です。SVMでクラス分類を行うためのクラスですね。もしもSVMで回帰を行いたい場合はその下のSVRをimportしてください。

次はデータの読み込みです。

mnist = datasets.load_digits()
train_images, test_images, train_labels, test_labels 
    = train_test_split(mnist.data, mnist.target, test_size=0.2)

データの分割(教師データやテストデータを作る)には、普通np.arrayのスライスを使います。シャッフルしてからスライスで40%や20%のデータをテストデータとするのですが、skleranのtrain_test-split関数を使うことで楽ができます。

今読み込んだデータを見てみましょう。

plt.figure(figsize=(15,15))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(train_images[i].reshape((8,8)), cmap=plt.cm.binary)
    plt.xlabel(train_labels[i])

モデルの学習

clf = SVC(verbose=True, random_state=2525)
clf.fit(train_images, train_labels)

[出力結果]

[LibSVM]
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape='ovr', degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=2525, shrinking=True,
  tol=0.001, verbose=True)

モデルの作成から学習までの流れもskleranなら2stepで簡単です。実験に再現性をもたせたい場合は、random_stateに適当な値を渡してください。このrandom_stateを指定しないと、適当な乱数シードが使われるので再現が難しくなります。

また、sklearnのSVMでは、有名なSVMライブラリであるLibSVMが使われます。これはCで書かれているので、非常に速いのが特徴です。

モデルの評価

predicted_labels = clf.predict(test_images)
predicted_labels

[出力結果]

array([3, 0, 3, 3, 3, 3, 3, 3, 1, 3, 0, 3, 1, 4, 4, 0, 3, 3, 3, 3, 4, 6,
       3, 7, 9, 8, 6, 4, 3, 3, 3, 3, 3, 2, 3, 2, 0, 3, 0, 1, 1, 3, 0, 3,
       2, 6, 3, 3, 3, 0, 3, 3, 8, 3, 3, 3, 3, 3, 4, 3, 3, 3, 3, 3, 3, 4,
       6, 3, 6, 3, 3, 3, 1, 7, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 3,
       6, 3, 9, 3, 0, 3, 3, 1, 1, 3, 0, 3, 3, 3, 3, 3, 3, 3, 8, 7, 7, 3,
       3, 3, 3, 3, 3, 1, 8, 3, 3, 5, 3, 6, 3, 9, 3, 6, 3, 3, 3, 3, 3, 8,
       3, 3, 6, 3, 3, 3, 1, 3, 1, 9, 3, 3, 8, 7, 3, 2, 1, 3, 4, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3, 3, 7, 9, 3, 3, 3, 3, 3, 3, 0, 3,
       3, 1, 3, 3, 3, 1, 3, 5, 6, 3, 3, 3, 1, 3, 3, 2, 3, 3, 3, 3, 3, 4,
       3, 1, 7, 3, 3, 3, 3, 3, 0, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 3,
       0, 3, 3, 3, 3, 9, 4, 8, 6, 3, 0, 3, 6, 6, 7, 3, 3, 7, 3, 6, 3, 3,
       3, 3, 3, 3, 3, 3, 7, 3, 1, 3, 9, 3, 3, 3, 3, 6, 3, 3, 3, 0, 6, 3,
       2, 4, 3, 5, 3, 3, 3, 3, 3, 4, 3, 6, 6, 3, 3, 3, 2, 3, 1, 7, 3, 3,
       0, 2, 5, 4, 3, 3, 3, 0, 3, 9, 3, 3, 9, 2, 1, 0, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 6, 4, 3, 9, 2, 4, 3, 6, 3, 3, 3, 4, 3, 3, 3, 6,
       4, 1, 3, 3, 3, 3, 3, 3, 9, 3, 3, 3, 8, 3, 4, 6, 3, 3, 3, 3, 3, 6,
       4, 3, 7, 3, 3, 3, 4, 3])

機械学習モデルのインスタンス.predictを行うことで、テストデータに対して予測ラベルを計算できます。この予測ラベルがどのくらいあっているのかも確かめてみましょう。

print(f"acc: {clf.score(test_images, test_labels)}")
# acc: 0.4388888888888889

正答率43%、これはかなり酷い数字ですね。

予測結果を可視化してみます。

plt.figure(figsize=(15,15))

# 先頭から25枚テストデータを可視化
for i in range(25):
    
    # 画像を作成
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(test_images[i].reshape((8,8)), cmap=plt.cm.binary)
    
    # 今プロットを作っている画像データの予測ラベルと正解ラベルをセット
    predicted_label = predicted_labels[i]
    true_label      = test_labels[i]
    
    # 予測ラベルが正解なら緑、不正解なら赤色を使う
    if predicted_label == true_label:
        color = 'green' # True label color
    else:
        color = 'red'   # False label color
    plt.xlabel("{} True({})".format(predicted_label, 
                                  true_label),
                                  color=color)

 

緑がテストラベルと予測ラベルが同じだったもの、赤が間違っていたものです。

この酷い結果を次のハイパーパラメータ探索でどこまで向上させられるか、確認してみましょう。

パラメータチューニング

ここでは簡単に、グリッドサーチを使って「いいハイパーパラメータを見つける」探索を行います。search_paramsが探索空間です。

ここではkernelとCについて探索します。

search_params = [
    {
        "kernel"          : ["rbf","linear","sigmoid"],
        "C"               : [10**i for i in range(-10,10)],
        "random_state"    : [2525],
    }
]
gs = GridSearchCV(SVC(), 
                  search_params, 
                  cv = 3,
                  verbose=True, 
                  n_jobs=-1)
gs.fit(train_images, train_labels)

print(gs.best_estimator_)

[出力結果]

Fitting 3 folds for each of 60 candidates, totalling 180 fits
[Parallel(n_jobs=-1)]: Done  18 tasks      | elapsed:    0.3s
SVC(C=0.01, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape='ovr', degree=3, gamma='auto', kernel='linear',
  max_iter=-1, probability=False, random_state=2525, shrinking=True,
  tol=0.001, verbose=False)
[Parallel(n_jobs=-1)]: Done 180 out of 180 | elapsed:    1.8s finished

GridSearchCVのcv引数は、クロスバリデーションの際にデータを何個に分割するかを決めます。また、verbose=Trueとすることでログを表示してくれるので、時間がかかりそうな探索を行う場合はTrueにすることを忘れないように!

最後に、n_jobsで並列に動作させます。例えば、8コア16スレッドのCPUを使っているなら、これらのスレッドすべてを使ってGridSearchを行うようにする設定が、n_jobs=-1になります。

print(gs.best_estimator_.score(test_images,test_labels))
predicted_label = gs.best_estimator_.predict(test_images)

# 0.9833333333333333

正答率が98.3%まで大きく向上しました。グリッドサーチでいいハイパーパラメータを見つけた成果が出ています。

ちなみに先程と同じコードで可視化したものが以下になります。

見えている範囲の画像はすべて正しく予測できていますね。Deep Learning・ニューラルネットワークを使わないでも、ここまで高精度にクラス分類ができることがわかりました。

時間的成約や求められている性能を考えて、SVMやDeep Learningを使いこなせるとプロっぽいです。

まとめ

この記事ではSVM(サポートベクトルマシン)について紹介し、これを使ったクラス分類を行いました。sklearnを使っているので簡単に使えましたが、実際に自分で実装すると非常に大変なアルゴリズムです。

また、SVMにはパラメータが膨大にあり、例えば「どのカーネル」を選べばいいのか、などは以外にデータの特性でアタリを付けることができます。

もっと詳しくSVMについて勉強したいときは、侍エンジニアのマンツーマンレッスンを試してみてください!

この記事を書いた人

【プロフィール】
DX認定取得事業者に選定されている株式会社SAMURAIのマーケティング・コミュニケーション部が運営。「質の高いIT教育を、すべての人に」をミッションに、IT・プログラミングを学び始めた初学者の方に向け記事を執筆。
累計指導者数4万5,000名以上のプログラミングスクール「侍エンジニア」、累計登録者数1万8,000人以上のオンライン学習サービス「侍テラコヤ」で扱う教材開発のノウハウ、2013年の創業から運営で得た知見に基づき、記事の執筆だけでなく編集・監修も担当しています。
【専門分野】
IT/Web開発/AI・ロボット開発/インフラ開発/ゲーム開発/AI/Webデザイン

目次