ScaledDotProductAttention

バージョン名: ScaledDotProductAttention-13

カテゴリー: シーケンス処理

簡単な説明: ScaledDotProductAttention は、トレーニング関連のパラメーターを省略して、部分的に torch.nn.functional.scaled_dot_product_attention を実装します。

詳細説明:

ScaledDotProductAttention は、OpenVINO opset および numpy からの他の操作を使用して、次の疑似コードに従って機能を提供します。

def ScaledDotProductAttention(query, key, value, attn_mask=None, scale=None, *, causal):
    L, S = Gather(ShapeOf(query), -2), Gather(ShapeOf(key), -2)
    if scale is None:
        scale = 1.0 / Sqrt(ConvertLike(Gather(ShapeOf(query), -1), query))
    attn_bias = Broadcast(ConvertLike(0, query), [L, S])
    if causal:
        attn_bias = numpy.triu(Broadcast(ConvertLike(-inf, query), [L, S]), k=1)
    elif attn_mask is not None:
        if attn_mask.element_type == boolean:
            attn_bias = Select(LogicalNot(attn_mask), ConvertLike(-inf, query), ConvertLike(0, query))
        else:
            attn_bias += attn_mask
    attn_weight = MatMul(query, Transpose(key, [-2, -1])) * scale
    attn_weight += attn_bias
    attn_weight = Softmax(attn_weight, axis=-1)
    return MatMul(attn_weight, value)

属性:

  • causal

    • 説明: true の場合、擬似コードに従って因果的注意マスクを想定します。この場合、以下に説明する attention_mask 入力は無視されます。

    • 値の範囲: ブール値

    • タイプ: bool

    • 必須: はい

入力:

  • 1: query - タイプ T および形状 [N, ..., L, E] の 3 次元テンソル。必須。

  • 2: key - タイプ T および形状 [N, ..., S, E] の 3 次元テンソル。必須。

  • 3: value - タイプ T および形状 [N, ..., S, Ev] の 3 次元テンソル。必須。

  • 4: attention_mask - 2 つのオプションが使用可能です。attention_mask は無視されます (causalTrue に設定されている場合)。オプション。

    • T タイプまたは boolean タイプおよび形状 [N, ..., L, S] の少なくとも 3 次元テンソル。

    • 0T タイプのスカラー。スカラー 0 値は、アテンション・マスクを適用する必要がないことを示します (提供された疑似コードで attention_mask=None を指定するのと同様)。

  • 5: T タイプのスカラーテンソルを scale します。これは、上記の疑似コードでデフォルトで使用される 1/sqrt(query.shape[-1]) の代わりの代替スケール係数です。オプション。

出力:

  • 1: - スケーリングされたドット積アテンションの結果、タイプ T および形状 [N, ..., L, Ev] のテンソル。

タイプ:

  • T: サポートされている浮動小数点タイプ。

次元

  • N, ... - 1 つ以上のバッチ次元。各バッチ次元は、入力テンソル (クエリ、キー、値) 全体で定数であるか、同じバッチサイズを持つことを示すか、同じ値にブロードキャスト可能である必要があります。

  • S - ソースシーケンスの長さ

  • L - ターゲットシーケンスの長さ

  • E - クエリーとキーの埋め込み次元

  • Ev - 値の埋め込み次元

querykey および value の入力には、少なくとも 1 つのバッチ次元 N が必要です。他のバッチ次元 ... はオプションです。

例:

例 1: 1 つのバッチ次元、動的次元のサポート

 <layer id="285" name="aten::scaled_dot_product_attention_0" type="ScaledDotProductAttention" version="opset13">
                     <data causal="false" />
                     <input>
                             <!-- Example with simple dimensions, with N = 1, L = -1, S = -1, E = 80, Ev = 80-->
                             <port id="0" precision="FP32"> <!-- query -->
                                     <dim>1</dim> <!-- N -->
                                     <dim>-1</dim> <!-- L -->
                                     <dim>80</dim> <!-- E -->
                             </port>
                             <port id="1" precision="FP32"> <!-- key -->
                                     <dim>1</dim> <!-- N -->
                                     <dim>-1</dim> <!-- S -->
                                     <dim>80</dim> <!-- E -->
                             </port>
                             <port id="2" precision="FP32"> <!-- value -->
                                     <dim>1</dim> <!-- N -->
                                     <dim>-1</dim> <!-- S -->
                                     <dim>80</dim> <!-- Ev -->
                             </port>
                             <port id="3" precision="FP32"> <!-- attention_mask -->
                                     <dim>1</dim> <!-- N -->
                                     <dim>-1</dim> <!-- L -->
                                     <dim>-1</dim> <!-- S -->
                             </port>
                     </input>
                     <output>
                             <port id="4" precision="FP32">
                                     <dim>1</dim> <!-- N -->
                                     <dim>-1</dim> <!-- L -->
                                     <dim>80</dim> <!-- Ev -->
                             </port>
                     </output>
             </layer>

