リモートテンソル¶
ov::RemoteTensor クラスの機能:
デバイス固有のメモリーを操作するインターフェイスを提供します。
注
プラグインが独自のリモートテンソルのパブリック API を提供する場合、その API はヘッダーのみであり、プラグイン・ライブラリーに依存しません。
デバイス固有のリモート・テンソル・パブリック API¶
デバイス固有のリモートテンソルを操作するパブリック・インターフェイスにはヘッダーのみの実装が必要であり、プラグイン・ライブラリーに依存しません。
class VectorTensor : public ov::RemoteTensor {
public:
/**
* @brief Checks that type defined runtime parameters are presented in remote object
* @param tensor a tensor to check
*/
static void type_check(const Tensor& tensor) {
RemoteTensor::type_check(
tensor,
{{ov::device::full_name.name(), {"TEMPLATE"}}, {"vector_data_ptr", {}}, {"vector_data", {}}});
}
/**
* @brief Returns the underlying vector
* @return const reference to vector if T is compatible with element type
*/
template <class T>
const std::vector<T>& get_data() const {
auto params = get_params();
OPENVINO_ASSERT(params.count("vector_data"), "Cannot get data. Tensor is incorrect!");
try {
auto& vec = params.at("vector_data").as<const std::vector<T>>();
return vec;
} catch (const std::bad_cast&) {
OPENVINO_THROW("Cannot get data. Vector type is incorrect!");
}
}
/**
* @brief Returns the underlying vector
* @return reference to vector if T is compatible with element type
*/
template <class T>
std::vector<T>& get_data() {
auto params = get_params();
OPENVINO_ASSERT(params.count("vector_data"), "Cannot get data. Tensor is incorrect!");
try {
auto& vec = params.at("vector_data").as<std::vector<T>>();
return vec;
} catch (const std::bad_cast&) {
OPENVINO_THROW("Cannot get data. Vector type is incorrect!");
}
}
/**
* @brief Returns the const pointer to the data
*
* @return const pointer to the tensor data
*/
const void* get_data() const {
auto params = get_params();
OPENVINO_ASSERT(params.count("vector_data"), "Cannot get data. Tensor is incorrect!");
try {
auto* data = params.at("vector_data_ptr").as<const void*>();
return data;
} catch (const std::bad_cast&) {
OPENVINO_THROW("Cannot get data. Tensor is incorrect!");
}
}
/**
* @brief Returns the pointer to the data
*
* @return pointer to the tensor data
*/
void* get_data() {
auto params = get_params();
OPENVINO_ASSERT(params.count("vector_data"), "Cannot get data. Tensor is incorrect!");
try {
auto* data = params.at("vector_data_ptr").as<void*>();
return data;
} catch (const std::bad_cast&) {
OPENVINO_THROW("Cannot get data. Tensor is incorrect!");
}
}
};
以下の実装にはいくつかのメソッドがあります。
type_check()¶
静的メソッドは、いくつかの抽象リモートテンソルをこの特定のリモート・テンソル・タイプにキャストするのを理解するのに使用されます。
get_data()¶
リモートデータへのアクセスを取得するヘルパーであるメソッドのセット (この例に固有、他の実装には別の API を持つことができます)。
デバイス固有の内部テンソル実装¶
プラグインには、パブリック API と通信できるリモートテンソルの内部実装が必要です。この例には、stl ベクトルからメモリーをラップするリモートテンソルの実装が含まれます。
OpenVINO プラグイン API は、リモートテンソルの基本クラスとして使用するインターフェイス ov::IRemoteTensor を提供します。
実装例には 2 つのリモート・テンソル・クラスがあります。
テンプレート引数としてベクトルタイプを持ち、タイプ固有のテンソルを作成する内部タイプ依存実装。
内部でタイプ依存テンソルを扱うタイプ独立実装。
これに基づいて、タイプに依存しないリモート・テンソル・クラスの実装は次のようになります。
class VectorImpl : public ov::IRemoteTensor {
private:
std::shared_ptr<ov::IRemoteTensor> m_tensor;
public:
VectorImpl(const std::shared_ptr<ov::IRemoteTensor>& tensor) : m_tensor(tensor) {}
template <class T>
operator std::vector<T>&() const {
auto impl = std::dynamic_pointer_cast<VectorTensorImpl<T>>(m_tensor);
OPENVINO_ASSERT(impl, "Cannot get vector. Type is incorrect!");
return impl->get();
}
void* get_data() {
auto params = get_properties();
OPENVINO_ASSERT(params.count("vector_data"), "Cannot get data. Tensor is incorrect!");
try {
auto* data = params.at("vector_data_ptr").as<void*>();
return data;
} catch (const std::bad_cast&) {
OPENVINO_THROW("Cannot get data. Tensor is incorrect!");
}
}
void set_shape(ov::Shape shape) override {
m_tensor->set_shape(std::move(shape));
}
const ov::element::Type& get_element_type() const override {
return m_tensor->get_element_type();
}
const ov::Shape& get_shape() const override {
return m_tensor->get_shape();
}
size_t get_size() const override {
return m_tensor->get_size();
}
size_t get_byte_size() const override {
return m_tensor->get_byte_size();
}
const ov::Strides& get_strides() const override {
return m_tensor->get_strides();
}
const ov::AnyMap& get_properties() const override {
return m_tensor->get_properties();
}
const std::string& get_device_name() const override {
return m_tensor->get_device_name();
}
};
この実装は、ラップされた stl tensor を取得するヘルパーを提供し、ov::IRemoteTensor クラスのメソッドをオーバーライドして、タイプ依存の実装を呼び出します。
タイプ依存のリモートテンソルには次の実装があります。
template <class T>
class VectorTensorImpl : public ov::IRemoteTensor {
void update_strides() {
if (m_element_type.bitwidth() < 8)
return;
auto& shape = get_shape();
m_strides.clear();
if (!shape.empty()) {
m_strides.resize(shape.size());
m_strides.back() = shape.back() == 0 ? 0 : m_element_type.size();
std::copy(shape.rbegin(), shape.rend() - 1, m_strides.rbegin() + 1);
std::partial_sum(m_strides.rbegin(), m_strides.rend(), m_strides.rbegin(), std::multiplies<size_t>());
}
}
ov::element::Type m_element_type;
ov::Shape m_shape;
ov::Strides m_strides;
std::vector<T> m_data;
std::string m_dev_name;
ov::AnyMap m_properties;
public:
VectorTensorImpl(const ov::element::Type element_type, const ov::Shape& shape)
: m_element_type{element_type},
m_shape{shape},
m_data(ov::shape_size(shape)),
m_dev_name("TEMPLATE"),
m_properties{{ov::device::full_name.name(), m_dev_name},
{"vector_data", m_data},
{"vector_data_ptr", static_cast<void*>(m_data.data())}} {
update_strides();
}
const ov::element::Type& get_element_type() const override {
return m_element_type;
}
const ov::Shape& get_shape() const override {
return m_shape;
}
const ov::Strides& get_strides() const override {
OPENVINO_ASSERT(m_element_type.bitwidth() >= 8,
"Could not get strides for types with bitwidths less then 8 bit. Tensor type: ",
m_element_type);
return m_strides;
}
void set_shape(ov::Shape new_shape) override {
auto old_byte_size = get_byte_size();
OPENVINO_ASSERT(shape_size(new_shape) * get_element_type().size() <= old_byte_size,
"Could set new shape: ",
new_shape);
m_shape = std::move(new_shape);
update_strides();
}
const ov::AnyMap& get_properties() const override {
return m_properties;
}
const std::string& get_device_name() const override {
return m_dev_name;
}
};
クラスフィールド¶
クラスにはいくつかのフィールドがあります。
m_element_type
- テンソルの要素タイプ。m_shape
- テンソルの形状。m_strides
- テンソルのストライド。m_data
- ラップされたベクトル。m_dev_name
- デバイス名m_properties
- リモートテンソルのタイプを検出するリモートテンソル固有のプロパティー。
VectorTensorImpl()¶
リモートテンソル実装のコンストラクター。データを含むベクトルを作成し、デバイス名とプロパティーを初期化し、形状、要素タイプとストライドを更新します。
get_element_type()¶
このメソッドはテンソルの要素タイプを返します。
get_shape()¶
このメソッドはテンソルの形状を返します。
get_strides()¶
このメソッドはテンソルのストライドを返します。
set_shape()¶
このメソッドはリモートテンソルの新しい形状を設定します。
get_properties()¶
このメソッドは、テンソル固有のプロパティーを返します。
get_device_name()¶
このメソッドは、テンソル固有のデバイス名を返します。