GatherTree¶
バージョン名: GatherTree-1
カテゴリー: データ移動
簡単な説明: 各ステップの ID と親ビーム ID から完全なビームを生成します。
詳細説明:
GatherTree 操作は、親ビーム ID を表す入力テンソル parent_ids
に基づいて、ビーム検索の各ステップごとの ID を表す特定の入力テンソル step_id
のトークン ID を並べ替えます。特定のビームについて、最初にデコードされた end_token
を含むタイムステップを過ぎると、すべての値が end_token
で埋められます。
擬似コードのアルゴリズムは次のとおりです。
final_ids[ :, :, :] = end_token
for batch in range(BATCH_SIZE):
for beam in range(BEAM_WIDTH):
max_sequence_in_beam = min(MAX_TIME, max_seq_len[batch])
parent = parent_ids[max_sequence_in_beam - 1, batch, beam]
final_ids[max_sequence_in_beam - 1, batch, beam] = step_ids[max_sequence_in_beam - 1, batch, beam]
for level in reversed(range(max_sequence_in_beam - 1)):
final_ids[level, batch, beam] = step_ids[level, batch, parent]
parent = parent_ids[level, batch, parent]
# For a given beam, past the time step containing the first decoded end_token
# all values are filled in with end_token.
finished = False
for time in range(max_sequence_in_beam):
if(finished):
final_ids[time, batch, beam] = end_token
elif(final_ids[time, batch, beam] == end_token):
finished = True
GatherTree 操作は、TensorFlow の GatherTree 操作と同等です。
属性: GatherTree 操作には属性がありません。
入力:
-
1:
step_ids
- 各ステップごとのインデックス。タイプ T およびランク 3 のテンソル。レイアウトは[MAX_TIME, BATCH_SIZE, BEAM_WIDTH]
です。必須。 -
2:
parent_ids
- 親ビームのインデックス。タイプ T およびランク 3 のテンソル。レイアウトは[MAX_TIME, BATCH_SIZE, BEAM_WIDTH]
です。必須。 -
3:
max_seq_len
- バッチ内の各シーケンスの最大長。タイプ T およびランク 1 のテンソル。レイアウトは[BATCH_SIZE]
です。必須。 -
4:
end_token
- シーケンス内の終了マーカーの値。T タイプのスカラー。必須。
注: 入力には整数値のみを含める必要があります。
出力:
-
1:
final_ids
-parent_ids
入力に基づいて並べ替えられたトークン ID。タイプ T およびランク 3 のテンソル。レイアウトは[MAX_TIME, BATCH_SIZE, BEAM_WIDTH]
です。
タイプ:
T: サポートされている数値タイプ。
例:
<layer type="GatherTree" ...>
<input>
<port id="0">
<dim>100</dim>
<dim>1</dim>
<dim>10</dim>
</port>
<port id="1">
<dim>100</dim>
<dim>1</dim>
<dim>10</dim>
</port>
<port id="2">
<dim>1</dim>
</port>
<port id="3">
</port>
</input>
<output>
<port id="0">
<dim>100</dim>
<dim>1</dim>
<dim>10</dim>
</port>
</output>
</layer>