PyTorch RNN-T モデルの変換

危険

ここで説明されているコードは非推奨になりました。従来のソリューションの適用を避けるため使用しないでください。下位互換性を確保するためにしばらく保持されますが、最新のアプリケーションでは使用してはなりません

このガイドでは、非推奨となった変換方法について説明します。新しい推奨方法に関するガイドは、Python チュートリアルに記載されています。

このガイドでは、MLCommons リポジトリーからの RNN-T モデルの変換について説明します。IR に変換する前に、以下の手順に従って PyTorch モデルを ONNX にエクスポートします。

ステップ 1. MLCommons リポジトリー (リビジョン r1.0) から RNN-T PyTorch 実装のクローンを作成します。完全なリポジトリーを使用せずに RNN-T モデルのみを取得する浅いクローンを作成します。すでに完全なリポジトリーがある場合、これをスキップして ステップ 2 に進みます。

git clone -b r1.0 -n https://github.com/mlcommons/inference rnnt_for_openvino --depth 1
cd rnnt_for_openvino
git checkout HEAD speech_recognition/rnnt

ステップ 2. MLCommons 推論リポジトリーの完全なクローンがすでに存在する場合、IR への変換が行われる事前トレーニング済みの PyTorch モデル用のフォルダーを作成します。ステップ 5 では、完全なクローンへのパスも指定する必要があります。浅いクローンがある場合は、この手順をスキップしてください。

mkdir rnnt_for_openvino
cd rnnt_for_openvino

ステップ 3. こちらから、PyTorch 実装用の事前トレーニングされた重みをダウンロードします。UNIX のようなシステムでは wget を使用できます。

wget https://zenodo.org/record/3662521/files/DistributedDataParallel_1576581068.9962234-epoch-100.pt

リンクは speech_recoginitin/rnnt サブフォルダー内の setup.sh から取得されました。ガイドに従った場合と全く同じ重みが得られます。

ステップ 4. 必要な Python パッケージをインストールします。

pip3 install torch toml

ステップ 5. 以下のスクリプトを使用して、RNN-T モデルを ONNX にエクスポートします。以下のコードを export_rnnt_to_onnx.py ファイルにコピーし、現在のディレクトリー rnnt_for_openvino で実行します。

MLCommons 推論リポジトリーの完全なクローンがすでにある場合は、mlcommons_inference_path 変数を指定する必要があります。

import toml
import torch
import sys


def load_and_migrate_checkpoint(ckpt_path):
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    migrated_state_dict = {}
    for key, value in checkpoint['state_dict'].items():
        key = key.replace("joint_net", "joint.net")
        migrated_state_dict[key] = value
    del migrated_state_dict["audio_preprocessor.featurizer.fb"]
    del migrated_state_dict["audio_preprocessor.featurizer.window"]
    return migrated_state_dict


mlcommons_inference_path = './'  # specify relative path for MLCommons inferene
checkpoint_path = 'DistributedDataParallel_1576581068.9962234-epoch-100.pt'
config_toml = 'speech_recognition/rnnt/pytorch/configs/rnnt.toml'
config = toml.load(config_toml)
rnnt_vocab = config['labels']['labels']
sys.path.insert(0, mlcommons_inference_path + 'speech_recognition/rnnt/pytorch')

from model_separable_rnnt import RNNT

model = RNNT(config['rnnt'], len(rnnt_vocab) + 1, feature_config=config['input_eval'])
model.load_state_dict(load_and_migrate_checkpoint(checkpoint_path))

seq_length, batch_size, feature_length = 157, 1, 240
inp = torch.randn([seq_length, batch_size, feature_length])
feature_length = torch.LongTensor([seq_length])
x_padded, x_lens = model.encoder(inp, feature_length)
torch.onnx.export(model.encoder, (inp, feature_length), "rnnt_encoder.onnx", opset_version=12,
                  input_names=['input', 'feature_length'], output_names=['x_padded', 'x_lens'],
                  dynamic_axes={'input': {0: 'seq_len', 1: 'batch'}})

symbol = torch.LongTensor([[20]])
hidden = torch.randn([2, batch_size, 320]), torch.randn([2, batch_size, 320])
g, hidden = model.prediction.forward(symbol, hidden)
torch.onnx.export(model.prediction, (symbol, hidden), "rnnt_prediction.onnx", opset_version=12,
                  input_names=['symbol', 'hidden_in_1', 'hidden_in_2'],
                  output_names=['g', 'hidden_out_1', 'hidden_out_2'],
                  dynamic_axes={'symbol': {0: 'batch'}, 'hidden_in_1': {1: 'batch'}, 'hidden_in_2': {1: 'batch'}})

f = torch.randn([batch_size, 1, 1024])
model.joint.forward(f, g)
torch.onnx.export(model.joint, (f, g), "rnnt_joint.onnx", opset_version=12,
                  input_names=['0', '1'], output_names=['result'], dynamic_axes={'0': {0: 'batch'}, '1': {0: 'batch'}})
python3 export_rnnt_to_onnx.py

この手順が完了すると、ファイル rnnt_encoder.onnxrnnt_prediction.onnx、および rnnt_joint.onnx が現在のディレクトリーに保存されます。

ステップ 6. 変換コマンドを実行します。

mo --input_model rnnt_encoder.onnx --input "input[157,1,240],feature_length->157"
mo --input_model rnnt_prediction.onnx --input "symbol[1,1],hidden_in_1[2,1,320],hidden_in_2[2,1,320]"
mo --input_model rnnt_joint.onnx --input "0[1,1,1024],1[1,1,320]"

シーケンス長 = 157 のハードコードされた値は MLCommons から取得されたものですが、IR への変換によりネットワークの再形成の可能性が維持されます。入力形状は、変換中または推論中に手動で任意の値に変更できます。