MMS: OpenVINO™ で音声テクノロジーを 1000 以上の言語に拡張#

この Jupyter ノートブックは、ローカルへのインストール後にのみ起動できます。

GitHub

大規模多言語スピーチ (MMS) プロジェクトでは、1,100 を超える言語 (従来の 10 倍以上) をサポートする単一の多言語音声認識モデル、4,000 を超える言語 (従来の 40 倍) を識別できる言語識別モデル、1,400 を超える言語をサポートする事前トレーニング済みモデル、および 1,100 を超える言語のテキスト読み上げモデルを構築することで、音声テクノロジーを約 100 言語から 1,000 を超える言語に拡張します。

MMS モデルは、Scaling Speech Technology to 1,000+ Languages で提案されました。モデルとコードは元々ここでリリースされています。

MMS プロジェクトにはさまざまなオープンソース・モデルがあります: 自動音声認識 (ASR)、言語識別 (LID)、および音声合成 (TTS)。これについて簡単な図を以下に示します。

LID と ASR フロー

LID と ASR フロー#

このノートブックでは、ASR と LID を検討します。LID モデルを使用して言語を識別し、言語固有の ASR モデルを使用してそれを認識します。モデルの推論速度を向上させるため、追加のモデル量子化ステップが採用されています。ノートブックの最後には、Gradio ベースのインタラクティブ・デモがあります。

目次:

必要条件#

%pip install -q --upgrade pip 
%pip install -q "transformers>=4.33.1" "torch>=2.1" "openvino>=2023.1.0" "numpy>=1.21.0" "nncf>=2.9.0" 
%pip install -q --extra-index-url https://download.pytorch.org/whl/cpu torch "datasets>=2.14.6" accelerate soundfile librosa "gradio>=4.19" jiwer
from pathlib import Path 

import torch 

import openvino as ov

サンプル音声を準備#

オーディオファイルを読み取り、オーディオデータを処理します。オーディオデータが 16000 kHz でサンプリングされていることを確認してください。この例では、多言語 LibriSpeech (MLS) データセットのストリーミング可能なバージョンを使用します。7 つの言語の例をサポートしています: 'german', 'dutch', 'french', 'spanish', 'italian', 'portuguese', 'polish'。いずれかを選択してください。

import ipywidgets as widgets 

SAMPLE_LANG = widgets.Dropdown( 
    options=["german", "dutch", "french", "spanish", "italian", "portuguese", "polish"], 
    value="german", 
    description="Dataset language:", 
    disabled=False, 
) 

SAMPLE_LANG
Dropdown(description='Dataset language:', options=('german', 'dutch', 'french', 'spanish', 'italian', 'portugu…

データセット全体をダウンロードしない場合は、streaming=True を指定します。

from datasets import load_dataset 

mls_dataset = load_dataset("facebook/multilingual_librispeech", SAMPLE_LANG.value, split="test", streaming=True, trust_remote_code=True) 
mls_dataset = iter(mls_dataset) # 反復可能にする 

example = next(mls_dataset) # 1 つのサンプルを取得

例には辞書構造があります。音声データとテキストの書き起こしが含まれています。

print(example) # look at structure
{'file': None, 'audio': {'path': '1054_1599_000000.flac', 'array': array([-0.00131226, -0.00152588, -0.00134277, ..., 0.00411987, 0.00308228, -0.00015259]), 'sampling_rate': 16000}, 'text': 'mein sechster sohn scheint wenigstens auf den ersten blick der tiefsinnigste von allen ein kopfhänger und doch ein schwätzer deshalb kommt man ihm nicht leicht bei ist er am unterliegen so verfällt er in unbesiegbare traurigkeit', 'speaker_id': 1054, 'chapter_id': 1599, 'id': '1054_1599_000000'}
import IPython.display as ipd 

print(example["text"]) 
ipd.Audio(example["audio"]["array"], rate=16_000)
mein sechster sohn scheint wenigstens auf den ersten blick der tiefsinnigste von allen ein kopfhänger und doch ein schwätzer deshalb kommt man ihm nicht leicht bei ist er am unterliegen so verfällt er in unbesiegbare traurigkeit

言語識別 (LID)#

事前学習済みモデルとプロセッサーをダウンロード#

認識できる言語の数に応じて、126、256、512、1024、2048、4017 などさまざまな LID モデルが用意されています。ここでは 126 を使用します。

from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor 

model_id = "facebook/mms-lid-126" 

lid_processor = AutoFeatureExtractor.from_pretrained(model_id) 
lid_model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)

元のモデルを使用して推論を実行#

inputs = lid_processor(example["audio"]["array"], sampling_rate=16_000, return_tensors="pt") 

with torch.no_grad(): 
    outputs = lid_model(**inputs).logits 

lang_id = torch.argmax(outputs, dim=-1)[0].item() 
detected_lang = lid_model.config.id2label[lang_id] 
print(detected_lang)
deu

OpenVINO IR モデルに変換して推論を実行#

OpenVINO を使用して推論を実行するデバイスをドロップダウン・リストから選択します。

core = ov.Core() 

device = widgets.Dropdown( 
    options=core.available_devices + ["AUTO"], 
    value="AUTO", 
    description="Device:", 
    disabled=False, 
) 

device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')

モデルを OpenVINO 形式に変換してコンパイルします

MAX_SEQ_LENGTH = 30480 

lid_model_xml_path = Path("models/ov_lid_model.xml") 

def get_lid_model(model_path, compiled=True): 
    input_values = torch.zeros([1, MAX_SEQ_LENGTH], dtype=torch.float) 

    if not model_path.exists() and model_path == lid_model_xml_path: 
        lid_model_xml_path.parent.mkdir(parents=True, exist_ok=True) 
        converted_model = ov.convert_model(lid_model, example_input={"input_values": input_values}) 
        ov.save_model(converted_model, lid_model_xml_path) 
        if not compiled: 
            return converted_model 
    if compiled: 
        return core.compile_model(model_path, device_name=device.value) 
    return core.read_model(model_path) 

compiled_lid_model = get_lid_model(lid_model_xml_path)
/home/nsavel/venvs/ov_notebooks_tmp/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:595: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect.We can't record the data flow of Python values, so this value will be treated as a constant in the future.This means that the trace might not generalize to other inputs! 
  if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): /home/nsavel/venvs/ov_notebooks_tmp/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:634: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect.We can't record the data flow of Python values, so this value will be treated as a constant in the future.This means that the trace might not generalize to other inputs! 
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):

