Distil-Whisper と OpenVINO を使用した自動音声認識

この Jupyter ノートブックは、ローカルへのインストール後にのみ起動できます。

GitHub

Distil-Whisper は、OpenAI による Whisper モデルの蒸留版です。Distil-Whisper は、論文「Robust Knowledge Distillation via Large-Scale Pseudo Labelling」で提案されています。著者によると、Distil-Whisper は Whisper と比較して、パラメーターが 50% 少ない状態で数倍高速に実行され、分布外評価データで 1% 以内の単語エラー率 (WER) を実現できるということです。

Whisper は、トランスフォーマー・ベースのエンコーダー/デコーダーモデルであり、シーケンスツーシーケンス・モデルとも呼ばれます。オーディオ・スペクトログラム機能のシーケンスをテキストトークンのシーケンスにマッピングします。まず、生のオーディオ入力は、特徴抽出器の動作によって log-Mel スペクトログラムに変換されます。次に、トランスフォーマー・エンコーダーはスペクトログラムをエンコードして、エンコーダーの隠し状態のシーケンスを形成します。最後に、デコーダーは、以前のトークンとエンコーダーの隠れ状態の両方を条件として、テキストトークンを自己回帰的に予測します。

下の図でモデルのアーキテクチャーを確認できます。

whisper_architecture.svg

whisper_architecture.svg

このチュートリアルでは、OpenVINO を使用して Distil-Whisper を実行する方法について説明します。Hugging Face Transformers ライブラリーの事前トレーニング済みモデルを使用します。ユーザー・エクスペリエンスを簡素化するために、Hugging Face の Optimum ライブラリーを使用してモデルを OpenVINO™ IR 形式に変換します。OpenVINO Distil-Whisper モデルのパフォーマンスをさらに向上させるため、NNCF からの INT8 トレーニング後の量子化が適用されます。

目次

必要条件

%pip install -q "transformers>=4.35" onnx "git+https://github.com/huggingface/optimum-intel.git" --extra-index-url https://download.pytorch.org/whl/cpu
%pip install -q "openvino>=2023.2.0" datasets  "gradio>=4.0" "librosa" "soundfile"
%pip install -q "nncf>=2.6.0" "jiwer"

PyTorch モデルのロード

AutoModelForSpeechSeq2Seq.from_pretrained メソッドは、Transformers ライブラリーを使用して PyTorch Whisper モデルの初期化に使用されます。このチュートリアルでは、デフォルトで distil-whisper/distil-large-v2 モデルを例として使用します。モデルは最初の実行時に一度ダウンロードされますが、このプロセスには時間がかかる場合があります。

また、distil-whisper/distil-medium.endistil-whisper/distil-small.en など、Distil-Whisper hugging face コレクションから他のモデルを選択することもできます。オリジナルの Whisper アーキテクチャーのモデルも入手可能です。詳細はこちらをご覧ください。

このモデルの使用では、前処理と後処理が重要です。AutoProcessor 初期化に使用されるクラス WhisperProcessor は、モデルのオーディオ入力データを準備し、それをメルスペクトログラムに変換し、トークナイザーを使用して予測された出力 token_ids を文字列にデコードする役割を担います。

import ipywidgets as widgets

model_ids = {
    "Distil-Whisper": [
        "distil-whisper/distil-large-v2",
        "distil-whisper/distil-medium.en",
        "distil-whisper/distil-small.en"
    ],
    "Whisper": [
        "openai/whisper-large-v3",
        "openai/whisper-large-v2",
        "openai/whisper-large",
        "openai/whisper-medium",
        "openai/whisper-small",
        "openai/whisper-base",
        "openai/whisper-tiny",
        "openai/whisper-medium.en",
        "openai/whisper-small.en",
        "openai/whisper-base.en",
        "openai/whisper-tiny.en",
    ]
}

model_type = widgets.Dropdown(
    options=model_ids.keys(),
    value="Distil-Whisper",
    description="Model type:",
    disabled=False,
)

model_type
model_id = widgets.Dropdown(
    options=model_ids[model_type.value],
    value=model_ids[model_type.value][0],
    description="Model:",
    disabled=False,
)

model_id
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

processor = AutoProcessor.from_pretrained(model_id.value)

pt_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id.value)
pt_model.eval();
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

入力サンプルを準備

