PyTorch ランチャーの構成方法

PyTorch ランチャーは、精度チェッカーツール内でモデルを簡単に起動するためにサポートされるラッパーの 1 つです。このランチャーを使用すると、推論バックエンドに PyTorch* フレームワークを使用してモデルを実行できます。

PyTorch ランチャーを有効にするには、構成ファイルの launchers セクションに framework: pytorch を追加し、次のパラメーターを指定する必要があります。

  • device - 推論に使用するデバイス (cpucuda など) を指定します。

  • module- ロード用の PyTorch ネットワーク・モジュール。

  • checkpoint - 事前トレーニングされたモデルのチェックポイント (オプション)。

  • python_path - 現在の Python 環境でネットワーク・モジュールを表示する PYTHONPATH の付録 (オプション)。

  • module_args - ネットワーク・モジュールの位置引数のリスト (オプション)。

  • module_kwargs - ネットワーク・モジュールのキーワード引数を表す辞書 (key: valuekey は引数名、value は引数の値)。

  • adapter - 生の出力がデータセットの表現にどのように変換されるかという問題に対処するため、一部のアダプターはフレームワークに固有であることがあります。アダプターの使用方法の詳細な説明は、こちらでご覧いただけます。

  • batch - 実行中のモデルのバッチサイズ (オプション、デフォルトは 1)。

次に、モデルに複数の入力がある場合は、特定のパラメーター inputs でそれらを指定する必要があります。各入力の説明には次の情報が含まれている必要があります。

  • name - ネットワーク内のレイヤー名を入力します

  • type - 入力値のタイプに応じて、入力ポリシーに影響します。次のオプションが利用できます。

    • CONST_INPUT - 入力は、構成で提供される定数で埋められます。value の提供も必要です。

    • IMAGE_INFO - 入力形状の情報をレイヤーに設定する特定のキー (高速 RCNN ベースのトポロジーで使用)。value は実行時に計算されるため、指定の必要はありません。形式値は [H, W, S] 形式の N 要素からなるリストです。N はバッチサイズ、H - 元の画像の高さ、W - 元の画像の幅、S - 元の画像のスケール (デフォルトは 1) です。

    • ORIG_IMAGE_INFO - 前処理前の元の画像サイズの情報を設定する特定のキー。

    • PROCESSED_IMAGE_INFO - 前処理後の入力サイズの情報を設定する特定のキー。

    • SCALE_FACTOR - 画像スケール係数の情報を設定する特定のキーは [SCALE_Y, SCALE_X] として定義されます。ここで、SCALE_Y = <resized_image_height>/<original_image_heightSCALE_X = <resized_image_width> / <original_image_width>

    • IGNORE_INPUT - 評価中は空のままにしておく必要がある入力。

    • INPUT - メイン・データ・ストリームのネットワーク入力 (画像など)。複数のデータ入力がある場合、特定の value でどのデータを提供するかを指定する値として識別子の正規表現を提供する必要があります。

  • shape - 入力レイヤーの形状は、バッチサイズを除くすべてのサイズのカンマで区切って記述されます。

    オプションで、モデルが非標準のデータレイアウト (PyTorch のデフォルトレイアウトは NCHW) と精度でトレーニングされた場合のレイアウトを決定できます (サポートされている精度: FP32 - float、FP16 - signed shot、U8 - unsigned char、U16 - unsigned short int、I8 - signed char、I16 - short int、I32 - int、I64 - long int)。

モデルに複数の出力がある場合は、オプション output_names を使用してアダプターで値を取得できるように、構成でそれらの名前を指定する必要もあります。

PyTorch ランチャー構成例 (torchvision から AlexNet モデルを実行する方法を示します):

launchers:
  - framework: pytorch
    device: CPU
    module: torchvision.models.alexnet

    module_kwargs:
      pretrained: True

    adapter: classification