これで推論を実行できるようになりました。

def detect_language(compiled_model, audio_data): 
    inputs = lid_processor(audio_data, sampling_rate=16_000, return_tensors="pt") 

    outputs = compiled_model(inputs["input_values"])[0] 

    lang_id = torch.argmax(torch.from_numpy(outputs), dim=-1)[0].item() 
    detected_lang = lid_model.config.id2label[lang_id] 

    return detected_lang
detect_language(compiled_lid_model, example["audio"]["array"])
'deu'

別の言語を確認してみましょう。

SAMPLE_LANG = widgets.Dropdown( 
    options=["german", "dutch", "french", "spanish", "italian", "portuguese", "polish"], 
    value="french", 
    description="Dataset language:", 
    disabled=False, 
) 

SAMPLE_LANG
Dropdown(description='Dataset language:', index=2, options=('german', 'dutch', 'french', 'spanish', 'italian',…
mls_dataset = load_dataset("facebook/multilingual_librispeech", SAMPLE_LANG.value, split="test", streaming=True, trust_remote_code=True) 
mls_dataset = iter(mls_dataset) 

example = next(mls_dataset) 
print(example["text"]) 
ipd.Audio(example["audio"]["array"], rate=16_000)
grisé par ce parfum il fit des vers en l'honneur de l'humble fleur des bois et il les récita tout haut à ses pieds une violette l'entendit elle crut qu'il ne parlait que pour elle
language_id = detect_language(compiled_lid_model, example["audio"]["array"]) 
print(language_id)
fra

自動音声認識 (ASR)#

事前学習済みモデルとプロセッサーをダウンロード#

事前学習済みモデルとプロセッサーをダウンロードします。デフォルトでは、MMS は英語のアダプター重みを読み込みます。別の言語のアダプターの重みをロードする場合、target_lang=<your-chosen-target-lang>ignore_mismatched_sizes=True を必ず指定してください。指定された言語の語彙に応じて言語モデルヘッドのサイズを変更できるようにするには、ignore_mismatched_sizes=True キーワードを渡す必要があります。同様に、プロセッサーにも同じターゲット言語をロードする必要があります。サポート言語は後から変更することもできます。

from transformers import Wav2Vec2ForCTC, AutoProcessor 

model_id = "facebook/mms-1b-all" 

asr_processor = AutoProcessor.from_pretrained(model_id) 
asr_model = Wav2Vec2ForCTC.from_pretrained(model_id)

サポートされるすべての言語を確認できます:

asr_processor.tokenizer.vocab.keys()
dict_keys(['abi', 'abk', 'abp', 'aca', 'acd', 'ace', 'acf', 'ach', 'acn', 'acr', 'acu', 'ade', 'adh', 'adj', 'adx', 'aeu', 'afr', 'agd', 'agg', 'agn', 'agr', 'agu', 'agx', 'aha', 'ahk', 'aia', 'aka', 'akb', 'ake', 'akp', 'alj', 'alp', 'alt', 'alz', 'ame', 'amf', 'amh', 'ami', 'amk', 'ann', 'any', 'aoz', 'apb', 'apr', 'ara', 'arl', 'asa', 'asg', 'asm', 'ast', 'ata', 'atb', 'atg', 'ati', 'atq', 'ava', 'avn', 'avu', 'awa', 'awb', 'ayo', 'ayr', 'ayz', 'azb', 'azg', 'azj-script_cyrillic', 'azj-script_latin', 'azz', 'bak', 'bam', 'ban', 'bao', 'bas', 'bav', 'bba', 'bbb', 'bbc', 'bbo', 'bcc-script_arabic', 'bcc-script_latin', 'bcl', 'bcw', 'bdg', 'bdh', 'bdq', 'bdu', 'bdv', 'beh', 'bel', 'bem', 'ben', 'bep', 'bex', 'bfa', 'bfo', 'bfy', 'bfz', 'bgc', 'bgq', 'bgr', 'bgt', 'bgw', 'bha', 'bht', 'bhz', 'bib', 'bim', 'bis', 'biv', 'bjr', 'bjv', 'bjw', 'bjz', 'bkd', 'bkv', 'blh', 'blt', 'blx', 'blz', 'bmq', 'bmr', 'bmu', 'bmv', 'bng', 'bno', 'bnp', 'boa', 'bod', 'boj', 'bom', 'bor', 'bos', 'bov', 'box', 'bpr', 'bps', 'bqc', 'bqi', 'bqj', 'bqp', 'bre', 'bru', 'bsc', 'bsq', 'bss', 'btd', 'bts', 'btt', 'btx', 'bud', 'bul', 'bus', 'bvc', 'bvz', 'bwq', 'bwu', 'byr', 'bzh', 'bzi', 'bzj', 'caa', 'cab', 'cac-dialect_sanmateoixtatan', 'cac-dialect_sansebastiancoatan', 'cak-dialect_central', 'cak-dialect_santamariadejesus', 'cak-dialect_santodomingoxenacoj', 'cak-dialect_southcentral', 'cak-dialect_western', 'cak-dialect_yepocapa', 'cap', 'car', 'cas', 'cat', 'cax', 'cbc', 'cbi', 'cbr', 'cbs', 'cbt', 'cbu', 'cbv', 'cce', 'cco', 'cdj', 'ceb', 'ceg', 'cek', 'ces', 'cfm', 'cgc', 'che', 'chf', 'chv', 'chz', 'cjo', 'cjp', 'cjs', 'ckb', 'cko', 'ckt', 'cla', 'cle', 'cly', 'cme', 'cmn-script_simplified', 'cmo-script_khmer', 'cmo-script_latin', 'cmr', 'cnh', 'cni', 'cnl', 'cnt', 'coe', 'cof', 'cok', 'con', 'cot', 'cou', 'cpa', 'cpb', 'cpu', 'crh', 'crk-script_latin', 'crk-script_syllabics', 'crn', 'crq', 'crs', 'crt', 'csk', 'cso', 'ctd', 'ctg', 'cto', 'ctu', 'cuc', 'cui', 'cuk', 'cul', 'cwa', 'cwe', 'cwt', 'cya', 'cym', 'daa', 'dah', 'dan', 'dar', 'dbj', 'dbq', 'ddn', 'ded', 'des', 'deu', 'dga', 'dgi', 'dgk', 'dgo', 'dgr', 'dhi', 'did', 'dig', 'dik', 'dip', 'div', 'djk', 'dnj-dialect_blowowest', 'dnj-dialect_gweetaawueast', 'dnt', 'dnw', 'dop', 'dos', 'dsh', 'dso', 'dtp', 'dts', 'dug', 'dwr', 'dyi', 'dyo', 'dyu', 'dzo', 'eip', 'eka', 'ell', 'emp', 'enb', 'eng', 'enx', 'epo', 'ese', 'ess', 'est', 'eus', 'evn', 'ewe', 'eza', 'fal', 'fao', 'far', 'fas', 'fij', 'fin', 'flr', 'fmu', 'fon', 'fra', 'frd', 'fry', 'ful', 'gag-script_cyrillic', 'gag-script_latin', 'gai', 'gam', 'gau', 'gbi', 'gbk', 'gbm', 'gbo', 'gde', 'geb', 'gej', 'gil', 'gjn', 'gkn', 'gld', 'gle', 'glg', 'glk', 'gmv', 'gna', 'gnd', 'gng', 'gof-script_latin', 'gog', 'gor', 'gqr', 'grc', 'gri', 'grn', 'grt', 'gso', 'gub', 'guc', 'gud', 'guh', 'guj', 'guk', 'gum', 'guo', 'guq', 'guu', 'gux', 'gvc', 'gvl', 'gwi', 'gwr', 'gym', 'gyr', 'had', 'hag', 'hak', 'hap', 'hat', 'hau', 'hay', 'heb', 'heh', 'hif', 'hig', 'hil', 'hin', 'hlb', 'hlt', 'hne', 'hnn', 'hns', 'hoc', 'hoy', 'hrv', 'hsb', 'hto', 'hub', 'hui', 'hun', 'hus-dialect_centralveracruz', 'hus-dialect_westernpotosino', 'huu', 'huv', 'hvn', 'hwc', 'hye', 'hyw', 'iba', 'ibo', 'icr', 'idd', 'ifa', 'ifb', 'ife', 'ifk', 'ifu', 'ify', 'ign', 'ikk', 'ilb', 'ilo', 'imo', 'ina', 'inb', 'ind', 'iou', 'ipi', 'iqw', 'iri', 'irk', 'isl', 'ita', 'itl', 'itv', 'ixl-dialect_sangasparchajul', 'ixl-dialect_sanjuancotzal', 'ixl-dialect_santamarianebaj', 'izr', 'izz', 'jac', 'jam', 'jav', 'jbu', 'jen', 'jic', 'jiv', 'jmc', 'jmd', 'jpn', 'jun', 'juy', 'jvn', 'kaa', 'kab', 'kac', 'kak', 'kam', 'kan', 'kao', 'kaq', 'kat', 'kay', 'kaz', 'kbo', 'kbp', 'kbq', 'kbr', 'kby', 'kca', 'kcg', 'kdc', 'kde', 'kdh', 'kdi', 'kdj', 'kdl', 'kdn', 'kdt', 'kea', 'kek', 'ken', 'keo', 'ker', 'key', 'kez', 'kfb', 'kff-script_telugu', 'kfw', 'kfx', 'khg', 'khm', 'khq', 'kia', 'kij', 'kik', 'kin', 'kir', 'kjb', 'kje', 'kjg', 'kjh', 'kki', 'kkj', 'kle', 'klu', 'klv', 'klw', 'kma', 'kmd', 'kml', 'kmr-script_arabic', 'kmr-script_cyrillic', 'kmr-script_latin', 'kmu', 'knb', 'kne', 'knf', 'knj', 'knk', 'kno', 'kog', 'kor', 'kpq', 'kps', 'kpv', 'kpy', 'kpz', 'kqe', 'kqp', 'kqr', 'kqy', 'krc', 'kri', 'krj', 'krl', 'krr', 'krs', 'kru', 'ksb', 'ksr', 'kss', 'ktb', 'ktj', 'kub', 'kue', 'kum', 'kus', 'kvn', 'kvw', 'kwd', 'kwf', 'kwi', 'kxc', 'kxf', 'kxm', 'kxv', 'kyb', 'kyc', 'kyf', 'kyg', 'kyo', 'kyq', 'kyu', 'kyz', 'kzf', 'lac', 'laj', 'lam', 'lao', 'las', 'lat', 'lav', 'law', 'lbj', 'lbw', 'lcp', 'lee', 'lef', 'lem', 'lew', 'lex', 'lgg', 'lgl', 'lhu', 'lia', 'lid', 'lif', 'lin', 'lip', 'lis', 'lit', 'lje', 'ljp', 'llg', 'lln', 'lme', 'lnd', 'lns', 'lob', 'lok', 'lom', 'lon', 'loq', 'lsi', 'lsm', 'ltz', 'luc', 'lug', 'luo', 'lwo', 'lww', 'lzz', 'maa-dialect_sanantonio', 'maa-dialect_sanjeronimo', 'mad', 'mag', 'mah', 'mai', 'maj', 'mak', 'mal', 'mam-dialect_central', 'mam-dialect_northern', 'mam-dialect_southern', 'mam-dialect_western', 'maq', 'mar', 'maw', 'maz', 'mbb', 'mbc', 'mbh', 'mbj', 'mbt', 'mbu', 'mbz', 'mca', 'mcb', 'mcd', 'mco', 'mcp', 'mcq', 'mcu', 'mda', 'mdf', 'mdv', 'mdy', 'med', 'mee', 'mej', 'men', 'meq', 'met', 'mev', 'mfe', 'mfh', 'mfi', 'mfk', 'mfq', 'mfy', 'mfz', 'mgd', 'mge', 'mgh', 'mgo', 'mhi', 'mhr', 'mhu', 'mhx', 'mhy', 'mib', 'mie', 'mif', 'mih', 'mil', 'mim', 'min', 'mio', 'mip', 'miq', 'mit', 'miy', 'miz', 'mjl', 'mjv', 'mkd', 'mkl', 'mkn', 'mlg', 'mlt', 'mmg', 'mnb', 'mnf', 'mnk', 'mnw', 'mnx', 'moa', 'mog', 'mon', 'mop', 'mor', 'mos', 'mox', 'moz', 'mpg', 'mpm', 'mpp', 'mpx', 'mqb', 'mqf', 'mqj', 'mqn', 'mri', 'mrw', 'msy', 'mtd', 'mtj', 'mto', 'muh', 'mup', 'mur', 'muv', 'muy', 'mvp', 'mwq', 'mwv', 'mxb', 'mxq', 'mxt', 'mxv', 'mya', 'myb', 'myk', 'myl', 'myv', 'myx', 'myy', 'mza', 'mzi', 'mzj', 'mzk', 'mzm', 'mzw', 'nab', 'nag', 'nan', 'nas', 'naw', 'nca', 'nch', 'ncj', 'ncl', 'ncu', 'ndj', 'ndp', 'ndv', 'ndy', 'ndz', 'neb', 'new', 'nfa', 'nfr', 'nga', 'ngl', 'ngp', 'ngu', 'nhe', 'nhi', 'nhu', 'nhw', 'nhx', 'nhy', 'nia', 'nij', 'nim', 'nin', 'nko', 'nlc', 'nld', 'nlg', 'nlk', 'nmz', 'nnb', 'nno', 'nnq', 'nnw', 'noa', 'nob', 'nod', 'nog', 'not', 'npi', 'npl', 'npy', 'nso', 'nst', 'nsu', 'ntm', 'ntr', 'nuj', 'nus', 'nuz', 'nwb', 'nxq', 'nya', 'nyf', 'nyn', 'nyo', 'nyy', 'nzi', 'obo', 'oci', 'ojb-script_latin', 'ojb-script_syllabics', 'oku', 'old', 'omw', 'onb', 'ood', 'orm', 'ory', 'oss', 'ote', 'otq', 'ozm', 'pab', 'pad', 'pag', 'pam', 'pan', 'pao', 'pap', 'pau', 'pbb', 'pbc', 'pbi', 'pce', 'pcm', 'peg', 'pez', 'pib', 'pil', 'pir', 'pis', 'pjt', 'pkb', 'pls', 'plw', 'pmf', 'pny', 'poh-dialect_eastern', 'poh-dialect_western', 'poi', 'pol', 'por', 'poy', 'ppk', 'pps', 'prf', 'prk', 'prt', 'pse', 'pss', 'ptu', 'pui', 'pus', 'pwg', 'pww', 'pxm', 'qub', 'quc-dialect_central', 'quc-dialect_east', 'quc-dialect_north', 'quf', 'quh', 'qul', 'quw', 'quy', 'quz', 'qvc', 'qve', 'qvh', 'qvm', 'qvn', 'qvo', 'qvs', 'qvw', 'qvz', 'qwh', 'qxh', 'qxl', 'qxn', 'qxo', 'qxr', 'rah', 'rai', 'rap', 'rav', 'raw', 'rej', 'rel', 'rgu', 'rhg', 'rif-script_arabic', 'rif-script_latin', 'ril', 'rim', 'rjs', 'rkt', 'rmc-script_cyrillic', 'rmc-script_latin', 'rmo', 'rmy-script_cyrillic', 'rmy-script_latin', 'rng', 'rnl', 'roh-dialect_sursilv', 'roh-dialect_vallader', 'rol', 'ron', 'rop', 'rro', 'rub', 'ruf', 'rug', 'run', 'rus', 'sab', 'sag', 'sah', 'saj', 'saq', 'sas', 'sat', 'sba', 'sbd', 'sbl', 'sbp', 'sch', 'sck', 'sda', 'sea', 'seh', 'ses', 'sey', 'sgb', 'sgj', 'sgw', 'shi', 'shk', 'shn', 'sho', 'shp', 'sid', 'sig', 'sil', 'sja', 'sjm', 'sld', 'slk', 'slu', 'slv', 'sml', 'smo', 'sna', 'snd', 'sne', 'snn', 'snp', 'snw', 'som', 'soy', 'spa', 'spp', 'spy', 'sqi', 'sri', 'srm', 'srn', 'srp-script_cyrillic', 'srp-script_latin', 'srx', 'stn', 'stp', 'suc', 'suk', 'sun', 'sur', 'sus', 'suv', 'suz', 'swe', 'swh', 'sxb', 'sxn', 'sya', 'syl', 'sza', 'tac', 'taj', 'tam', 'tao', 'tap', 'taq', 'tat', 'tav', 'tbc', 'tbg', 'tbk', 'tbl', 'tby', 'tbz', 'tca', 'tcc', 'tcs', 'tcz', 'tdj', 'ted', 'tee', 'tel', 'tem', 'teo', 'ter', 'tes', 'tew', 'tex', 'tfr', 'tgj', 'tgk', 'tgl', 'tgo', 'tgp', 'tha', 'thk', 'thl', 'tih', 'tik', 'tir', 'tkr', 'tlb', 'tlj', 'tly', 'tmc', 'tmf', 'tna', 'tng', 'tnk', 'tnn', 'tnp', 'tnr', 'tnt', 'tob', 'toc', 'toh', 'tom', 'tos', 'tpi', 'tpm', 'tpp', 'tpt', 'trc', 'tri', 'trn', 'trs', 'tso', 'tsz', 'ttc', 'tte', 'ttq-script_tifinagh', 'tue', 'tuf', 'tuk-script_arabic', 'tuk-script_latin', 'tuo', 'tur', 'tvw', 'twb', 'twe', 'twu', 'txa', 'txq', 'txu', 'tye', 'tzh-dialect_bachajon', 'tzh-dialect_tenejapa', 'tzj-dialect_eastern', 'tzj-dialect_western', 'tzo-dialect_chamula', 'tzo-dialect_chenalho', 'ubl', 'ubu', 'udm', 'udu', 'uig-script_arabic', 'uig-script_cyrillic', 'ukr', 'umb', 'unr', 'upv', 'ura', 'urb', 'urd-script_arabic', 'urd-script_devanagari', 'urd-script_latin', 'urk', 'urt', 'ury', 'usp', 'uzb-script_cyrillic', 'uzb-script_latin', 'vag', 'vid', 'vie', 'vif', 'vmw', 'vmy', 'vot', 'vun', 'vut', 'wal-script_ethiopic', 'wal-script_latin', 'wap', 'war', 'waw', 'way', 'wba', 'wlo', 'wlx', 'wmw', 'wob', 'wol', 'wsg', 'wwa', 'xal', 'xdy', 'xed', 'xer', 'xho', 'xmm', 'xnj', 'xnr', 'xog', 'xon', 'xrb', 'xsb', 'xsm', 'xsr', 'xsu', 'xta', 'xtd', 'xte', 'xtm', 'xtn', 'xua', 'xuo', 'yaa', 'yad', 'yal', 'yam', 'yao', 'yas', 'yat', 'yaz', 'yba', 'ybb', 'ycl', 'ycn', 'yea', 'yka', 'yli', 'yor', 'yre', 'yua', 'yue-script_traditional', 'yuz', 'yva', 'zaa', 'zab', 'zac', 'zad', 'zae', 'zai', 'zam', 'zao', 'zaq', 'zar', 'zas', 'zav', 'zaw', 'zca', 'zga', 'zim', 'ziw', 'zlm', 'zmz', 'zne', 'zos', 'zpc', 'zpg', 'zpi', 'zpl', 'zpm', 'zpo', 'zpt', 'zpu', 'zpz', 'ztq', 'zty', 'zul', 'zyb', 'zyp', 'zza'])

モデルの場合は load_adapter() 関数を、トークナイザーの場合は set_target_lang() を呼び出して、言語アダプターを切り替えます。前の手順で検出されたターゲット言語 ("detect_language_id") を入力として渡します。

asr_processor.tokenizer.set_target_lang(language_id) 
asr_model.load_adapter(language_id)
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize

推論には元のモデルを使用#

inputs = asr_processor(example["audio"]["array"], sampling_rate=16_000, return_tensors="pt") 

with torch.no_grad(): 
    outputs = asr_model(**inputs).logits 

ids = torch.argmax(outputs, dim=-1)[0] 
transcription = asr_processor.decode(ids) 
print(transcription)
grisé par ce parfum il fit des vers en l'honneur de l'humble fleur des bois et il les récita tout haut à ses pieds une violette l'entendit elle crut qu'il ne parlait que pour elle

OpenVINO IR モデルに変換して推論を実行#

ov.convert_model 関数を使用して OpenVINO IR モデル形式に直接変換します。ov.save_model 関数を使用して、変換結果をシリアル化します。今後の使用の利便性のために、これらの目的の関数を作成します。

asr_model_xml_path_template = "models/ov_asr_{}_model.xml" 

def get_asr_model(model_path_template, language_id, compiled=True): 
    input_values = torch.zeros([1, MAX_SEQ_LENGTH], dtype=torch.float) 
    model_path = Path(model_path_template.format(language_id)) 

    asr_processor.tokenizer.set_target_lang(language_id) 
    if not model_path.exists() and model_path_template == asr_model_xml_path_template: 
        asr_model.load_adapter(language_id)

        model_path.parent.mkdir(parents=True, exist_ok=True) 
        converted_model = ov.convert_model(asr_model, example_input={"input_values": input_values}) 
        ov.save_model(converted_model, model_path) 
        if not compiled: 
            return converted_model 

    if compiled: 
        return core.compile_model(model_path, device_name=device.value) 
    return core.read_model(model_path) 

compiled_asr_model = get_asr_model(asr_model_xml_path_template, language_id)
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize

推論を実行します。

def recognize_audio(compiled_model, src_audio): 
    inputs = asr_processor(src_audio, sampling_rate=16_000, return_tensors="pt") 
    outputs = compiled_model(inputs["input_values"])[0] 

    ids = torch.argmax(torch.from_numpy(outputs), dim=-1)[0] 
    transcription = asr_processor.decode(ids) 

    return transcription 

transcription = recognize_audio(compiled_asr_model, example["audio"]["array"]) 
print("Original text:", example["text"]) 
print("Transcription:", transcription)
Original text: grisé par ce parfum il fit des vers en l'honneur de l'humble fleur des bois et il les récita tout haut à ses pieds une violette l'entendit elle crut qu'il ne parlait que pour elle 
Transcription: grisé par ce parfum il fit des vers en l'honneur de l'humble fleur des bois et il les récita tout haut à ses pieds une violette l'entendit elle crut qu'il ne parlait que pour elle

量子化#

NNCF は、量子化レイヤーをモデルグラフに追加し、トレーニング・データセットのサブセットを使用してこれらの追加の量子化レイヤーのパラメーターを初期化することで、トレーニング後の量子化を可能にします。量子化操作は FP32/FP16 ではなく INT8 で実行されるため、モデル推論が高速化されます。

最適化プロセスには次の手順が含まれます:

  1. 量子化用のキャリブレーション・データセットを作成します。

  2. nncf.quantize() を実行して、量子化されたモデルを取得します。

  3. openvino.save_model() 関数を使用して量子化された INT8 モデルをシリアル化します。

注: 量子化は時間とメモリーを消費する操作です。以下の量子化コードの実行には時間がかかる場合があります。

compiled_quantized_lid_model = None 
quantized_asr_model_xml_path_template = None 

to_quantize = widgets.Checkbox( 
    value=False, 
    description="Quantization", 
    disabled=False, 
) 

to_quantize
Checkbox(value=True, description='Quantization')

to_quantize が選択されていない場合に量子化をスキップするスキップマジック拡張機能をロードします

# `skip_kernel_extension` モジュールを取得 
import requests 

r = requests.get( 

url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/skip_kernel_extension.py", 
) 
open("skip_kernel_extension.py", "w").write(r.text) 

%load_ext skip_kernel_extension

キャリブレーション・データセットの準備#

モデルを量子化する言語を選択します:

%%skip not $to_quantize.value 

from IPython.display import display 

display(SAMPLE_LANG)

選択した言語の同じ MLS データセットの検証分割を読み込みます。

%%skip not $to_quantize.value 

mls_dataset = iter(load_dataset("facebook/multilingual_librispeech", SAMPLE_LANG.value, split="validation", streaming=True, trust_remote_code=True)) 
example = next(mls_dataset)

量子化用のキャリブレーション・データセットを作成します。

%%skip not $to_quantize.value 

CALIBRATION_DATASET_SIZE = 5 

calibration_data = [] 
for i in range(CALIBRATION_DATASET_SIZE): 
    data = asr_processor(next(mls_dataset)['audio']['array'], sampling_rate=16_000, return_tensors="np") 
    calibration_data.append(data["input_values"])

言語識別モデルの量子化#

LID モデルの量子化を実行します。

%%skip not $to_quantize.value 

import nncf 

quantized_lid_model_xml_path = Path(str(lid_model_xml_path).replace(".xml", "_quantized.xml")) 

if not quantized_lid_model_xml_path.exists(): 
    quantized_lid_model = nncf.quantize( 
        get_lid_model(lid_model_xml_path, compiled=False), 
        calibration_dataset=nncf.Dataset(calibration_data), 
        subset_size=len(calibration_data), 
        model_type=nncf.ModelType.TRANSFORMER 
    ) 
    ov.save_model(quantized_lid_model, quantized_lid_model_xml_path) 
    compiled_quantized_lid_model = core.compile_model(quantized_lid_model, device_name=device.value) 
else: 
    compiled_quantized_lid_model = get_lid_model(quantized_lid_model_xml_path)
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, openvino
Statistics collection: 100%|
██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:06<00:00, 1.24s/it] 
Applying Smooth Quant: 100%|
██████████████████████████████████████████████████████████████████████████████████████████████| 291/291 [00:18<00:00, 15.34it/s]
INFO:nncf:144 ignored nodes was found by name in the NNCFGraph
Statistics collection: 100%|
██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:18<00:00, 3.65s/it] 
Applying Fast Bias correction: 100%|
██████████████████████████████████████████████████████████████████████████████████████| 298/298 [05:09<00:00, 1.04s/it]

