ONNX Runtime for WebをVue.js+WebGL環境で試す

Microsoftから ONNX Runtime for Web (ORT Web) なるものが9月2日にリリースされました。

ONNX (Open Neural Network Exchange) について

ONNX (Open Neural Network Exchange) は機械学習のモデルフォーマットの一つです。機械学習フレームワークはTensowflowやPyTorch、MXNetやCaffe2などたくさんありますが、ONNXフォーマットを使えばそれらのフレームワーク間において相互運用が可能になります。共通で使えるファイル形式ということです。ONNXについての歴史や開発背景はWikipediaの説明を読むのが一番わかりやすいかと思います。Open Neural Network Exchange – Wikipedia

このブログでも2年以上前ですがOpenCVのONNX対応について紹介しているので、OpenCVで画像処理と機械学習をやってみたい人は参考までに。

ONNX Runtime for Web (ORT Web)

ONNX Runtimeはハイパフォーマンスのクロスプラットフォーム対応の推論エンジンです。ONNX Runtime自体は昔からありましたが、今回MicrosoftがリリースしたORT WebはWebブラウザで動作するランタイムとして使うことができます。クライアントサイドでONNXフォーマットのモデルを使った計算(推論処理)が可能となります。Webブラウザで動くランタイムとしては onnx.js というものがありましたがこちらは今後利用非推奨となるようです。これからは替わりにORT Webを使いましょう。

使い方

Microsoft公式にチュートリアル(Quick Start)があります。

既存の速いランタイムを使う場合と比べてNode.jsで動かすのはあまりメリットが無いので(JSで書けるくらい)、今回はWebブラウザで動くものを試します。HTMLのscript tagに全ての実装を埋め込むチュートリアルはわかりやすいですが、昨今のフロントエンド開発においては実践向きではないので必然的にbundler版を試すことになります。とは言ってもbundler版もwebpackを使っているだけでそんなに変わらないようです。いずれにせよチュートリアル通りに動かせば特に躓くところは無さそうなので今回はもう少し現実的ににVue.jsと併せて使ってみました。まぁ2,3年後もVueが活発に使われているか分かりませんが。

環境

Vue.js: 3.2.11
ONNX Runtime for Web (onnxruntime-web): 1.8.0
Webブラウザ: Chrome 93.0, Firefox 92.0
実装: JavaScript (async/awaitを使うためES2017以降)

Vue.jsでONNX Runtime for Webを使う

今のVue.jsはvue-cliがあるので、webpackを直接使う機会は減っているかとは思います。ここでもvue-cliのvue-cli-serviceを使って動作確認やデプロイをします。

# Vue.jsプロジェクトの生成 (説明を簡単にするためVue 3用のプリセットを利用)
$ vue create ort-sample
Vue CLI v4.5.13
? Please pick a preset:
  Default ([Vue 2] babel, eslint)
❯ Default (Vue 3) ([Vue 3] babel, eslint)
  Manually select features

# ONNX Runtime for Web (onnxruntime-web)のインストール
$ cd ort-sample
$ yarn add onnxruntime-web

bundler版チュートリアルのJavaScript実装をVue.js用に移植します。vue-cliでプロジェクトを作ると最初からあるsrc/components/HelloWorld.vueをHelloONNX.vueにリネーム、補足情報を少し追加して移植します。




# CSS部分は省略

onnxruntime-webのモジュールとして、InferenceSessionTensor が提供されています。今のところ2つだけなので覚えることは少ないですが、Tensorにはリッチな機能は備わっていないので大半の処理は自前で実装する必要があります。公式チュートリアルでは以下の簡単な行列計算(3×4と4×3の行列積)のモデルを通すだけなので簡単です。モデルファイル本体(model.onnx)はチュートリアルに付属しています。また、InferenceSession のAPIはPromiseを返すので async/await を付けて出力を受け取ります。

(Netronでmodel.onnxを可視化)

上記モデル構造については実装上で少し意識する必要があります。具体的にはモデルの入力と出力層の名前を知る必要がありますが、それは InferenceSession.inputNames/outputNames プロパティで取得できます。ここでは入力層が a,b で出力層が c となっています。

そろそろ動作確認したいところですが、実行前に事前準備をしておきます。モデルファイルとwasmファイルの配置です。チュートリアルではCopyPluginを使っているのですが、たいしたことをしていないので適切な場所にファイルを配置して、あとはvue-cli-serviceに任せます。具体的にはpublicディレクトリ以下に該当ファイルを置くだけでビルド時にコピーしてくれるようです。

$ tree -L 2 public
public
├── index.html
├── js
│   ├── ort-wasm-threaded.wasm
│   └── ort-wasm.wasm
└── model.onnx

wasmファイルはnode_modules/onnxruntime-web/dist/以下にありますのでコピーしてpublic/js/に、モデルファイル(model.onnx)をpublic/直下に配置したら動作確認してみます。

