今回はオンライン機械学習アルゴリズムとして知られている AROW (Adaptive Regularization of Weight Vectors) を試してみました。内容的には以下のエントリーの続きになりますが、今回からタイトル文言を少し変えようと思います。TypeScript入門という段階はそろそろ脱したかなと思うのと、TypeScriptよりも直接JavaScriptで書く量の方が増えてきたためです。
前回のエントリー内で、次はブースティング系アルゴリズムを実装してみたいと書いたのですが、オンライン機械学習 (機械学習プロフェッショナルシリーズ)を読んでいたらこの分野への興味が強くなってしまったので、ちょっと寄り道。
AROW (Adaptive Regularization of Weight Vectors)
AROWのアルゴリズムは以下のようになっています(元論文から引用)。
AROWは Confidence Weighted Learning (CW) というアルゴリズムを改善したものです。CWでは”現在の訓練データを常に正しく分類する”という条件で最適化するので、訓練データにノイズの入っていると上手く学習できないという欠点がありますが、AROWではこれまでの分布(パラメータ)に近い分布を探しつつ、各特徴の確信度(Confidence)を更新毎に上げるという条件(正則化項として加える)も併せて考慮して最適化するため、ノイズが多いデータでもCWと比較して分類精度が高くなるという特徴があります。パラメータも一つだけというのも嬉しいです。実装上は計算量の削減の為に共分散行列の非対角項を0にして対角項だけ計算することが多いそうですが、これでも精度はほとんど落ちないようです。
Node.js Stream APIで学習データの読み込み
今回からはWebブラウザではなくNode.js上で実行します。理由はデータセットの効率的な読み込み処理の為にStream APIを使いたかった為です。LIBSVMフォーマットのデータファイルをStream APIで読み込むには例えば以下のように書くことができます。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
var fs = require('fs'); var readline = require('readline'); var loadLIBSVM = function (filePath, callback) { var options = { encoding: 'utf8', highWaterMark: 256 }; var stream = fs.createReadStream(filePath, options); var rl = readline.createInterface(stream, null); rl.on('line', function (line) { var fields = line.split(/[\s:]/); var label = fields[0].charAt(0) == '+' ? +1 : -1; var x = []; var len = fields.length; for (var i = 1; i < len; i += 2) { var index = parseInt(fields[i]) - 1; var value = parseFloat(fields[i + 1]); var element = { index: index, value: value }; x.push(element); } callback(x, label); }); rl.on('close', function () { console.log('load complete'); }); }; |
これで一つのサンプルが取り出される度にコールバック関数が呼ばれます。これを応用して入力をファイルではなくHTTPのストリームを指定して、例えばJSON形式で一つのサンプルを渡してもらうインタフェースにすると、WebAPIで学習データをストリームで受け取り続けることができるのでオンライン学習と相性が良さそうです。ただ、上記のように単一のファイルを読み込む場合はオーバーヘッドが大きいので、通常の非同期読み込みの手順を使う方が良いと思います。
検証
AROW本体やデータ読み込み用モジュール、動作確認用のテストコード一式はこれまで通りGitHubに置いておきます。今回からはTypeScriptとJavaScriptの両方のコードを上げておきます。
今回はLIBSVM Data: Classificationの news20.binary データセットで分類精度や収束速度などを確認してみます。
1 2 3 4 5 |
news20.binary Preprocessing: Each instance has unit length. # of classes: 2 # of data: 19,996 # of features: 1,355,191 |
事前にデータをシャッフルして15000例の訓練データと4996例のテストデータに分けています。今回作ったLIBSVMデータファイル読み込みモジュールは、非同期によるストリーム読み込みだけでなく同期読み込みもできるようにしてあります。動作確認には後者を使った方が楽です。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
var featureSize = 1355191; var r = 0.1; // ハイパーパラメータ var clf = new AROW(featureSize, r); // 非同期ストリーム読み込みによる学習は以下のように書く DataLoader.read('news20_train', function(x, label) { clf.update(x, label); }); // 同期読み込みによる学習とテスト var trainData = new DataLoader('news20_train'); var testData = new DataLoader('news20_test'); console.log('load complete.'); console.log(trainData.size, testData.size); var maxIter = 5; for (var i = 0; i < maxIter; i++) { trainData.data.forEach(function (e) { clf.update(e.x, e.label); }); var error = 0; testData.data.forEach(function (e) { var predLabel = clf.predict(e.x); if (predLabel != e.label) { error++; } }); console.log('iteration: ' + (i + 1) + ', error rate = ' + (error / testData.size)); } |
1 2 3 4 5 6 7 |
load complete. 15000 4996 iteration: 1, error rate = 0.02822257806244996 iteration: 2, error rate = 0.027622097678142513 iteration: 3, error rate = 0.027822257806244997 iteration: 4, error rate = 0.027822257806244997 iteration: 5, error rate = 0.027822257806244997 |
パラメータは適当に決めたのですが、分類精度(正答率)は約97.2%となりました。注目すべきは収束速度で、データを一巡しただけでほぼ収束しているように見えます。速いですね。
線形SVMとも比較してみます。まだJavaScriptでSVMの実装はしていないので、ここではscikit-learnの実装(sklearn.svm.LinearSVC)を使いました。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
import numpy as np from sklearn.datasets import load_svmlight_file from sklearn.svm import LinearSVC from sklearn.grid_search import GridSearchCV from sklearn.metrics import accuracy_score, classification_report X_train, y_train = load_svmlight_file('news20_train') ## 学習データ X_test, y_test = load_svmlight_file('news20_test') ## テストデータ estimator = LinearSVC() param_grid = [{'C':np.logspace(-3, 1, 20),},] clf = GridSearchCV(estimator, param_grid, cv=5, scoring='accuracy', n_jobs=3) clf.fit(X_train, y_train) print(clf.best_estimator_) pred = clf.predict(X_test) print('## accuracy: %s' % (accuracy_score(y_test, pred),)) cr = classification_report(y_test, pred) print(cr) |
* 実行結果
1 2 3 4 5 6 7 8 9 10 11 |
LinearSVC(C=3.7926901907322459, class_weight=None, dual=True, fit_intercept=True, intercept_scaling=1, loss='squared_hinge', max_iter=1000, multi_class='ovr', penalty='l2', random_state=None, tol=0.0001, verbose=0) ## accuracy: 0.964792958592 precision recall f1-score support -1.0 0.97 0.96 0.96 2474 1.0 0.96 0.97 0.97 2522 avg / total 0.96 0.96 0.96 4996 |
パラメータはグリッドサーチで決定、分類精度は約96.5%で、AROWと同等あるいは少し悪いくらい。妥当な結果と言えるでしょうか。
おわりに
今回はオンライン機械学習アルゴリズムのひとつであるAROWを試してみました。最近はディープラーニング系のニューラルネットワークのプログラムを趣味でも書くことが増えてきたのですけど、それらに比べてオンライン機械学習のアルゴリズムは実装が楽で収束も速いので嬉しいです。それに加えてデータを蓄積しておかなくて良いので、近年のWebサービスで扱われるような断続的なストリームデータとも相性が良いと思います。この後はAROWより後発のアルゴリズムであるSCWなども試しつつ、Node.js上でのより効率的な処理の書き方なども併せて調べていくつもりです。
参考
オンライン機械学習 (機械学習プロフェッショナルシリーズ)
AROW (Adaptive Regularization of Weight Vectors) (論文PDF)
SCW (Soft Confidence Weighted Learning) (論文PDF)
CW, AROW, and SCW
[機械学習] AROWのコードを書いてみた
Node.js v6.3.1 Documentation
Node.js の Stream API で「データの流れ」を扱う方法