Distil-Whisper と OpenVINO を使用した自動音声認識#

この Jupyter ノートブックは、ローカルへのインストール後にのみ起動できます。

GitHub

Distil-Whisper は、OpenAI による Whisper モデルの精製版です。Distil-Whisper は、論文 Robust Knowledge Distillation via Large-Scale Pseudo Labelling で提案されています。著者によると、Distil-Whisper は Whisper と比較して、パラメーターが 50% 少ない状態で数倍高速に実行され、分布外評価データで 1% 以内の単語エラー率 (WER) を実現できるということです。

Whisper は、トランスフォーマー・ベースのエンコーダー/デコーダーモデルであり、シーケンスツーシーケンス・モデルとも呼ばれます。オーディオ・スペクトログラム機能のシーケンスをテキストトークンのシーケンスにマッピングします。まず、生のオーディオ入力は、特徴抽出器の動作によって log-Mel スペクトログラムに変換されます。次に、トランスフォーマー・エンコーダーはスペクトログラムをエンコードして、エンコーダーの隠し状態のシーケンスを形成します。最後に、デコーダーは、以前のトークンとエンコーダーの隠れ状態の両方を条件として、テキストトークンを自己回帰的に予測します。

下の図でモデルのアーキテクチャーを確認できます:

whisper_architecture.svg

whisper_architecture.svg#

このチュートリアルでは、OpenVINO を使用して Distil-Whisper を実行する方法について説明します。Hugging Face トランスフォーマー・ライブラリーの事前トレーニング済みモデルを使用します。ユーザー・エクスペリエンスを簡素化するために、Hugging Face の Optimum ライブラリーを使用してモデルを OpenVINO™ IR 形式に変換します。OpenVINO Distil-Whisper モデルのパフォーマンスをさらに向上させるため、NNCF からの INT8 トレーニング後の量子化が適用されます。

目次:

必要条件#

%pip install -q "transformers>=4.35" "torch>=2.1" onnx "git+https://github.com/huggingface/optimum-intel.git" "peft==0.6.2" --extra-index-url https://download.pytorch.org/whl/cpu 
%pip install -q "openvino>=2023.2.0" datasets "gradio>=4.0" "librosa" "soundfile" 
%pip install -q "nncf>=2.6.0" "jiwer"

PyTorch モデルのロード#

AutoModelForSpeechSeq2Seq.from_pretrained メソッドは、Transformers ライブラリーを使用して PyTorch Whisper モデルの初期化に使用されます。このチュートリアルでは、デフォルトで distil-whisper/distil-large-v2 モデルを例として使用します。モデルは最初の実行時に 1 度ダウンロードされますが、このプロセスには時間がかかる場合があります。

また、distil-whisper/distil-medium.endistil-whisper/distil-small.en など、Distil-Whisper hugging face コレクションから他のモデルを選択することもできます。オリジナルの Whisper アーキテクチャーのモデルも入手可能です。詳細はこちらをご覧ください。

このモデルの使用では、前処理と後処理が重要です。初期化に使用されるクラス WhisperProcessor は、モデルのオーディオ入力データを準備し、それをメルスペクトログラムに変換し、トークナイザーを使用して予測された出力 token_ids を文字列にデコードする役割を担います。

import ipywidgets as widgets 

model_ids = { 
    "Distil-Whisper": [ 
        "distil-whisper/distil-large-v2", 
        "distil-whisper/distil-medium.en", 
        "distil-whisper/distil-small.en", 
    ], 
    "Whisper": [ 
        "openai/whisper-large-v3", 
        "openai/whisper-large-v2", 
        "openai/whisper-large", 
        "openai/whisper-medium", 
        "openai/whisper-small", 
        "openai/whisper-base", 
        "openai/whisper-tiny", 
        "openai/whisper-medium.en", 
        "openai/whisper-small.en", 
        "openai/whisper-base.en", 
        "openai/whisper-tiny.en", 
    ], 
} 

model_type = widgets.Dropdown( 
    options=model_ids.keys(), 
    value="Distil-Whisper", 
    description="Model type:", 
    disabled=False, 
) 

model_type
model_id = widgets.Dropdown( 
    options=model_ids[model_type.value], 
    value=model_ids[model_type.value][0], 
    description="Model:", 
    disabled=False, 
) 

