ScaledDotProductAttention¶
バージョン名: ScaledDotProductAttention-13
カテゴリー: シーケンス処理
簡単な説明: ScaledDotProductAttention は、トレーニング関連のパラメーターを省略して、部分的に torch.nn.functional.scaled_dot_product_attention を実装します。
詳細説明:
ScaledDotProductAttention は、OpenVINO opset および numpy
からの他の操作を使用して、次の疑似コードに従って機能を提供します。
def ScaledDotProductAttention(query, key, value, attn_mask=None, scale=None, *, causal):
L, S = Gather(ShapeOf(query), -2), Gather(ShapeOf(key), -2)
if scale is None:
scale = 1.0 / Sqrt(ConvertLike(Gather(ShapeOf(query), -1), query))
attn_bias = Broadcast(ConvertLike(0, query), [L, S])
if causal:
attn_bias = numpy.triu(Broadcast(ConvertLike(-inf, query), [L, S]), k=1)
elif attn_mask is not None:
if attn_mask.element_type == boolean:
attn_bias = Select(LogicalNot(attn_mask), ConvertLike(-inf, query), ConvertLike(0, query))
else:
attn_bias += attn_mask
attn_weight = MatMul(query, Transpose(key, [-2, -1])) * scale
attn_weight += attn_bias
attn_weight = Softmax(attn_weight, axis=-1)
return MatMul(attn_weight, value)
属性:
-
causal
説明: true の場合、擬似コードに従って因果的注意マスクを想定します。この場合、以下に説明する
attention_mask
入力は無視されます。値の範囲: ブール値
タイプ:
bool
必須: はい
入力:
1:
query
- タイプ T および形状[N, ..., L, E]
の 3 次元テンソル。必須。2:
key
- タイプ T および形状[N, ..., S, E]
の 3 次元テンソル。必須。3:
value
- タイプ T および形状[N, ..., S, Ev]
の 3 次元テンソル。必須。-
4:
attention_mask
- 2 つのオプションが使用可能です。attention_mask
は無視されます (causal
がTrue
に設定されている場合)。オプション。T タイプまたは
boolean
タイプおよび形状[N, ..., L, S]
の少なくとも 3 次元テンソル。値
0
の T タイプのスカラー。スカラー 0 値は、アテンション・マスクを適用する必要がないことを示します (提供された疑似コードで attention_mask=None を指定するのと同様)。
5: T タイプのスカラーテンソルを
scale
します。これは、上記の疑似コードでデフォルトで使用される 1/sqrt(query.shape[-1]) の代わりの代替スケール係数です。オプション。
出力:
1: - スケーリングされたドット積アテンションの結果、タイプ T および形状
[N, ..., L, Ev]
のテンソル。
タイプ:
T: サポートされている浮動小数点タイプ。
次元
N, ...
- 1 つ以上のバッチ次元。各バッチ次元は、入力テンソル (クエリ、キー、値) 全体で定数であるか、同じバッチサイズを持つことを示すか、同じ値にブロードキャスト可能である必要があります。S
- ソースシーケンスの長さL
- ターゲットシーケンスの長さE
- クエリーとキーの埋め込み次元Ev
- 値の埋め込み次元
query
、key
および value
の入力には、少なくとも 1 つのバッチ次元 N
が必要です。他のバッチ次元 ...
はオプションです。
例:
例 1: 1 つのバッチ次元、動的次元のサポート
<layer id="285" name="aten::scaled_dot_product_attention_0" type="ScaledDotProductAttention" version="opset13">
<data causal="false" />
<input>
<!-- Example with simple dimensions, with N = 1, L = -1, S = -1, E = 80, Ev = 80-->
<port id="0" precision="FP32"> <!-- query -->
<dim>1</dim> <!-- N -->
<dim>-1</dim> <!-- L -->
<dim>80</dim> <!-- E -->
</port>
<port id="1" precision="FP32"> <!-- key -->
<dim>1</dim> <!-- N -->
<dim>-1</dim> <!-- S -->
<dim>80</dim> <!-- E -->
</port>
<port id="2" precision="FP32"> <!-- value -->
<dim>1</dim> <!-- N -->
<dim>-1</dim> <!-- S -->
<dim>80</dim> <!-- Ev -->
</port>
<port id="3" precision="FP32"> <!-- attention_mask -->
<dim>1</dim> <!-- N -->
<dim>-1</dim> <!-- L -->
<dim>-1</dim> <!-- S -->
</port>
</input>
<output>
<port id="4" precision="FP32">
<dim>1</dim> <!-- N -->
<dim>-1</dim> <!-- L -->
<dim>80</dim> <!-- Ev -->
</port>
</output>
</layer>
例 2: 複数のバッチ次元の一致
<layer id="286" name="aten::scaled_dot_product_attention_0" type="ScaledDotProductAttention" version="opset13">
<data causal="false" />
<input>
<!-- Multiple batch dimensions: N1 = 1, N2 = 2, N3 = 3-->
<port id="0" precision="FP32"> <!-- query -->
<dim>1</dim> <!-- N1 -->
<dim>2</dim> <!-- N2 -->
<dim>3</dim> <!-- N3 -->
<dim>-1</dim> <!-- L -->
<dim>80</dim> <!-- E -->
</port>
<port id="1" precision="FP32"> <!-- key -->
<dim>1</dim> <!-- N1 -->
<dim>2</dim> <!-- N2 -->
<dim>3</dim> <!-- N3 -->
<dim>-1</dim> <!-- S -->
<dim>80</dim> <!-- E -->
</port>
<port id="2" precision="FP32"> <!-- value -->
<dim>1</dim> <!-- N1 -->
<dim>2</dim> <!-- N2 -->
<dim>3</dim> <!-- N3 -->
<dim>-1</dim> <!-- S -->
<dim>80</dim> <!-- Ev -->
</port>
<port id="3" precision="FP32"> <!-- attention_mask -->
<dim>1</dim> <!-- N1 -->
<dim>2</dim> <!-- N2 -->
<dim>3</dim> <!-- N3 -->
<dim>-1</dim> <!-- L -->
<dim>-1</dim> <!-- S -->
</port>
</input>
<output>
<port id="4" precision="FP32">
<dim>1</dim> <!-- N1 -->
<dim>2</dim> <!-- N2 -->
<dim>3</dim> <!-- N3 -->
<dim>-1</dim> <!-- L -->
<dim>80</dim> <!-- Ev -->
</port>
</output>
</layer>
例 3: バッチ次元のブロードキャストあり
<layer id="287" name="aten::scaled_dot_product_attention_0" type="ScaledDotProductAttention" version="opset13">
<data causal="false" />
<input>
<!-- Multiple batch dimensions, broadcastable to the following values: N1 = 4, N2 = 6, N3 = 10-->
<port id="0" precision="FP32"> <!-- query -->
<dim>1</dim> <!-- N1 (repeat 4 times) -->
<dim>6</dim> <!-- N2 (repeat 1 time)-->
<dim>5</dim> <!-- N3 (repeat 2 times)-->
<dim>-1</dim> <!-- L -->
<dim>80</dim> <!-- E -->
</port>
<port id="1" precision="FP32"> <!-- key -->
<dim>2</dim> (repeat 2 times)<!-- N1 -->
<dim>2</dim> (repeat 3 times)<!-- N2 -->
<dim>2</dim> (repeat 5 times)<!-- N3 -->
<dim>-1</dim> <!-- S -->
<dim>80</dim> <!-- E -->
</port>
<port id="2" precision="FP32"> <!-- value -->
<dim>4</dim> <!-- N1 (repeat 1 time)-->
<dim>3</dim> <!-- N2 (repeat 2 times)-->
<dim>10</dim> <!-- N3 (repeat 1 time)-->
<dim>-1</dim> <!-- S -->
<dim>80</dim> <!-- Ev -->
</port>
<port id="3" precision="FP32"> <!-- attention_mask -->
<dim>1</dim> <!-- N1 (repeat 4 times)-->
<dim>2</dim> <!-- N2 (repeat 3 times)-->
<dim>1</dim> <!-- N3 (repeat 10 times)-->
<dim>-1</dim> <!-- L -->
<dim>-1</dim> <!-- S -->
</port>
</input>
<output>
<!-- Output contains broadcasted dimensions N1 = 4, N2 = 6, N3 = 10-->
<port id="4" precision="FP32">
<dim>4</dim> <!-- N1 -->
<dim>6</dim> <!-- N2 -->
<dim>10</dim> <!-- N3 -->
<dim>-1</dim> <!-- L -->
<dim>80</dim> <!-- Ev -->
</port>
</output>
</layer>