リモートテンソル

ov::RemoteTensor クラスの機能:

  • デバイス固有のメモリーを操作するインターフェイスを提供します。

プラグインが独自のリモートテンソルのパブリック API を提供する場合、その API はヘッダーのみであり、プラグイン・ライブラリーに依存しません。

デバイス固有のリモート・テンソル・パブリック API

デバイス固有のリモートテンソルを操作するパブリック・インターフェイスにはヘッダーのみの実装が必要であり、プラグイン・ライブラリーに依存しません。

class VectorTensor : public ov::RemoteTensor {
public:
    /**
     * @brief Checks that type defined runtime parameters are presented in remote object
     * @param tensor a tensor to check
     */
    static void type_check(const Tensor& tensor) {
        RemoteTensor::type_check(
            tensor,
            {{ov::device::full_name.name(), {"TEMPLATE"}}, {"vector_data_ptr", {}}, {"vector_data", {}}});
    }

    /**
     * @brief Returns the underlying vector
     * @return const reference to vector if T is compatible with element type
     */
    template <class T>
    const std::vector<T>& get_data() const {
        auto params = get_params();
        OPENVINO_ASSERT(params.count("vector_data"), "Cannot get data. Tensor is incorrect!");
        try {
            auto& vec = params.at("vector_data").as<const std::vector<T>>();
            return vec;
        } catch (const std::bad_cast&) {
            OPENVINO_THROW("Cannot get data. Vector type is incorrect!");
        }
    }

    /**
     * @brief Returns the underlying vector
     * @return reference to vector if T is compatible with element type
     */
    template <class T>
    std::vector<T>& get_data() {
        auto params = get_params();
        OPENVINO_ASSERT(params.count("vector_data"), "Cannot get data. Tensor is incorrect!");
        try {
            auto& vec = params.at("vector_data").as<std::vector<T>>();
            return vec;
        } catch (const std::bad_cast&) {
            OPENVINO_THROW("Cannot get data. Vector type is incorrect!");
        }
    }

    /**
     * @brief Returns the const pointer to the data
     *
     * @return const pointer to the tensor data
     */
    const void* get_data() const {
        auto params = get_params();
        OPENVINO_ASSERT(params.count("vector_data"), "Cannot get data. Tensor is incorrect!");
        try {
            auto* data = params.at("vector_data_ptr").as<const void*>();
            return data;
        } catch (const std::bad_cast&) {
            OPENVINO_THROW("Cannot get data. Tensor is incorrect!");
        }
    }

    /**
     * @brief Returns the pointer to the data
     *
     * @return pointer to the tensor data
     */
    void* get_data() {
        auto params = get_params();
        OPENVINO_ASSERT(params.count("vector_data"), "Cannot get data. Tensor is incorrect!");
        try {
            auto* data = params.at("vector_data_ptr").as<void*>();
            return data;
        } catch (const std::bad_cast&) {
            OPENVINO_THROW("Cannot get data. Tensor is incorrect!");
        }
    }
};

以下の実装にはいくつかのメソッドがあります。

type_check()

静的メソッドは、いくつかの抽象リモートテンソルをこの特定のリモート・テンソル・タイプにキャストするのを理解するのに使用されます。

get_data()

リモートデータへのアクセスを取得するヘルパーであるメソッドのセット (この例に固有、他の実装には別の API を持つことができます)。

デバイス固有の内部テンソル実装

プラグインには、パブリック API と通信できるリモートテンソルの内部実装が必要です。この例には、stl ベクトルからメモリーをラップするリモートテンソルの実装が含まれます。

OpenVINO プラグイン API は、リモートテンソルの基本クラスとして使用するインターフェイス ov::IRemoteTensor を提供します。

実装例には 2 つのリモート・テンソル・クラスがあります。

  • テンプレート引数としてベクトルタイプを持ち、タイプ固有のテンソルを作成する内部タイプ依存実装。

  • 内部でタイプ依存テンソルを扱うタイプ独立実装。

これに基づいて、タイプに依存しないリモート・テンソル・クラスの実装は次のようになります。

class VectorImpl : public ov::IRemoteTensor {
private:
    std::shared_ptr<ov::IRemoteTensor> m_tensor;

public:
    VectorImpl(const std::shared_ptr<ov::IRemoteTensor>& tensor) : m_tensor(tensor) {}

    template <class T>
    operator std::vector<T>&() const {
        auto impl = std::dynamic_pointer_cast<VectorTensorImpl<T>>(m_tensor);
        OPENVINO_ASSERT(impl, "Cannot get vector. Type is incorrect!");
        return impl->get();
    }