量子化モデルを使用して言語を検出します。

%%skip not $to_quantize.value 

language_id = detect_language(compiled_quantized_lid_model, example['audio']['array']) 
print("Detected language:", language_id)
Detected language: fra

音声認識モデルの量子化#

ASR モデルの量子化を実行します。

%%skip not $to_quantize.value 

quantized_asr_model_xml_path_template = asr_model_xml_path_template.replace(".xml", "_quantized.xml") 
quantized_asr_model_xml_path = Path(quantized_asr_model_xml_path_template.format(language_id)) 

if not quantized_asr_model_xml_path.exists(): 
    quantized_asr_model = nncf.quantize( 
        get_asr_model(asr_model_xml_path_template, language_id, compiled=False), 
        calibration_dataset=nncf.Dataset(calibration_data), 
        subset_size=len(calibration_data), 
        model_type=nncf.ModelType.TRANSFORMER 
    ) 
    ov.save_model(quantized_asr_model, quantized_asr_model_xml_path) 
    compiled_quantized_asr_model = core.compile_model(quantized_asr_model, device_name=device.value) 
else: 
    compiled_quantized_asr_model = get_asr_model(quantized_asr_model_xml_path_template, language_id)
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize
Statistics collection: 100%|
██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:05<00:00, 1.17s/it] 
Applying Smooth Quant: 100%|
██████████████████████████████████████████████████████████████████████████████████████████████| 290/290 [00:17<00:00, 16.39it/s]
INFO:nncf:144 ignored nodes was found by name in the NNCFGraph
Statistics collection: 100%|
██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:19<00:00, 3.93s/it] 
Applying Fast Bias correction: 100%|
██████████████████████████████████████████████████████████████████████████████████████| 393/393 [05:22<00:00, 1.22it/s]