model_id
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq 

processor = AutoProcessor.from_pretrained(model_id.value) 

pt_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id.value) 
pt_model.eval();
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

入力サンプルを準備#

プロセッサーは、numpy 配列形式のオーディオデータとオーディオ・サンプリング・レートに関する情報を期待し、予測を行う input_features テンソルを返します。オーディオから numpy 形式への変換は、Hugging Face データセットの実装によって処理されます。

from datasets import load_dataset 

def extract_input_features(sample): 
    input_features = processor( 
        sample["audio"]["array"], 
        sampling_rate=sample["audio"]["sampling_rate"], 
        return_tensors="pt", 
    ).input_features 
    return input_features 

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True) 
sample = dataset[0] 
input_features = extract_input_features(sample)

モデルの推論を実行#

音声認識を実行するには、モデルの生成インターフェイスを使用できます。生成が完了したら、processor.batch_decode を使用して、予測された token_ids をテキスト転写にデコードできます。

import IPython.display as ipd 

predicted_ids = pt_model.generate(input_features) 
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True) 

display(ipd.Audio(sample["audio"]["array"], rate=sample["audio"]["sampling_rate"])) 
print(f"Reference: {sample['text']}") 
print(f"Result: {transcription[0]}")
Reference: MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL 
Result: Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.

Optimum ライブラリーを使用して OpenVINO モデルをロード#

Hugging Face Optimum API は、Hugging Face Transformers ライブラリーのモデルを OpenVINO™ IR 形式に変換および量子化できる高レベル API です。詳細については、Hugging Face Optimum のドキュメントを参照してください。

Optimum Intel を使用すると、Hugging Face ハブ から最適化されたモデルをロードし、Hugging Face API を使用して OpenVINO ランタイムで推論を実行するパイプラインを作成できます。Optimum 推論モデルは、Hugging Face Transformers モデルと API の互換性があります。つまり、AutoModelForXxx クラスを対応する OVModelForXxx クラスに置き換えるだけで済みます。

以下は distil-whisper モデルの例です

-from transformers import AutoModelForSpeechSeq2Seq 
+from optimum.intel.openvino import OVModelForSpeechSeq2Seq 
from transformers import AutoTokenizer, pipeline 

model_id = "distil-whisper/distil-large-v2" 
-model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id) 
+model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True)

モデルクラスの初期化は、from_pretrained メソッドの呼び出しから始まります。Transformers モデルをダウンロードして変換する場合は、パラメーター export=True を追加する必要があります。save_pretrained メソッドを使用して、変換されたモデルを次回に使用するため保存できます。トークナイザーとプロセッサーは、OpenVINO モデルとも互換性のあるモデルとともに配布されます。つまり、初期化された初期プロセッサーを再利用できるということです。

from pathlib import Path 
from optimum.intel.openvino import OVModelForSpeechSeq2Seq 

model_path = Path(model_id.value.replace("/", "_")) 
ov_config = {"CACHE_DIR": ""} 

if not model_path.exists(): 
    ov_model = OVModelForSpeechSeq2Seq.from_pretrained( 
        model_id.value, 
        ov_config=ov_config, 
        export=True, 
        compile=False, 
        load_in_8bit=False, 
    ) 
    ov_model.half() 
    ov_model.save_pretrained(model_path) 
else: 
    ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_path, ov_config=ov_config, compile=False)
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, onnx, openvino

推論デバイスの選択#

import openvino as ov 
import ipywidgets as widgets 

core = ov.Core() 

device = widgets.Dropdown( 
    options=core.available_devices + ["AUTO"], 
    value="AUTO", 
    description="Device:", 
    disabled=False, 
) 

device
Dropdown(description='Device:', index=4, options=('CPU', 'GPU.0', 'GPU.1', 'GPU.2', 'AUTO'), value='AUTO')

OpenVINO モデルをコンパイル#

ov_model.to(device.value) 
ov_model.compile()
Compiling the encoder to AUTO ...
Compiling the decoder to AUTO ...
Compiling the decoder to AUTO ...

OpenVINO モデル推論を実行#

predicted_ids = ov_model.generate(input_features) 
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True) 