    void* get_data() {
        auto params = get_properties();
        OPENVINO_ASSERT(params.count("vector_data"), "Cannot get data. Tensor is incorrect!");
        try {
            auto* data = params.at("vector_data_ptr").as<void*>();
            return data;
        } catch (const std::bad_cast&) {
            OPENVINO_THROW("Cannot get data. Tensor is incorrect!");
        }
    }

    void set_shape(ov::Shape shape) override {
        m_tensor->set_shape(std::move(shape));
    }

    const ov::element::Type& get_element_type() const override {
        return m_tensor->get_element_type();
    }

    const ov::Shape& get_shape() const override {
        return m_tensor->get_shape();
    }

    size_t get_size() const override {
        return m_tensor->get_size();
    }

    size_t get_byte_size() const override {
        return m_tensor->get_byte_size();
    }

    const ov::Strides& get_strides() const override {
        return m_tensor->get_strides();
    }

    const ov::AnyMap& get_properties() const override {
        return m_tensor->get_properties();
    }

    const std::string& get_device_name() const override {
        return m_tensor->get_device_name();
    }
};

この実装は、ラップされた stl tensor を取得するヘルパーを提供し、ov::IRemoteTensor クラスのメソッドをオーバーライドして、タイプ依存の実装を呼び出します。

タイプ依存のリモートテンソルには次の実装があります。

template <class T>
class VectorTensorImpl : public ov::IRemoteTensor {
    void update_strides() {
        if (m_element_type.bitwidth() < 8)
            return;
        auto& shape = get_shape();
        m_strides.clear();
        if (!shape.empty()) {
            m_strides.resize(shape.size());
            m_strides.back() = shape.back() == 0 ? 0 : m_element_type.size();
            std::copy(shape.rbegin(), shape.rend() - 1, m_strides.rbegin() + 1);
            std::partial_sum(m_strides.rbegin(), m_strides.rend(), m_strides.rbegin(), std::multiplies<size_t>());
        }
    }
    ov::element::Type m_element_type;
    ov::Shape m_shape;
    ov::Strides m_strides;
    std::vector<T> m_data;
    std::string m_dev_name;
    ov::AnyMap m_properties;

public:
    VectorTensorImpl(const ov::element::Type element_type, const ov::Shape& shape)
        : m_element_type{element_type},
          m_shape{shape},
          m_data(ov::shape_size(shape)),
          m_dev_name("TEMPLATE"),
          m_properties{{ov::device::full_name.name(), m_dev_name},
                       {"vector_data", m_data},
                       {"vector_data_ptr", static_cast<void*>(m_data.data())}} {
        update_strides();
    }

    const ov::element::Type& get_element_type() const override {
        return m_element_type;
    }

    const ov::Shape& get_shape() const override {
        return m_shape;
    }
    const ov::Strides& get_strides() const override {
        OPENVINO_ASSERT(m_element_type.bitwidth() >= 8,
                        "Could not get strides for types with bitwidths less then 8 bit. Tensor type: ",
                        m_element_type);
        return m_strides;
    }

    void set_shape(ov::Shape new_shape) override {
        auto old_byte_size = get_byte_size();
        OPENVINO_ASSERT(shape_size(new_shape) * get_element_type().size() <= old_byte_size,
                        "Could set new shape: ",
                        new_shape);
        m_shape = std::move(new_shape);
        update_strides();
    }

    const ov::AnyMap& get_properties() const override {
        return m_properties;
    }

    const std::string& get_device_name() const override {
        return m_dev_name;
    }
};

クラスフィールド

クラスにはいくつかのフィールドがあります。

  • m_element_type - テンソルの要素タイプ。

  • m_shape - テンソルの形状。

  • m_strides - テンソルのストライド。

  • m_data - ラップされたベクトル。

  • m_dev_name - デバイス名

  • m_properties - リモートテンソルのタイプを検出するリモートテンソル固有のプロパティー。

VectorTensorImpl()

リモートテンソル実装のコンストラクター。データを含むベクトルを作成し、デバイス名とプロパティーを初期化し、形状、要素タイプとストライドを更新します。

get_element_type()

このメソッドはテンソルの要素タイプを返します。

get_shape()

このメソッドはテンソルの形状を返します。

get_strides()

このメソッドはテンソルのストライドを返します。

set_shape()

このメソッドはリモートテンソルの新しい形状を設定します。

get_properties()

このメソッドは、テンソル固有のプロパティーを返します。

get_device_name()

このメソッドは、テンソル固有のデバイス名を返します。