MatMul

バージョン名: MatMul-1

カテゴリー: 行列乗算

簡単な説明: 一般化された行列の乗算

詳細な説明:

MatMul 操作は 2 つのテンソルを受け取り、引数の形状に応じて通常の行列-行列乗算、行列-ベクトル乗算、またはベクトル-行列乗算を行います。入力テンソルは任意のランク >= 1 を持つことができます。各テンソルの右端の 2 つの軸は行列の行と列の次元として解釈され、左端のすべての軸 (存在する場合) は多次元バッチとして解釈されます: [BATCH_DIM_1, BATCH_DIM_2,..., BATCH_DIM_K, ROW_INDEX_DIM, COL_INDEX_DIM]。この操作は、バッチ次元の通常のブロードキャスト・セマンティクスをサポートします。これにより、行列のペアのバッチを一度に乗算できます。

行列乗算の前に、入力引数の暗黙的な形状の調整が行われます。これは次の手順で構成されます。

  1. オプションの transpose_a および transpose_b 属性で指定された転置を適用します。最も右の 2 つの次元のみが置き換えられ、他の次元はそのままです。1D テンソルの場合、転置属性は無視されます。

  2. 1 次元テンソルの削除 (アンスクイーズ) は、入力ごとに個別に適用されます。このステップで挿入された軸は、出力形状には含まれません。

    • 最初の入力のランクが 1 である場合、形状の左側の ROW_INDEX_DIM にサイズ 1 の軸を追加することによって、(transpose_a に関係なく) 常に 2D テンソル行ベクトルでは削除されます。例えば、[S][1, S] に変形されます。

    • 2 番目の入力のランクが 1 である場合、形状の右側の COL_INDEX_DIM にサイズ 1 の軸を追加することによって、(transpose_b に関係なく) 常に 2D テンソル列ベクトルでは削除されます。例えば、[S][S, 1] に変形されます。

  3. ステップ 1 と 2 の後で入力引数のランクが異なる場合、小さいランクのテンソルが必要な軸数だけ形状の左側から削除され、両方の形状が同じランクになります。

  4. ブロードキャストの通常のルールがバッチ次元にも適用されます。

ステップ 2 で挿入された一時軸は、乗算後の最終出力形状から削除されます。ベクトル-行列乗算の後、ROW_INDEX_DIM に挿入された一時軸が削除されます。行列-ベクトル乗算の後、COL_INDEX_DIM に挿入された一時的な軸が削除されます。2 つの 1D テンソル乗算 [S] x [S] の出力形状はスカラーに圧縮されます。

出力形状推論ロジックの例 (ここで ND は 1D より大きいことを意味します):

  • 1D x 1D: [X] x [X] -> [1, X] x [X, 1] -> [1, 1] => [] (スカラー)

  • 1D x ND: [X] x [B, ..., X, Y] -> [1, X] x [B, ..., X, Y] -> [B, ..., 1, Y] => [B, ..., Y]

  • ND x 1D: [B, ..., X, Y] x [Y] -> [B, ..., X, Y] x [Y, 1] -> [B, ..., X, 1] => [B, ..., X]

  • ND x ND: [B, ..., X, Y] x [B, ..., Y, Z] => [B, ..., X, Z]

2 つの属性 transpose_atranspose_b は、対応して 1 番目と 2 番目の入力テンソルの右端の 2 つの次元の埋め込み転置を指定します。これは、対応する入力テンソルの ROW_INDEX_DIM と COL_INDEX_DIM の交換を意味します。バッチ次元と 1D テンソルは、これらの属性の影響を受けません。

属性:

  • transpose_a

    • 説明: 最初の入力の次元 ROW_INDEX_DIM と COL_INDEX_DIM を転置します。false は転置しないことを、true は転置することを意味します。最初の入力が 1D テンソルの場合、これは無視されます。

    • 値の範囲: true または false

    • タイプ: ブール値

    • デフォルト値 : false

    • 必須: いいえ

  • transpose_b

    • 説明: 2 番目の入力の次元 ROW_INDEX_DIM と COL_INDEX_DIM を転置します。false は転置しないことを、true は転置することを意味します。2 番目の入力が 1D テンソルの場合は無視されます。

    • 値の範囲: true または false

    • タイプ: ブール値

    • デフォルト値 : false

    • 必須: いいえ

入力:

  • 1: 行列 A を持つ T 型のタイプ。ランク >= 1。必須。

  • 2: 行列 B を持つ T 型のタイプ。ランク >= 1。必須。

出力:

  • 1: 乗算の結果を含む T タイプのテンソル。

タイプ:

  • T: サポートされている浮動小数点または整数タイプ。

例:

ベクトル-行列乗算

<layer ... type="MatMul">
    <input>
        <port id="0">
            <dim>1024</dim>
        </port>
        <port id="1">
            <dim>1024</dim>
            <dim>1000</dim>
        </port>
    </input>
    <output>
        <port id="2">
            <dim>1000</dim>
        </port>
    </output>
</layer>

行列-ベクトル乗算

<layer ... type="MatMul">
    <input>
        <port id="0">
            <dim>1000</dim>
            <dim>1024</dim>
        </port>
        <port id="1">
            <dim>1024</dim>
        </port>
    </input>
    <output>
        <port id="2">
            <dim>1000</dim>
        </port>
    </output>
</layer>

行列-行列乗算 (バッチサイズ 1 の FullyConnected など)

<layer ... type="MatMul">
    <input>
        <port id="0">
            <dim>1</dim>
            <dim>1024</dim>
        </port>
        <port id="1">
            <dim>1024</dim>
            <dim>1000</dim>
        </port>
    </input>
    <output>
        <port id="2">
            <dim>1</dim>
            <dim>1000</dim>
        </port>
    </output>
</layer>

2 番目の行列の転置を埋め込んだベクトル行列乗算

<layer ... type="MatMul">
    <data transpose_b="true"/>
    <input>
        <port id="0">
            <dim>1024</dim>
        </port>
        <port id="1">
            <dim>1000</dim>
            <dim>1024</dim>
        </port>
    </input>
    <output>
        <port id="2">
            <dim>1000</dim>
        </port>
    </output>
</layer>

行列-行列乗算 (バッチサイズ 10 の FullyConnected など)

<layer ... type="MatMul">
    <input>
        <port id="0">
            <dim>10</dim>
            <dim>1024</dim>
        </port>
        <port id="1">
            <dim>1024</dim>
            <dim>1000</dim>
        </port>
    </input>
    <output>
        <port id="2">
            <dim>10</dim>
            <dim>1000</dim>
        </port>
    </output>
</layer>

ブロードキャストによる 5 つの行列のバッチと 1 つの行列の乗算

<layer ... type="MatMul">
    <input>
        <port id="0">
            <dim>5</dim>
            <dim>10</dim>
            <dim>1024</dim>
        </port>
        <port id="1">
            <dim>1024</dim>
            <dim>1000</dim>
        </port>
    </input>
    <output>
        <port id="2">
            <dim>5</dim>
            <dim>10</dim>
            <dim>1000</dim>
        </port>
    </output>
</layer>