リモートテンソル#

ov::RemoteTensor クラスの機能:

  • デバイス固有のメモリーを操作するインターフェイスを提供します。

プラグインが独自のリモートテンソルのパブリック API を提供する場合、その API はヘッダーのみであり、プラグイン・ライブラリーに依存しません。

デバイス固有のリモート・テンソル・パブリック API#

デバイス固有のリモートテンソルを操作するパブリック・インターフェイスにはヘッダーのみの実装が必要であり、プラグイン・ライブラリーに依存しません。

class VectorTensor : public ov::RemoteTensor { 
public:     /** 
     * @brief Checks that type defined runtime parameters are presented in remote object 
     * @param tensor a tensor to check 
     */ 
    static void type_check(const Tensor& tensor) { 
        RemoteTensor::type_check( 
            tensor, 
            {{ov::device::full_name.name(), {"TEMPLATE"}}, {"vector_data_ptr", {}}, {"vector_data", {}}}); 
    } 

    /**
     * @brief Returns the underlying vector
     * @return const reference to vector if T is compatible with element type
     */ 
    template <class T> 
    const std::vector<T>& get_data() const { 
        auto params = get_params(); 
        OPENVINO_ASSERT(params.count("vector_data"), "Cannot get data. Tensor is incorrect!"); 
        try { 
            auto& vec = params.at("vector_data").as<const std::vector<T>>(); 
            return vec; } catch (const std::bad_cast&) { 
            OPENVINO_THROW("Cannot get data.Vector type is incorrect!"); 
        } 
    } 

    /**
     * @brief Returns the underlying vector
     * @return reference to vector if T is compatible with element type
     */ 
    template <class T> 
        std::vector<T>& get_data() { 
            auto params = get_params(); 
            OPENVINO_ASSERT(params.count("vector_data"), "Cannot get data. Tensor is incorrect!"); 
        try { 
            auto& vec = params.at("vector_data").as<std::vector<T>>(); 
           return vec; 
        } catch (const std::bad_cast&) { 
            OPENVINO_THROW("Cannot get data.Vector type is incorrect!"); 
        } 
    } 

    /**
     * @brief Returns the const pointer to the data
     *
     * @return const pointer to the tensor data
     */ 
    const void* get_data() const { 
        auto params = get_params(); 
        OPENVINO_ASSERT(params.count("vector_data"), "Cannot get data. Tensor is incorrect!"); 
        try { 
            auto* data = params.at("vector_data_ptr").as<const void*>(); 
            return data; 
        } catch (const std::bad_cast&) { 
            OPENVINO_THROW("Cannot get data.Tensor is incorrect!"); 
        } 
    } 

    /**
     * @brief Returns the pointer to the data
     *
     * @return pointer to the tensor data
     */ 
    void* get_data() { 
        auto params = get_params(); 
        OPENVINO_ASSERT(params.count("vector_data"), "Cannot get data. Tensor is incorrect!"); 
        try { 
            auto* data = params.at("vector_data_ptr").as<void*>(); 
            return data; 
        } catch (const std::bad_cast&) { 
            OPENVINO_THROW("Cannot get data.Tensor is incorrect!"); 
        } 
    } 
};

以下の実装にはいくつかのメソッドがあります:

type_check()#

静的メソッドは、いくつかの抽象リモートテンソルをこの特定のリモート・テンソル・タイプにキャストするのを理解するのに使用されます。

get_data()#

リモートデータへのアクセスを取得するヘルパーであるメソッドのセット (この例に固有、他の実装には別の API を持つことができます)。

デバイス固有の内部テンソル実装#

プラグインには、パブリック API と通信できるリモートテンソルの内部実装が必要です。この例には、stl ベクトルからメモリーをラップするリモートテンソルの実装が含まれます。

OpenVINO プラグイン API は、リモートテンソルの基本クラスとして使用するインターフェイス ov::IRemoteTensor を提供します。

実装例には 2 つのリモート・テンソル・クラスがあります:

  • テンプレート引数としてベクトルタイプを持ち、タイプ固有のテンソルを作成する内部タイプ依存実装。

  • 内部でタイプ依存テンソルを扱うタイプ独立実装。

これに基づいて、タイプに依存しないリモート・テンソル・クラスの実装は次のようになります:

class VectorImpl : public ov::IRemoteTensor { 
private: 
    std::shared_ptr<ov::IRemoteTensor> m_tensor; 

public:     VectorImpl(const std::shared_ptr<ov::IRemoteTensor>& tensor) : m_tensor(tensor) {} 

    template <class T> 
    operator std::vector<T>&() const { 
        auto impl = std::dynamic_pointer_cast<VectorTensorImpl<T>>(m_tensor); 
        OPENVINO_ASSERT(impl, "Cannot get vector. Type is incorrect!"); 
        return impl->get(); 
    } 

