GatherTree

バージョン名: GatherTree-1

カテゴリー: データ移動

簡単な説明: 各ステップの ID と親ビーム ID から完全なビームを生成します。

詳細説明:

GatherTree 操作は、親ビーム ID を表す入力テンソル parent_ids に基づいて、ビーム検索の各ステップごとの ID を表す特定の入力テンソル step_id のトークン ID を並べ替えます。特定のビームについて、最初にデコードされた end_token を含むタイムステップを過ぎると、すべての値が end_token で埋められます。

擬似コードのアルゴリズムは次のとおりです。

final_ids[ :, :, :] = end_token
for batch in range(BATCH_SIZE):
    for beam in range(BEAM_WIDTH):
        max_sequence_in_beam = min(MAX_TIME, max_seq_len[batch])

        parent = parent_ids[max_sequence_in_beam - 1, batch, beam]

        final_ids[max_sequence_in_beam - 1, batch, beam] = step_ids[max_sequence_in_beam - 1, batch, beam]

        for level in reversed(range(max_sequence_in_beam - 1)):
            final_ids[level, batch, beam] = step_ids[level, batch, parent]

            parent = parent_ids[level, batch, parent]

        # For a given beam, past the time step containing the first decoded end_token
        # all values are filled in with end_token.
        finished = False
        for time in range(max_sequence_in_beam):
            if(finished):
                final_ids[time, batch, beam] = end_token
            elif(final_ids[time, batch, beam] == end_token):
                finished = True

GatherTree 操作は、TensorFlow の GatherTree 操作と同等です。

属性: GatherTree 操作には属性がありません。

入力:

  • 1: step_ids - 各ステップごとのインデックス。タイプ T およびランク 3 のテンソル。レイアウトは [MAX_TIME, BATCH_SIZE, BEAM_WIDTH] です。必須。

  • 2: parent_ids - 親ビームのインデックス。タイプ T およびランク 3 のテンソル。レイアウトは [MAX_TIME, BATCH_SIZE, BEAM_WIDTH] です。必須。

  • 3: max_seq_len - バッチ内の各シーケンスの最大長。タイプ T およびランク 1 のテンソル。レイアウトは [BATCH_SIZE] です。必須。

  • 4: end_token - シーケンス内の終了マーカーの値。T タイプのスカラー。必須。

注: 入力には整数値のみを含める必要があります。

出力:

  • 1: final_ids - parent_ids 入力に基づいて並べ替えられたトークン ID。タイプ T およびランク 3 のテンソル。レイアウトは [MAX_TIME, BATCH_SIZE, BEAM_WIDTH] です。

タイプ:

  • T: サポートされている数値タイプ。

例:

<layer type="GatherTree" ...>
    <input>
        <port id="0">
            <dim>100</dim>
            <dim>1</dim>
            <dim>10</dim>
        </port>
        <port id="1">
            <dim>100</dim>
            <dim>1</dim>
            <dim>10</dim>
        </port>
        <port id="2">
            <dim>1</dim>
        </port>
        <port id="3">
        </port>
    </input>
    <output>
        <port id="0">
            <dim>100</dim>
            <dim>1</dim>
            <dim>10</dim>
        </port>
    </output>
</layer>