CTCLoss

バージョン名: CTCLoss-4

カテゴリー: シーケンス処理

簡単な説明: CTCLoss は、CTC (コネクショニズム時間分類) 損失を計算します。

詳細説明:

CTCLoss 操作は、Connectionist Temporal Classics - Labeling Unsegmented Sequence Data with Recurrent Neural Networks: Graves et al., 2016 で紹介されています。

CTCLoss は、ロジット logits[i,:,:] の指定された入力シーケンスに対してターゲット labels[i,:] が発生する可能性 (または実際に) を推定します。要約すると、CTCLoss 操作は、ターゲット labels[i,:] にアライメントされたすべてのシーケンスを見つけ、logits[i,:,:] でアライメントされたシーケンスの対数確率を計算し、これらの対数確率の負の和を計算します。

ロジット logits の入力シーケンスは、異なる長さにできます。各シーケンスの長さ logits[i,:,:]logit_length[i] と等しくなります。ターゲットシーケンス labels[i,:] の長さは label_length[i] と等しくなります。ターゲットシーケンスの長さは、対応する入力シーケンス logits[i,:,:] の長さを超えてはなりません。それ以外の場合、操作の動作は未定義です。

CTCLoss 計算スキーム:

  1. ソフトマックスの公式を使用して、logits から i 番目の入力シーケンスのタイムステップ t における j 番目の文字の確率を計算します。

\[p_{i,t,j} = \frac{\exp(logits[i,t,j])}{\sum^{K}_{k=0}{\exp(logits[i,t,k])}}\]
  1. 指定された i 番目のターゲットに対して、labels[i,:] からすべての位置一致したパスを検索します。デコード後に両方のチェーンが等しい場合、パス S = (c1,c2,...,cT) はターゲット G=(g1,g2,...,gT) に位置合わせます。デコードでは、ターゲット G から長さ label_length[i] の部分文字列が抽出され、preprocess_collapse_repeated が true の場合は G 内の繰り返し文字がマージされ、unique が true のは文字の出現順序で一意の要素が検索されます。デコードでは、ctc_merge_repeated が true の場合に S 内の繰り返し文字がマージされ、blank_index で表される空白文字が削除されます。デフォルトでは、blank_indexC-1 に等しくなります。ここで、C はブランクを含むクラスの数です。例えば、デフォルトの ctc_merge_repeatedpreprocess_collapse_repeatedunique そして blank_index の場合、長さ label_length[i]=4 のターゲットシーケンス G=(0,3,2,2,2,2,2,4,3)(0,3,2,2)、長さ logit_length[i]=9 のパス S=(0,0,4,3,2,2,4,2,4)(0,3,2,2) になります。ここで C=5 です。0,4,3,3,2,4,2,2,2 など、G と一致する他のパスも存在します。ターゲット label[:,i] との位置合わせがチェックされるパスは、長さが logit_length[i] = L_i である必要があります。位置合わせされたパス (位置合わせ) の確率を次のように計算します。

\[p(S) = \prod_{t=1}^{L_i} p_{i,t,ct}\]
  1. 最後に、見つかったすべてのアライメントの合計確率の負の対数を計算します。

\[CTCLoss = - \ln \sum_{S} p(S)\]

注 1: この計算スキームは、最適な実装の手順を提供するものではなく、説明を分かりやすくするために役立ちます。

注 2: これは、整列されたパスの対数確率 \(\ln p(S)\) を入力ロジットの log-softmax の合計として計算することを推奨します。計算中のアンダーフローやオーバーフローを回避するのに役立ちます。整列されたパスの対数確率があれば、これらのパスの合計確率の対数は次のように計算できます。

\[\ln(a + b) = \ln(a) + \ln(1 + \exp(\ln(b) - \ln(a)))\]

属性:

  • preprocess_collapse_repeated

    • 説明: preprocess_collapse_repeated は、損失計算の前の前処理ステップのフラグであり、損失に渡される labels[i,:] 内の繰り返しラベルが単一のラベルにマージされます。

    • 値の範囲: true または false

    • タイプ: boolean

    • デフォルト値 : false

    • 必須: いいえ

  • ctc_merge_repeated

    • 説明: ctc_merge_repeated は、CTC 損失計算中の潜在的な位置合わせで繰り返される文字をマージするためのフラグです。

    • 値の範囲: true または false

    • タイプ: boolean

    • デフォルト値: true

    • 必須: いいえ

  • unique

    • 説明: unique は、潜在的なアラインメントと照合する前に、ターゲット labels[i,:] の一意の要素を検索するフラグです。処理された labels[i,:] 内の固有の要素は、元の labels[i,:] での出現順に並べ替えられます。例えば、長さ labels[i,:]=(0,1,1,0,1,3,3,2,2,3)label_length[i]=10 の処理されたシーケンスは、unique が true の場合 (0,1,3,2) になります。

    • 値の範囲: true または false

    • タイプ: boolean

    • デフォルト値 : false

    • 必須: いいえ

入力:

  • 1: logits - ロジットのシーケンスのバッチを含む入力テンソル。要素のタイプは T_F です。テンソルの形状は [N, T, C] です。ここで、N はバッチサイズ、T は最大シーケンス長、C はブランクを含むクラスの数です。必須。

  • 2: logit_length - タイプ T1 および形状 [N] の 1D 入力テンソル。テンソルは、T 以下の負ではない値で構成されなければなりません。ロジット logits[i,:,:] の入力シーケンスの長さ。必須。

  • 3: labels - T2 タイプの形状 [N, T] を持つ 2D テンソル。ターゲットシーケンス labels[i,:] の長さは、label_length[i] に等しく、blank_index を除く範囲 [0; C-1] が含まれている必要があります。必須。

  • 4: label_length - タイプ T1 および形状 [N] の 1D テンソル。テンソルは、すべての可能な i に対して T および label_length[i] <= logit_length[i] 以下の負ではない値で構成されなければなりません。必須。

  • 5: blank_index - T2 タイプのスカラー。空白のラベルに使用するクラスのインデックスを設定します。デフォルト値は C-1 です。オプション。

Output

  • 1: 形状 [N] の出力テンソル、アライメントの対数確率の負の合計。要素のタイプは T_F です。

タイプ:

  • T_F: サポートされている浮動小数点タイプ。

  • T1T2: int32 または int64

例:

<layer ... type="CTCLoss" ...>
    <input>
        <port id="0">
            <dim>8</dim>
            <dim>20</dim>
            <dim>128</dim>
        </port>
        <port id="1">
            <dim>8</dim>
        </port>
        <port id="2">
            <dim>8</dim>
            <dim>20</dim>
        </port>
        <port id="3">
            <dim>8</dim>
        </port>
        <port id="4">  <!-- blank_index value is: 120 -->
    </input>
    <output>
        <port id="0">
            <dim>8</dim>
        </port>
    </output>
</layer>