MMS:OpenVINO™ で音声テクノロジーを 1000 以上の言語に拡張¶
この Jupyter ノートブックは、ローカルへのインストール後にのみ起動できます。
大規模多言語スピーチ (MMS) プロジェクトでは、1,100 を超える言語 (従来の 10 倍以上) をサポートする単一の多言語音声認識モデル、4,000 を超える言語 (従来の 40 倍) を識別できる言語識別モデル、1,400 を超える言語をサポートする事前トレーニング済みモデル、および 1,100 を超える言語のテキスト読み上げモデルを構築することで、音声テクノロジーを約 100 言語から 1,000 を超える言語に拡張します。
MMS モデルは、Scaling Speech Technology to 1,000+ Languages で提案されました。モデルとコードは元々ここでリリースされています。
MMS プロジェクトには、自動音声認識 (ASR)、言語識別 (LID)、音声合成 (TTS) など、さまざまなオープンソース・モデルがあります。これについて簡単な図を以下に示します。
このノートブックでは、ASR と LID を検討します。LID モデルを使用して言語を識別し、言語固有の ASR モデルを使用してそれを認識します。モデルの推論速度を向上させるため、追加のモデル量子化ステップが採用されています。ノートブックの最後には、Gradio ベースのインタラクティブ・デモがあります。
目次¶
必要条件¶
%pip install -q --upgrade pip
%pip install -q "transformers>=4.33.1" "openvino>=2023.1.0" "numpy>=1.21.0,<=1.24" "nncf>=2.6.0"
%pip install -q --extra-index-url https://download.pytorch.org/whl/cpu torch datasets accelerate soundfile librosa gradio jiwer
from pathlib import Path
import torch
import openvino as ov
サンプル音声を準備¶
オーディオファイルを読み取り、オーディオデータを処理します。オーディオデータが 16000 kHz でサンプリングされていることを確認してください。この例では、多言語 LibriSpeech (MLS) データセットのストリーミング可能なバージョンを使用します。7 つの言語の例をサポートしています: 'german', 'dutch', 'french', 'spanish', 'italian', 'portuguese', 'polish'
。いずれかを選択してください。
import ipywidgets as widgets
SAMPLE_LANG = widgets.Dropdown(
options=['german', 'dutch', 'french', 'spanish', 'italian', 'portuguese', 'polish'],
value='german',
description='Dataset language:',
disabled=False,
)
SAMPLE_LANG
Dropdown(description='Dataset language:', options=('german', 'dutch', 'french', 'spanish', 'italian', 'portugu…
データセット全体をダウンロードしない場合は、streaming=True
を指定します。
from datasets import load_dataset
mls_dataset = load_dataset("facebook/multilingual_librispeech", SAMPLE_LANG.value, split="test", streaming=True)
mls_dataset = iter(mls_dataset) # make it iterable
example = next(mls_dataset) # get one example
例には辞書構造があります。音声データとテキストの書き起こしが含まれています。
print(example) # look at structure
{'file': None, 'audio': {'path': '1054_1599_000000.flac', 'array': array([-0.00131226, -0.00152588, -0.00134277, ..., 0.00411987,
0.00308228, -0.00015259]), 'sampling_rate': 16000}, 'text': 'mein sechster sohn scheint wenigstens auf den ersten blick der tiefsinnigste von allen ein kopfhänger und doch ein schwätzer deshalb kommt man ihm nicht leicht bei ist er am unterliegen so verfällt er in unbesiegbare traurigkeit', 'speaker_id': 1054, 'chapter_id': 1599, 'id': '1054_1599_000000'}
import IPython.display as ipd
print(example['text'])
ipd.Audio(example['audio']['array'], rate=16_000)
mein sechster sohn scheint wenigstens auf den ersten blick der tiefsinnigste von allen ein kopfhänger und doch ein schwätzer deshalb kommt man ihm nicht leicht bei ist er am unterliegen so verfällt er in unbesiegbare traurigkeit
言語識別 (LID)¶
事前学習済みモデルとプロセッサーをダウンロード¶
認識できる言語の数に応じて、126、256、512、1024、2048、4017 などさまざまな LID モデルが用意されています。ここでは 126 を使用します。
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
model_id = "facebook/mms-lid-126"
lid_processor = AutoFeatureExtractor.from_pretrained(model_id)
lid_model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)
元のモデルを使用して推論を実行¶
inputs = lid_processor(example['audio']['array'], sampling_rate=16_000, return_tensors="pt")
with torch.no_grad():
outputs = lid_model(**inputs).logits
lang_id = torch.argmax(outputs, dim=-1)[0].item()
detected_lang = lid_model.config.id2label[lang_id]
print(detected_lang)
deu
OpenVINO IR モデルに変換して推論を実行¶
OpenVINO を使用して推論を実行するデバイスをドロップダウン・リストから選択します。
core = ov.Core()
device = widgets.Dropdown(
options=core.available_devices + ["AUTO"],
value='AUTO',
description='Device:',
disabled=False,
)
device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
モデルを OpenVINO 形式に変換してコンパイルします。
MAX_SEQ_LENGTH = 30480
lid_model_xml_path = Path('models/ov_lid_model.xml')
def get_lid_model(model_path, compiled=True):
input_values = torch.zeros([1, MAX_SEQ_LENGTH], dtype=torch.float)
if not model_path.exists() and model_path == lid_model_xml_path:
lid_model_xml_path.parent.mkdir(parents=True, exist_ok=True)
converted_model = ov.convert_model(lid_model, example_input={'input_values': input_values})
ov.save_model(converted_model, lid_model_xml_path)
if not compiled:
return converted_model
if compiled:
return core.compile_model(model_path, device_name=device.value)
return core.read_model(model_path)
compiled_lid_model = get_lid_model(lid_model_xml_path)
/home/nsavel/venvs/ov_notebooks_tmp/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:595: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
/home/nsavel/venvs/ov_notebooks_tmp/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:634: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
これで推論を実行できるようになりました。
def detect_language(compiled_model, audio_data):
inputs = lid_processor(audio_data, sampling_rate=16_000, return_tensors="pt")
outputs = compiled_model(inputs['input_values'])[0]
lang_id = torch.argmax(torch.from_numpy(outputs), dim=-1)[0].item()
detected_lang = lid_model.config.id2label[lang_id]
return detected_lang
detect_language(compiled_lid_model, example['audio']['array'])
'deu'
別の言語を確認してみましょう。
SAMPLE_LANG = widgets.Dropdown(
options=['german', 'dutch', 'french', 'spanish', 'italian', 'portuguese', 'polish'],
value='french',
description='Dataset language:',
disabled=False,
)
SAMPLE_LANG
Dropdown(description='Dataset language:', index=2, options=('german', 'dutch', 'french', 'spanish', 'italian',…
mls_dataset = load_dataset("facebook/multilingual_librispeech", SAMPLE_LANG.value, split="test", streaming=True)
mls_dataset = iter(mls_dataset)
example = next(mls_dataset)
print(example['text'])
ipd.Audio(example['audio']['array'], rate=16_000)
grisé par ce parfum il fit des vers en l'honneur de l'humble fleur des bois et il les récita tout haut à ses pieds une violette l'entendit elle crut qu'il ne parlait que pour elle
language_id = detect_language(compiled_lid_model, example['audio']['array'])
print(language_id)
fra
自動音声認識 (ASR)¶
事前学習済みモデルとプロセッサーをダウンロード¶
事前学習済みモデルとプロセッサーをダウンロードします。デフォルトでは、MMS は英語のアダプター重みを読み込みます。別の言語のアダプターの重みをロードする場合、target_lang=<your-chosen-target-lang>
と ignore_mismatched_sizes=True
を必ず指定してください。指定された言語の語彙に応じて言語モデルヘッドのサイズを変更できるようにするには、ignore_mismatched_sizes=True
キーワードを渡す必要があります。同様に、プロセッサーにも同じターゲット言語をロードする必要があります。サポート言語は後から変更することもできます。
from transformers import Wav2Vec2ForCTC, AutoProcessor
model_id = "facebook/mms-1b-all"
asr_processor = AutoProcessor.from_pretrained(model_id)
asr_model = Wav2Vec2ForCTC.from_pretrained(model_id)
サポートされるすべての言語を確認できます。
asr_processor.tokenizer.vocab.keys()
dict_keys(['abi', 'abk', 'abp', 'aca', 'acd', 'ace', 'acf', 'ach', 'acn', 'acr', 'acu', 'ade', 'adh', 'adj', 'adx', 'aeu', 'afr', 'agd', 'agg', 'agn', 'agr', 'agu', 'agx', 'aha', 'ahk', 'aia', 'aka', 'akb', 'ake', 'akp', 'alj', 'alp', 'alt', 'alz', 'ame', 'amf', 'amh', 'ami', 'amk', 'ann', 'any', 'aoz', 'apb', 'apr', 'ara', 'arl', 'asa', 'asg', 'asm', 'ast', 'ata', 'atb', 'atg', 'ati', 'atq', 'ava', 'avn', 'avu', 'awa', 'awb', 'ayo', 'ayr', 'ayz', 'azb', 'azg', 'azj-script_cyrillic', 'azj-script_latin', 'azz', 'bak', 'bam', 'ban', 'bao', 'bas', 'bav', 'bba', 'bbb', 'bbc', 'bbo', 'bcc-script_arabic', 'bcc-script_latin', 'bcl', 'bcw', 'bdg', 'bdh', 'bdq', 'bdu', 'bdv', 'beh', 'bel', 'bem', 'ben', 'bep', 'bex', 'bfa', 'bfo', 'bfy', 'bfz', 'bgc', 'bgq', 'bgr', 'bgt', 'bgw', 'bha', 'bht', 'bhz', 'bib', 'bim', 'bis', 'biv', 'bjr', 'bjv', 'bjw', 'bjz', 'bkd', 'bkv', 'blh', 'blt', 'blx', 'blz', 'bmq', 'bmr', 'bmu', 'bmv', 'bng', 'bno', 'bnp', 'boa', 'bod', 'boj', 'bom', 'bor', 'bos', 'bov', 'box', 'bpr', 'bps', 'bqc', 'bqi', 'bqj', 'bqp', 'bre', 'bru', 'bsc', 'bsq', 'bss', 'btd', 'bts', 'btt', 'btx', 'bud', 'bul', 'bus', 'bvc', 'bvz', 'bwq', 'bwu', 'byr', 'bzh', 'bzi', 'bzj', 'caa', 'cab', 'cac-dialect_sanmateoixtatan', 'cac-dialect_sansebastiancoatan', 'cak-dialect_central', 'cak-dialect_santamariadejesus', 'cak-dialect_santodomingoxenacoj', 'cak-dialect_southcentral', 'cak-dialect_western', 'cak-dialect_yepocapa', 'cap', 'car', 'cas', 'cat', 'cax', 'cbc', 'cbi', 'cbr', 'cbs', 'cbt', 'cbu', 'cbv', 'cce', 'cco', 'cdj', 'ceb', 'ceg', 'cek', 'ces', 'cfm', 'cgc', 'che', 'chf', 'chv', 'chz', 'cjo', 'cjp', 'cjs', 'ckb', 'cko', 'ckt', 'cla', 'cle', 'cly', 'cme', 'cmn-script_simplified', 'cmo-script_khmer', 'cmo-script_latin', 'cmr', 'cnh', 'cni', 'cnl', 'cnt', 'coe', 'cof', 'cok', 'con', 'cot', 'cou', 'cpa', 'cpb', 'cpu', 'crh', 'crk-script_latin', 'crk-script_syllabics', 'crn', 'crq', 'crs', 'crt', 'csk', 'cso', 'ctd', 'ctg', 'cto', 'ctu', 'cuc', 'cui', 'cuk', 'cul', 'cwa', 'cwe', 'cwt', 'cya', 'cym', 'daa', 'dah', 'dan', 'dar', 'dbj', 'dbq', 'ddn', 'ded', 'des', 'deu', 'dga', 'dgi', 'dgk', 'dgo', 'dgr', 'dhi', 'did', 'dig', 'dik', 'dip', 'div', 'djk', 'dnj-dialect_blowowest', 'dnj-dialect_gweetaawueast', 'dnt', 'dnw', 'dop', 'dos', 'dsh', 'dso', 'dtp', 'dts', 'dug', 'dwr', 'dyi', 'dyo', 'dyu', 'dzo', 'eip', 'eka', 'ell', 'emp', 'enb', 'eng', 'enx', 'epo', 'ese', 'ess', 'est', 'eus', 'evn', 'ewe', 'eza', 'fal', 'fao', 'far', 'fas', 'fij', 'fin', 'flr', 'fmu', 'fon', 'fra', 'frd', 'fry', 'ful', 'gag-script_cyrillic', 'gag-script_latin', 'gai', 'gam', 'gau', 'gbi', 'gbk', 'gbm', 'gbo', 'gde', 'geb', 'gej', 'gil', 'gjn', 'gkn', 'gld', 'gle', 'glg', 'glk', 'gmv', 'gna', 'gnd', 'gng', 'gof-script_latin', 'gog', 'gor', 'gqr', 'grc', 'gri', 'grn', 'grt', 'gso', 'gub', 'guc', 'gud', 'guh', 'guj', 'guk', 'gum', 'guo', 'guq', 'guu', 'gux', 'gvc', 'gvl', 'gwi', 'gwr', 'gym', 'gyr', 'had', 'hag', 'hak', 'hap', 'hat', 'hau', 'hay', 'heb', 'heh', 'hif', 'hig', 'hil', 'hin', 'hlb', 'hlt', 'hne', 'hnn', 'hns', 'hoc', 'hoy', 'hrv', 'hsb', 'hto', 'hub', 'hui', 'hun', 'hus-dialect_centralveracruz', 'hus-dialect_westernpotosino', 'huu', 'huv', 'hvn', 'hwc', 'hye', 'hyw', 'iba', 'ibo', 'icr', 'idd', 'ifa', 'ifb', 'ife', 'ifk', 'ifu', 'ify', 'ign', 'ikk', 'ilb', 'ilo', 'imo', 'ina', 'inb', 'ind', 'iou', 'ipi', 'iqw', 'iri', 'irk', 'isl', 'ita', 'itl', 'itv', 'ixl-dialect_sangasparchajul', 'ixl-dialect_sanjuancotzal', 'ixl-dialect_santamarianebaj', 'izr', 'izz', 'jac', 'jam', 'jav', 'jbu', 'jen', 'jic', 'jiv', 'jmc', 'jmd', 'jpn', 'jun', 'juy', 'jvn', 'kaa', 'kab', 'kac', 'kak', 'kam', 'kan', 'kao', 'kaq', 'kat', 'kay', 'kaz', 'kbo', 'kbp', 'kbq', 'kbr', 'kby', 'kca', 'kcg', 'kdc', 'kde', 'kdh', 'kdi', 'kdj', 'kdl', 'kdn', 'kdt', 'kea', 'kek', 'ken', 'keo', 'ker', 'key', 'kez', 'kfb', 'kff-script_telugu', 'kfw', 'kfx', 'khg', 'khm', 'khq', 'kia', 'kij', 'kik', 'kin', 'kir', 'kjb', 'kje', 'kjg', 'kjh', 'kki', 'kkj', 'kle', 'klu', 'klv', 'klw', 'kma', 'kmd', 'kml', 'kmr-script_arabic', 'kmr-script_cyrillic', 'kmr-script_latin', 'kmu', 'knb', 'kne', 'knf', 'knj', 'knk', 'kno', 'kog', 'kor', 'kpq', 'kps', 'kpv', 'kpy', 'kpz', 'kqe', 'kqp', 'kqr', 'kqy', 'krc', 'kri', 'krj', 'krl', 'krr', 'krs', 'kru', 'ksb', 'ksr', 'kss', 'ktb', 'ktj', 'kub', 'kue', 'kum', 'kus', 'kvn', 'kvw', 'kwd', 'kwf', 'kwi', 'kxc', 'kxf', 'kxm', 'kxv', 'kyb', 'kyc', 'kyf', 'kyg', 'kyo', 'kyq', 'kyu', 'kyz', 'kzf', 'lac', 'laj', 'lam', 'lao', 'las', 'lat', 'lav', 'law', 'lbj', 'lbw', 'lcp', 'lee', 'lef', 'lem', 'lew', 'lex', 'lgg', 'lgl', 'lhu', 'lia', 'lid', 'lif', 'lin', 'lip', 'lis', 'lit', 'lje', 'ljp', 'llg', 'lln', 'lme', 'lnd', 'lns', 'lob', 'lok', 'lom', 'lon', 'loq', 'lsi', 'lsm', 'ltz', 'luc', 'lug', 'luo', 'lwo', 'lww', 'lzz', 'maa-dialect_sanantonio', 'maa-dialect_sanjeronimo', 'mad', 'mag', 'mah', 'mai', 'maj', 'mak', 'mal', 'mam-dialect_central', 'mam-dialect_northern', 'mam-dialect_southern', 'mam-dialect_western', 'maq', 'mar', 'maw', 'maz', 'mbb', 'mbc', 'mbh', 'mbj', 'mbt', 'mbu', 'mbz', 'mca', 'mcb', 'mcd', 'mco', 'mcp', 'mcq', 'mcu', 'mda', 'mdf', 'mdv', 'mdy', 'med', 'mee', 'mej', 'men', 'meq', 'met', 'mev', 'mfe', 'mfh', 'mfi', 'mfk', 'mfq', 'mfy', 'mfz', 'mgd', 'mge', 'mgh', 'mgo', 'mhi', 'mhr', 'mhu', 'mhx', 'mhy', 'mib', 'mie', 'mif', 'mih', 'mil', 'mim', 'min', 'mio', 'mip', 'miq', 'mit', 'miy', 'miz', 'mjl', 'mjv', 'mkd', 'mkl', 'mkn', 'mlg', 'mlt', 'mmg', 'mnb', 'mnf', 'mnk', 'mnw', 'mnx', 'moa', 'mog', 'mon', 'mop', 'mor', 'mos', 'mox', 'moz', 'mpg', 'mpm', 'mpp', 'mpx', 'mqb', 'mqf', 'mqj', 'mqn', 'mri', 'mrw', 'msy', 'mtd', 'mtj', 'mto', 'muh', 'mup', 'mur', 'muv', 'muy', 'mvp', 'mwq', 'mwv', 'mxb', 'mxq', 'mxt', 'mxv', 'mya', 'myb', 'myk', 'myl', 'myv', 'myx', 'myy', 'mza', 'mzi', 'mzj', 'mzk', 'mzm', 'mzw', 'nab', 'nag', 'nan', 'nas', 'naw', 'nca', 'nch', 'ncj', 'ncl', 'ncu', 'ndj', 'ndp', 'ndv', 'ndy', 'ndz', 'neb', 'new', 'nfa', 'nfr', 'nga', 'ngl', 'ngp', 'ngu', 'nhe', 'nhi', 'nhu', 'nhw', 'nhx', 'nhy', 'nia', 'nij', 'nim', 'nin', 'nko', 'nlc', 'nld', 'nlg', 'nlk', 'nmz', 'nnb', 'nno', 'nnq', 'nnw', 'noa', 'nob', 'nod', 'nog', 'not', 'npi', 'npl', 'npy', 'nso', 'nst', 'nsu', 'ntm', 'ntr', 'nuj', 'nus', 'nuz', 'nwb', 'nxq', 'nya', 'nyf', 'nyn', 'nyo', 'nyy', 'nzi', 'obo', 'oci', 'ojb-script_latin', 'ojb-script_syllabics', 'oku', 'old', 'omw', 'onb', 'ood', 'orm', 'ory', 'oss', 'ote', 'otq', 'ozm', 'pab', 'pad', 'pag', 'pam', 'pan', 'pao', 'pap', 'pau', 'pbb', 'pbc', 'pbi', 'pce', 'pcm', 'peg', 'pez', 'pib', 'pil', 'pir', 'pis', 'pjt', 'pkb', 'pls', 'plw', 'pmf', 'pny', 'poh-dialect_eastern', 'poh-dialect_western', 'poi', 'pol', 'por', 'poy', 'ppk', 'pps', 'prf', 'prk', 'prt', 'pse', 'pss', 'ptu', 'pui', 'pus', 'pwg', 'pww', 'pxm', 'qub', 'quc-dialect_central', 'quc-dialect_east', 'quc-dialect_north', 'quf', 'quh', 'qul', 'quw', 'quy', 'quz', 'qvc', 'qve', 'qvh', 'qvm', 'qvn', 'qvo', 'qvs', 'qvw', 'qvz', 'qwh', 'qxh', 'qxl', 'qxn', 'qxo', 'qxr', 'rah', 'rai', 'rap', 'rav', 'raw', 'rej', 'rel', 'rgu', 'rhg', 'rif-script_arabic', 'rif-script_latin', 'ril', 'rim', 'rjs', 'rkt', 'rmc-script_cyrillic', 'rmc-script_latin', 'rmo', 'rmy-script_cyrillic', 'rmy-script_latin', 'rng', 'rnl', 'roh-dialect_sursilv', 'roh-dialect_vallader', 'rol', 'ron', 'rop', 'rro', 'rub', 'ruf', 'rug', 'run', 'rus', 'sab', 'sag', 'sah', 'saj', 'saq', 'sas', 'sat', 'sba', 'sbd', 'sbl', 'sbp', 'sch', 'sck', 'sda', 'sea', 'seh', 'ses', 'sey', 'sgb', 'sgj', 'sgw', 'shi', 'shk', 'shn', 'sho', 'shp', 'sid', 'sig', 'sil', 'sja', 'sjm', 'sld', 'slk', 'slu', 'slv', 'sml', 'smo', 'sna', 'snd', 'sne', 'snn', 'snp', 'snw', 'som', 'soy', 'spa', 'spp', 'spy', 'sqi', 'sri', 'srm', 'srn', 'srp-script_cyrillic', 'srp-script_latin', 'srx', 'stn', 'stp', 'suc', 'suk', 'sun', 'sur', 'sus', 'suv', 'suz', 'swe', 'swh', 'sxb', 'sxn', 'sya', 'syl', 'sza', 'tac', 'taj', 'tam', 'tao', 'tap', 'taq', 'tat', 'tav', 'tbc', 'tbg', 'tbk', 'tbl', 'tby', 'tbz', 'tca', 'tcc', 'tcs', 'tcz', 'tdj', 'ted', 'tee', 'tel', 'tem', 'teo', 'ter', 'tes', 'tew', 'tex', 'tfr', 'tgj', 'tgk', 'tgl', 'tgo', 'tgp', 'tha', 'thk', 'thl', 'tih', 'tik', 'tir', 'tkr', 'tlb', 'tlj', 'tly', 'tmc', 'tmf', 'tna', 'tng', 'tnk', 'tnn', 'tnp', 'tnr', 'tnt', 'tob', 'toc', 'toh', 'tom', 'tos', 'tpi', 'tpm', 'tpp', 'tpt', 'trc', 'tri', 'trn', 'trs', 'tso', 'tsz', 'ttc', 'tte', 'ttq-script_tifinagh', 'tue', 'tuf', 'tuk-script_arabic', 'tuk-script_latin', 'tuo', 'tur', 'tvw', 'twb', 'twe', 'twu', 'txa', 'txq', 'txu', 'tye', 'tzh-dialect_bachajon', 'tzh-dialect_tenejapa', 'tzj-dialect_eastern', 'tzj-dialect_western', 'tzo-dialect_chamula', 'tzo-dialect_chenalho', 'ubl', 'ubu', 'udm', 'udu', 'uig-script_arabic', 'uig-script_cyrillic', 'ukr', 'umb', 'unr', 'upv', 'ura', 'urb', 'urd-script_arabic', 'urd-script_devanagari', 'urd-script_latin', 'urk', 'urt', 'ury', 'usp', 'uzb-script_cyrillic', 'uzb-script_latin', 'vag', 'vid', 'vie', 'vif', 'vmw', 'vmy', 'vot', 'vun', 'vut', 'wal-script_ethiopic', 'wal-script_latin', 'wap', 'war', 'waw', 'way', 'wba', 'wlo', 'wlx', 'wmw', 'wob', 'wol', 'wsg', 'wwa', 'xal', 'xdy', 'xed', 'xer', 'xho', 'xmm', 'xnj', 'xnr', 'xog', 'xon', 'xrb', 'xsb', 'xsm', 'xsr', 'xsu', 'xta', 'xtd', 'xte', 'xtm', 'xtn', 'xua', 'xuo', 'yaa', 'yad', 'yal', 'yam', 'yao', 'yas', 'yat', 'yaz', 'yba', 'ybb', 'ycl', 'ycn', 'yea', 'yka', 'yli', 'yor', 'yre', 'yua', 'yue-script_traditional', 'yuz', 'yva', 'zaa', 'zab', 'zac', 'zad', 'zae', 'zai', 'zam', 'zao', 'zaq', 'zar', 'zas', 'zav', 'zaw', 'zca', 'zga', 'zim', 'ziw', 'zlm', 'zmz', 'zne', 'zos', 'zpc', 'zpg', 'zpi', 'zpl', 'zpm', 'zpo', 'zpt', 'zpu', 'zpz', 'ztq', 'zty', 'zul', 'zyb', 'zyp', 'zza'])
モデルの場合は load_adapter()
関数を、トークナイザーの場合は set_target_lang()
を呼び出して、言語アダプターを切り替えます。前の手順で検出されたターゲット言語 ("detect_language_id"
) を入力として渡します。
asr_processor.tokenizer.set_target_lang(language_id)
asr_model.load_adapter(language_id)
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
推論には元のモデルを使用¶
inputs = asr_processor(example['audio']['array'], sampling_rate=16_000, return_tensors="pt")
with torch.no_grad():
outputs = asr_model(**inputs).logits
ids = torch.argmax(outputs, dim=-1)[0]
transcription = asr_processor.decode(ids)
print(transcription)
grisé par ce parfum il fit des vers en l'honneur de l'humble fleur des bois et il les récita tout haut à ses pieds une violette l'entendit elle crut qu'il ne parlait que pour elle
OpenVINO IR モデルに変換して推論を実行¶
ov.convert_model
関数を使用して OpenVINO IR モデル形式に直接変換します。ov.save_model
関数を使用して、変換結果をシリアル化します。今後の使用の利便性のために、これらの目的の関数を作成します。
asr_model_xml_path_template = 'models/ov_asr_{}_model.xml'
def get_asr_model(model_path_template, language_id, compiled=True):
input_values = torch.zeros([1, MAX_SEQ_LENGTH], dtype=torch.float)
model_path = Path(model_path_template.format(language_id))
asr_processor.tokenizer.set_target_lang(language_id)
if not model_path.exists() and model_path_template == asr_model_xml_path_template:
asr_model.load_adapter(language_id)
model_path.parent.mkdir(parents=True, exist_ok=True)
converted_model = ov.convert_model(asr_model, example_input={'input_values': input_values})
ov.save_model(converted_model, model_path)
if not compiled:
return converted_model
if compiled:
return core.compile_model(model_path, device_name=device.value)
return core.read_model(model_path)
compiled_asr_model = get_asr_model(asr_model_xml_path_template, language_id)
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
推論を実行します。
def recognize_audio(compiled_model, src_audio):
inputs = asr_processor(src_audio, sampling_rate=16_000, return_tensors="pt")
outputs = compiled_model(inputs['input_values'])[0]
ids = torch.argmax(torch.from_numpy(outputs), dim=-1)[0]
transcription = asr_processor.decode(ids)
return transcription
transcription = recognize_audio(compiled_asr_model, example['audio']['array'])
print("Original text:", example['text'])
print("Transcription:", transcription)
Original text: grisé par ce parfum il fit des vers en l'honneur de l'humble fleur des bois et il les récita tout haut à ses pieds une violette l'entendit elle crut qu'il ne parlait que pour elle
Transcription: grisé par ce parfum il fit des vers en l'honneur de l'humble fleur des bois et il les récita tout haut à ses pieds une violette l'entendit elle crut qu'il ne parlait que pour elle
量子化¶
NNCF は、量子化レイヤーをモデルグラフに追加し、トレーニング・データセットのサブセットを使用してこれらの追加の量子化レイヤーのパラメーターを初期化することで、トレーニング後の量子化を可能にします。量子化操作は FP32
/FP16
ではなく INT8
で実行されるため、モデル推論が高速化されます。
最適化プロセスには次の手順が含まれます。
量子化用のキャリブレーション・データセットを作成します。
nncf.quantize()
を実行して、量子化されたモデルを取得します。openvino.save_model()
関数を使用して量子化されたINT8
モデルをシリアル化します。
注: 量子化は時間とメモリーを消費する操作です。以下の量子化コードの実行には時間がかかる場合があります。
compiled_quantized_lid_model = None
quantized_asr_model_xml_path_template = None
to_quantize = widgets.Checkbox(
value=False,
description='Quantization',
disabled=False,
)
to_quantize
Checkbox(value=True, description='Quantization')
to_quantize が選択されていない場合に量子化をスキップするスキップマジック拡張機能をロードします。
import sys
sys.path.append("../utils")
%load_ext skip_kernel_extension
キャリブレーション・データセットの準備¶
モデルを量子化する言語を選択します。
%%skip not $to_quantize.value
from IPython.display import display
display(SAMPLE_LANG)
選択した言語の同じ MLS データセットの検証分割を読み込みます。
%%skip not $to_quantize.value
mls_dataset = iter(load_dataset("facebook/multilingual_librispeech", SAMPLE_LANG.value, split="validation", streaming=True))
example = next(mls_dataset)
量子化用のキャリブレーション・データセットを作成します。
%%skip not $to_quantize.value
CALIBRATION_DATASET_SIZE = 5
calibration_data = []
for i in range(CALIBRATION_DATASET_SIZE):
data = asr_processor(next(mls_dataset)['audio']['array'], sampling_rate=16_000, return_tensors="np")
calibration_data.append(data["input_values"])
言語識別モデルの量子化¶
LID モデルの量子化を実行します。
%%skip not $to_quantize.value
import nncf
quantized_lid_model_xml_path = Path(str(lid_model_xml_path).replace(".xml", "_quantized.xml"))
if not quantized_lid_model_xml_path.exists():
quantized_lid_model = nncf.quantize(
get_lid_model(lid_model_xml_path, compiled=False),
calibration_dataset=nncf.Dataset(calibration_data),
preset=nncf.QuantizationPreset.MIXED,
subset_size=len(calibration_data),
model_type=nncf.ModelType.TRANSFORMER
)
ov.save_model(quantized_lid_model, quantized_lid_model_xml_path)
compiled_quantized_lid_model = core.compile_model(quantized_lid_model, device_name=device.value)
else:
compiled_quantized_lid_model = get_lid_model(quantized_lid_model_xml_path)
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, openvino
Statistics collection: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:06<00:00, 1.24s/it]
Applying Smooth Quant: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 291/291 [00:18<00:00, 15.34it/s]
INFO:nncf:144 ignored nodes was found by name in the NNCFGraph
Statistics collection: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:18<00:00, 3.65s/it]
Applying Fast Bias correction: 100%|██████████████████████████████████████████████████████████████████████████████████████| 298/298 [05:09<00:00, 1.04s/it]
量子化モデルを使用して言語を検出します。
%%skip not $to_quantize.value
language_id = detect_language(compiled_quantized_lid_model, example['audio']['array'])
print("Detected language:", language_id)
Detected language: fra
音声認識モデルの量子化¶
ASR モデルの量子化を実行します。
%%skip not $to_quantize.value
quantized_asr_model_xml_path_template = asr_model_xml_path_template.replace(".xml", "_quantized.xml")
quantized_asr_model_xml_path = Path(quantized_asr_model_xml_path_template.format(language_id))
if not quantized_asr_model_xml_path.exists():
quantized_asr_model = nncf.quantize(
get_asr_model(asr_model_xml_path_template, language_id, compiled=False),
calibration_dataset=nncf.Dataset(calibration_data),
preset=nncf.QuantizationPreset.MIXED,
subset_size=len(calibration_data),
model_type=nncf.ModelType.TRANSFORMER
)
ov.save_model(quantized_asr_model, quantized_asr_model_xml_path)
compiled_quantized_asr_model = core.compile_model(quantized_asr_model, device_name=device.value)
else:
compiled_quantized_asr_model = get_asr_model(quantized_asr_model_xml_path_template, language_id)
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Statistics collection: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:05<00:00, 1.17s/it]
Applying Smooth Quant: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 290/290 [00:17<00:00, 16.39it/s]
INFO:nncf:144 ignored nodes was found by name in the NNCFGraph
Statistics collection: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:19<00:00, 3.93s/it]
Applying Fast Bias correction: 100%|██████████████████████████████████████████████████████████████████████████████████████| 393/393 [05:22<00:00, 1.22it/s]
量子化モデルを使用して転写を実行し、その結果を元のモデルによって生成された結果と比較します。
%%skip not $to_quantize.value
compiled_asr_model = get_asr_model(asr_model_xml_path_template, language_id)
transcription_original = recognize_audio(compiled_asr_model, example['audio']['array'])
transcription_quantized = recognize_audio(compiled_quantized_asr_model, example['audio']['array'])
print("Transcription by original model: ", transcription_original)
print("Transcription by quantized model:", transcription_quantized)
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Transcription by original model: le salon était de la plus haute magnificence dorée comme la galerie de diane aux tuileries avec des tableaux à l'huile au lombri il y avait des tâches claires dans ces tableaux julien apprit plus tard que les sujets avaient semblé peu décent à la maîtresse du logis qui avait fait corriger les tableaux
Transcription by quantized model: le salon était de la plus haute magnificence doré comme la galerie de diane aux tuileries avec des tableaux à l'huile au lombri il y avait des tâches claires dans ces tableaux julien apprit plus tard que les sujets avaient semblé peu decent à la maîtresse du logis qui avait fait corriger les tableaux
モデルのサイズ、パフォーマンス、精度を比較¶
まずモデルのサイズを比較します。
%%skip not $to_quantize.value
def calculate_compression_rate(model_path_ov, model_path_ov_int8, model_type):
model_size_fp32 = model_path_ov.with_suffix(".bin").stat().st_size / 10 ** 6
model_size_int8 = model_path_ov_int8.with_suffix(".bin").stat().st_size / 10 ** 6
print(f"{model_type} model footprint comparison:")
print(f" * FP32 IR model size: {model_size_fp32:.2f} MB")
print(f" * INT8 IR model size: {model_size_int8:.2f} MB")
return model_size_fp32, model_size_int8
lid_model_size_fp32, lid_model_size_int8 = \
calculate_compression_rate(lid_model_xml_path, quantized_lid_model_xml_path, 'LID')
asr_model_size_fp32, asr_model_size_int8 = \
calculate_compression_rate(Path(asr_model_xml_path_template.format(language_id)), quantized_asr_model_xml_path, 'ASR')
LID model footprint comparison:
* FP32 IR model size: 1931.81 MB
* INT8 IR model size: 968.96 MB
ASR model footprint comparison:
* FP32 IR model size: 1930.10 MB
* INT8 IR model size: 968.29 MB
次に、MLS データセットのテスト分割で元のモデルと量子化されたモデルの精度値を比較します。ここでは Word Error Rate (WER) メトリックに依存し、精度を (1 - WER)
として計算します。
また、言語識別モデルと音声認識モデルの両方の推論時間も測定します。
%%skip not $to_quantize.value
import time
from tqdm.notebook import tqdm
import numpy as np
from jiwer import wer
TEST_DATASET_SIZE = 20
test_dataset = load_dataset("facebook/multilingual_librispeech", SAMPLE_LANG.value, split="test", streaming=True)
test_dataset = test_dataset.take(TEST_DATASET_SIZE)
def calculate_transcription_time_and_accuracy(lid_model, asr_model):
ground_truths = []
predictions = []
identification_time = []
transcription_time = []
for data_item in tqdm(test_dataset, desc="Measuring performance and accuracy", total=TEST_DATASET_SIZE):
audio = data_item["audio"]["array"]
start_time = time.perf_counter()
detect_language(lid_model, audio)
end_time = time.perf_counter()
identification_time.append(end_time - start_time)
start_time = time.perf_counter()
transcription = recognize_audio(asr_model, audio)
end_time = time.perf_counter()
transcription_time.append(end_time - start_time)
ground_truths.append(data_item["text"])
predictions.append(transcription)
word_accuracy = (1 - wer(ground_truths, predictions)) * 100
mean_identification_time = np.mean(identification_time)
mean_transcription_time = np.mean(transcription_time)
return mean_identification_time, mean_transcription_time, word_accuracy
identification_time_fp32, transcription_time_fp32, accuracy_fp32 = \
calculate_transcription_time_and_accuracy(compiled_lid_model, compiled_asr_model)
identification_time_int8, transcription_time_int8, accuracy_int8 = \
calculate_transcription_time_and_accuracy(compiled_quantized_lid_model, compiled_quantized_asr_model)
print(f"LID model footprint reduction: {lid_model_size_fp32 / lid_model_size_int8:.3f}")
print(f"ASR model footprint reduction: {asr_model_size_fp32 / asr_model_size_int8:.3f}")
print(f"Language identification performance speedup: {identification_time_fp32 / identification_time_int8:.3f}")
print(f"Language transcription performance speedup: {transcription_time_fp32 / transcription_time_int8:.3f}")
print(f"Transcription word accuracy. FP32: {accuracy_fp32:.2f}%. INT8: {accuracy_int8:.2f}%. Accuracy drop :{accuracy_fp32 - accuracy_int8:.2f}%.")
Measuring performance and accuracy: 0%| | 0/20 [00:00<?, ?it/s]
Measuring performance and accuracy: 0%| | 0/20 [00:00<?, ?it/s]
LID model footprint reduction: 1.994
ASR model footprint reduction: 1.993
Language identification performance speedup: 1.425
Language transcription performance speedup: 1.489
Transcription word accuracy. FP32: 85.01%. INT8: 84.76%. Accuracy drop :0.25%.
Gradio によるインタラクティブなデモ¶
このデモでは、独自の例を試すことができます。オーディオデータが 16000 kHz でサンプリングされていることを確認してください。
import gradio as gr
import librosa
import time
title = 'MMS with Gradio'
description = 'Gradio Demo for MMS and OpenVINO™. Upload a source audio, then click the "Submit" button to detect a language ID and a transcription. ' \
'Make sure that the audio data is sampled to 16000 kHz. If this language has not been used before, it may take some time to prepare the ASR model.' \
'\n' \
'> Note: In order to run quantized model to transcribe some language, first the quantized model for that specific language must be prepared.'
current_state = {
"fp32": {"model": None, "language": None},
"int8": {"model": None, "language": None}
}
def infer(src_audio_path, quantized):
src_audio, _ = librosa.load(src_audio_path)
lid_model = compiled_quantized_lid_model if quantized else compiled_lid_model
start_time = time.perf_counter()
detected_language_id = detect_language(lid_model, src_audio)
end_time = time.perf_counter()
identification_delta_time = f"{end_time - start_time:.2f}"
state = current_state["int8" if quantized else "fp32"]
if detected_language_id != state["language"]:
template_path = quantized_asr_model_xml_path_template if quantized else asr_model_xml_path_template
try:
gr.Info(f"Loading {'quantized' if quantized else ''} ASR model for '{detected_language_id}' language. "
"This will take some time.")
state["model"] = get_asr_model(template_path, detected_language_id)
state["language"] = detected_language_id
except RuntimeError as e:
if "Unable to read the model:" in str(e) and quantized:
raise gr.Error(f"There is no quantized ASR model for '{detected_language_id}' language. "
"Please run quantization for this language first.")
start_time = time.perf_counter()
transcription = recognize_audio(state["model"], src_audio)
end_time = time.perf_counter()
transcription_delta_time = f"{end_time - start_time:.2f}"
return detected_language_id, transcription, identification_delta_time, transcription_delta_time
with gr.Blocks() as demo:
with gr.Row():
gr.Markdown(f"# {title}")
with gr.Row():
gr.Markdown(description)
run_button = {True: None, False: None}
detected_language = {True: None, False: None}
transcription = {True: None, False: None}
identification_time = {True: None, False: None}
transcription_time = {True: None, False: None}
for quantized in [False, True]:
if quantized and not to_quantize.value:
break
with gr.Row():
with gr.Column():
if not quantized:
audio = gr.Audio(label="Source Audio", type='filepath')
run_button_name = "Run INT8" if quantized else "Run FP32" if to_quantize.value else "Run"
run_button[quantized] = gr.Button(value=run_button_name)
with gr.Column():
detected_language[quantized] = gr.Textbox(label=f"Detected language ID{' (Quantized)' if quantized else ''}")
transcription[quantized] = gr.Textbox(label=f"Transcription{' (Quantized)' if quantized else ''}")
identification_time[quantized] = gr.Textbox(label=f"Identification time{' (Quantized)' if quantized else ''}")
transcription_time[quantized] = gr.Textbox(label=f"Transcription time{' (Quantized)' if quantized else ''}")
run_button[False].click(infer,
inputs=[audio, gr.Number(0, visible=False)],
outputs=[detected_language[False], transcription[False], identification_time[False], transcription_time[False]])
if to_quantize.value:
run_button[True].click(infer,
inputs=[audio, gr.Number(1, visible=False)],
outputs=[detected_language[True], transcription[True], identification_time[True], transcription_time[True]])
try:
try:
demo.queue().launch(debug=False)
except Exception:
demo.queue().launch(share=True, debug=False)
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_name='your server name', server_port='server port in int')
# Read more in the docs: https://gradio.app/docs/
Running on local URL: http://127.0.0.1:7860 To create a public link, set share=True in launch().
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
WARNING:nncf:NNCF provides best results with torch==2.0.1, while current torch version is 1.13.1+cu117. If you encounter issues, consider switching to torch==2.0.1
No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda-11.7'
/home/nsavel/venvs/ov_notebooks_tmp/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:595: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
/home/nsavel/venvs/ov_notebooks_tmp/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:634: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):