量子化対応トレーニング (QAT)¶
はじめに¶
量子化対応トレーニングは、モデルを量子化し、微調整を適用して量子化による精度低下を回復する一般的な方法です。実際、これは最も正確な量子化方法です。このドキュメントでは、ニューラル・ネットワーク圧縮フレームワーク (NNCF) から QAT を適用して 8 ビットの量子化モデルを取得する方法について説明します。Python プログラミングの知識があり、ソース DL フレームワークのモデルのトレーニング・コードに精通していることを前提としています。
NNCF QAT の使用¶
ここでは、NNCF の QAT を PyTorch または TensorFlow 2 で作成されたトレーニング・スクリプトに統合する手順を示します。
注
現在、NNCF for TensorFlow 2 は、Keras Sequential API または Functional API を使用して作成されたモデルの最適化をサポートしています。
1. NNCF API をインポート¶
このステップでは、トレーニング・スクリプトの先頭に NNCF 関連のインポート文を追加します。
import torch
import nncf # Important - should be imported right after torch
from nncf import NNCFConfig
from nncf.torch import create_compressed_model, register_default_init_args
import tensorflow as tf
from nncf import NNCFConfig
from nncf.tensorflow import create_compressed_model, register_default_init_args
2. NNCF 構成の作成¶
ここでは、モデル関連のパラメーター ("input_info"
セクション) と最適化メソッドのパラメーター ("compression"
セクション) で構成される NNCF 構成を定義する必要があります。収束を高速化するため、DL フレームワークに固有のデータセット・オブジェクトを登録することも推奨されます。これはモデル作成ステップで量子化パラメーターを初期化する際に使用されます。
nncf_config_dict = {
"input_info": {"sample_size": [1, 3, 224, 224]}, # input shape required for model tracing
"compression": {
"algorithm": "quantization", # 8-bit quantization with default settings
},
}
nncf_config = NNCFConfig.from_dict(nncf_config_dict)
nncf_config = register_default_init_args(nncf_config, train_loader) # train_loader is an instance of torch.utils.data.DataLoader
nncf_config_dict = {
"input_info": {"sample_size": [1, 3, 224, 224]}, # input shape required for model tracing
"compression": {
"algorithm": "quantization", # 8-bit quantization with default settings
},
}
nncf_config = NNCFConfig.from_dict(nncf_config_dict)
nncf_config = register_default_init_args(nncf_config, train_dataset, batch_size=1) # train_dataset is an instance of tf.data.Dataset
3. 最適化の適用¶
このステップでは、前のステップで定義した構成を使用して、create_compressed_model()
で元のモデル・オブジェクトをラップする必要があります。このメソッドは、圧縮コントローラーと、元のモデルと同じように使用できるラップされたモデルを返します。モデルが対応する一連の変換を実行し、最適化に必要な追加の操作を含めることができるように、このステップで最適化メソッドが適用されることに注意してください。QAT の場合、圧縮コントローラー・オブジェクトはモデルのエクスポートに使用され、オプションで以下に示すように分散トレーニングにも使用されます。
model = TorchModel() # instance of torch.nn.Module
compression_ctrl, model = create_compressed_model(model, nncf_config)
model = KerasModel() # instance of the tensorflow.keras.Model
compression_ctrl, model = create_compressed_model(model, nncf_config)
4. モデルの微調整¶
このステップでは、ベースライン・モデルに対して適用した方法でモデルに微調整を加えることを前提としています。QAT の場合、10e-5 などの小さな学習率でモデルを数エポック・トレーニングする必要があります。原則として、このステップはスキップできます。これは、トレーニング後の最適化がモデルに適用されるためです。
... # fine-tuning preparations, e.g. dataset, loss, optimizer setup, etc.
# tune quantized model for 5 epochs as the baseline
for epoch in range(0, 5):
compression_ctrl.scheduler.epoch_step() # Epoch control API
for i, data in enumerate(train_loader):
compression_ctrl.scheduler.step() # Training iteration control API
... # training loop body
... # fine-tuning preparations, e.g. dataset, loss, optimizer setup, etc.
# create compression callbacks to control optimization parameters and dump compression statistics
compression_callbacks = create_compression_callbacks(compression_ctrl, log_dir="./compression_log")
# tune quantized model for 5 epochs the same way as the baseline
model.fit(train_dataset, epochs=5, callbacks=compression_callbacks)
5. マルチ GPU 分散トレーニング¶
マルチ GPU 分散トレーニング (DataParallel ではない) の場合、微調整の前に compression_ctrl.distributed()
を呼び出す必要があります。これにより、分散モードで機能するためいくつかの調整を行えるように最適化メソッドに通知されます。
compression_ctrl.distributed() # call it before the training loop
compression_ctrl.distributed() # call it before the training
6. 量子化モデルのエクスポート¶
微調整が終了したら、量子化モデルを対応する形式にエクスポートして、さらに推論を行うことができます: PyTorch の場合は ONNX、TensorFlow 2 の場合は凍結グラフ。
compression_ctrl.export_model("compressed_model.onnx")
compression_ctrl.export_model("compressed_model.pb") #export to Frozen Graph
注
重みの精度は、モデルを OpenVINO 中間表現に変換するステップの後でのみ INT8 になります。その形式でのみ、モデルのフットプリントの削減が期待できます。
これらは、NNCF の QAT メソッドを適用する基本的な手順です。ただし、状況によっては、トレーニング中にモデルのチェックポイントを保存/復元する必要があります。NNCF は元のモデルを独自のオブジェクトでラップするため、これらのニーズに対応する API を提供します。
7. (オプション) チェックポイントの保存¶
モデルのチェックポイントを保存するには、次の API を使用します。
checkpoint = {
'state_dict': model.state_dict(),
'compression_state': compression_ctrl.get_compression_state(),
... # the rest of the user-defined objects to save
}
torch.save(checkpoint, path_to_checkpoint)
from nncf.tensorflow.utils.state import TFCompressionState
from nncf.tensorflow.callbacks.checkpoint_callback import CheckpointManagerCallback
checkpoint = tf.train.Checkpoint(model=model,
compression_state=TFCompressionState(compression_ctrl),
... # the rest of the user-defined objects to save
)
callbacks = []
callbacks.append(CheckpointManagerCallback(checkpoint, path_to_checkpoint))
...
model.fit(..., callbacks=callbacks)
8. (オプション) チェックポイントから復元¶
チェックポイントからモデルを復元するには、次の API を使用します。
resuming_checkpoint = torch.load(path_to_checkpoint)
compression_state = resuming_checkpoint['compression_state']
compression_ctrl, model = create_compressed_model(model, nncf_config, compression_state=compression_state)
state_dict = resuming_checkpoint['state_dict']
model.load_state_dict(state_dict)
from nncf.tensorflow.utils.state import TFCompressionStateLoader
checkpoint = tf.train.Checkpoint(compression_state=TFCompressionStateLoader())
checkpoint.restore(path_to_checkpoint)
compression_state = checkpoint.compression_state.state
compression_ctrl, model = create_compressed_model(model, nncf_config, compression_state)
checkpoint = tf.train.Checkpoint(model=model,
...)
checkpoint.restore(path_to_checkpoint)
NNCF でのチェックポイントの保存/復元の詳細については、次のドキュメントを参照してください。