display(ipd.Audio(sample["audio"]["array"], rate=sample["audio"]["sampling_rate"])) 
print(f"Reference: {sample['text']}") 
print(f"Result: {transcription[0]}")
/home/nsavel/venvs/ov_notebooks_tmp/lib/python3.8/site-packages/optimum/intel/openvino/modeling_seq2seq.py:457: FutureWarning: shared_memory is deprecated and will be removed in 2024.0.Value of shared_memory is going to override share_inputs value.Please use only share_inputs explicitly. 
  last_hidden_state = torch.from_numpy(self.request(inputs, shared_memory=True)["last_hidden_state"]).to( 
/home/nsavel/venvs/ov_notebooks_tmp/lib/python3.8/site-packages/optimum/intel/openvino/modeling_seq2seq.py:538: FutureWarning: shared_memory is deprecated and will be removed in 2024.0.Value of shared_memory is going to override share_inputs value.Please use only share_inputs explicitly. 
  self.request.start_async(inputs, shared_memory=True)
Reference: MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL 
Result: Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.

PyTorch と OpenVINO のパフォーマンスを比較#

import time 
import numpy as np 
from tqdm.notebook import tqdm 

def measure_perf(model, sample, n=10): 
    timers = [] 
    input_features = extract_input_features(sample) 
    for _ in tqdm(range(n), desc="Measuring performance"): 
        start = time.perf_counter() 
        model.generate(input_features) 
        end = time.perf_counter() 
        timers.append(end - start) 
    return np.median(timers)
perf_torch = measure_perf(pt_model, sample) 
perf_ov = measure_perf(ov_model, sample)
Measuring performance: 0%|          | 0/10 [00:00<?, ?it/s]
Measuring performance: 0%|          | 0/10 [00:00<?, ?it/s]
print(f"Mean torch {model_id.value} generation time: {perf_torch:.3f}s") 
print(f"Mean openvino {model_id.value} generation time: {perf_ov:.3f}s") 
print(f"Performance {model_id.value} openvino speedup: {perf_torch / perf_ov:.3f}")
Mean torch distil-large-v2 generation time: 3.064s 
Mean openvino distil-large-v2 generation time: 1.819s 
Performance distil-large-v2 openvino speedup: 1.684

HuggingFace パイプラインを使用した OpenVINO モデル#

オリジナルの PyTorch モデルと同様に、OpenVINO モデルも automatic-speech-recognition (自動音声認識) 用の HuggingFace パイプライン・インターフェイスと互換性があります。パイプラインは長い音声の書き起こしに使用できます。Distil-Whisper はチャンク・アルゴリズムを使用して、長い形式のオーディオファイルを転記します。実際には、チャンク化された長い形式のアルゴリズムは、Whisper の論文で OpenAI が提案した順次アルゴリズムよりも 9 倍高速です。チャンク化を有効にするには、chunk_length_s パラメーターをパイプラインに渡します。Distil-Whisper の場合、チャンクの長さは 15 秒が最適です。バッチ処理を有効にするには、引数 batch_size を渡します。

from transformers import pipeline 

ov_model.generation_config = pt_model.generation_config 

pipe = pipeline( 
    "automatic-speech-recognition", 
    model=ov_model, 
    tokenizer=processor.tokenizer, 
    feature_extractor=processor.feature_extractor, 
    max_new_tokens=128, 
    chunk_length_s=15, 
    batch_size=16, 
)
The model 'OVModelForWhisper' is not supported for automatic-speech-recognition. Supported models are ['Pop2PianoForConditionalGeneration', 'SeamlessM4TForSpeechToText', 'SpeechEncoderDecoderModel', 'Speech2TextForConditionalGeneration', 'SpeechT5ForSpeechToText', 'WhisperForConditionalGeneration', 'Data2VecAudioForCTC', 'HubertForCTC', 'MCTCTForCTC', 'SEWForCTC', 'SEWDForCTC', 'UniSpeechForCTC', 'UniSpeechSatForCTC', 'Wav2Vec2ForCTC', 'Wav2Vec2ConformerForCTC', 'WavLMForCTC'].
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation", 
trust_remote_code=True) 
sample_long = dataset[0] 