プロセッサーは、numpy 配列形式のオーディオデータとオーディオ・サンプリング・レートに関する情報を期待し、予測を行う input_features テンソルを返します。オーディオから numpy 形式への変換は、Hugging Face データセットの実装によって処理されます。

from datasets import load_dataset

def extract_input_features(sample):
    input_features = processor(
        sample["audio"]["array"],
        sampling_rate=sample["audio"]["sampling_rate"],
        return_tensors="pt",
    ).input_features
    return input_features

dataset = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
sample = dataset[0]
input_features = extract_input_features(sample)

モデルの推論を実行

音声認識を実行するには、モデルの生成インターフェースを使用できます。生成が完了したら、processor.batch_decode を使用して、予測された token_ids をテキスト転写にデコードできます。

import IPython.display as ipd

predicted_ids = pt_model.generate(input_features)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

display(ipd.Audio(sample["audio"]["array"], rate=sample["audio"]["sampling_rate"]))
print(f"Reference: {sample['text']}")
print(f"Result: {transcription[0]}")
Reference: MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL
Result:  Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.

Optimum ライブラリーを使用して OpenVINO モデルをロード

Hugging Face Optimum API は、Hugging Face Transformers ライブラリーのモデルを OpenVINO™ IR 形式に変換および量子化できる高レベル API です。詳細については、Hugging Face Optimum のドキュメントを参照してください。

Optimum Intel を使用すると、Hugging Face Hub から最適化されたモデルをロードし、Hugging Face API を使用して OpenVINO ランタイムで推論を実行するパイプラインを作成できます。Optimum 推論モデルは、Hugging Face Transformers モデルと API の互換性があります。つまり、AutoModelForXxx クラスを対応する OVModelForXxx クラスに置き換えるだけで済みます。

以下は distil-whisper モデルの例です。

-from transformers import AutoModelForSpeechSeq2Seq
+from optimum.intel.openvino import OVModelForSpeechSeq2Seq
from transformers import AutoTokenizer, pipeline

model_id = "distil-whisper/distil-large-v2"
-model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
+model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True)

モデルクラスの初期化は、from_pretrained メソッドの呼び出しから始まります。Transformers モデルをダウンロードして変換する場合は、パラメーター export=True を追加する必要があります。save_pretrained メソッドを使用して、変換されたモデルを次回に使用するため保存できます。トークナイザーとプロセッサーは、OpenVINO モデルとも互換性のあるモデルとともに配布されます。つまり、初期化された初期プロセッサーを再利用できるということです。

from pathlib import Path
from optimum.intel.openvino import OVModelForSpeechSeq2Seq

model_path = Path(model_id.value.replace('/', '_'))
ov_config = {"CACHE_DIR": ""}

if not model_path.exists():
    ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
        model_id.value, ov_config=ov_config, export=True, compile=False, load_in_8bit=False
    )
    ov_model.half()
    ov_model.save_pretrained(model_path)
else:
    ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
        model_path, ov_config=ov_config, compile=False
    )
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, onnx, openvino

推論デバイスの選択

import openvino as ov
import ipywidgets as widgets

core = ov.Core()

device = widgets.Dropdown(
    options=core.available_devices + ["AUTO"],
    value="AUTO",
    description="Device:",
    disabled=False,
)

device
Dropdown(description='Device:', index=4, options=('CPU', 'GPU.0', 'GPU.1', 'GPU.2', 'AUTO'), value='AUTO')

OpenVINO モデルをコンパイル

ov_model.to(device.value)
ov_model.compile()
Compiling the encoder to AUTO ...
Compiling the decoder to AUTO ...
Compiling the decoder to AUTO ...

OpenVINO モデル推論を実行

predicted_ids = ov_model.generate(input_features)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

