JavaScriptで機械学習の実装 5 Gradient Boosting

少し間が空いてしまいましたけど、今回はGradient Boosting (勾配ブースティング)と呼ばれる機械学習アルゴリズムを試してみます。機械学習関連のコンペでも大人気の手法ですね。かなり昔(2011年)に決定木をJavaScriptで実装したことはあるので勾配ブースティングも併せて学んでおこうと思います。

Gradient Boosting (勾配ブースティング)

ブースティングについてWikipediaを引用すると、

ブースティング(英: Boosting)とは、教師あり学習を実行するための機械学習メタアルゴリズムの一種。ブースティングは、Michael Kearns の提示した「一連の弱い学習機をまとめることで強い学習機を生成できるか?」という疑問に基づいている[1]。弱い学習機は、真の分類と若干の相関のある分類器と定義される。- ブースティング – Wikipedia

ブースティングは逐次的(弱学習器を1つずつ順番に構築)に弱学習器を構築していく手法で、新しい弱学習器を構築する際にはそれまでに構築された全ての弱学習器の予測結果を利用するという特徴があります。Gradient Boosting (勾配ブースティング)では正解値と弱学習器の予測値との差を負の勾配と見なして、それを最小化するように逐次的に弱学習器を学習させていきます。弱学習器として決定木が用いられるため Gradient Boosting Decision Tree (GBDT) と呼ばれます。ちなみにOpenCVに顔検出機能はAdaBoostと呼ばれるブースティングアルゴリズムを利用していますね。AdaBoostは学生時代に特にお世話になりました。

今回は回帰問題を扱うので弱学習器には回帰木を用います。JavaScript(ES2015)での学習処理部分の実装を抜粋しますが、ここを見るだけでも勾配ブースティングの特徴がよく表れているのではないでしょうか。

検証

データセットはBoston housingを用い、人口 1 人当たりの犯罪発生数やNOxの濃度などの情報からボストン市の住宅価格を予測します。

今回もJavaScript(ES2015)で実装したのですが、いつものようにscikit-learn風なインタフェースにしていて、パラメータのデフォルト値はscikit-learnのGBDT実装(sklearn.ensemble.GradientBoostingRegressor)を参考にしました。もちろん外からパラメータ変更もできます。また、ファイルからのデータ読み込み処理は省略してデータセット自体も1つのモジュールとして埋め込んでいます。ソースコードもいつものようにGitHubに置いてあるので興味があれば。

クライアント側では以下のように使います。

* 出力結果例

モデル評価用にRMSEやR^2スコア(決定係数)計算用のモジュールも併せて作りました。まずは残差プロットでざっくりとモデルの性能を眺めてみます。最近はGoogleスプレッドシートでグラフを描くことが増えてきました。さすがにExcelよりは機能不足ですけど簡単なグラフならすぐ作成できるので便利に使っています。

良い感じに残差が均一に分散しています。それなりに上手く学習できているようです。

定量的な確認も少しだけ。データをシャッフルして3回計測したRMSEとR^2スコア(決定係数)の平均値を載せます。GBDTの学習率は0.1で固定して、回帰木の本数(ブースティングの繰り返し回数)と深さを変更しました。

RMSE R^2 score (決定係数)
回帰木の本数: 100, 深さ: 3 3.503 0.817
回帰木の本数: 200, 深さ: 4 3.156 0.851

ある程度想定した通りの結果が出ているので問題なさそうです。パラメータ調整についてはまだ理論を理解しきれていない部分もあるので適当なのですが、いろいろパラメータを変えて試してみた所、回帰木は深くしすぎてもダメみたいです。3 ~ 5くらいがちょうど良いんでしょうか。

おわりに

勾配ブースティングは以前からずっと作ってみたいと思っていたので、今回実装する機会を得ることができて良かったです。ところでロシアのYandexが7月18日にCatBoostというすばらしい名前の勾配ブースティング実装を公開しました。今はこちらを使っていろいろ勉強中です。まだ日本語のドキュメント類は無いみたいなので情報がまとまったらCatBoostに関するエントリーを書くかもしれません。またよろしくおねがいします。

参考

あわせて読む:

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です