JAXライクなfunctorchで機械学習を速くする – part 2

今回も引き続きfunctorchを使っていろいろ試してみます。前回のエントリーはfunctorchの基本機能を紹介しました。今回はfunctorchによる機械学習のユースケースについて考えてみたいと思います。gradvmapなどの基本機能の説明は前回のエントリーを参照してください。

いろいろ試してみて気付いたのですが、結論から言うと、functorchは適切に使えば性能を発揮できます。ただし、従来のオブジェクト指向や手続き型の処理を関数型っぽく書き換えるための準備と再設計が必要です。特に機械学習の文脈で内部状態を持たないようにするのは実際には難しいので、そこはfunctorchがある程度は上手くカバーしてくれますが、全てを置き換えるのではなく有効な箇所を少しずつ置き換えていく形で試してみるのが良いと思いました。

環境

Google Colab Pro
CUDA Version: 11.2
Tesla T4 / P100-PCIE 16GB VRAM
torch-1.11.0+cu102
functorch-0.1.1

前回はあえてCPUインスタンスを使いましたが今回は素直にGPUインスタンスを使います。現在はColab ProだとP100かT4が割り当てられるようですが、Pro+だとA100が当たることもあるので羨ましいです。

ニューラルネットワークの高速化

公式のチュートリアルを多少改変、補足や注意事項を加えて説明します。

ベースとなるモデルは以下のようなシンプルなCNNを定義しておきます。イメージとしてはAlexNetっぽい構造のモデルをCIFAR-10用に簡略化したものにしています。注意点として、Dropout層のようなランダム要素を含むレイヤーは意図的に一旦除いています。

学習データは適当に作ります。もちろんtorchvisionからCIFAR-10の実データを取ってきて使ってもOKです。バッチサイズは環境に合わせて適宜調整できますが、ベンチマーク目的なので小さくしすぎないように注意してください。またCIFAR-10用のCNNなのでデータの形状はモデルの入力に合うようにしておきます。

PyTorchで従来の学習のステップを書くときに、以下のようなコードがよく出てくるかと思います。

loss.backward()でミニバッチ毎の勾配の平均を求めていますが、functorchを効果的に適用するためにミニバッチ単位ではなくサンプル毎の勾配を求めるように修正します。まずは、functorchを使わない場合は以下のように書けます。

model(sample)の入力sampleは(N,C,H,W)形式なので、その前にunsqueeze(0)でバッチ分の次元を追加しています。

make_functional

次にfunctorchを使って書き直します。gradとvmapは前回紹介しましたが、ここでもう一つ新しい機能が登場します。

make_functional(model) はニューラルネットワークのモデルと状態(パラメータ)を分離する関数です。これを通すことで上記のfmodelは純粋関数となり、参照透過性を持ちます(同じ入力なら常に同じ出力)。functorchでは以下のようなFunctionalModuleクラスのオブジェクトとして扱われます。

注意点として、ここで分離されたパラメータオブジェクト(torch.nn.parameter.Parameter)には勾配情報を保存しておく必要はありませんので無効にしておきます。エラーが出るわけではありませんが、関数型のアプローチでは内部状態を更新するような処理は行いません。

ちなみにこのエントリー執筆時にはまだ使えませんでしたが、make_functionalにdisable_autograd_trackingパラメータというものが付くそうですので最新版ではこれを使いましょう。requires_gradをいちいち変更する手間が省けます。

また、make_functional_with_buffer という亜種もあります。モデル内部に任意領域としてのバッファを持っていることを知らない人もいるかもしれませんが、たまに便利なので実際にはこちらを使うことをおすすめします。

bufferはどういうときに使うのかは以下のフォーラムスレッドを参考に。

話を戻して、前述のcompute_grad関数をFunctionalModuleを使って書き換えます。モデルと分離されたパラメータを忘れずに持ってきましょう。

次に前述のcompute_sample_grads関数をfunctorchのgradとvmapを使って書き換えると以下のようにシンプルになり、vmapによりサンプル毎の計算は並列処理されます。

ここまでの処理においてfunctorchの利用有無で動作確認しておきます。サンプル毎の勾配の値が同じかどうか確認するテストコードになっています。

アサートエラーがでなければ問題ないです。

注意点

vmapを適用する関数について、特にランダム要素の有無に注意する必要があります。適用する関数は参照透過性が必要であるため、同じ入力でも実行の度に結果が変わるような関数を扱うことはできません(副作用については後述)。例えば今回定義したCNNでは意図的にDropout層を除いていますが、そのままDropout層を入れても以下のようなエラーが出ます。エラーメッセージがわかりやすいのですぐにミスに気付けますね。

しかし、機械学習において乱数的な振る舞いはいろんな所で必要になってくるのでfunctorchではそのためのオプションが用意されています。

randomness (str) – Specifies whether the randomness in this vmap should be the same or different across batches. If ‘different’, the randomness for each batch will be different. If ‘same’, the randomness will be the same across batches. If ‘error’, any calls to random functions will error. Default: ‘error’. WARNING: this flag only applies to random PyTorch operations and does not apply to Python’s random module or numpy randomness.