display(ipd.Audio(sample["audio"]["array"], rate=sample["audio"]["sampling_rate"]))
print(f"Reference: {sample['text']}")
print(f"Result: {transcription[0]}")
/home/nsavel/venvs/ov_notebooks_tmp/lib/python3.8/site-packages/optimum/intel/openvino/modeling_seq2seq.py:457: FutureWarning: shared_memory is deprecated and will be removed in 2024.0. Value of shared_memory is going to override share_inputs value. Please use only share_inputs explicitly.
  last_hidden_state = torch.from_numpy(self.request(inputs, shared_memory=True)["last_hidden_state"]).to(
/home/nsavel/venvs/ov_notebooks_tmp/lib/python3.8/site-packages/optimum/intel/openvino/modeling_seq2seq.py:538: FutureWarning: shared_memory is deprecated and will be removed in 2024.0. Value of shared_memory is going to override share_inputs value. Please use only share_inputs explicitly.
  self.request.start_async(inputs, shared_memory=True)
Reference: MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL
Result:  Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.

PyTorch と OpenVINO のパフォーマンスを比較

import time
import numpy as np
from tqdm.notebook import tqdm


def measure_perf(model, sample, n=10):
    timers = []
    input_features = extract_input_features(sample)
    for _ in tqdm(range(n), desc="Measuring performance"):
        start = time.perf_counter()
        model.generate(input_features)
        end = time.perf_counter()
        timers.append(end - start)
    return np.median(timers)
perf_torch = measure_perf(pt_model, sample)
perf_ov = measure_perf(ov_model, sample)
Measuring performance:   0%|          | 0/10 [00:00<?, ?it/s]
Measuring performance:   0%|          | 0/10 [00:00<?, ?it/s]
print(f"Mean torch {model_id.value} generation time: {perf_torch:.3f}s")
print(f"Mean openvino {model_id.value} generation time: {perf_ov:.3f}s")
print(f"Performance {model_id.value} openvino speedup: {perf_torch / perf_ov:.3f}")
Mean torch distil-large-v2 generation time: 3.064s
Mean openvino distil-large-v2 generation time: 1.819s
Performance distil-large-v2 openvino speedup: 1.684

load_in_8bit

OpenAI Whisper と比較

Hugging Face パイプラインを使用した OpenVINO モデルの使用

オリジナルの PyTorch モデルと同様に、OpenVINO モデルも automatic-speech-recognition (自動音声認識) 用の Hugging Face パイプライン・インターフェイスと互換性があります。パイプラインは長い音声の書き起こしに使用できます。Distil-Whisper はチャンク・アルゴリズムを使用して、長い形式のオーディオファイルを転記します。実際には、チャンク化された長い形式のアルゴリズムは、Whisper の論文で OpenAI が提案した順次アルゴリズムよりも 9 倍高速です。チャンク化を有効にするには、chunk_length_s パラメーターをパイプラインに渡します。Distil-Whisper の場合、チャンクの長さは 15 秒が最適です。バッチ処理を有効にするには、引数 batch_size を渡します。

from transformers import pipeline

ov_model.generation_config = pt_model.generation_config

pipe = pipeline(
    "automatic-speech-recognition",
    model=ov_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,
    batch_size=16,
)
The model 'OVModelForWhisper' is not supported for automatic-speech-recognition. Supported models are ['Pop2PianoForConditionalGeneration', 'SeamlessM4TForSpeechToText', 'SpeechEncoderDecoderModel', 'Speech2TextForConditionalGeneration', 'SpeechT5ForSpeechToText', 'WhisperForConditionalGeneration', 'Data2VecAudioForCTC', 'HubertForCTC', 'MCTCTForCTC', 'SEWForCTC', 'SEWDForCTC', 'UniSpeechForCTC', 'UniSpeechSatForCTC', 'Wav2Vec2ForCTC', 'Wav2Vec2ConformerForCTC', 'WavLMForCTC'].
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample_long = dataset[0]


def format_timestamp(seconds: float):
    """
    format time in srt-file expected format
    """
    assert seconds >= 0, "non-negative timestamp expected"
    milliseconds = round(seconds * 1000.0)

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000

    return (
        f"{hours}:" if hours > 0 else "00:"
    ) + f"{minutes:02d}:{seconds:02d},{milliseconds:03d}"


def prepare_srt(transcription):
    """
    Format transcription into srt file format
    """
    segment_lines = []
    for idx, segment in enumerate(transcription["chunks"]):
        segment_lines.append(str(idx + 1) + "\n")
        timestamps = segment["timestamp"]
        time_start = format_timestamp(timestamps[0])
        time_end = format_timestamp(timestamps[1])
        time_str = f"{time_start} --> {time_end}\n"
        segment_lines.append(time_str)
        segment_lines.append(segment["text"] + "\n\n")
    return segment_lines

return_timestamps 引数により、処理された各チャンクに関連付けられた音声の開始と終了のタイムスタンプを取得できます。音声の分離やビデオ字幕の生成などのタスクに役立つ可能性があります。この例では、一般的な字幕形式の 1 つである SRT 形式で出力フォーマットを提供します。

result = pipe(sample_long["audio"].copy(), return_timestamps=True)
srt_lines = prepare_srt(result)

display(
    ipd.Audio(sample_long["audio"]["array"], rate=sample_long["audio"]["sampling_rate"])
)
print("".join(srt_lines))
1
00:00:00,000 --> 00:00:06,560
 Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.

2
00:00:06,560 --> 00:00:11,280
 Nor is Mr. Quilter's manner less interesting than his matter.

3
00:00:11,280 --> 00:00:16,840
 He tells us that at this festive season of the year, with Christmas and roast beef looming

4
00:00:16,840 --> 00:00:23,760
 before us, similes drawn from eating and its results occur most readily to the mind.

5
00:00:23,760 --> 00:00:29,360
 He has grave doubts whether Sir Frederick Leighton's work is really Greek after all, and

6
00:00:29,360 --> 00:00:33,640
 can discover in it but little of Rocky Ithaca.

7
00:00:33,640 --> 00:00:39,760
 Lennel's pictures are a sort of upgards and Adam paintings, and Mason's exquisite

8
00:00:39,760 --> 00:00:44,720
 idles are as national as a jingo poem.

9
00:00:44,720 --> 00:00:50,320
 Mr. Burkett Foster's landscapes smile at one much in the same way that Mr. Carker used

10
00:00:50,320 --> 00:00:52,920
 to flash his teeth.

11
00:00:52,920 --> 00:00:58,680
 And Mr. John Collier gives his sitter a cheerful slap on the back, before he says, like

12
00:00:58,680 --> 00:01:01,120
 a shampooer and a Turkish bath,

13
00:01:01,120 --> 00:01:02,000
 Next man!

量子化

NNCF は、モデルグラフに量子化レイヤーを追加し、トレーニング・データセットのサブセットを使用してこれらの追加の量子化レイヤーのパラメーターを初期化することで、トレーニング後の量子化を可能にします。このフレームワークは、元のトレーニング・コードへの変更が最小限になるように設計されています。

最適化プロセスには次の手順が含まれます。

  1. 量子化用のキャリブレーション・データセットを作成します。

  2. nncf.quantize を実行して、量子化されたエンコーダーおよびデコーダーモデルを取得します。

  3. openvino.save_model 関数を使用して INT8 モデルをシリアル化します。

注: 量子化は時間とメモリーを消費する操作です。以下の量子化コードの実行には時間がかかる場合があります。

Distil-Whisper 量子化を行うかどうかを以下で選択してください。

to_quantize = widgets.Checkbox(
    value=True,
    description='Quantization',
    disabled=False,
)

to_quantize
Checkbox(value=True, description='Quantization')
# Fetch notebook_utils module
import urllib.request

urllib.request.urlretrieve(
    url='https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/main/notebooks/utils/skip_kernel_extension.py',
    filename='skip_kernel_extension.py'
)

%load_ext skip_kernel_extension

キャリブレーション・データセットの準備

最初のステップは、量子化のキャリブレーション・データセットを準備することです。whisper エンコーダーとデコーダーを別々に量子化するため、各モデルのキャリブレーション・データセットを準備する必要があります。モデル入力をインターセプトしてリストに収集する InferRequestWrapper クラスをインポートします。次に、少量のオーディオサンプルに対してモデル推論を実行します。一般的に、キャリブレーション・データセットのサイズを大きくすると、量子化の品質が向上します。

%%skip not $to_quantize.value

from itertools import islice
from optimum.intel.openvino.quantization import InferRequestWrapper


def collect_calibration_dataset(ov_model: OVModelForSpeechSeq2Seq, calibration_dataset_size: int):
    # Overwrite model request properties, saving the original ones for restoring later
    original_encoder_request = ov_model.encoder.request
    original_decoder_with_past_request = ov_model.decoder_with_past.request
    encoder_calibration_data = []
    decoder_calibration_data = []
    ov_model.encoder.request = InferRequestWrapper(original_encoder_request, encoder_calibration_data)
    ov_model.decoder_with_past.request = InferRequestWrapper(original_decoder_with_past_request,
                                                             decoder_calibration_data)

    calibration_dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
    for sample in tqdm(islice(calibration_dataset, calibration_dataset_size), desc="Collecting calibration data",
                       total=calibration_dataset_size):
        input_features = extract_input_features(sample)
        ov_model.generate(input_features)

    ov_model.encoder.request = original_encoder_request
    ov_model.decoder_with_past.request = original_decoder_with_past_request

    return encoder_calibration_data, decoder_calibration_data

Distil-Whisper エンコーダーとデコーダーのモデルを量子化

以下では、Distil-Whisper エンコーダーおよびデコーダー (履歴付きモデル) で nncf.quantize を呼び出す quantize 関数を実行します。全体の推論時間に占める割合が無視できるため、第 1 ステップのデコーダーは量子化しません。

%%skip not $to_quantize.value

import gc
import shutil
import nncf

CALIBRATION_DATASET_SIZE = 50
quantized_model_path = Path(f"{model_path}_quantized")


def quantize(ov_model: OVModelForSpeechSeq2Seq, calibration_dataset_size: int):
    if not quantized_model_path.exists():
        encoder_calibration_data, decoder_calibration_data = collect_calibration_dataset(
            ov_model, calibration_dataset_size
        )
        print("Quantizing encoder")
        quantized_encoder = nncf.quantize(
            ov_model.encoder.model,
            nncf.Dataset(encoder_calibration_data),
            subset_size=len(encoder_calibration_data),
            model_type=nncf.ModelType.TRANSFORMER,
            # Smooth Quant algorithm reduces activation quantization error; optimal alpha value was obtained through grid search
            advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alpha=0.50)
        )
        ov.save_model(quantized_encoder, quantized_model_path / "openvino_encoder_model.xml")
        del quantized_encoder
        del encoder_calibration_data
        gc.collect()

        print("Quantizing decoder with past")
        quantized_decoder_with_past = nncf.quantize(
            ov_model.decoder_with_past.model,
            nncf.Dataset(decoder_calibration_data),
            subset_size=len(decoder_calibration_data),
            model_type=nncf.ModelType.TRANSFORMER,
            # Smooth Quant algorithm reduces activation quantization error; optimal alpha value was obtained through grid search
            advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alpha=0.95)
        )
        ov.save_model(quantized_decoder_with_past, quantized_model_path / "openvino_decoder_with_past_model.xml")
        del quantized_decoder_with_past
        del decoder_calibration_data
        gc.collect()

        # Copy the config file and the first-step-decoder manually
        shutil.copy(model_path / "config.json", quantized_model_path / "config.json")
        shutil.copy(model_path / "openvino_decoder_model.xml", quantized_model_path / "openvino_decoder_model.xml")
        shutil.copy(model_path / "openvino_decoder_model.bin", quantized_model_path / "openvino_decoder_model.bin")

    quantized_ov_model = OVModelForSpeechSeq2Seq.from_pretrained(quantized_model_path, ov_config=ov_config, compile=False)
    quantized_ov_model.to(device.value)
    quantized_ov_model.compile()
    return quantized_ov_model


