OpenVINO マッチャーパス

ov::pass::MatcherPass はパターンベースの変換に使用されます。

MatcherPass 変換クラスのテンプレート

// transformations/template_pattern_transformation.hpp
/**
 * @ingroup ie_transformation_common_api
 * @brief Add transformation description.
 */
class ov::pass::DecomposeDivideMatcher : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("DecomposeDivideMatcher", "0");
    DecomposeDivideMatcher();
};
// template_pattern_transformation.cpp
ov::pass::DecomposeDivideMatcher::DecomposeDivideMatcher() {
    MATCHER_SCOPE(DecomposeDivideMatcher);
    // Pattern example
    auto input0 = pattern::any_input();
    auto input1 = pattern::any_input();
    auto div = std::make_shared<ov::opset3::Divide>(input0, input1);

    ov::matcher_pass_callback callback = [](pattern::Matcher& m) {
        auto div = std::dynamic_pointer_cast<ov::opset3::Divide>(m.get_match_root());
        // We can not apply this transformation in case with integer input data type
        if (!div || div->input(0).get_element_type().is_integral()) {
            return false;
        }

        // Decompose Divide into Multiply with Power operations
        auto pow = std::make_shared<ov::opset3::Power>(
            div->input_value(1),
            opset3::Constant::create(div->get_input_element_type(1), Shape{1}, {-1}));

        auto mul = std::make_shared<ov::opset3::Multiply>(div->input_value(0), pow);

        // Save original name to last operation in replacement sub-graph
        mul->set_friendly_name(div->get_friendly_name());

        // Copy runtime info attributes to newly created operation
        ov::copy_runtime_info(div, {pow, mul});

        // Replace Divide operation with Multiply
        ov::replace_node(div, mul);

        // Return true as the root node was changed
        return true;
    };

    // Register pattern with Divide operation as a pattern root node
    auto m = std::make_shared<ov::pass::pattern::Matcher>(div, "ConvertDivide");
    // Register Matcher
    register_matcher(m, callback);
}

ov::pass::MatcherPass を使用するには、次のステップを完了する必要があります。

  1. パターンを作成

  2. コールバックを実装

  3. パターンとマッチャーを登録

  4. MatcherPass を実行

これらの各ステップを確認します。

パターンを作成

パターンは単一のルート ov::Model です。ただし、唯一の違いは、モデル・オブジェクトを作成する必要がなく、opset または特別なパターンの操作を作成して接続するだけで済みます。次に、最後に作成した操作を取得し、それをパターンのルートに配置する必要があります。このルートノードはパターンマッチングのルートノードとして使用されます。

コンシューマーを持たず、ルートに登録されていないパターン内のノードは、パターンマッチングでは使用されません。

// Pattern example
auto input = std::make_shared<ov::opset8::Parameter>(ov::element::i64, ov::Shape{1});
auto shapeof = std::make_shared<ov::opset8::ShapeOf>(input);

// Create Matcher with Parameter->ShapeOf pattern
auto m = std::make_shared<ov::pass::pattern::Matcher>(shapeof, "MyPatternBasedTransformation");

上の例の Parameter 操作には、タイプと形状が指定されています。これらの属性は、パラメーター操作クラスを作成する場合にのみ必要であり、パターンマッチングでは使用されません。

パターンの例の詳細は、パターンマッチングを参照してください。

コールバックを実装

コールバックは、すべてのパターンの入り口に適用されるアクションです。通常、コールバックは、検出されたサブグラフを持つマッチャー・オブジェクトを受け取るラムダ関数です。

ov::graph_rewrite_callback callback = [](ov::pass::pattern::Matcher& m) {
    // Get root node
    std::shared_ptr<ov::Node> root_node = m.get_match_root();

    // Get all nodes matched by pattern
    ov::NodeVector nodes = m.get_matched_nodes();

    // Transformation code
    return false;
};

上の例は、コールバック構造と、パターンによって検出されたノードにアクセスするマッチャーの使い方を示しています。ルートノードが置き換えられ、別のパターンを同じルートノードに適用できない場合、コールバックの戻り値は true です。それ以外は false です。

ルートノード下にあるノードを操作することはお勧めできません。これは、トポロジーの順序でルートノードに続くすべてのノードが有効であり、パターンマッチングで使用できることが期待されるため、GraphRewrite の実行に影響する可能性があります。

MatcherPass は、追加のパターンマッチングで使用できる、新しく作成されたノードをレポートする機能も提供します。MatcherPass が ov::pass::Manager または ov::pass::GraphRewrite に登録されている場合、それらのノードは追加のパターンマッチングに追加されます。つまり、ov::pass::GraphRewrite に登録されたマッチャーパスがこれらのノードに適用されます。

以下の例は、単一の MatcherPass が register_new_node メソッドを使用して一連の操作を融合する方法を示しています。

