前回はONNX Runtime for Web (ORT Web)をVue.jsアプリケーションで使ってみました。
公式チュートリアルは単純な行列計算でしたので、今回はもう少し実践的に自前のモデルを使った画像認識を試してみたいと思います。
環境
前回の環境とほぼ同じですが、実装言語はTypeScriptに変更しました。数値計算系の部分は型があるとデバッグしやすいので。
Vue.js: 3.2.11
ONNX Runtime for Web (onnxruntime-web): 1.8.0
Webブラウザ: Chrome 93.0, Firefox 92.0
実装: TypeScript 4.1.6
ONNXモデルファイルの準備
ONNXフォーマットのモデルの作り方は前回の記事を参照してください。PyTorchで作ったモデルをONNXフォーマットに変換したものを準備します。
モデルアーキテクチャ: AlexNet
データセット: CIFAR-10 (32x32px画像、10カテゴリ)
画像認識のモデルとしては入門レベルのシンプルなものです。ちなみに精度(accuracy)は86%くらいでした。CIFAR-10データセットは32×32ピクセルの小さな画像なので人間には少々見難いですが、処理が軽いので動作確認には適しています。
実装
既存のONNX Runtimeで実行するのと基本的には同様の手順で動きますが、ORT Webのインタフェースは必要最低限のものしか提供されていないので細かい計算は自前で行う必要があります。今回は記事は主にHTML Canvas固有の処理の説明になります。
まず実装の全容を以下に載せます。ORT WebとVue.js以外の外部モジュールは使用していません。ここでは1ファイルに全容を収めるためにVue.jsの単一ファイルコンポーネントとして実装していますが、計算処理部分は別モジュールに切り出すのが適切かと思います。1ファイルなのでgistにも一応貼っておきます。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
<template> <div class="hello"> <h1>{{ msg }}</h1> <canvas width="32" height="32" ref="canvas"></canvas> <button type="button" @click="inference">inference</button> <span>{{ infoLabel }}</span> </div> </template> <script lang="ts"> import { defineComponent } from 'vue'; import { InferenceSession, Tensor } from 'onnxruntime-web'; interface DataType { modelPath: string, imagePath: string, imageData: ImageData | null, ctx: CanvasRenderingContext2D | null, session: InferenceSession | null, infoLabel: string } export default defineComponent({ name: 'HelloONNX', props: { msg: String, }, data(): DataType { return { modelPath: 'cifar10_net.onnx', // ONNXモデルファイル名 imagePath: require("@/assets/cat9.png"), // テスト画像を埋め込み imageData: null, ctx: null, session: null, infoLabel: "" }; }, async mounted() { const option = {executionProviders: ['wasm', 'webgl']}; this.session = await InferenceSession.create(this.modelPath, option); this.infoLabel = "loading model complete." const image = new Image(); image.src = this.imagePath; const isCanvas = (x: any): x is HTMLCanvasElement => x instanceof HTMLCanvasElement; image.onload = () => { // Canvas要素に画像ファイルを貼り、画像データを取得する const ref = this.$refs; if(!isCanvas(ref.canvas)) return; this.ctx = ref.canvas.getContext("2d"); if(this.ctx == null) return; const [w, h] = [this.ctx.canvas.width, this.ctx.canvas.height]; this.ctx.drawImage(image, 0, 0, w, h); this.imageData = this.ctx.getImageData(0, 0, w, h); }; }, methods: { async inference(): Promise<void> { if(this.session == null || this.imageData == null) return; const { data, width, height } = this.imageData; // 入力データの正規化(標準化) const processed = this.normalize((data as Uint8ClampedArray), width, height); // ORT Web用にデータ変換して、推論処理を実行 const tensor = new Tensor("float32", processed, [1, 3, width, height]); const feed = { input: tensor }; const result = await this.session.run(feed); // 推論結果(カテゴリと信頼度)を取得 const predicted = this.softmax((result.output.data as Float32Array)); // 信頼度が一番高いカテゴリと信頼度を画面表示 this.infoLabel = this.getClass(predicted).toString(); }, normalize(src: Uint8ClampedArray, width: number, height: number): Float32Array { const dst = new Float32Array(width * height * 3); const transforms = [[0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]]; const step = width * height; for(let y = 0; y < height; y++) { for(let x = 0; x < width; x++) { const [di, si] = [y * width + x, (y * width + x) * 4]; // 各チャンネルで平均を引いて標準偏差で割る(標準化) // さらに RGBARGBARGBA... から RRR...GGG...BBB... の順にデータを詰め替え dst[di] = ((src[si + 0] / 255) - transforms[0][0]) / transforms[1][0]; dst[di + step] = ((src[si + 1] / 255) - transforms[0][1]) / transforms[1][1]; dst[di + step * 2] = ((src[si + 2] / 255) - transforms[0][2]) / transforms[1][2]; } } return dst; }, softmax(data: Float32Array): Float32Array { const max = Math.max(...data); const d = data.map(y => Math.exp(y - max)).reduce((a, b) => a + b); return data.map((value, index) => Math.exp(value - max) / d); }, getClass(data: Float32Array): [string, number] { const classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']; const maxProb = Math.max(...data); return [classes[data.indexOf(maxProb)], maxProb]; } } }); </script> ... CSSパートは省略 |
TypeScriptのstrict type checkingを有効にしているので余計な処理が入っていますが、処理の本体は inference 関数以下に全て書いています。
処理の流れ
- HTML Canvas要素に画像ファイルを読み込んで、画像データを取得 (mounted)
- 画像データの標準化 (normalize)
- ORT Webで推論処理を実行 (InferenceSession.run)
- 推論結果を整形して表示 (softmax, getClass)
画像データ取得部分はORT Web固有の処理などは無いので難しい部分はないかと思います。次のデータの標準化についてですが、これを前処理として忘れずに行っておきます。具体的には画像の各チャンネル毎に、データセットの全ての画像から算出した平均値を引いて標準偏差で割ってやります。モデルは以下のように作っているので平均や標準偏差も同じ値を使いました。
1 2 3 4 5 6 7 8 9 10 11 12 13 |
# PyTorchでモデルを作った時のtransformの設定 transform = { 'train': transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]), 'val': transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) } |
ここで、ORT WebのTensorデータを作る際の注意点があります。まず、Canvas要素のImageData
配列は、RGBARGBARGBA… の順で一次元に画像データが詰め込まれていますが、ORT WebのTensor
に変換する際には RRR…GGG…BBB… の順で詰め込む必要があります。配列の大きさも異なり、ImageData
はA(アルファチャンネル)を含み、Tensor
は含みません。つまり32×32ピクセルの画像だとImageData
は32x32x4=4096で、Tensor
は32x32x3=3072となります。上記実装内ではそれぞれの配列のインデックス計算をして標準化したデータを詰め替えていることになります。わかりにくいですが計算コストを抑えるために同じループ内で処理しています。
ちなみにJavaScript(TypeScript)ではPythonのNumPy ndarrayの操作をシミュレートする ndarray モジュールがあるのでそれを使うのも良いと思います。今回の計算程度でndarrayを使うとコード量自体は逆に増えてしまいますが見た目は読みやすくはなります。
残りのsoftmax
関数は文字通りSoftmax関数の計算を行い、モデル(ニューラルネットワーク)の出力結果をカテゴリ毎の確率(信頼度)に変換しています。それからgetClass
関数で一番確率が高いカテゴリを返して処理完了となります。上記実装内の predicted 変数に各カテゴリ毎の確率が入っているので、例えば確率トップ5を出したいなら中身をソートして表示すれば良いだけです。今回はテスト画像を32×32ピクセルのものを予め用意していますが、サイズの違う画像を入力にする場合はリサイズ処理も前処理として必要になります。
今回はVue.jsのアプリケーションとして実装しているので表示は以下のようになります。以下は32×32ピクセルの猫の画像を入力にして正しく認識されていることがわかります(約98%の確率でcatであると推論)。
おわりに
今回はORT Webを使った少し実践的なサンプルを作ってみました。ONNXモデルファイルを自前で作る前に、ONNX Model Zooにたくさん転がっているのでまずはそちらを使って遊ぶのが良いです。やはりブラウザで動くのは楽しいですね。
一方で前回の記事にも書きましたが、ORT Webのユーザー層のスキルセットがマッチしないかもという懸念は残ります。上記の実装は少しだけ機械学習の事前知識がないと難しいと感じるかもしれません。ORT Webのインタフェースは必要最低限なので、それを補完するライブラリがあると便利そうです。ORT Webはまだリリースされて間もないので、これからのアップデートに期待したいと思います。
- 今回実装したVueコンポーネント: Image Classification using ONNX Runtime for Web (ORT Web)