TensorFlow GNMT モデルの変換

危険

ここで説明されているコードは非推奨になりました。従来のソリューションの適用を避けるため使用しないでください。下位互換性を確保するためにしばらく保持されますが、最新のアプリケーションでは使用してはなりません

このガイドでは、非推奨となった変換方法について説明します。新しい推奨方法に関するガイドは、Python チュートリアルに記載されています。

このチュートリアルでは、Google ニューラル機械翻訳 (GNMT) モデルを中間表現 (IR) に変換する方法について説明します。

TensorFlow GNMT モデル実装の公開バージョンが GitHub で入手できます。このチュートリアルでは、GNMT モデルを TensorFlow ニューラル機械翻訳 (NMT) リポジトリーから IR に変換する方法について説明します。

パッチファイルの作成

モデルを変換する前に、リポジトリー用のパッチファイルを作成します。このパッチは、推論グラフのダンプを有効にするコマンドライン引数をフレームワーク・オプションに追加してフレームワーク・コードを変更します。

  1. 書き込み可能なディレクトリーに移動し、GNMT_inference.patch ファイルを作成します。

  2. 次の差分コードをファイルにコピーします。

    diff --git a/nmt/inference.py b/nmt/inference.py
    index 2cbef07..e185490 100644
    --- a/nmt/inference.py
    +++ b/nmt/inference.py
    @@ -17,9 +17,11 @@
    from __future__ import print_function
    
    import codecs
    +import os
    import time
    
    import tensorflow as tf
    +from tensorflow.python.framework import graph_io
    
    from . import attention_model
    from . import gnmt_model
    @@ -105,6 +107,29 @@ def start_sess_and_load_model(infer_model, ckpt_path):
       return sess, loaded_infer_model
    
    
    +def inference_dump_graph(ckpt_path, path_to_dump, hparams, scope=None):
    +    model_creator = get_model_creator(hparams)
    +    infer_model = model_helper.create_infer_model(model_creator, hparams, scope)
    +    sess = tf.Session(
    +        graph=infer_model.graph, config=utils.get_config_proto())
    +    with infer_model.graph.as_default():
    +        loaded_infer_model = model_helper.load_model(
    +            infer_model.model, ckpt_path, sess, "infer")
    +    utils.print_out("Dumping inference graph to {}".format(path_to_dump))
    +    loaded_infer_model.saver.save(
    +        sess,
    +        os.path.join(path_to_dump + 'inference_GNMT_graph')
    +        )
    +    utils.print_out("Dumping done!")
    +
    +    output_node_name = 'index_to_string_Lookup'
    +    utils.print_out("Freezing GNMT graph with output node {}...".format(output_node_name))
    +    frozen = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def,
    +                                                          [output_node_name])
    +    graph_io.write_graph(frozen, '.', os.path.join(path_to_dump, 'frozen_GNMT_inference_graph.pb'), as_text=False)
    +    utils.print_out("Freezing done. Freezed model frozen_GNMT_inference_graph.pb saved to {}".format(path_to_dump))
    +
    +
    def inference(ckpt_path,
                   inference_input_file,
                   inference_output_file,
    diff --git a/nmt/nmt.py b/nmt/nmt.py
    index f5823d8..a733748 100644
    --- a/nmt/nmt.py
    +++ b/nmt/nmt.py
    @@ -310,6 +310,13 @@ def add_arguments(parser):
       parser.add_argument("--num_intra_threads", type=int, default=0,
                         help="number of intra_op_parallelism_threads")
    
    +  # Special argument for inference model dumping without inference
    +  parser.add_argument("--dump_inference_model", type="bool", nargs="?",
    +                      const=True, default=False,
    +                      help="Argument for dump inference graph for specified trained ckpt")
    +
    +  parser.add_argument("--path_to_dump", type=str, default="",
    +                      help="Path to dump inference graph.")
    
    def create_hparams(flags):
       """Create training hparams."""
    @@ -396,6 +403,9 @@ def create_hparams(flags):
          language_model=flags.language_model,
          num_intra_threads=flags.num_intra_threads,
          num_inter_threads=flags.num_inter_threads,
    +
    +      dump_inference_model=flags.dump_inference_model,
    +      path_to_dump=flags.path_to_dump,
       )
    
    
    @@ -613,7 +623,7 @@ def create_or_load_hparams(
       return hparams
    
    
    -def run_main(flags, default_hparams, train_fn, inference_fn, target_session=""):
    +def run_main(flags, default_hparams, train_fn, inference_fn, inference_dump, target_session=""):
       """Run main."""
       # Job
       jobid = flags.jobid
    @@ -653,8 +663,26 @@ def run_main(flags, default_hparams, train_fn, inference_fn, target_session=""):
             out_dir, default_hparams, flags.hparams_path,
             save_hparams=(jobid == 0))
    
    -  ## Train/Decode
    -  if flags.inference_input_file:
    +  #  Dumping inference model
    +  if flags.dump_inference_model:
    +      # Inference indices
    +      hparams.inference_indices = None
    +      if flags.inference_list:
    +          (hparams.inference_indices) = (
    +              [int(token) for token in flags.inference_list.split(",")])
    +
    +      # Ckpt
    +      ckpt = flags.ckpt
    +      if not ckpt:
    +          ckpt = tf.train.latest_checkpoint(out_dir)
    +
    +      # Path to dump graph
    +      assert flags.path_to_dump != "", "Please, specify path_to_dump model."
    +      path_to_dump = flags.path_to_dump
    +      if not tf.gfile.Exists(path_to_dump): tf.gfile.MakeDirs(path_to_dump)
    +
    +      inference_dump(ckpt, path_to_dump, hparams)
    +  elif flags.inference_input_file:
       # Inference output directory
       trans_file = flags.inference_output_file
       assert trans_file
    @@ -693,7 +721,8 @@ def main(unused_argv):
       default_hparams = create_hparams(FLAGS)
       train_fn = train.train
       inference_fn = inference.inference
    -  run_main(FLAGS, default_hparams, train_fn, inference_fn)
    +  inference_dump = inference.inference_dump_graph
    +  run_main(FLAGS, default_hparams, train_fn, inference_fn, inference_dump)
    
    
    if __name__ == "__main__":
    
  3. ファイルを保存して閉じます。