ov_quantized_model = quantize(ov_model, CALIBRATION_DATASET_SIZE)
Collecting calibration data:   0%|          | 0/10 [00:00<?, ?it/s]
Quantizing encoder
Statistics collection: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:15<00:00,  1.55s/it]
Applying Smooth Quant: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [00:10<00:00, 12.24it/s]
INFO:nncf:96 ignored nodes was found by name in the NNCFGraph
Statistics collection: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:29<00:00,  2.99s/it]
Applying Fast Bias correction: 100%|██████████████████████████████████████████████████████████████████████████████████████| 162/162 [00:21<00:00,  7.60it/s]
Quantizing decoder with past
Statistics collection: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:04<00:00, 85.63it/s]
Applying Smooth Quant: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 16.09it/s]
INFO:nncf:12 ignored nodes was found by name in the NNCFGraph
Statistics collection: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 390/390 [00:07<00:00, 52.93it/s]
Applying Fast Bias correction: 100%|████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 18.50it/s]
Compiling the encoder to AUTO ...
Compiling the decoder to AUTO ...
Compiling the decoder to AUTO ...

量子化モデルの推論を実行

オリジナルモデルと量子化モデルの転写結果を比較してみましょう。

%%skip not $to_quantize.value

dataset = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
sample = dataset[0]
input_features = extract_input_features(sample)