量子化モデルを使用して転写を実行し、その結果を元のモデルによって生成された結果と比較します。

%%skip not $to_quantize.value 

compiled_asr_model = get_asr_model(asr_model_xml_path_template, language_id) 
transcription_original = recognize_audio(compiled_asr_model, example['audio']['array']) 
transcription_quantized = recognize_audio(compiled_quantized_asr_model, example['audio']['array']) 
print("Transcription by original model: ", transcription_original) 
print("Transcription by quantized model:", transcription_quantized)
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize 
Transcription by original model: le salon était de la plus haute magnificence dorée comme la galerie de diane aux tuileries avec des tableaux à l'huile au lombri il y avait des tâches claires dans ces tableaux julien apprit plus tard que les sujets avaient semblé peu décent à la maîtresse du logis qui avait fait corriger les tableaux 
Transcription by quantized model: le salon était de la plus haute magnificence doré comme la galerie de diane aux tuileries avec des tableaux à l'huile au lombri il y avait des tâches claires dans ces tableaux julien apprit plus tard que les sujets avaient semblé peu decent à la maîtresse du logis qui avait fait corriger les tableaux

モデルのサイズ、パフォーマンス、精度を比較#

まずモデルのサイズを比較します。

%%skip not $to_quantize.value 

def calculate_compression_rate(model_path_ov, model_path_ov_int8, model_type): 
    model_size_fp32 = model_path_ov.with_suffix(".bin").stat().st_size / 10 ** 6 
    model_size_int8 = model_path_ov_int8.with_suffix(".bin").stat().st_size / 10 ** 6 
    print(f"{model_type} model footprint comparison:") 
    print(f" * FP32 IR model size: {model_size_fp32:.2f} MB") 
    print(f" * INT8 IR model size: {model_size_int8:.2f} MB") 
    return model_size_fp32, model_size_int8 