GNMT モデルを IR に変換

TensorFlow* バージョン 1.13.0 以下をインストールします。

ステップ 1. この GitHub リポジトリーのクローンを作成し、コミットをチェックアウトします。

  1. NMT リポジトリーのクローンを作成します。

    git clone https://github.com/tensorflow/nmt.git
    
  2. 必要なコミットを確認してください。

    git checkout b278487980832417ad8ac701c672b5c3dc7fa553
    

ステップ 2. トレーニングされたモデルを取得します。次の 2 つのオプションがあります。

  • NMT フレームワークを使用して、GNMT wmt16_gnmt_4_layer.json または wmt16_gnmt_8_layer.json 構成ファイルを使用してモデルをトレーニングします。

  • NMT リポジトリーで提供される事前トレーニング済みチェックポイントは使用しないでください。これらは旧式で、現在のリポジトリーのバージョンとは互換性がない可能性があります。

このチュートリアルでは、wmt16_gnmt_4_layer.json 構成のトレーニング済み GNMT モデル (ドイツ語から英語への翻訳) を前提としています。

ステップ 3. 推論グラフを作成します。

OpenVINO は、モデルが推論のみに使用されることを前提としています。したがって、モデルを IR に変換する前に、トレーニング・グラフを推論グラフに変換する必要があります。GNMT モデルでは、トレーニング・グラフと推論グラフには異なるデコーダーがあります。トレーニング・グラフは greedy 検索デコード・アルゴリズムを使用し、推論グラフは beam 検索デコード・アルゴリズムを使用します。

  1. GNMT_inference.patch パッチをリポジトリーに適用します。ない場合は、パッチファイル手順を作成をします。

    git apply /path/to/patch/GNMT_inference.patch
    
  2. NMT フレームワークを実行して推論モデルをダンプします。

    python -m nmt.nmt
       --src=de
       --tgt=en
       --ckpt=/path/to/ckpt/translate.ckpt
       --hparams_path=/path/to/repository/nmt/nmt/standard_hparams/wmt16_gnmt_4_layer.json
       --vocab_prefix=/path/to/vocab/vocab.bpe.32000
       --out_dir=""
       --dump_inference_model
       --infer_mode beam_search
       --path_to_dump /path/to/dump/model/
    

異なるチェックポイントを使用する場合、srctgtckpthparams_path、および vocab_prefix パラメーターに対応する値を使用します。推論チェックポイント inference_GNMT_graph と凍結された推論グラフ frozen_GNMT_inference_graph.pb/path/to/dump/model/ フォルダーにあります。

vocab.bpe.32000 を生成するには、nmt/scripts/wmt16_en_de.sh スクリプトを実行します。チェックポイント・グラフの埋め込みレイヤーとボキャブラリー (ソースとターゲットの両方) 間にサイズ不一致の問題が発生した場合は、nmt.py ファイルの extend_hparams 関数の 508 行目以降 (src_vocab_size 変数と tgt_vocab_size 変数) に次のコードを必ず追加してください。