もしランダム要素を含む関数を扱う場合はvmapのrandomnessオプションを使えば上記エラーは回避できます。ただし、”same”を付けると全てのバッチで同じ値となるので留意しておきます。

ベンチマーク

ベンチマークは公式チュートリアルにあるコードをほぼそのまま使いました。PyTorchにベンチマーク用パッケージ(torch.utils.benchmark)があるのを初めて知りました。こちらも環境によって反復回数(n_timeit)を調整します。

Tesla T4/P100-PCIEそれぞれの環境でのベンチマーク結果は以下のようになりました。

Tesla T4は推論処理用アクセラレータと謳っているのでここで使うには適切ではないと思われますが、T4だと300%強、P100で600%ほど高速化されました。torch.utils.benchmarkパッケージで実行される処理については内部実装を深堀りしてはいませんが、公式モジュールですし結果は信用して良いと思います。ここではTimerしか使ってないですが他にも細かくいろいろ設定できるらしいのでまた別エントリーかgist等で紹介します。

アンサンブル学習の高速化

モデルのアンサンブル学習もfunctorchで効率化できます。

今回のエントリーで紹介したサンプル毎の勾配計算の並列化が理解できれば上記チュートリアルもすぐ理解できると思うのでここでは軽く紹介します。vmapによる並列化がモデル単位になっただけです。つまり、functorchの文脈ではモデルオブジェクトは純粋関数として扱えるのでそのままvmapに適用できます。

上記例は同じモデルを10つ並べているだけなので実用として使えるようなものではないですが、functorchによるアンサンブル処理のイメージは掴めるのではないでしょうか。

注意点としては、シングルノードで実行されるのでコンピューティングリソース(特にVRAM)が潤沢である必要があります。マルチノードでの分散学習は準備が面倒でコストがかかるので、もしスペックの高いサーバが一台あればfunctorchでアンサンブル学習の並列化を試してみるのも良いかもしれません。

補足事項など

今回試したベンチマークはあくまで学習処理の一部分(勾配計算)のみ切り出して計測したものです。実際の学習処理ではtorch.optim.Optimizerでのパラメータ更新部分も同様に対応する必要がありますが、実装が煩雑にならないようにするための議論がこちらで行われています。

また、PyTorchでは破壊的な(副作用のある)オペレーションがたくさんありますが、functorchではそれを上手く隠蔽・除去して変換してくれるようです。例えば、torch.nn.ReLU など一部のレイヤーはinplaceパラメータがありますが、これをTrueにしていても特に挙動がおかしくなったり、エラーが発生することはありません。make_functionalが内部で丁寧にパラメータを抽出・分離してくれるようです。

ただし、PyTorchが提供している破壊的な関数(xxx_())は以下のようにvmapを適用しても問題なく破壊的に動作します。意図的にこういう実装はしないとは思いますが、このような破壊的な操作を伴う関数を使う際は注意してください。

あとvmapでは今のところ制御文も使えません。

これは上記issuesで議論中のようです。JAXにはjax.lax.cond関数があるのでこちらを使うようですが、functorchでもたぶんすぐ対応されると思います。

終わりに

冒頭でも書いたように、functorchで効率的に機械学習を行うには関数型の考え方に切り替える必要があるようです。単に make_functional でモデルと状態を切り分けるだけでは、デバッグが少々楽になる気がするだけで効率化には繋がりません。バグをなるべく防ぎつつ高速化したい場合のアプローチとしては、なんとなく以下のように取り組むと良さそうです。

  1. モデルと状態(パラメータ)を分離させる -> make_functional / make_functional_with_buffers
  2. 微分対象の処理単位を小さくする(モジュラー性) -> grad
  3. 参照透過性を備えた2.の処理単位で並列化 -> vmap

こう書いてみると簡単ですが、これまでバッチ単位で処理することに慣れていた人達にとっては抵抗感があるのではないでしょうか。例えばBatch Normalizationの扱いはどうなるの、とか。それからあんまり調子に乗ってvmapしてると CUDA out of memory. が頻発するのでリソース調整をより慎重に行う必要があります。今回はモデルアンサンブルの例を一応挙げましたが、実際に行うのは現実的ではないかも知れません。いろんな種類のモデルをvmapに適用できる形に全て変換するのは大変だし、シングルノードで学習するのはGPUリソースがまず足りないでしょう。

また、モデル学習処理全体を効率化したい場合は、例えばDataLoaderの取り扱いやGPU周りの設定、TorchScriptの利用なども含めて満遍なく対応することになります。それらと併せてfunctorchも適切に使えば学習処理のコアな部分を集中的に高速化できるので、がんばってチャレンジする価値はあるかもしれません。個人的にはオプティマイザー周りの処理(パラメータの更新)がもっと綺麗に書けるようになったら是非採用したいと思っています。

Papers With CodeのTrending Researchとか見てると、明らかにPyTorchでのリファレンス実装が多いようです。高速化のためにJAX/Flax移植をがんばったり、マルチノード/マルチGPUの分散学習環境を準備するよりも、functorchの利用で十分高速化するのなら低コストで済むので助かりますね。

最後に注意点として、functorchは2022/06現在まだβ版となっています。APIも今後仕様が変わる可能性があるので利用の際は注意しましょう。

参考:

あわせて読む:

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です