lid_model_size_fp32, lid_model_size_int8 = \ 
    calculate_compression_rate(lid_model_xml_path, quantized_lid_model_xml_path, 'LID') 
asr_model_size_fp32, asr_model_size_int8 = \ 
    qalculate_compression_rate(Path(asr_model_xml_path_template.format(language_id)), quantized_asr_model_xml_path, 'ASR')
LID model footprint comparison:
    * FP32 IR model size: 1931.81 MB 
    * INT8 IR model size: 968.96 MB 
ASR model footprint comparison:
    * FP32 IR model size: 1930.10 MB 
    * INT8 IR model size: 968.29 MB

次に、MLS データセットのテスト分割で元のモデルと量子化されたモデルの精度値を比較します。ここでは Word Error Rate (WER) メトリックに依存し、精度を (1 - WER) として計算します。

また、言語識別モデルと音声認識モデルの両方の推論時間も測定します。

%%skip not $to_quantize.value 

import time 
from tqdm.notebook import tqdm 
import numpy as np 
from jiwer import wer 

TEST_DATASET_SIZE = 20 
test_dataset = load_dataset("facebook/multilingual_librispeech", SAMPLE_LANG.value, 
split="test", streaming=True, trust_remote_code=True) test_dataset = 
test_dataset.take(TEST_DATASET_SIZE) 

