JavaScriptで機械学習の実装 4 SCW

前回 JavaScriptで機械学習の実装 3 AROW に引き続きオンライン学習アルゴリズムを試しています。Node.js勉強中なのでJavaScriptを使ってこつこつ学んでいきましょう。

今回は SCW (Soft Confidence Weighted Learning) と呼ばれるアルゴリズムを扱います。

SCW (Soft Confidence Weighted Learning)

2012年に提案された、CW (Confidence Weighted Learning)と前回紹介したAROW (Adaptive Regularization of Weight Vectors)の特徴を備え持つ手法です。SCWはCWと同様に重みベクトルの各重みが正規分布に従って生成されていると考えます。

SCWのアルゴリズムは以下のようになっています(元論文から引用)。
scw_algorithm

μが信頼している重み、Σが各重みの信頼度・自信(confidence)を表しており、ハイパーパラメータCでどの程度の誤りを許容するか(損失関数を常に0にするという制約を緩める)を調整しています。
scw_update_kld

損失関数は以下のように定義されており、信頼度が低いパラメータは重視されないようになっています。
scw_loss_function

また、αとβはどちらもCWと同様に閉じた形で更新式が表されます。αの計算式は二通りあり、それぞれSCW-IとSCW-IIと呼ばれています。具体的な式については論文などを参照。

SCWの特性として論文の最後で以下の4つが挙げられています。

  1. large margin training
  2. confidence weighting
  3. adaptive margin
  4. capability of handling non-separable data

性能面ではCWやAROWと比べて多くのケースにおいて高精度で効率が良い(収束が速い)とされているようです。

検証

今回はLIBSVM Data: Classificationの a9a データセットで分類精度や収束速度などを確認してみます。a9a は世帯収入が一定以上かどうかの二値、線形分離不可能なデータになっています。

a9a
Preprocessing: The same as a1a. [JP98a]
# of classes: 2
# of data: 32,561 / 16,281 (testing)
# of features: 123 / 123 (testing) 

SCW実装の本体やデータ読み込み用モジュール、動作確認用のテストコード一式はこれまで通りGitHubに置いておきます。今回もAROWと同様にTypeScriptとJavaScriptの両方のコードを上げてあります。

疎ベクトル用に最適化してないナイーブな実装なので、1,355,191次元もある news20.binary データセットを用いたAROWとの比較は難しいです。ということで a9a を使って動作確認をします。

データファイルの読み込み部分は前回作ったLIBSVMフォーマットファイルをStream APIで読み込むモジュールを使い回しました。Node.js環境において以下のような感じで学習とテストを実行できます。
[javascript]
// 特徴量の次元数
var featureSize = 123;
// ハイパーパラメータ
var eta = 0.9;
var C= 1.0;
// SCW-Iで学習する場合は SCW.SCW_I 、SCW-IIで学習する場合は SCW.SCW_II を第4引数で指定
var clf = new SCW(featureSize, eta, C, SCW.SCW_II);

// 非同期ストリーム読み込みによる学習は以下のように書く
DataLoader.read(‘news20_train’, featureSize, function(x, label) {
clf.update(x, label);
});

// 同期読み込みによる学習とテスト
var trainData = new DataLoader(‘news20_train’, featureSize);
var testData = new DataLoader(‘news20_test’, featureSize);
console.log(‘load data complete.’);
console.log(trainData.size, testData.size);

// 学習
trainData.data.forEach(function (e) {
clf.update(e.x, e.label);
});

// テスト
var error = 0;
testData.data.forEach(function (e) {
var predLabel = clf.predict(e.x);
if (predLabel != e.label) {
error++;
}
});
console.log(‘error rate = ‘ + (error / testData.size));
[/javascript]

load data complete.
32561 16281
error rate = 0.15404459185553712

SCWの2つのハイパーパラメータは適当に設定し、データは1巡のみで84.6%の分類精度となりました。AROWの時もそうでしたが、やはりオンライン学習は収束が速くて助かります。

オンライン学習の性質を確認するため、学習データを1件ずつ与えて学習する毎に精度確認を行いました。データを与える順番を変えるためにデータをシャッフルして合計3回試行しています。
result_a9a_scw
a9aなら5,000サンプルくらい与えてやれば十分なようです。

線形SVMのscikit-learn実装(sklearn.svm.LinearSVC)でも試してみました。グリッドサーチして学習した結果、こちらも84%の分類精度となりました。a9a データセットにおいては線形SVMと精度面での違いはそこまでないようです。

LinearSVC(C=0.12742749857031335, class_weight=None, dual=True,
     fit_intercept=True, intercept_scaling=1, loss='squared_hinge',
     max_iter=1000, multi_class='ovr', penalty='l2', random_state=None,
     tol=0.0001, verbose=0)

             precision    recall  f1-score   support

       -1.0       0.88      0.93      0.90     12435
        1.0       0.72      0.59      0.65      3846

avg / total       0.84      0.85      0.84     16281

SCWはAROWよりもパラメータが一つ多いので調整がめんどくさそうに思えますが、適当に設定しても高い精度が出るのでそんなに悩まなくても大丈夫です。とりあえず動作確認はできたと思うので今回はこの辺で。次はそろそろ分散学習アルゴリズム周りを学んでいきたいです。

参考

あわせて読む:

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です