def format_timestamp(seconds: float): 
    """ 
    format time in srt-file expected format 
    """ 
    assert seconds >= 0, "non-negative timestamp expected" 
    milliseconds = round(seconds * 1000.0) 

    hours = milliseconds // 3_600_000 
    milliseconds -= hours * 3_600_000 

    minutes = milliseconds // 60_000 
    milliseconds -= minutes * 60_000 

    seconds = milliseconds // 1_000 
    milliseconds -= seconds * 1_000 

    return (f"{hours}:" if hours > 0 else "00:")+ f"{minutes:02d}:{seconds:02d},{milliseconds:03d}" 

def prepare_srt(transcription): 
    """ 
    Format transcription into srt file format 
    """ 
    segment_lines = [] 
    for idx, segment in enumerate(transcription["chunks"]): 
        segment_lines.append(str(idx + 1) + "\n") 
        timestamps = segment["timestamp"] 
        time_start = format_timestamp(timestamps[0]) 
        time_end = format_timestamp(timestamps[1]) 
        time_str = f"{time_start} --> {time_end}\n" 
        segment_lines.append(time_str) 
        segment_lines.append(segment["text"] + "\n\n") 
    return segment_lines

return_timestamps 引数により、処理された各チャンクに関連付けられた音声の開始と終了のタイムスタンプを取得できます。音声の分離やビデオ字幕の生成などのタスクに役立つ可能性があります。この例では、一般的な字幕形式の 1 つである SRT 形式で出力フォーマットを提供します。

result = pipe(sample_long["audio"].copy(), return_timestamps=True)
srt_lines = prepare_srt(result) 

display(ipd.Audio(sample_long["audio"]["array"], rate=sample_long["audio"]["sampling_rate"])) 
print("".join(srt_lines))
1 
00:00:00,000 --> 00:00:06,560 
 Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.

2 
00:00:06,560 --> 00:00:11,280 
 Nor is Mr. Quilter's manner less interesting than his matter.

3 
00:00:11,280 --> 00:00:16,840 
 He tells us that at this festive season of the year, with Christmas and roast beef looming 

4 
00:00:16,840 --> 00:00:23,760 
 before us, similes drawn from eating and its results occur most readily to the mind.

5 
00:00:23,760 --> 00:00:29,360 
 He has grave doubts whether Sir Frederick Leighton's work is really Greek after all, and 

6 
00:00:29,360 --> 00:00:33,640 
 can discover in it but little of Rocky Ithaca.

7 
00:00:33,640 --> 00:00:39,760 
 Lennel's pictures are a sort of upgards and Adam paintings, and Mason's exquisite 

8 
00:00:39,760 --> 00:00:44,720 
 idles are as national as a jingo poem.

9 
00:00:44,720 --> 00:00:50,320 
 Mr. Burkett Foster's landscapes smile at one much in the same way that Mr. Carker used 

10 
00:00:50,320 --> 00:00:52,920 
 to flash his teeth.11 
00:00:52,920 --> 00:00:58,680 
 And Mr. John Collier gives his sitter a cheerful slap on the back, before he says, like 

12 
00:00:58,680 --> 00:01:01,120 
 a shampooer and a Turkish bath, 

13 
00:01:01,120 --> 00:01:02,000 
 Next man!

量子化#

NNCF は、モデルグラフに量子化レイヤーを追加し、トレーニング・データセットのサブセットを使用してこれらの追加の量子化レイヤーのパラメーターを初期化することで、トレーニング後の量子化を可能にします。このフレームワークは、元のトレーニング・コードへの変更が最小限になるように設計されています。

最適化プロセスには次の手順が含まれます:

  1. 量子化用のキャリブレーション・データセットを作成します。

  2. nncf.quantize を実行して、量子化されたエンコーダーおよびデコーダーモデルを取得します。

  3. openvino.save_model 関数を使用して INT8 モデルをシリアル化します。

: 量子化は時間とメモリーを消費する操作です。以下の量子化コードの実行には時間がかかる場合があります。

Distil-Whisper 量子化を行うかどうかを以下で選択してください。

to_quantize = widgets.Checkbox( 
    value=True, 
    description="Quantization", 
    disabled=False, 
) 

to_quantize
Checkbox(value=True, description='Quantization')
# `skip_kernel_extension` モジュールを取得 
import requests 

