JavaScriptで機械学習の実装 3 AROW

今回はオンライン機械学習アルゴリズムとして知られている AROW (Adaptive Regularization of Weight Vectors) を試してみました。内容的には以下のエントリーの続きになりますが、今回からタイトル文言を少し変えようと思います。TypeScript入門という段階はそろそろ脱したかなと思うのと、TypeScriptよりも直接JavaScriptで書く量の方が増えてきたためです。

前回のエントリー内で、次はブースティング系アルゴリズムを実装してみたいと書いたのですが、オンライン機械学習 (機械学習プロフェッショナルシリーズ)を読んでいたらこの分野への興味が強くなってしまったので、ちょっと寄り道。

AROW (Adaptive Regularization of Weight Vectors)

AROWのアルゴリズムは以下のようになっています(元論文から引用)。

arow_algorithm
AROWは Confidence Weighted Learning (CW) というアルゴリズムを改善したものです。CWでは”現在の訓練データを常に正しく分類する”という条件で最適化するので、訓練データにノイズの入っていると上手く学習できないという欠点がありますが、AROWではこれまでの分布(パラメータ)に近い分布を探しつつ、各特徴の確信度(Confidence)を更新毎に上げるという条件(正則化項として加える)も併せて考慮して最適化するため、ノイズが多いデータでもCWと比較して分類精度が高くなるという特徴があります。パラメータも一つだけというのも嬉しいです。実装上は計算量の削減の為に共分散行列の非対角項を0にして対角項だけ計算することが多いそうですが、これでも精度はほとんど落ちないようです。

Node.js Stream APIで学習データの読み込み

今回からはWebブラウザではなくNode.js上で実行します。理由はデータセットの効率的な読み込み処理の為にStream APIを使いたかった為です。LIBSVMフォーマットのデータファイルをStream APIで読み込むには例えば以下のように書くことができます。

これで一つのサンプルが取り出される度にコールバック関数が呼ばれます。これを応用して入力をファイルではなくHTTPのストリームを指定して、例えばJSON形式で一つのサンプルを渡してもらうインタフェースにすると、WebAPIで学習データをストリームで受け取り続けることができるのでオンライン学習と相性が良さそうです。ただ、上記のように単一のファイルを読み込む場合はオーバーヘッドが大きいので、通常の非同期読み込みの手順を使う方が良いと思います。

検証

AROW本体やデータ読み込み用モジュール、動作確認用のテストコード一式はこれまで通りGitHubに置いておきます。今回からはTypeScriptとJavaScriptの両方のコードを上げておきます。

今回はLIBSVM Data: Classificationの news20.binary データセットで分類精度や収束速度などを確認してみます。

事前にデータをシャッフルして15000例の訓練データと4996例のテストデータに分けています。今回作ったLIBSVMデータファイル読み込みモジュールは、非同期によるストリーム読み込みだけでなく同期読み込みもできるようにしてあります。動作確認には後者を使った方が楽です。

パラメータは適当に決めたのですが、分類精度(正答率)は約97.2%となりました。注目すべきは収束速度で、データを一巡しただけでほぼ収束しているように見えます。速いですね。

線形SVMとも比較してみます。まだJavaScriptでSVMの実装はしていないので、ここではscikit-learnの実装(sklearn.svm.LinearSVC)を使いました。

* 実行結果

パラメータはグリッドサーチで決定、分類精度は約96.5%で、AROWと同等あるいは少し悪いくらい。妥当な結果と言えるでしょうか。

おわりに

今回はオンライン機械学習アルゴリズムのひとつであるAROWを試してみました。最近はディープラーニング系のニューラルネットワークのプログラムを趣味でも書くことが増えてきたのですけど、それらに比べてオンライン機械学習のアルゴリズムは実装が楽で収束も速いので嬉しいです。それに加えてデータを蓄積しておかなくて良いので、近年のWebサービスで扱われるような断続的なストリームデータとも相性が良いと思います。この後はAROWより後発のアルゴリズムであるSCWなども試しつつ、Node.js上でのより効率的な処理の書き方なども併せて調べていくつもりです。

参考

オンライン機械学習 (機械学習プロフェッショナルシリーズ)
AROW (Adaptive Regularization of Weight Vectors) (論文PDF)
SCW (Soft Confidence Weighted Learning) (論文PDF)
CW, AROW, and SCW
[機械学習] AROWのコードを書いてみた
Node.js v6.3.1 Documentation
Node.js の Stream API で「データの流れ」を扱う方法

あわせて読む:

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です