def calculate_transcription_time_and_accuracy(lid_model, asr_model): 
    ground_truths = [] 
    predictions = [] 
    identification_time = [] 
    transcription_time = [] 
    for data_item in tqdm(test_dataset, desc="Measuring performance and accuracy", total=TEST_DATASET_SIZE): 
        audio = data_item["audio"]["array"] 

        start_time = time.perf_counter() 
        detect_language(lid_model, audio) 
        end_time = time.perf_counter() 
        identification_time.append(end_time - start_time) 

        start_time = time.perf_counter() 
        transcription = recognize_audio(asr_model, audio) 
        end_time = time.perf_counter() 
        transcription_time.append(end_time - start_time) 

        ground_truths.append(data_item["text"]) 
        predictions.append(transcription) 

    word_accuracy = (1 - wer(ground_truths, predictions)) * 100 
    mean_identification_time = np.mean(identification_time) 
    mean_transcription_time = np.mean(transcription_time) 
    return mean_identification_time, mean_transcription_time, word_accuracy 

identification_time_fp32, transcription_time_fp32, accuracy_fp32 = \ 
    calculate_transcription_time_and_accuracy(compiled_lid_model, compiled_asr_model) 
identification_time_int8, transcription_time_int8, accuracy_int8 = \ 
    calculate_transcription_time_and_accuracy(compiled_quantized_lid_model, compiled_quantized_asr_model) 
