GatherTree#

バージョン名: GatherTree-1

カテゴリー: データ移動

簡単な説明: 各ステップの ID と親ビーム ID から完全なビームを生成します。

詳細な説明

GatherTree 操作は、親ビーム ID を表す入力テンソル parent_ids に基づいて、ビーム検索の各ステップごとの ID を表す特定の入力テンソル step_id のトークン ID を並べ替えます。特定のビームについて、最初にデコードされた end_token を含むタイムステップを過ぎると、すべての値が end_token で埋められます。

擬似コードのアルゴリズムは次のとおりです:

final_ids[ :, :, :] = end_token 
for batch in range(BATCH_SIZE): 
    for beam in range(BEAM_WIDTH): 
        max_sequence_in_beam = min(MAX_TIME, max_seq_len[batch]) 
        parent = parent_ids[max_sequence_in_beam - 1, batch, beam] 
        final_ids[max_sequence_in_beam - 1, batch, beam] = step_ids[max_sequence_in_beam - 1, batch, beam] 

        for level in reversed(range(max_sequence_in_beam - 1)): 
            final_ids[level, batch, beam] = step_ids[level, batch, parent] 
            parent = parent_ids[level, batch, parent] 

        # 特定のビームの場合、最初にデコードされた end_token を含む時間ステップを経過すると、 
        # すべての値は end_token で埋められます。 
        finished = False 
        for time in range(max_sequence_in_beam): 
            if(finished): 
                final_ids[time, batch, beam] = end_token 
            elif(final_ids[time, batch, beam] == end_token): 
                finished = True

GatherTree 操作は、TensorFlow の GatherTree 操作と同等です。

属性: GatherTree 操作には属性がありません。

入力

  • 1: step_ids - 各ステップごとのインデックス。タイプ T およびランク 3 のテンソル。レイアウトは [MAX_TIME, BATCH_SIZE, BEAM_WIDTH] です。必須。

  • 2: parent_ids - 親ビームのインデックス。タイプ T およびランク 3 のテンソル。レイアウトは [MAX_TIME, BATCH_SIZE, BEAM_WIDTH] です。必須。

  • 3: max_seq_len - バッチ内の各シーケンスの最大長。タイプ T およびランク 1 のテンソル。レイアウトは [BATCH_SIZE] です。必須。

  • 4: end_token - シーケンス内の終了マーカーの値。T タイプのスカラー。必須。

  • : 入力には整数値のみを含める必要があります。

出力

  • 1: final_ids - parent_ids 入力に基づいて並べ替えられたトークン ID。タイプ T およびランク 3 のテンソル。レイアウトは [MAX_TIME, BATCH_SIZE, BEAM_WIDTH] です。

タイプ

  • T: サポートされている数値タイプ。

<layer type="GatherTree" ...> 
    <input> 
        <port id="0"> 
            <dim>100</dim> 
            <dim>1</dim> 
            <dim>10</dim> 
        </port> 
        <port id="1"> 
            <dim>100</dim> 
            <dim>1</dim> 
            <dim>10</dim> 
        </port> 
        <port id="2"> 
            <dim>1</dim> 
        </port> 
        <port id="3"> 
        </port> 
    </input> 
    <output> 
        <port id="0"> 
            <dim>100</dim> 
            <dim>1</dim> 
            <dim>10</dim> 
        </port> 
    </output> 
</layer>