EmbeddingBagOffsetsSum#

バージョン名: EmbeddingBagOffsetsSum-3

カテゴリー: Sparse

簡単な説明: 中間の埋め込みをインスタンス化せずに、埋め込みの “バッグ” の合計を計算します。

詳細な説明:

EmbeddingBagOffsets 操作は、インデックスとオフセット入力が 1D テンソルである torch.nn.EmbeddingBag の実装です。

この演算子は、indices 内の各インデックスに対して、emb_table 埋め込みテーブルから値を収集します。次に、同じバッグ範囲内のインデックスの値 (offset 入力に基づく) が、reduction 属性に従って削減されます。

offsets の値は、各 “バッグ” のインデックス・テンソルの開始インデックスを定義します。値 [0, 3, 4, 4, 6]offsets は、バッグごとのインデックス [indices[0:3], indices[3:4], empty_bag, indices[4:6], indices[6:]] スライスに対応する [3, 1, 0, 2, num_indices-6] 要素を含む 5 つの “バッグ” を定義します。

EmbeddingBagOffsetsSum は、次の NumPy コードの一部と同等です:

def embedding_bag_offsets( 
    emb_table: np.ndarray, 
    indices: np.ndarray, 
    offsets: np.ndarray, 
    default_index: Optional[int] = None, 
    per_sample_weights: Optional[np.ndarray] = None, 
): 
    if per_sample_weights is None: 
        per_sample_weights = np.ones_like(indices) 
    embeddings = [] 
    for emb_idx, emb_weight in zip(indices, per_sample_weights): 
        embeddings.append(emb_table[emb_idx] * emb_weight) 
    previous_offset = offsets[0] 
    bags = [] 
    offsets = np.append(offsets, len(indices)) 
    for bag_offset in offsets[1:]: 
        bag_size = bag_offset - previous_offset 
        if bag_size != 0: 
            embedding_bag = embeddings[previous_offset:bag_offset] 
            reduced_bag = np.add.reduce(embedding_bag) 
            bags.append(reduced_bag) 
        else:             # Empty bag case 
            if default_index is not None and default_index != -1: 
                bags.append(emb_table[default_index]) 
            else: 
                bags.append(np.zeros(emb_table.shape[1:])) 
            previous_offset = bag_offset 
return np.stack(bags, axis=0)

属性: EmbeddingBagOffsetsSum 操作には属性がありません。

入力:

  • 1: 形状 [num_emb, emb_dim1, emb_dim2, ...] およびタイプ T のモジュールの埋め込みルックアップ・テーブルを含む emb_table テンソル。必須。

  • 2: 形状 [num_indices] およびタイプ T_IND のインデックス・テンソル。必須。

  • 3: インデックス内の各 “バッグ” の開始インデックス位置を含む、形状 [batch] およびタイプ T_INDoffsets テンソル。オフセットの最大値は indices の長さより大きくすることはできません。必須。

  • 4: 空の “バッグ” を埋める埋め込みテーブル内のデフォルトのインデックスを含む、T_IND タイプの default_index スカラー。-1 に設定される、または提供されない場合、空の “バッグ” にはゼロが埋められます。負の値を使用した逆インデックスはサポートされていません。オプション。

  • 5: インデックスと同じ形状でタイプ Tper_sample_weights テンソル。このテンソルの各値は、各インデックスの埋め込みテーブルからプールされた各値と乗算されます。オプション。デフォルトは 1 のテンソルです。オプション。

出力:

  • 1: 形状 [batch, emb_dim1, emb_dim2, ...] および各バッグの埋め込みを含むタイプ T のテンソル。

タイプ

  • T: 任意の数値タイプ。

  • T_IND: int32 または int64

例 1: per_sample_weights が指定され、default_index は 0 に設定されると、指定されたインデックスの emb_table から収集された値で空のバッグが埋められます。

<layer ... type="EmbeddingBagOffsetsSum" ... > 
    <input> 
        <port id="0"> <!-- emb_table 値: [[-0.2, -0.6], [-0.1, -0.4], [-1.9, -1.8], [-1., 1.5], [ 0.8, -0.7]] --> 
            <dim>5</dim> 
            <dim>2</dim> 
        </port> 
        <port id="1"> <!-- インデックス値: [0, 2, 3, 4] --> 
            <dim>4</dim> 
        </port> 
        <port id="2"> <!-- オフセット値: [0, 2, 2] - 3 つの "バッグ" には [2,0,4-2] の要素が含まれており、2 番目の "バッグ" は空です --> 
            <dim>3</dim> 
        </port> 
        <port id="3"/> <!-- default_index 値: 0 --> 
        <port id="4"/> <!-- per_sample_weights 値: [0.5, 0.5, 0.5, 0.5] --> 
            <dim>4</dim> 
        </port> 
    </input> 
    <output> 
        <port id="5"> <!-- output 値: [[-1.05, -1.2], [-0.2, -0.6], [-0.1, 0.4]] --> 
            <dim>3</dim> 
            <dim>2</dim> 
        </port> 
    </output> 
</layer>

例 2: per_sample_weights が指定され、default_index は -1 に設定されると、バッグは 0 で埋められます。

<layer ... type="EmbeddingBagOffsets" ... > 
    <input> 
        <port id="0"> <!-- emb_table 値: [[-0.2, -0.6], [-0.1, -0.4], [-1.9, -1.8], [-1., 1.5], [ 0.8, -0.7]] --> 
            <dim>5</dim> 
            <dim>2</dim> 
        </port> 
        <port id="1"> <!-- インデックス値: [0, 2, 3, 4] --> 
            <dim>4</dim> 
        </port> 
        <port id="2"> <!-- オフセット値: [0, 2, 2] - 3 つの "バッグ" には [2,0,4-2] の要素が含まれており、2 番目の "バッグ" は空です --> 
            <dim>3</dim> 
        </port> 
        <port id="3"/> <!-- default_index 値: -1 - 空のバッグを 0 で埋める --> 
        <port id="4"/> <!-- per_sample_weights 値: [0.5, 0.5, 0.5, 0.5] --> 
            <dim>4</dim> 
        </port> 
    </input> 
    <output> 
        <port id="5"> <!-- output 値: [[-1.05, -1.2], [0., 0], [-0.1, 0.4]] --> 
            <dim>3</dim> 
            <dim>2</dim> 
        </port> 
    </output> 
</layer>