print(f"LID model footprint reduction: {lid_model_size_fp32 / lid_model_size_int8:.3f}") 
print(f"ASR model footprint reduction: {asr_model_size_fp32 / asr_model_size_int8:.3f}") 
print(f"Language identification performance speedup: {identification_time_fp32 / identification_time_int8:.3f}") 
print(f"Language transcription performance speedup: {transcription_time_fp32 / transcription_time_int8:.3f}") 
print(f"Transcription word accuracy. FP32: {accuracy_fp32:.2f}%. INT8: {accuracy_int8:.2f}%. Accuracy drop :{accuracy_fp32 - accuracy_int8:.2f}%.")
Measuring performance and accuracy: 0%|          | 0/20 [00:00<?, ?it/s]
Measuring performance and accuracy: 0%|          | 0/20 [00:00<?, ?it/s]
LID model footprint reduction: 1.994 
ASR model footprint reduction: 1.993 
Language identification performance speedup: 1.425 
Language transcription performance speedup: 1.489 
Transcription word accuracy. FP32: 85.01%. INT8: 84.76%. Accuracy drop :0.25%.

Gradio によるインタラクティブなデモ#

このデモでは、独自の例を試すことができます。オーディオデータが 16000 kHz でサンプリングされていることを確認してください。

import gradio as gr 
import librosa 
import time 

