JAXライクなfunctorchで機械学習を速くする – part 1

PyTorch 1.11からβ版として追加された functorch と呼ばれる機能を試してみました。PyTorch 1.9くらいのときから試験版として本体に組み込まれて提供されていましたが、どうやらfunctorchという別モジュールに切り出して提供されるようになったようです。

functorchとは

PyTorch公式サイトには以下のように説明されています。

functorch is a library that adds composable function transforms to PyTorch. It aims to provide composable vmap (vectorization) and autodiff transforms that work with PyTorch modules and PyTorch autograd with good eager-mode performance.

  • computing per-sample-gradients (or other per-sample quantities)
  • running ensembles of models on a single machine
  • efficiently batching together tasks in the inner-loop of MAML
  • efficiently computing Jacobians and Hessians
  • efficiently computing batched Jacobians and Hessians

PyTorch単体では記述が面倒だったサンプル毎の勾配計算やモデルアンサンブルの計算等を効率良く実行することができます。どうやら自動微分とベクトル化が目玉機能のようです。

functorchはGoogle JAXをインスパイアしておりAPIの見た目もだいたいJAXと同じです。もちろんデータの取り回しはJAX(jax.numpy)のDeviceArrayではなくtorch.TensorとなるのでPyTorch用に最適化されています。

最初は公式サイトのサンプルを動かしつつ応用的なコードも書いていきますが、このエントリーではfunctorchの全ての機能は説明できないので、少しずつ分かる範囲だけ試していこうと思います。今回は基本機能を知る準備編(part 1)ですね。処理速度などの非機能面については次回以降のエントリーで紹介します。

導入

Python開発環境がある程度整っていればインストールは簡単です。Google Colab無料版でも動くので環境が用意出来ない人はColabを使うといいでしょう。

functorchは独立したモジュールなのでpipで別途インストールします。

インストールできているかどうか動作確認します。

↑のサンプルコードの意味は以降で説明するのでとりあえずvmapgradが読み込みできていればOKです。

使い方

functorchの説明には”composable function transforms”という表現が頻出しますが、コード上では高階関数を作って運用します。ここからは関数型に頭を切り替えて使っていきましょう。

grad (gradient computation)

functorch has auto-differentiation transforms (grad(f) returns a function that computes the gradient of f)

functorch.gradはわかりやすいのでまずこちらから試します。gradは関数の自動微分をやってくれるAPIで、冒頭の公式説明にもあったようにサンプル毎の勾配計算にも利用できますし、後述のfunctorch.vmapと組み合わせて使うことで機械学習処理のパフォーマンスを向上させることができます。

まずおさらいとして、PyTorch単体で自動微分する場合、backward()後にtorch.Tensorのgrad属性を参照することで微分係数を得ることができます。

上記コードのdx_func関数をgradを使って書くと以下のように書けます。gradは新しい関数を作って返すので、それを即時に適用することができます。

↑の例では適当に遊んでみただけですが、こういうIIFE(Immediately Invoked Function Expression: 即時実行関数式)な書き方は昔のJavaScriptでよく使われてましたね。Pythonだとあまり見かけないかもしれませんが、関数型言語っぽい書き味に慣れていきましょう。

さて、導入時の動作確認用サンプルコードを改めて確認してみると、今なら簡単に理解できるかと思います。

ちなみにtorch.allclose関数は第一引数と第二引数の値がほぼ等しいかどうかをチェックする関数で、NumPyにも同名の関数があります。
* torch.allclose — PyTorch 1.11.0 documentation

もちろん多変数関数の自動微分もできます。gradのargnumsパラメータで微分対象の変数を指定します。

vmap (auto-vectorization)

a vectorization/batching transform (vmap(f) returns a function that computes f over batches of inputs), and others

functorch.vmapはコードを大きく変更することなく関数を自動ベクトル化(auto-vectorization)して並列処理することでパフォーマンスを向上させます。機械学習だと特にバッチ学習/推論の際に有用です。内部実装的にはコア部分がC++で書かれており、過去のエントリーでも何度か紹介しているATen(テンソル演算用C++モジュール)もがっつり使っているようです。