predicted_ids = ov_model.generate(input_features)
transcription_original = processor.batch_decode(predicted_ids, skip_special_tokens=True)

predicted_ids = ov_quantized_model.generate(input_features)
transcription_quantized = processor.batch_decode(predicted_ids, skip_special_tokens=True)

display(ipd.Audio(sample["audio"]["array"], rate=sample["audio"]["sampling_rate"]))
print(f"Original : {transcription_original[0]}")
print(f"Quantized: {transcription_quantized[0]}")
Original :  Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.
Quantized:  Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.

結果は同じです !

元のモデルと量子化されたモデルのパフォーマンスと精度を比較

最後に、精度とパフォーマンスの観点から、元の Distil-Whisper モデルと量子化された Distil-Whisper モデルを比較します。

精度を測定するために、1 - WER をメトリックとして使用します。WER は Word Error Rate (単語誤り率) の略です。

推論時間を測定するときは、エンコーダーとデコーダーの過去のモデル転送と、モデル全体の推論を個別に測定します。

%%skip not $to_quantize.value

import time
from contextlib import contextmanager
from jiwer import wer, wer_standardize


TEST_DATASET_SIZE = 50
MEASURE_TIME = False

@contextmanager
def time_measurement():
    global MEASURE_TIME
    try:
        MEASURE_TIME = True
        yield
    finally:
        MEASURE_TIME = False