r = requests.get( 

url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/skip_kernel_extension.py", 
) 
open("skip_kernel_extension.py", "w").write(r.text) 

%load_ext skip_kernel_extension

キャリブレーション・データセットの準備#

最初のステップは、量子化のキャリブレーション・データセットを準備することです。whisper エンコーダーとデコーダーを別々に量子化するため、各モデルのキャリブレーション・データセットを準備する必要があります。モデル入力をインターセプトしてリストに収集する InferRequestWrapper クラスをインポートします。次に、少量のオーディオサンプルに対してモデル推論を実行します。一般的に、キャリブレーション・データセットのサイズを大きくすると、量子化の品質が向上します。

%%skip not $to_quantize.value 

from itertools import islice 
from optimum.intel.openvino.quantization import InferRequestWrapper 

def collect_calibration_dataset(ov_model: OVModelForSpeechSeq2Seq, calibration_dataset_size: int):
    # モデル要求のプロパティーを上書きし、後で復元できるように元のプロパティーを保存
    encoder_calibration_data = [] 
    decoder_calibration_data = [] 
    ov_model.encoder.request = InferRequestWrapper(ov_model.encoder.request, encoder_calibration_data, apply_caching=True) 
    ov_model.decoder_with_past.request = InferRequestWrapper(ov_model.decoder_with_past.request, decoder_calibration_data, apply_caching=True) 

    try: 
        calibration_dataset = load_dataset("openslr/librispeech_asr", "clean", split="validation", streaming=True, trust_remote_code=True) 
        for sample in tqdm(islice(calibration_dataset, calibration_dataset_size), desc="Collecting calibration data", total=calibration_dataset_size): 
            input_features = extract_input_features(sample) 
            ov_model.generate(input_features) 
    finally: 
        ov_model.encoder.request = ov_model.encoder.request.request 
        ov_model.decoder_with_past.request = ov_model.decoder_with_past.request.request 

    return encoder_calibration_data, decoder_calibration_data

Distil-Whisper エンコーダーとデコーダーのモデルを量子化#

以下では、Distil-Whisper エンコーダーおよびデコーダー (履歴付きモデル) で nncf.quantize を呼び出す quantize 関数を実行します。全体の推論時間に占める割合が無視できるため、第 1 ステップのデコーダーは量子化しません。

%%skip not $to_quantize.value 

import gc 
import shutil 
import nncf 

CALIBRATION_DATASET_SIZE = 50 
quantized_model_path = Path(f"{model_path}_quantized") 

def quantize(ov_model: OVModelForSpeechSeq2Seq, calibration_dataset_size: int): 
    if not quantized_model_path.exists(): 
        encoder_calibration_data, decoder_calibration_data = collect_calibration_dataset( 
            ov_model, calibration_dataset_size 
        ) 
        print("Quantizing encoder") 
        quantized_encoder = nncf.quantize( 
            ov_model.encoder.model, 
            nncf.Dataset(encoder_calibration_data), 
            subset_size=len(encoder_calibration_data), 
            model_type=nncf.ModelType.TRANSFORMER, 
            # スムーズ・クォンタム・アルゴリズムは活性化量子化誤差を削減し、グリッドサーチを通じて最適なアルファ値を取得 
advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alpha=0.50) 
        ) 
        ov.save_model(quantized_encoder, quantized_model_path / "openvino_encoder_model.xml") 
        del quantized_encoder 
        del encoder_calibration_data 
        gc.collect() 

        print("Quantizing decoder with past") 
        quantized_decoder_with_past = nncf.quantize( 
            ov_model.decoder_with_past.model, 
            nncf.Dataset(decoder_calibration_data), 
            subset_size=len(decoder_calibration_data), 
            model_type=nncf.ModelType.TRANSFORMER, 
            # スムーズ・クォンタム・アルゴリズムは活性化量子化誤差を削減し、グリッドサーチを通じて最適なアルファ値を取得
advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alpha=0.95) 
        ) 
        ov.save_model(quantized_decoder_with_past, quantized_model_path / "openvino_decoder_with_past_model.xml") 
        del quantized_decoder_with_past 
        del decoder_calibration_data 
        gc.collect() 

        # 設定ファイルと first-step-decoder を手動でコピー 
        shutil.copy(model_path / "config.json", quantized_model_path / "config.json") 
        shutil.copy(model_path / "openvino_decoder_model.xml", quantized_model_path / "openvino_decoder_model.xml") 
        shutil.copy(model_path / "openvino_decoder_model.bin", quantized_model_path / "openvino_decoder_model.bin") 

    quantized_ov_model = OVModelForSpeechSeq2Seq.from_pretrained(quantized_model_path, ov_config=ov_config, compile=False) 
    quantized_ov_model.to(device.value) 
    quantized_ov_model.compile() 
    return quantized_ov_model 

