PyTorch BERT-NER モデルの変換

危険

ここで説明されているコードは非推奨になりました。従来のソリューションの適用を避けるため使用しないでください。下位互換性を確保するためにしばらく保持されますが、最新のアプリケーションでは使用してはなりません

このガイドでは、非推奨となった変換方法について説明します。新しい推奨方法に関するガイドは、Python チュートリアルに記載されています。

この記事の目的は、PyTorch BERT-NER モデルを OpenVINO IR に変換する段階的なガイドを提示することです。最初に、モデルをダウンロードして ONNX に変換する必要があります。

モデルをダウンロードして ONNX に変換

事前トレーニングされたモデルをダウンロードするか、モデルを自身でトレーニングするには、BERT-NER モデル・リポジトリーの手順を参照してください。設定ファイルを含むモデルは out_base ディレクトリーに保存されます。

モデルを ONNX 形式に変換するには、モデル・リポジトリーのルート・ディレクトリーで次のスクリプトを作成して実行します。事前トレーニング済みモデルをダウンロードする場合は、スクリプトを実行するため bert.py をダウンロードする必要があります。手順変換は commit-SHA: e5be564156f194f1becb0d82aeaf6e762d9eb9ed を使用してテストされました。

import torch

from bert import Ner

ner = Ner("out_base")

input_ids, input_mask, segment_ids, valid_positions = ner.preprocess('Steve went to Paris')
input_ids = torch.tensor([input_ids], dtype=torch.long, device=ner.device)
input_mask = torch.tensor([input_mask], dtype=torch.long, device=ner.device)
segment_ids = torch.tensor([segment_ids], dtype=torch.long, device=ner.device)
valid_ids = torch.tensor([valid_positions], dtype=torch.long, device=ner.device)

ner_model, tknizr, model_config = ner.load_model("out_base")

with torch.no_grad():
    logits = ner_model(input_ids, segment_ids, input_mask, valid_ids)
torch.onnx.export(ner_model,
                  (input_ids, segment_ids, input_mask, valid_ids),
                  "bert-ner.onnx",
                  input_names=['input_ids', 'segment_ids', 'input_mask', 'valid_ids'],
                  output_names=['output'],
                  dynamic_axes={
                      "input_ids": {0: "batch_size"},
                      "segment_ids": {0: "batch_size"},
                      "input_mask": {0: "batch_size"},
                      "valid_ids": {0: "batch_size"},
                      "output": {0: "output"}
                  },
                  opset_version=11,
                  )

このスクリプトは ONNX モデルファイル bert-ner.onnx を生成します。

ONNX BERT-NER モデルから IR への変換

mo --input_model bert-ner.onnx --input "input_mask[1,128],segment_ids[1,128],input_ids[1,128]"

ここで、1batch_size128sequence_length です。