title = "MMS with Gradio" 
description = ( 
    'Gradio Demo for MMS and OpenVINO™. Upload a source audio, then click the "Submit" button to detect a language ID and a transcription.   ' 
    "Make sure that the audio data is sampled to 16000 kHz. If this language has not been used before, it may take some time to prepare the ASR model."     "\n" 
    "> Note: In order to run quantized model to transcribe some language, first the quantized model for that specific language must be prepared." ) 

current_state = { 
    "fp32": {"model": None, "language": None}, 
    "int8": {"model": None, "language": None}, 
} 

def infer(src_audio_path, quantized): 
    src_audio, _ = librosa.load(src_audio_path) 
    lid_model = compiled_quantized_lid_model if quantized else compiled_lid_model 

    start_time = time.perf_counter() 
    detected_language_id = detect_language(lid_model, src_audio) 
    end_time = time.perf_counter() 
    identification_delta_time = f"{end_time - start_time:.2f}" 

    state = current_state["int8" if quantized else "fp32"] 
    if detected_language_id != state["language"]: 
        template_path = quantized_asr_model_xml_path_template if quantized else asr_model_xml_path_template 
        try: 
            gr.Info(f"Loading {'quantized' if quantized else ''} ASR model for '{detected_language_id}' language." "This will take some time.") 
            state["model"] = get_asr_model(template_path, detected_language_id) 
            state["language"] = detected_language_id 
        except RuntimeError as e: 
            if "Unable to read the model:" in str(e) and quantized: 
                raise gr.Error(f"There is no quantized ASR model for '{detected_language_id}' language." "Please run quantization for this language first.") 

    start_time = time.perf_counter() 
    transcription = recognize_audio(state["model"], src_audio) 
    end_time = time.perf_counter() 
    transcription_delta_time = f"{end_time - start_time:.2f}" 

    return ( 
        detected_language_id, 
        transcription, 
        identification_delta_time, 
        transcription_delta_time, 
    ) 

with gr.Blocks() as demo: 
    with gr.Row(): 
        gr.Markdown(f"# {title}") 
    with gr.Row(): 
        gr.Markdown(description) 

    run_button = {True: None, False: None} 
    detected_language = {True: None, False: None} 
    transcription = {True: None, False: None} 
    identification_time = {True: None, False: None} 
    transcription_time = {True: None, False: None} 
    for quantized in [False, True]: 
        if quantized and not to_quantize.value: 
            break 
        with gr.Row(): 
            with gr.Column(): 
                if not quantized: 
                    audio = gr.Audio(label="Source Audio", type="filepath") 
                run_button_name = "Run INT8" if quantized else "Run FP32" if to_quantize.value else "Run" 
                run_button[quantized] = gr.Button(value=run_button_name) 
            with gr.Column(): 
                detected_language[quantized] = gr.Textbox(label=f"Detected language ID{' (Quantized)' if quantized else ''}") 
                transcription[quantized] = gr.Textbox(label=f"Transcription{' (Quantized)' if quantized else ''}") 
                identification_time[quantized] = gr.Textbox(label=f"Identification time{' (Quantized)' if quantized else ''}") 
                transcription_time[quantized] = gr.Textbox(label=f"Transcription time{' (Quantized)' if quantized else ''}") 

    run_button[False].click( 
        infer, 
        inputs=[audio, gr.Number(0, visible=False)], 
        outputs=[ 
            detected_language[False], 
            transcription[False], 
            identification_time[False], 
            transcription_time[False], 
        ], 
    ) 
    if to_quantize.value: 
        run_button[True].click( 
            infer, 
            inputs=[audio, gr.Number(1, visible=False)], 
            outputs=[ 
                detected_language[True], 
                transcription[True], 
                identification_time[True], 
                transcription_time[True], 
            ], 
        ) 

try: 
    demo.queue().launch(debug=False) 
except Exception: 
    demo.queue().launch(share=True, debug=False) 
# リモートで起動する場合は、server_name と server_port を指定 
# demo.launch(server_name='your server name', server_port='server port in int') 
# 詳細はドキュメントをご覧ください: https://gradio.app/docs/
ローカル URL で実行中: http://127.0.0.1:7860 
パブリックリンクを作成するには、launch()share=True を設定します。
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize 
Ignored unknown kwarg option normalize 
WARNING:nncf:NNCF provides best results with torch==2.0.1, while current torch version is 1.13.1+cu117. If you encounter issues, consider switching to torch==2.0.1
No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda-11.7' 
/home/nsavel/venvs/ov_notebooks_tmp/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:595: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect.We can't record the data flow of Python values, so this value will be treated as a constant in the future.This means that the trace might not generalize to other inputs! 
  if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): /home/nsavel/venvs/ov_notebooks_tmp/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:634: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect.We can't record the data flow of Python values, so this value will be treated as a constant in the future.This means that the trace might not generalize to other inputs! 
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):