ov_quantized_model = quantize(ov_model, CALIBRATION_DATASET_SIZE)
Collecting calibration data: 0%|          | 0/10 [00:00<?, ?it/s]
Quantizing encoder
Statistics collection: 100%|
██████████████████████████████████████████████████████████████████████████████████████████████| 10/10 
[00:15<00:00, 1.55s/it] Applying Smooth Quant: 100%|
██████████████████████████████████████████████████████████████████████████████████████████████| 128/128 
[00:10<00:00, 12.24it/s]
INFO:nncf:96 ignored nodes was found by name in the NNCFGraph
Statistics collection: 100%|
████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 
[00:29<00:00, 2.99s/it] 
Applying Fast Bias correction: 100%|██████████████████████████████████████████████████████████████████████████████████████| 162/162 [00:21<00:00, 7.60it/s]
Quantizing decoder with past
Statistics collection: 100%|
██████████████████████████████████████████████████████████████████████████████████████████████| 390/390 
[00:04<00:00, 85.63it/s] 
Applying Smooth Quant: 100%|
████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 
[00:00<00:00, 16.09it/s]
INFO:nncf:12 ignored nodes was found by name in the NNCFGraph
Statistics collection: 100%|
██████████████████████████████████████████████████████████████████████████████████████████████| 390/390 
[00:07<00:00, 52.93it/s] 
Applying Fast Bias correction: 100%|
████████████████████████████████████████████████████████████████████████████████████████| 14/14 
[00:00<00:00, 18.50it/s] 
Compiling the encoder to AUTO ...
Compiling the decoder to AUTO ...
Compiling the decoder to AUTO ...

量子化モデルの推論を実行#

オリジナルモデルと量子化モデルの転写結果を比較してみましょう。

%%skip not $to_quantize.value 

dataset = load_dataset( 
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True 
) 
sample = dataset[0] 
input_features = extract_input_features(sample) 

predicted_ids = ov_model.generate(input_features) 
transcription_original = processor.batch_decode(predicted_ids, skip_special_tokens=True) 

predicted_ids = ov_quantized_model.generate(input_features) 
transcription_quantized = processor.batch_decode(predicted_ids, skip_special_tokens=True) 

display(ipd.Audio(sample["audio"]["array"], rate=sample["audio"]["sampling_rate"])) 
print(f"Original : {transcription_original[0]}") 
print(f"Quantized: {transcription_quantized[0]}")

Original : Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. 
Quantized: Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.

結果は同じです !

元のモデルと量子化されたモデルのパフォーマンスと精度を比較#

最後に、精度とパフォーマンスの観点から、元の Distil-Whisper モデルと量子化された Distil-Whisper モデルを比較します。

精度を測定するために、1 - WER をメトリックとして使用します。WER は Word Error Rate (単語誤り率) の略です。

推論時間を測定するときは、エンコーダーとデコーダーの過去のモデル転送と、モデル全体の推論を個別に測定します。

%%skip not $to_quantize.value 

import time 
from contextlib import contextmanager 
from jiwer import wer, wer_standardize 

TEST_DATASET_SIZE = 50 
MEASURE_TIME = False 

@contextmanager 
def time_measurement(): 
    global MEASURE_TIME 
    try: MEASURE_TIME = True 
        yield 
    finally: MEASURE_TIME = False 

def time_fn(obj, fn_name, time_list): 
    original_fn = getattr(obj, fn_name) 

    def wrapper(*args, **kwargs): 
        if not MEASURE_TIME: 
            return original_fn(\*args, \*\*kwargs) 
        start_time = time.perf_counter() 
        result = original_fn(\*args, \*\*kwargs) 
        end_time = time.perf_counter() 
        time_list.append(end_time - start_time) 
        return result 

    setattr(obj, fn_name, wrapper) 