def time_fn(obj, fn_name, time_list):
    original_fn = getattr(obj, fn_name)

    def wrapper(*args, **kwargs):
        if not MEASURE_TIME:
            return original_fn(*args, **kwargs)
        start_time = time.perf_counter()
        result = original_fn(*args, **kwargs)
        end_time = time.perf_counter()
        time_list.append(end_time - start_time)
        return result

    setattr(obj, fn_name, wrapper)

def calculate_transcription_time_and_accuracy(ov_model, test_samples):
    encoder_infer_times = []
    decoder_with_past_infer_times = []
    whole_infer_times = []
    time_fn(ov_model, "generate", whole_infer_times)
    time_fn(ov_model.encoder, "forward", encoder_infer_times)
    time_fn(ov_model.decoder_with_past, "forward", decoder_with_past_infer_times)

    ground_truths = []
    predictions = []
    for data_item in tqdm(test_samples, desc="Measuring performance and accuracy"):
        input_features = extract_input_features(data_item)

        with time_measurement():
            predicted_ids = ov_model.generate(input_features)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

        ground_truths.append(data_item["text"])
        predictions.append(transcription[0])

    word_accuracy = (1 - wer(ground_truths, predictions, reference_transform=wer_standardize,
                             hypothesis_transform=wer_standardize)) * 100
    mean_whole_infer_time = sum(whole_infer_times)
    mean_encoder_infer_time = sum(encoder_infer_times)
    mean_decoder_with_time_infer_time = sum(decoder_with_past_infer_times)
    return word_accuracy, (mean_whole_infer_time, mean_encoder_infer_time, mean_decoder_with_time_infer_time)

test_dataset = load_dataset("librispeech_asr", "clean", split="test", streaming=True)
test_dataset = test_dataset.shuffle(seed=42).take(TEST_DATASET_SIZE)
test_samples = [sample for sample in test_dataset]

accuracy_original, times_original = calculate_transcription_time_and_accuracy(ov_model, test_samples)
accuracy_quantized, times_quantized = calculate_transcription_time_and_accuracy(ov_quantized_model, test_samples)
print(f"Encoder performance speedup: {times_original[1] / times_quantized[1]:.3f}")
print(f"Decoder with past performance speedup: {times_original[2] / times_quantized[2]:.3f}")
print(f"Whole pipeline performance speedup: {times_original[0] / times_quantized[0]:.3f}")
print(f"Whisper transcription word accuracy. Original model: {accuracy_original:.2f}%. Quantized model: {accuracy_quantized:.2f}%.")
print(f"Accuracy drop: {accuracy_original - accuracy_quantized:.2f}%.")
Got disconnected from remote data host. Retrying in 5sec [1/20]
Got disconnected from remote data host. Retrying in 5sec [2/20]
Measuring performance and accuracy:   0%|          | 0/50 [00:00<?, ?it/s]
Measuring performance and accuracy:   0%|          | 0/50 [00:00<?, ?it/s]
Encoder performance speedup: 1.751
Decoder with past performance speedup: 1.777
Whole pipeline performance speedup: 1.711
Whisper transcription word accuracy. Original model: 85.29%. Quantized model: 85.29%.
Accuracy drop: 0.00%.

ご覧のとおり、量子化により、精度が大幅に低下することなくモデルの推論時間が大幅に改善されました。

インタラクティブなデモ