ov::pass::ReluReluFusionMatcher::ReluReluFusionMatcher() {
    MATCHER_SCOPE(ReluReluFusionMatcher);
    auto m_relu1 = ov::pass::pattern::wrap_type<ov::opset3::Relu>(pattern::consumers_count(1));
    auto m_relu2 = ov::pass::pattern::wrap_type<ov::opset3::Relu>({m_relu1});

    ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
        // Map that helps to connect labels with matched outputs
        auto& node_to_output = m.get_pattern_value_map();

        // Create new Relu operation and add register it for additional execution
        auto new_relu =
            register_new_node<ov::opset3::Relu>(node_to_output.at(m_relu1).get_node_shared_ptr()->input_value(0));

        // Copy runtime info attributes to newly created operation
        ov::copy_runtime_info(m.get_matched_nodes(), new_relu);

        // Save last Relu name to new Relu operation
        new_relu->set_friendly_name(m.get_match_root()->get_friendly_name());

        // Replace Relu->Relu with Relu
        ov::replace_node(m.get_match_root(), new_relu);

        // Return true as the root node was changed
        return true;
    };

    // Register pattern with Relu operation as a pattern root node
    auto m = std::make_shared<ov::pass::pattern::Matcher>(m_relu2, "ReluReluFusion");
    // Register Matcher
    register_matcher(m, callback);
}

複数のノードを登録する場合は、トポロジー順に追加してください。これは時間のかかる操作であるため、ノードはトポロジー的にソートされません。

パターンとマッチャーを登録

最後のステップは、マッチャーとコールバックを MatcherPass パス内に登録することです。これには、register_matcher メソッドを呼び出します。

1 つの MatcherPass クラスに対して登録できるマッチャーは 1 つだけです。

// Register matcher and callback
register_matcher(m, callback);

MatcherPass を実行

MatcherPass には複数の実行方法があります。

  • 単一ノードで実行 - 別の変換内で MatcherPass を実行する場合に便利です。

    if (ov::pass::DecomposeDivideMatcher().apply(node)) {
        // successful execution (root node was replaced)
    }
  • GraphRewrite を使用して ov::Model で実行 - このアプローチでは、ov::Model 全体で MatcherPass を実行できます。さらに、複数の MatcherPass 変換を 1 つの GraphRewite に登録し、1 つのグラフ走査で実行することができます。

    // Two matcher passes will run simultaneously in a single graph traversal
    ov::pass::GraphRewrite pass;
    pass.add_matcher<ov::pass::DecomposeDivideMatcher>();
    pass.add_matcher<ov::pass::ReluReluFusionMatcher>();
    pass.run_on_model(f);
  • ov::pass::Manager を使用して ov::Model で実行 - このアプローチは、別の変換タイプとして ov::Model で実行する MatcherPass を登録するのに役立ちます。

    // Two matchers will run independently (two independent graph traversals)
    // pass::Manager automatically creates GraphRewrite container for each MatcherPass
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::DecomposeDivideMatcher>();
    manager.register_pass<ov::pass::ReluReluFusionMatcher>();
    manager.run_passes(f);

パターンパッチング

通常の操作ではパターンを表現できない場合や、複雑すぎることがあります。例えば、畳み込み操作の特定の入力タイプを指定せずに畳み込み -> 加算サブグラフを検出したい場合、または一部の操作が異なるタイプを持つパターンを作成したい場合です。このために、OpenVINO™ は、GraphRewrite 変換のパターンを構築する追加のヘルパーを提供します。

2 つのメインヘルパーがあります。

  1. ov::pass::pattern::any_input - 入力のタイプが定義されていない場合に入力を表現するのに役立ちます。

  2. ov::pass::pattern::wrap_type <T> - ノード属性を指定せずにパターンのノードを表現するのに役立ちます。

どのように動作するか理解するため例を示します。

ノード属性はパターンマッチングには関与せず、操作の作成にのみ使用されます。パターンマッチングには操作タイプのみが必要です。

以下の例は、ov::passpattern::any_input の基本的な使用法を示しています。ここでは、最初の任意の入力と 2 番目の入力として定数を使用して乗算パターンを構築します。また、乗算は可換であるため、入力をどの順序で設定するか (any_input/Constant または Constant/any_input) は問題ではありません。両方のケースが一致するためです。

// Detect Multiply with arbitrary first input and second as Constant
// ov::pattern::op::Label - represent arbitrary input
auto input = ov::pass::pattern::any_input();
auto value = ov::opset8::Constant::create(ov::element::f32, ov::Shape{1}, {0.5});
auto mul = std::make_shared<ov::opset8::Multiply>(input, value);
auto m = std::make_shared<ov::pass::pattern::Matcher>(mul, "MultiplyMatcher");

この例では、操作に複数の入力がある場合にパターンを構築する方法を示します。

// Detect Concat operation with arbitrary number of inputs
auto concat = ov::pass::pattern::wrap_type<ov::opset8::Concat>();
auto m = std::make_shared<ov::pass::pattern::Matcher>(concat, "ConcatMatcher");

この例では、プレディケートを使用してパターンを構築する方法を示します。また、特定のノードでパターンを手動で照合する方法も示します。

// Detect Multiply->Add sequence where mul has exactly one consumer
auto mul = ov::pass::pattern::wrap_type<ov::opset8::Multiply>(ov::pass::pattern::consumers_count(1)/*сheck consumers count*/);
auto add = ov::pass::pattern::wrap_type<ov::opset8::Add>({mul, ov::pass::pattern::any_input()});
auto m = std::make_shared<ov::pass::pattern::Matcher>(add, "MultiplyAddMatcher");
// Matcher can be used to match pattern manually on given node
if (m->match(node->output(0))) {
    // Successfully matched
}

マッチャー・オブジェクトには一致したノードが保持されるため、手動マッチングの場合は注意してください。一致をクリアするには、m->clear_state() メソッドを使用します。

関連情報