src_vocab_size -= 1
tgt_vocab_size -= 1

ステップ 4. モデルを IR に変換します。

mo
--input_model /path/to/dump/model/frozen_GNMT_inference_graph.pb
--input "IteratorGetNext:1{i32}[1],IteratorGetNext:0{i32}[1,50],dynamic_seq2seq/hash_table_Lookup_1:0[1]->[2],dynamic_seq2seq/hash_table_Lookup:0[1]->[1]"
--output dynamic_seq2seq/decoder/decoder/GatherTree
--output_dir /path/to/output/IR/

OpenVINO™ は IteratorGetNext および LookupTableFindV2 操作をサポートしていないため、--input および --output オプションを使用した入力および出力のカットが必要です。

入力のカット:

  • IteratorGetNext 操作はデータセットに対して反復されます。これは出力ポートによってカットされます: ポート 0 には形状 [batch_size, max_sequence_length] のデータテンソルが含まれ、ポート 1 には形状 [batch_size] のすべてのバッチの sequence_length が含まれます。

  • LookupTableFindV2 操作 (グラフ内の dynamic_seq2seq/hash_table_Lookup_1 および dynamic_seq2seq/hash_table_Lookup ノード) は定数値でカットされます)。

出力のカット:

  • LookupTableFindV2 操作は出力から切り取られ、動的な dynamic_seq2seq/decoder/decoder/GatherTree ノードが新しい終了点として扱われます。

モデルカットの詳細については、モデルの一部を切り取るガイドを参照してください。

GNMT モデルを使用

このステップでは、モデルが中間表現に変換されていることを前提とします。

モデルの入力:

  • IteratorGetNext/placeholder_out_port_0 形状 [batch_size, max_sequence_length] の入力には、batch_size でデコードされた入力文が含まれます。すべての文は、ボキャブラリー内の文要素のインデックスと同じ方法でデコードされ、eos (文末記号) のインデックスが埋め込まれます。文の長さが max_sequence_length 未満の場合、残りの要素は eos トークンのインデックスで埋められます。

  • IteratorGetNext/placeholder_out_port_1 形状 [batch_size] の入力には、最初の入力からのすべての文のシーケンス長が含まれます。例えば、max_sequence_length = 50batch_size = 1 で、文に要素が 30 個しかない場合、IteratorGetNext/placeholder_out_port_1 の入力テンソルは [30] でなければなりません。

モデルの出力:

  • dynamic_seq2seq/decoder/decoder/GatherTree 形状 [max_sequence_length * 2, batch, beam_size] のテンソルであり、入力からのすべての文に対する beam_size の最適な翻訳が含まれます (ボキャブラリー内の単語のインデックスとしてもデコードされます)。

TensorFlow のテンソルの形状は異なる場合があります: max_sequence_length * 2 の代わりに、それより小さい任意の値にすることができます。これは、OpenVINO が出力の動的形状をサポートしていない一方で、TensorFlow は、eos シンボルを生成する際にデコード反復を停止できるためです。

GNMT IR を実行

  1. ベンチマーク・アプリの場合:

    benchmark_app -m <path to the generated GNMT IR> -d CPU
    
  2. OpenVINO ランタイム Python API の場合:

    例を実行する前に、GNMT .xml および .bin ファイルへのパスを MODEL_PATHWEIGHTS_PATH に挿入し、入力データに従って input_data_tensorseq_lengths テンソルを入力します。

    from openvino.inference_engine import IENetwork, IECore
    
    MODEL_PATH = '/path/to/IR/frozen_GNMT_inference_graph.xml'
    WEIGHTS_PATH = '/path/to/IR/frozen_GNMT_inference_graph.bin'
    
    # Creating network
    net = IENetwork(
       model=MODEL_PATH,
       weights=WEIGHTS_PATH)
    
    # Creating input data
    input_data = {'IteratorGetNext/placeholder_out_port_0': input_data_tensor,
                'IteratorGetNext/placeholder_out_port_1': seq_lengths}
    
    # Creating plugin and loading extensions
    ie = IECore()
    ie.add_extension(extension_path="libcpu_extension.so", device_name="CPU")
    
    # Loading network
    exec_net = ie.load_network(network=net, device_name="CPU")
    
    # Run inference
    result_ie = exec_net.infer(input_data)
    

Python API の詳細については、OpenVINO ランタイム Python API ガイドを参照してください。