    void* get_data() { 
        auto params = get_properties(); 
        OPENVINO_ASSERT(params.count("vector_data"), "Cannot get data. Tensor is incorrect!"); 
        try { 
            auto* data = params.at("vector_data_ptr").as<void*>(); 
            return data; 
        } catch (const std::bad_cast&) { 
            OPENVINO_THROW("Cannot get data.Tensor is incorrect!"); 
        } 
    } 

    void set_shape(ov::Shape shape) override { 
        m_tensor->set_shape(std::move(shape)); 
    } 

    const ov::element::Type& get_element_type() const override { 
        return m_tensor->get_element_type(); 
    } 

    const ov::Shape& get_shape() const override { 
        return m_tensor->get_shape(); 
    } 

    size_t get_size() const override { 
        return m_tensor->get_size(); 
    } 

    size_t get_byte_size() const override { 
        return m_tensor->get_byte_size(); 
    } 

    const ov::Strides& get_strides() const override { 
        return m_tensor->get_strides(); 
    } 

    const ov::AnyMap& get_properties() const override { 
        return m_tensor->get_properties(); 
    } 

    const std::string& get_device_name() const override { 
        return m_tensor->get_device_name(); 
    } 
};

この実装は、ラップされた stl tensor を取得するヘルパーを提供し、ov::IRemoteTensor クラスのメソッドをオーバーライドして、タイプ依存の実装を呼び出します。

タイプ依存のリモートテンソルには次の実装があります:

template <class T> 
class VectorTensorImpl : public ov::IRemoteTensor { 
    void update_strides() { 
        if (m_element_type.bitwidth() < 8) 
            return; 
        auto& shape = get_shape(); 
        m_strides.clear(); 
        if (!shape.empty()) { 
            m_strides.resize(shape.size()); 
            m_strides.back() = shape.back() == 0 ? 0 : m_element_type.size(); 
            std::copy(shape.rbegin(), shape.rend() - 1, m_strides.rbegin() + 1); 
            std::partial_sum(m_strides.rbegin(), m_strides.rend(), m_strides.rbegin(), std::multiplies<size_t>()); 
        } 
    } 
    ov::element::Type m_element_type; 
    ov::Shape m_shape; 
    ov::Strides m_strides; 
    std::vector<T> m_data; 
    std::string m_dev_name; 
    ov::AnyMap m_properties; 

public:     VectorTensorImpl(const ov::element::Type element_type, const ov::Shape& shape) : 
        m_element_type{element_type}, 
        m_shape{shape}, 
        m_data(ov::shape_size(shape)), 
        m_dev_name("TEMPLATE"), 
        m_properties{{ov::device::full_name.name(), m_dev_name}, {
                     "vector_data", m_data}, {
                     "vector_data_ptr", static_cast<void*>(m_data.data())}} { update_strides(); 
    } 

    const ov::element::Type& get_element_type() const override { 
        return m_element_type; 
    } 

    const ov::Shape& get_shape() const override { 
        return m_shape; 
    } 
    const ov::Strides& get_strides() const override { 
        OPENVINO_ASSERT(m_element_type.bitwidth() >= 8, 
                        "Could not get strides for types with bitwidths less then 8 bit. Tensor type: ", 
                        m_element_type); 
        return m_strides; 
    } 

    void set_shape(ov::Shape new_shape) override { 
        auto old_byte_size = get_byte_size(); 
        OPENVINO_ASSERT(shape_size(new_shape) * get_element_type().size() <= old_byte_size, 
                        "Could set new shape: ", 
                        new_shape); 
        m_shape = std::move(new_shape); 
        update_strides(); 
    } 

    const ov::AnyMap& get_properties() const override { 
        return m_properties; 
    } 

    const std::string& get_device_name() const override { 
        return m_dev_name; 
    } 
};

クラスフィールド#

クラスにはいくつかのフィールドがあります:

  • m_element_type - テンソルの要素タイプ。

  • m_shape - テンソルの形状。

  • m_strides - テンソルのストライド。

  • m_data - ラップされたベクトル。

  • m_dev_name - デバイス名。

  • m_properties - リモートテンソルのタイプを検出するリモートテンソル固有のプロパティー。

VectorTensorImpl()#

リモートテンソル実装のコンストラクター。データを含むベクトルを作成し、デバイス名とプロパティーを初期化し、形状、要素タイプとストライドを更新します。

get_element_type()#

このメソッドはテンソルの要素タイプを返します。

get_shape()#

このメソッドはテンソルの形状を返します。

get_strides()#

このメソッドはテンソルのストライドを返します。

set_shape()#

このメソッドはリモートテンソルの新しい形状を設定します。

get_properties()#

このメソッドは、テンソル固有のプロパティーを返します。

get_device_name()#

このメソッドは、テンソル固有のデバイス名を返します。