ちなみにATenが環境毎にどうやって並列処理しているかについては以下の関数で確認できます。

ここではGoogle Colab無料版を使っているので↑のような情報が得られました。OpenMPが有効で、スレッド並列数2で動作しているようです。

以下の公式サンプルコードを見てみます。

単純な線形モデルを定義し、vmapによってサンプル毎にモデルを適用しています。これをあえてvmapを使わずに書くと以下のようにforループが出現してしまい関数型っぽくなくなります。リスト内包表記を使うと多少シンプルに見えますが処理が効率化されるわけではありません。

ここまでは特に利点が見えにくいかもしれませんが、以下のようにvmapgradと組み合わせると、機械学習でよく使われる処理をスマートに書くことができ、なおかつ内部的にはvmapにより効率的に並列処理されます。

compute_loss関数ではモデルを適用してMSE Lossを計算、これをgradで自動微分対象として勾配を得る関数(1)を生成します。さらにvmapでサンプル毎に関数(1)を適用する関数(2)を生成、inputsに対してその関数(2)を即時適用しています。in_dimsパラメータでマッピングする次元を指定できるので、ここでは入力と同じ3要素のタプルで指定します。weightsのようにマッピングさせない場合はNoneを指定すればOKです。in_dimsパラメータはJAXだと別の名前(in_axes)ですが機能は同じはずです。functorchはJAX-likeなAPIを謳っていますが、PyTorchだとdimという単語がよく使われるのでそちらに合わせたのでしょう。

in_dimsがあればout_dimsもあります。頭の中でどうマッピングされるかイメージして指定しましょう。

vmapの挙動を見ると当然ではありますが入力するtensorの形状には注意します。あと、常用しないとは思いますが形状さえ合っていれば(マッピング可能であれば)ネストしても動作します。

↑の例は大袈裟気味に書きましたが、vmapやgrad対象とする関数は普通にdefで定義すると関数定義がスコープ内に残ってしまうので、無名関数として作って即時適用する形に慣れる方が良さそうです。

ここまでに紹介したパラメータとシンプルな入力データを使って最後におさらいします。

複数の入力ベクトルに対してvmapで並列処理したい場合は、

↑の例が分かれば、gradvmapの使い方はある程度理解できたのではないでしょうか。関数型に頭を切り替えられていれば意外とすんなり理解できるので単に慣れの問題だと思います。その他、gradvmapを使う際の細かい注意点や制限については以下の公式ページを参照してください。余裕があれば次回エントリーでいくつかピックアップして紹介するかもしれません。

おわりに

JAXは以前から興味があったにもかかわらずほとんど使ってなかったのですが、functorchの登場によって使うモチベーションが上がりました。functorchも最初は使いにくく感じるかもしれませんが、慣れればスラスラ書けるようになって面白いです。今回はfunctorchの基本機能となるgradvmapを主に紹介する準備編でした。functorchの真価はPyTorchの文脈上で機械学習処理を高速化することですので、次回はより実用的な例や他のAPIも試してみたいと思います。

また、functorch自体はまだβ版であり、APIはユーザーからのフィードバックなどを経て変わる可能性があるとのことです。今回のエントリー内での使い方も将来使えなくなるかもしれないのでその際はご了承ください。

ちなみに、GoogleにはTensorFlowがあるのになぜFlax(JAX)、Traxなど複数のフレームワークを作っているのかについてはQuoraにディスカッションがあったので見てみると面白いです。Google(DeepMind含む)内のそれぞれの組織がフレームワークをボトムアップで作っているだけということですが、巨大な組織だとよくある話かなと思います。確かにGoogleはチャットアプリもたくさん作ってるし、今後生き残った者に投資していくのでしょう。

参考

Tutorial: Writing JAX-like code in PyTorch with functorch – Simone Scardapane
FUNCTORCH | RICHARD ZOU & HORACE HE – YouTube
JAXによるスケーラブルな機械学習 – ZOZO TECH BLOG
jaxのautogradをpytorchのautogradと比較、単回帰まで(速度比較追加) – HELLO CYBERNETICS
機械学習で楽しむ JAX と NumPyro v0.1.0

あわせて読む:

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です