def calculate_transcription_time_and_accuracy(ov_model, test_samples): 
    encoder_infer_times = [] 
    decoder_with_past_infer_times = [] 
    whole_infer_times = [] 
    time_fn(ov_model, "generate", whole_infer_times) 
    time_fn(ov_model.encoder, "forward", encoder_infer_times) 
    time_fn(ov_model.decoder_with_past, "forward", decoder_with_past_infer_times) 

    ground_truths = [] 
    predictions = [] 
    for data_item in tqdm(test_samples, desc="Measuring performance and accuracy"): 
        input_features = extract_input_features(data_item) 

        with time_measurement(): 
            predicted_ids = ov_model.generate(input_features) 
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True) 

        ground_truths.append(data_item["text"]) 
        predictions.append(transcription[0]) 

    word_accuracy = (1 - wer(ground_truths, predictions, reference_transform=wer_standardize, hypothesis_transform=wer_standardize)) * 100 
    mean_whole_infer_time = sum(whole_infer_times) 
    mean_encoder_infer_time = sum(encoder_infer_times) 
    mean_decoder_with_time_infer_time = sum(decoder_with_past_infer_times) 
    return word_accuracy, (mean_whole_infer_time, mean_encoder_infer_time, mean_decoder_with_time_infer_time) 

test_dataset = load_dataset("openslr/librispeech_asr", "clean", split="test", streaming=True, trust_remote_code=True) 
test_dataset = test_dataset.shuffle(seed=42).take(TEST_DATASET_SIZE) 
test_samples = [sample for sample in test_dataset] 

accuracy_original, times_original = calculate_transcription_time_and_accuracy(ov_model, test_samples) 
accuracy_quantized, times_quantized = calculate_transcription_time_and_accuracy(ov_quantized_model, test_samples) 
print(f"Encoder performance speedup: {times_original[1] / times_quantized[1]:.3f}") 
print(f"Decoder with past performance speedup: {times_original[2] / times_quantized[2]:.3f}") 
print(f"Whole pipeline performance speedup: {times_original[0] / times_quantized[0]:.3f}") 
print(f"Whisper transcription word accuracy. Original model: {accuracy_original:.2f}%.Quantized model: {accuracy_quantized:.2f}%.") 
print(f"Accuracy drop: {accuracy_original - accuracy_quantized:.2f}%.")
Got disconnected from remote data host. Retrying in 5sec [1/20] 
Got disconnected from remote data host. Retrying in 5sec [2/20]
Measuring performance and accuracy: 0%|          | 0/50 [00:00<?, ?it/s]
Measuring performance and accuracy: 0%|          | 0/50 [00:00<?, ?it/s]
Encoder performance speedup: 1.751 
Decoder with past performance speedup: 1.777 
Whole pipeline performance speedup: 1.711 
Whisper transcription word accuracy. Original model: 85.29%. Quantized model: 85.29%. Accuracy drop: 0.00%.

量子化により、精度が大幅に低下することなくモデルの推論時間が大幅に改善されました。

インタラクティブなデモ#

また、Gradio インターフェイスを使用したインタラクティブなデモも提供しており、固有のオーディオデータでモデルの機能をテストしたり (アップロードボタンを使用)、マイクを使用して録音したりできます。なお、Distil-Whisper は現在、英語の音声認識のみで利用可能です。多言語サポートは後日提供される予定です。

from transformers.pipelines.audio_utils import ffmpeg_read 
import gradio as gr 

r = requests.get("https://huggingface.co/spaces/distil-whisper/whisper-vs-distil-whisper/resolve/main/assets/example_1.wav") 

with open("example_1.wav", "wb") as f: 
    f.write(r.content) 

BATCH_SIZE = 16 
MAX_AUDIO_MINS = 30 # 最大音声入力 (分) 