また、Gradio インターフェースを使用したインタラクティブなデモも提供しており、固有のオーディオデータでモデルの機能をテストしたり (アップロードボタンを使用)、マイクを使用して録音したりできます。なお、Distil-Whisper は現在、英語の音声認識のみで利用可能です。多言語サポートは後日提供される予定です。

from transformers.pipelines.audio_utils import ffmpeg_read
import gradio as gr
import urllib.request

urllib.request.urlretrieve(
    url="https://huggingface.co/spaces/distil-whisper/whisper-vs-distil-whisper/resolve/main/assets/example_1.wav",
    filename="example_1.wav",
)

BATCH_SIZE = 16
MAX_AUDIO_MINS = 30  # maximum audio input in minutes


generate_kwargs = {"language": "en", "task": "transcribe"} if not model_id.value.endswith(".en") else {}
ov_pipe = pipeline(
    "automatic-speech-recognition",
    model=ov_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,
    generate_kwargs=generate_kwargs,
)
ov_pipe_forward = ov_pipe._forward

if to_quantize.value:
    ov_quantized_model.generation_config = ov_model.generation_config
    ov_quantized_pipe = pipeline(
        "automatic-speech-recognition",
        model=ov_quantized_model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=15,
        generate_kwargs=generate_kwargs,
    )
    ov_quantized_pipe_forward = ov_quantized_pipe._forward


def transcribe(inputs, quantized=False):
    pipe = ov_quantized_pipe if quantized else ov_pipe
    pipe_forward = ov_quantized_pipe_forward if quantized else ov_pipe_forward

    if inputs is None:
        raise gr.Error(
            "No audio file submitted! Please record or upload an audio file before submitting your request."
        )

    with open(inputs, "rb") as f:
        inputs = f.read()

    inputs = ffmpeg_read(inputs, pipe.feature_extractor.sampling_rate)
    audio_length_mins = len(inputs) / pipe.feature_extractor.sampling_rate / 60

    if audio_length_mins > MAX_AUDIO_MINS:
        raise gr.Error(
            f"To ensure fair usage of the Space, the maximum audio length permitted is {MAX_AUDIO_MINS} minutes."
            f"Got an audio of length {round(audio_length_mins, 3)} minutes."
        )

    inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}

    def _forward_ov_time(*args, **kwargs):
        global ov_time
        start_time = time.time()
        result = pipe_forward(*args, **kwargs)
        ov_time = time.time() - start_time
        ov_time = round(ov_time, 2)
        return result

    pipe._forward = _forward_ov_time
    ov_text = pipe(inputs.copy(), batch_size=BATCH_SIZE)["text"]
    return ov_text, ov_time


with gr.Blocks() as demo:
    gr.HTML(
        """
                <div style="text-align: center; max-width: 700px; margin: 0 auto;">
                  <div
                    style="
                      display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;
                    "
                  >
                    <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
                      OpenVINO Distil-Whisper demo
                    </h1>
                  </div>
                </div>
            """
    )
    audio = gr.components.Audio(type="filepath", label="Audio input")
    with gr.Row():
        button = gr.Button("Transcribe")
        if to_quantize.value:
            button_q = gr.Button("Transcribe quantized")
    with gr.Row():
        infer_time = gr.components.Textbox(
            label="OpenVINO Distil-Whisper Transcription Time (s)"
        )
        if to_quantize.value:
            infer_time_q = gr.components.Textbox(
                label="OpenVINO Quantized Distil-Whisper Transcription Time (s)"
            )
    with gr.Row():
        transcription = gr.components.Textbox(
            label="OpenVINO Distil-Whisper Transcription", show_copy_button=True
        )
        if to_quantize.value:
            transcription_q = gr.components.Textbox(
                label="OpenVINO Quantized Distil-Whisper Transcription", show_copy_button=True
            )
    button.click(
        fn=transcribe,
        inputs=audio,
        outputs=[transcription, infer_time],
    )
    if to_quantize.value:
        button_q.click(
            fn=transcribe,
            inputs=[audio, gr.Number(value=1, visible=False)],
            outputs=[transcription_q, infer_time_q],
        )
    gr.Markdown("## Examples")
    gr.Examples(
        [["./example_1.wav"]],
        audio,
        outputs=[transcription, infer_time],
        fn=transcribe,
        cache_examples=False,
    )
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_name='your server name', server_port='server port in int')
# Read more in the docs: https://gradio.app/docs/
try:
    demo.launch(debug=False)
except Exception:
    demo.launch(share=True, debug=False)