例 2: 複数のバッチ次元の一致

 <layer id="286" name="aten::scaled_dot_product_attention_0" type="ScaledDotProductAttention" version="opset13">
                     <data causal="false" />
                     <input>
                             <!-- Multiple batch dimensions: N1 = 1, N2 = 2, N3 = 3-->
                             <port id="0" precision="FP32"> <!-- query -->
                                     <dim>1</dim> <!-- N1 -->
                                     <dim>2</dim> <!-- N2 -->
                                     <dim>3</dim> <!-- N3 -->
                                     <dim>-1</dim> <!-- L -->
                                     <dim>80</dim> <!-- E -->
                             </port>
                             <port id="1" precision="FP32"> <!-- key -->
                                     <dim>1</dim> <!-- N1 -->
                                     <dim>2</dim> <!-- N2 -->
                                     <dim>3</dim> <!-- N3 -->
                                     <dim>-1</dim> <!-- S -->
                                     <dim>80</dim> <!-- E -->
                             </port>
                             <port id="2" precision="FP32"> <!-- value -->
                                     <dim>1</dim> <!-- N1 -->
                                     <dim>2</dim> <!-- N2 -->
                                     <dim>3</dim> <!-- N3 -->
                                     <dim>-1</dim> <!-- S -->
                                     <dim>80</dim> <!-- Ev -->
                             </port>
                             <port id="3" precision="FP32"> <!-- attention_mask -->
                                     <dim>1</dim> <!-- N1 -->
                                     <dim>2</dim> <!-- N2 -->
                                     <dim>3</dim> <!-- N3 -->
                                     <dim>-1</dim> <!-- L -->
                                     <dim>-1</dim> <!-- S -->
                             </port>
                     </input>
                     <output>
                             <port id="4" precision="FP32">
                                     <dim>1</dim> <!-- N1 -->
                                     <dim>2</dim> <!-- N2 -->
                                     <dim>3</dim> <!-- N3 -->
                                     <dim>-1</dim> <!-- L -->
                                     <dim>80</dim> <!-- Ev -->
                             </port>
                     </output>
             </layer>

例 3: バッチ次元のブロードキャストあり

 <layer id="287" name="aten::scaled_dot_product_attention_0" type="ScaledDotProductAttention" version="opset13">
                     <data causal="false" />
                     <input>
                             <!-- Multiple batch dimensions, broadcastable to the following values: N1 = 4, N2 = 6, N3 = 10-->
                             <port id="0" precision="FP32"> <!-- query -->
                                     <dim>1</dim> <!-- N1 (repeat 4 times) -->
                                     <dim>6</dim> <!-- N2 (repeat 1 time)-->
                                     <dim>5</dim> <!-- N3 (repeat 2 times)-->
                                     <dim>-1</dim> <!-- L -->
                                     <dim>80</dim> <!-- E -->
                             </port>
                             <port id="1" precision="FP32"> <!-- key -->
                                     <dim>2</dim> (repeat 2 times)<!-- N1 -->
                                     <dim>2</dim> (repeat 3 times)<!-- N2 -->
                                     <dim>2</dim> (repeat 5 times)<!-- N3 -->
                                     <dim>-1</dim> <!-- S -->
                                     <dim>80</dim> <!-- E -->
                             </port>
                             <port id="2" precision="FP32"> <!-- value -->
                                     <dim>4</dim> <!-- N1 (repeat 1 time)-->
                                     <dim>3</dim> <!-- N2 (repeat 2 times)-->
                                     <dim>10</dim> <!-- N3 (repeat 1 time)-->
                                     <dim>-1</dim> <!-- S -->
                                     <dim>80</dim> <!-- Ev -->
                             </port>
                             <port id="3" precision="FP32"> <!-- attention_mask -->
                                     <dim>1</dim> <!-- N1 (repeat 4 times)-->
                                     <dim>2</dim> <!-- N2 (repeat 3 times)-->
                                     <dim>1</dim> <!-- N3 (repeat 10 times)-->
                                     <dim>-1</dim> <!-- L -->
                                     <dim>-1</dim> <!-- S -->
                             </port>
                     </input>
                     <output>
                             <!-- Output contains broadcasted dimensions N1 = 4, N2 = 6, N3 = 10-->
                             <port id="4" precision="FP32">
                                     <dim>4</dim> <!-- N1 -->
                                     <dim>6</dim> <!-- N2 -->
                                     <dim>10</dim> <!-- N3 -->
                                     <dim>-1</dim> <!-- L -->
                                     <dim>80</dim> <!-- Ev -->
                             </port>
                     </output>
             </layer>