generate_kwargs = {"language": "en", "task": "transcribe"} if not model_id.value.endswith(".en") else {} 
ov_pipe = pipeline( 
    "automatic-speech-recognition", 
    model=ov_model, 
    tokenizer=processor.tokenizer, 
    feature_extractor=processor.feature_extractor, 
    max_new_tokens=128, 
    chunk_length_s=15, 
    generate_kwargs=generate_kwargs, 
) 
ov_pipe_forward = ov_pipe._forward 

if to_quantize.value: 
    ov_quantized_model.generation_config = ov_model.generation_config 
    ov_quantized_pipe = pipeline( 
        "automatic-speech-recognition", 
        model=ov_quantized_model, 
        tokenizer=processor.tokenizer, 
        feature_extractor=processor.feature_extractor, 
        max_new_tokens=128, 
        chunk_length_s=15, 
        generate_kwargs=generate_kwargs, 
    ) 
    ov_quantized_pipe_forward = ov_quantized_pipe._forward 

def transcribe(inputs, quantized=False): 
    pipe = ov_quantized_pipe if quantized else ov_pipe 
    pipe_forward = ov_quantized_pipe_forward if quantized else ov_pipe_forward 

    if inputs is None: 
        raise gr.Error("No audio file submitted! Please record or upload an audio file before submitting your request.") 

    with open(inputs, "rb") as f: 
        inputs = f.read() 

    inputs = ffmpeg_read(inputs, pipe.feature_extractor.sampling_rate) 
    audio_length_mins = len(inputs) / pipe.feature_extractor.sampling_rate / 60 

    if audio_length_mins > MAX_AUDIO_MINS: 
        raise gr.Error( 
            f"To ensure fair usage of the Space, the maximum audio length permitted is {MAX_AUDIO_MINS} minutes." 
            f"Got an audio of length {round(audio_length_mins, 3)} minutes."         ) 

    inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate} 

    def _forward_ov_time(*args, **kwargs): 
        global ov_time 
        start_time = time.time() 
        result = pipe_forward(*args, **kwargs) 
        ov_time = time.time() - start_time 
        ov_time = round(ov_time, 2) 
        return result 

    pipe._forward = _forward_ov_time 
    ov_text = pipe(inputs.copy(), batch_size=BATCH_SIZE)["text"] 
    return ov_text, ov_time 

with gr.Blocks() as demo: 
    gr.HTML( 
        """ 
            <div style="text-align: center; max-width: 700px; margin: 0 auto;"> 
                <div 
                    style=" 
                        display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem; 
                    " 
                > 
                    <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;"> 
                        OpenVINO Distil-Whisper demo 
                    </h1> 
                </div> 
            </div> 
        """ 
    ) 
    audio = gr.components.Audio(type="filepath", label="Audio input") 
    with gr.Row(): 
        button = gr.Button("Transcribe") 
        if to_quantize.value: 
            button_q = gr.Button("Transcribe quantized") 
    with gr.Row(): 
        infer_time = gr.components.Textbox(label="OpenVINO Distil-Whisper Transcription Time (s)") 
        if to_quantize.value: 
            infer_time_q = gr.components.Textbox(label="OpenVINO Quantized Distil-Whisper Transcription Time (s)") 
    with gr.Row(): 
        transcription = gr.components.Textbox(label="OpenVINO Distil-Whisper Transcription", show_copy_button=True) 
        if to_quantize.value: 
            transcription_q = gr.components.Textbox( 
                label="OpenVINO Quantized Distil-Whisper Transcription", 
                show_copy_button=True, 
            ) 
    button.click( 
        fn=transcribe, 
        inputs=audio, 
        outputs=[transcription, infer_time], 
    ) 
    if to_quantize.value: 
        button_q.click( 
            fn=transcribe, 
            inputs=[audio, gr.Number(value=1, visible=False)], 
            outputs=[transcription_q, infer_time_q], 
        ) 
    gr.Markdown("## Examples") 
    gr.Examples( 
        [["./example_1.wav"]], 
        audio, 
        outputs=[transcription, infer_time], 
        fn=transcribe, 
        cache_examples=False, 
    ) 
# リモートで起動する場合は、server_name と server_port を指定 
# demo.launch(server_name='your server name', server_port='server port in int') 
# 詳細はドキュメントをご覧ください: https://gradio.app/docs/ 
try: 
    demo.launch(debug=False) 
except Exception: 
    demo.launch(share=True, debug=False)