$ yarn serve
 DONE  Compiled successfully in 10864ms                                                                         14:53:59

  App running at:
  - Local:   http://localhost:8080
  - Network: http://{おまえのIPアドレス}:8080

  Note that the development build is not optimized.
  To create a production build, run yarn build.

あとはlocalhostでブラウザから確認します。

ページの見出しタイトルとかモデル構造画像はおまけで表示していますが、肝心の計算の方も動作確認できました。

ここからはチュートリアルには載っていない追加情報となります。

ONNX Runtime for WebをWebGLで使う

ONNX Runtime for WebはデフォルトだとCPU(WebAssembly)が使われるようですがGPU(WebGL)を使うことも出来ます。WebGLを有効にするにはセッション作成時に executionProviders オプションを指定するだけです。

const option =  {executionProviders: ['webgl']};
const session = await InferenceSession.create('model.onnx', option);

そこそこの規模の計算になる場合は基本的にCPU(WebAssembly)よりGPU(WebGL)の方が高速に動作しますので是非覚えておきたいです。対応ブラウザについてですが、現在の主要なWebブラウザはWebGL対応しているのでそこは心配しなくても良さそうです。そもそもWebGL対応していないブラウザだとTyped Arrayも恐らく使えないためWebAssembly版も動作しないと考えられます。

モデルファイル読み込みの注意点

チュートリアルに付属しているモデルファイル(model.onnx)は120バイトですが、実用上のモデルファイルは数十MB、数百MBの大きさになるのでファイル読み込み時に注意が必要です。InferenceSessionにはそのための機能も備わっているので是非使いましょう。

const modelFilePath = './model.onnx';
if (typeof fetch !== 'undefined') {
  const response = await fetch(filePath);
  const buffer = await response.arrayBuffer();
  const option =  {executionProviders: ['webgl']};
  const session = await InferenceSession.create(buffer, option); // データのArrayBufferとオプションを指定
}

ファイル読み込みとセッション生成処理を分けて実行しています。実際のVueアプリケーションにおいては画面初期化処理の裏でファイル読み込みを事前に行っておくのが良いかと思います。fetch APIは主要なブラウザならほぼ実装されているかと思いますが、ファイル取得部分は他の方法を使っても問題ありません。

自前のモデルファイルを使う

ちょうど前回の記事でPyTorchの新しいパッケージング機能を紹介するときにモデルファイルを作ったので流用します。こちらはCNNのAlexNetをCIFAR-10で学習したモデルになります。

ONNXファイルへの変換はPyTorchの機能を使えば簡単です。以下にサンプルコードを載せますが、torch.onnxパッケージにあるexportを使うだけです。細かいオプションは他にもありますがここでは省略しています。

from alexnet import AlexNet
import torch
from torch import nn, onnx

def convert_onnx(model: nn.Module, output_path: str):
    model.eval()
    dummy_input = torch.randn(1, 3, 32, 32, requires_grad=True)
    onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        input_names=["input"],
        output_names=["output"]
    )

if __name__ == '__main__':
    model_path = './cifar10_net.pth'
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = AlexNet()
    net.load_state_dict(torch.load(model_path, map_location=device))
    output_path = './cifar10_net.onnx'
    convert_onnx(net, output_path)

これで cifar10_net.onnx ファイルが生成されていれば成功ですが、念のためにORT Webで読み込んでみると確実だと思います。AlexNetの小さいモデルですが90MB弱ありましたので、前述のテクニックで事前読み込みをしておくのが良さそうです。ちなみにORT Webに限らずモデルファイルを読み込んだ後に数回処理を通してウォーミングアップする手順(定常運用時の処理を速くする目的)がありますがそれは別の機会に紹介します。

おわりに

とりあえず今回はORT Webの触りの部分と補足情報を紹介しましたが。チュートリアルの行列計算では面白くないので、もう少し分かりやすいアプリケーションを作ってみたいと思います。

ORT Webを軽く触ってみて個人的に使うのが難しいと思った点について書くと、ORT Webの推定ユーザー層がデータサイエンス系のエンジニアやリサーチャーの方とは異なるというところかなと思います。チュートリアル程度の内容であればとっつきやすいですが、まともなアプリケーションを作ろうと思ったときに現在のORT Webがフロントエンド層のライブラリとしてユーザーフレンドリーとは言い辛いです。ORT WebのTensorがPythonのNumPyと比較すると機能が貧弱だったり、各種前処理・後処理をサポートするユーティリティが付属していない点については、今後のアップデートで改善されるかもしれませんが、今のところ使い勝手の点では従来のonnx.jsと同等水準のインタフェースなので、Webアプリ作る層のエンジニアにとっては多分使いにくいだろうなという印象です。

逆に言うと、データサイエンスがわかるフロントエンドエンジニアにとっては、ORT Web用の便利ライブラリを作れば周りに貢献できるかもしれません。まずは気軽に触ってみるところから始めてみるのが良いのではないでしょうか。

あわせて読む:

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です