今回も引き続きfunctorchを使っていろいろ試してみます。前回のエントリーはfunctorchの基本機能を紹介しました。今回はfunctorchによる機械学習のユースケースについて考えてみたいと思います。grad
やvmap
などの基本機能の説明は前回のエントリーを参照してください。
いろいろ試してみて気付いたのですが、結論から言うと、functorchは適切に使えば性能を発揮できます。ただし、従来のオブジェクト指向や手続き型の処理を関数型っぽく書き換えるための準備と再設計が必要です。特に機械学習の文脈で内部状態を持たないようにするのは実際には難しいので、そこはfunctorchがある程度は上手くカバーしてくれますが、全てを置き換えるのではなく有効な箇所を少しずつ置き換えていく形で試してみるのが良いと思いました。
環境
Google Colab Pro
CUDA Version: 11.2
Tesla T4 / P100-PCIE 16GB VRAM
torch-1.11.0+cu102
functorch-0.1.1
前回はあえてCPUインスタンスを使いましたが今回は素直にGPUインスタンスを使います。現在はColab ProだとP100かT4が割り当てられるようですが、Pro+だとA100が当たることもあるので羨ましいです。
ニューラルネットワークの高速化
公式のチュートリアルを多少改変、補足や注意事項を加えて説明します。
ベースとなるモデルは以下のようなシンプルなCNNを定義しておきます。イメージとしてはAlexNetっぽい構造のモデルをCIFAR-10用に簡略化したものにしています。注意点として、Dropout層のようなランダム要素を含むレイヤーは意図的に一旦除いています。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
## AlexNet like CNN for CIFAR-10 class SimpleCNN(nn.Module): features: nn.Sequential classifier: nn.Sequential flatten: nn.Flatten def __init__(self, num_classes: int=10): super(SimpleCNN, self).__init__() self.features = nn.Sequential( nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2), nn.Conv2d(64, 192, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2), nn.Conv2d(192, 384, kernel_size=3, padding=1), nn.ReLU(), nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(), nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2) ) self.flatten = nn.Flatten() self.classifier = nn.Sequential( # nn.Dropout(), ここではDropout層は除いておく、説明は後述 nn.Linear(256 * 2 * 2, 4096), nn.ReLU(), # nn.Dropout(), nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, num_classes), ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.features(x) x = self.flatten(x) logits = self.classifier(x) return logits ## 損失関数 (クロスエントロピー) def loss_fn(predictions, targets): return F.cross_entropy(predictions, targets) |
学習データは適当に作ります。もちろんtorchvisionからCIFAR-10の実データを取ってきて使ってもOKです。バッチサイズは環境に合わせて適宜調整できますが、ベンチマーク目的なので小さくしすぎないように注意してください。またCIFAR-10用のCNNなのでデータの形状はモデルの入力に合うようにしておきます。
1 2 3 4 5 |
device = "cuda:0" if torch.cuda.is_available() else "cpu" batch_size = 32 data = torch.randn(batch_size, 3, 32, 32, device=device) targets = torch.randint(10, (batch_size,), device=device) model = SimpleCNN().to(device=device) |
PyTorchで従来の学習のステップを書くときに、以下のようなコードがよく出てくるかと思います。
1 2 3 |
predictions = model(data) loss = loss_fn(predictions, targets) ## loss_fnは任意の損失関数 loss.backward() |
loss.backward()でミニバッチ毎の勾配の平均を求めていますが、functorchを効果的に適用するためにミニバッチ単位ではなくサンプル毎の勾配を求めるように修正します。まずは、functorchを使わない場合は以下のように書けます。
1 2 3 4 5 6 7 8 9 10 11 12 13 |
def compute_grad(sample, target): sample = sample.unsqueeze(0) target = target.unsqueeze(0) prediction = model(sample) loss = loss_fn(prediction, target) return torch.autograd.grad(loss, list(model.parameters())) def compute_sample_grads(data, targets): """ manually process each sample with per sample gradient """ sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)] sample_grads = zip(*sample_grads) sample_grads = [torch.stack(shards) for shards in sample_grads] return sample_grads |
model(sample)の入力sampleは(N,C,H,W)形式なので、その前にunsqueeze(0)でバッチ分の次元を追加しています。
make_functional
次にfunctorchを使って書き直します。gradとvmapは前回紹介しましたが、ここでもう一つ新しい機能が登場します。
1 2 |
from functorch import make_functional, grad, vmap fmodel, params = make_functional(model) |
make_functional(model)
はニューラルネットワークのモデルと状態(パラメータ)を分離する関数です。これを通すことで上記のfmodelは純粋関数となり、参照透過性を持ちます(同じ入力なら常に同じ出力)。functorchでは以下のようなFunctionalModule
クラスのオブジェクトとして扱われます。
1 2 3 4 5 6 7 8 9 10 11 |
FunctionalModule( (stateless_model): SimpleCNN( (features): Sequential( (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (1): ReLU() (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (3): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (4): ReLU() (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ... 省略 |
注意点として、ここで分離されたパラメータオブジェクト(torch.nn.parameter.Parameter)には勾配情報を保存しておく必要はありませんので無効にしておきます。エラーが出るわけではありませんが、関数型のアプローチでは内部状態を更新するような処理は行いません。
1 2 |
for param in params: param.requires_grad = False |
ちなみにこのエントリー執筆時にはまだ使えませんでしたが、make_functionalにdisable_autograd_tracking
パラメータというものが付くそうですので最新版ではこれを使いましょう。requires_gradをいちいち変更する手間が省けます。
- Added params_require_grad arg to make_functional* by vfdev-5 · Pull Request #701 · pytorch/functorch
1 2 |
## disable_autograd_tracking make_functional(model, disable_autograd_tracking=True) |
また、make_functional_with_buffer
という亜種もあります。モデル内部に任意領域としてのバッファを持っていることを知らない人もいるかもしれませんが、たまに便利なので実際にはこちらを使うことをおすすめします。
1 2 3 4 5 6 |
## torch.nn.Module.buffersオブジェクトのデータを分離 fmodel, params, buffers = make_functional_with_buffers(model) ## ここではbuffersの中身は無いので空のtupleが返る print(model.buffers(), buffers) <generator object Module.buffers at 0x7f90e0634bd0> () |
bufferはどういうときに使うのかは以下のフォーラムスレッドを参考に。
話を戻して、前述のcompute_grad関数をFunctionalModuleを使って書き換えます。モデルと分離されたパラメータを忘れずに持ってきましょう。
1 2 3 4 5 6 |
def compute_loss_stateless_model(params, sample, target): batch = sample.unsqueeze(0) targets = target.unsqueeze(0) predictions = fmodel(params, batch) loss = loss_fn(predictions, targets) return loss |
次に前述のcompute_sample_grads関数をfunctorchのgradとvmapを使って書き換えると以下のようにシンプルになり、vmapによりサンプル毎の計算は並列処理されます。
1 2 3 4 5 6 |
## サンプル毎の勾配計算 ft_compute_sample_grads = vmap(grad(compute_loss_stateless_model), in_dims=(None, 0, 0)) ## もちろん分けて書いても良い ft_compute_grad = grad(compute_loss_stateless_model) ft_compute_sample_grads = vmap(ft_compute_grad, in_dims=(None, 0, 0)) |
ここまでの処理においてfunctorchの利用有無で動作確認しておきます。サンプル毎の勾配の値が同じかどうか確認するテストコードになっています。
1 2 3 4 5 6 |
per_sample_grads = compute_sample_grads(data, targets) # functorch使わない版 ft_per_sample_grads = ft_compute_sample_grads(params, data, targets) # functorch版 # we can double check that the results using functorch grad and vmap match the results of hand processing each one individually: for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads): assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5) |
アサートエラーがでなければ問題ないです。
注意点
vmapを適用する関数について、特にランダム要素の有無に注意する必要があります。適用する関数は参照透過性が必要であるため、同じ入力でも実行の度に結果が変わるような関数を扱うことはできません(副作用については後述)。例えば今回定義したCNNでは意図的にDropout層を除いていますが、そのままDropout層を入れても以下のようなエラーが出ます。エラーメッセージがわかりやすいのですぐにミスに気付けますね。
1 |
RuntimeError: vmap: called random operation while in randomness error mode. Please either use the 'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap |
しかし、機械学習において乱数的な振る舞いはいろんな所で必要になってくるのでfunctorchではそのためのオプションが用意されています。
randomness (str) – Specifies whether the randomness in this vmap should be the same or different across batches. If ‘different’, the randomness for each batch will be different. If ‘same’, the randomness will be the same across batches. If ‘error’, any calls to random functions will error. Default: ‘error’. WARNING: this flag only applies to random PyTorch operations and does not apply to Python’s random module or numpy randomness.
もしランダム要素を含む関数を扱う場合はvmapのrandomnessオプションを使えば上記エラーは回避できます。ただし、”same”を付けると全てのバッチで同じ値となるので留意しておきます。
1 |
ft_compute_sample_grad = vmap(grad(compute_loss_stateless_model), in_dims=(None, 0, 0), randomness="same") |
ベンチマーク
ベンチマークは公式チュートリアルにあるコードをほぼそのまま使いました。PyTorchにベンチマーク用パッケージ(torch.utils.benchmark)があるのを初めて知りました。こちらも環境によって反復回数(n_timeit)を調整します。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
## functorchの利用有無で勾配計算を反復実行して速度比較 torch.backends.cudnn.benchmark = True def get_perf(first, first_descriptor, second, second_descriptor): """ takes torch.benchmark objects and compares delta of second vs first. """ second_res = second.times[0] first_res = first.times[0] gain = (first_res-second_res)/first_res if gain < 0: gain *=-1 final_gain = gain*100 print(f" Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} ") from torch.utils.benchmark import Timer without_vmap = Timer(stmt="compute_sample_grads(data, targets)", globals=globals()) with_vmap = Timer(stmt="ft_compute_sample_grads(params, data, targets)", globals=globals()) n_timeit = 100 no_vmap_timing = without_vmap.timeit(n_timeit) with_vmap_timing = with_vmap.timeit(n_timeit) print(f'Per-sample-grads without vmap {no_vmap_timing}') print(f'Per-sample-grads with vmap {with_vmap_timing}') get_perf(with_vmap_timing, "vmap", no_vmap_timing,"no vmap" ) |
Tesla T4/P100-PCIEそれぞれの環境でのベンチマーク結果は以下のようになりました。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
## T4 Per-sample-grads without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f5191f33850> compute_sample_grads(data, targets) 93.55 ms 1 measurement, 100 runs , 1 thread Per-sample-grads with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f51921d0e50> ft_compute_sample_grads(params, data, targets) 21.50 ms 1 measurement, 100 runs , 1 thread Performance delta: 335.1771 percent improvement with vmap ## P100-PCIE Per-sample-grads without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f90e169b190> compute_sample_grads(data, targets) 88.88 ms 1 measurement, 100 runs , 1 thread Per-sample-grads with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f90e17731d0> ft_compute_sample_grads(params, data, targets) 12.70 ms 1 measurement, 100 runs , 1 thread Performance delta: 600.1182 percent improvement with vmap |
Tesla T4は推論処理用アクセラレータと謳っているのでここで使うには適切ではないと思われますが、T4だと300%強、P100で600%ほど高速化されました。torch.utils.benchmarkパッケージで実行される処理については内部実装を深堀りしてはいませんが、公式モジュールですし結果は信用して良いと思います。ここではTimerしか使ってないですが他にも細かくいろいろ設定できるらしいのでまた別エントリーかgist等で紹介します。
アンサンブル学習の高速化
モデルのアンサンブル学習もfunctorchで効率化できます。
今回のエントリーで紹介したサンプル毎の勾配計算の並列化が理解できれば上記チュートリアルもすぐ理解できると思うのでここでは軽く紹介します。vmapによる並列化がモデル単位になっただけです。つまり、functorchの文脈ではモデルオブジェクトは純粋関数として扱えるのでそのままvmapに適用できます。
1 2 3 4 5 6 7 8 9 10 11 12 |
## アンサンブルのためにモデルを複数作り、データを分割する num_models = 10 models = [SimpleCNN().to(device) for _ in range(num_models)] minibatches = data[:num_models] ## combine_state_for_ensemble でスタックされたパラメータとバッファーに分離 ## https://pytorch.org/functorch/stable/generated/functorch.combine_state_for_ensemble.html from functorch import combine_state_for_ensemble fmodel, params, buffers = combine_state_for_ensemble(models) ## vmapでベクタライズ、即時適用 predictions = vmap(fmodel)(params, buffers, minibatches) |
上記例は同じモデルを10つ並べているだけなので実用として使えるようなものではないですが、functorchによるアンサンブル処理のイメージは掴めるのではないでしょうか。
注意点としては、シングルノードで実行されるのでコンピューティングリソース(特にVRAM)が潤沢である必要があります。マルチノードでの分散学習は準備が面倒でコストがかかるので、もしスペックの高いサーバが一台あればfunctorchでアンサンブル学習の並列化を試してみるのも良いかもしれません。
補足事項など
今回試したベンチマークはあくまで学習処理の一部分(勾配計算)のみ切り出して計測したものです。実際の学習処理ではtorch.optim.Optimizer
でのパラメータ更新部分も同様に対応する必要がありますが、実装が煩雑にならないようにするための議論がこちらで行われています。
また、PyTorchでは破壊的な(副作用のある)オペレーションがたくさんありますが、functorchではそれを上手く隠蔽・除去して変換してくれるようです。例えば、torch.nn.ReLU
など一部のレイヤーはinplaceパラメータがありますが、これをTrueにしていても特に挙動がおかしくなったり、エラーが発生することはありません。make_functional
が内部で丁寧にパラメータを抽出・分離してくれるようです。
ただし、PyTorchが提供している破壊的な関数(xxx_()
)は以下のようにvmapを適用しても問題なく破壊的に動作します。意図的にこういう実装はしないとは思いますが、このような破壊的な操作を伴う関数を使う際は注意してください。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
def add(x, y): x.add_(y) return x x = torch.randn(3) y = torch.randn(3) print(x, y) print(vmap(add, in_dims=(0, 0))(x, y)) print(x, y) 動作結果、xの値変更されている tensor([ 0.2734, -0.9181, -0.0404]) tensor([ 0.2881, -0.0075, -0.9145]) tensor([ 0.5615, -0.9256, -0.9549]) tensor([ 0.5615, -0.9256, -0.9549]) tensor([ 0.2881, -0.0075, -0.9145]) |
あとvmapでは今のところ制御文も使えません。
1 2 3 4 5 6 7 8 9 |
def relu(x): if x > 0: return x return 0 x = torch.randn(3) vmap(relu)(x) RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 . |
これは上記issuesで議論中のようです。JAXにはjax.lax.cond関数があるのでこちらを使うようですが、functorchでもたぶんすぐ対応されると思います。
終わりに
冒頭でも書いたように、functorchで効率的に機械学習を行うには関数型の考え方に切り替える必要があるようです。単に make_functional
でモデルと状態を切り分けるだけでは、デバッグが少々楽になる気がするだけで効率化には繋がりません。バグをなるべく防ぎつつ高速化したい場合のアプローチとしては、なんとなく以下のように取り組むと良さそうです。
- モデルと状態(パラメータ)を分離させる ->
make_functional / make_functional_with_buffers
- 微分対象の処理単位を小さくする(モジュラー性) ->
grad
- 参照透過性を備えた2.の処理単位で並列化 ->
vmap
こう書いてみると簡単ですが、これまでバッチ単位で処理することに慣れていた人達にとっては抵抗感があるのではないでしょうか。例えばBatch Normalizationの扱いはどうなるの、とか。それからあんまり調子に乗ってvmapしてると CUDA out of memory. が頻発するのでリソース調整をより慎重に行う必要があります。今回はモデルアンサンブルの例を一応挙げましたが、実際に行うのは現実的ではないかも知れません。いろんな種類のモデルをvmapに適用できる形に全て変換するのは大変だし、シングルノードで学習するのはGPUリソースがまず足りないでしょう。
また、モデル学習処理全体を効率化したい場合は、例えばDataLoaderの取り扱いやGPU周りの設定、TorchScriptの利用なども含めて満遍なく対応することになります。それらと併せてfunctorchも適切に使えば学習処理のコアな部分を集中的に高速化できるので、がんばってチャレンジする価値はあるかもしれません。個人的にはオプティマイザー周りの処理(パラメータの更新)がもっと綺麗に書けるようになったら是非採用したいと思っています。
Papers With CodeのTrending Researchとか見てると、明らかにPyTorchでのリファレンス実装が多いようです。高速化のためにJAX/Flax移植をがんばったり、マルチノード/マルチGPUの分散学習環境を準備するよりも、functorchの利用で十分高速化するのなら低コストで済むので助かりますね。
最後に注意点として、functorchは2022/06現在まだβ版となっています。APIも今後仕様が変わる可能性があるので利用の際は注意しましょう。
参考:
- Per-sample-gradients — functorch 0.1.1 documentation
- Model ensembling — functorch 0.1.1 documentation
- Working-with-FuncTorch-An-Introduction – Weights & Biases
- JAXとPyTorch、どっちが速いのか検証してみた – まったり勉強ノート
- JAXによるスケーラブルな機械学習 – ZOZO TECH BLOG