OpenVINO™ によるアテンション・センター・モデル¶
この Jupyter ノートブックはオンラインで起動でき、ブラウザーのウィンドウで対話型環境を開きます。ローカルにインストールすることもできます。次のオプションのいずれかを選択します。
このノートブックは、OpenVINO でアテンション・センター・モデルを使用する方法を示します。このモデルは TensorFlow Lite 形式であり、現在は TFLite フロントエンドによって OpenVINO でサポートされています。
アイ・トラッキングは、視覚的な注意や意思決定などの関連する質問に答えるため、視覚神経科学や認知科学で一般的に使用されています。どこを見るべきかを予測する計算モデルは、さまざまなコンピューター・ビジョン・タスクに直接応用できます。アテンション・センター・モデルは、RGB イメージを入力として受け取り、2D ポイントを出力として返します。この 2D ポイントは、画像上の人間の注意の予測中心、つまり、人々が最初に注意を払う画像の最も顕著な部分です。これにより、視覚的に最も顕著な領域を見つけて、できるだけ早く処理できるようになります。例えば、最初に注目する部分のエンコードをサポートする最新世代の画像形式 (JPEG XL など) に使用できます。ユーザー体験の向上に役立ち、画像の読み込みが速くなります。
アテンション・センター・モデルのアーキテクチャーは次のとおりです。
> アテンション・センター・モデルはディープ・ニューラル・ネットワークであり、画像を入力として受け取り、事前にトレーニングされた分類ネットワーク (ResNet、MobileNet など) をバックボーンとして使用します。バックボーン・ネットワークから出力されるいくつかの中間レイヤーは、アテンション・センター予測モジュールの入力として使用されます。これらの異なる中間レイヤーには、異なる情報が含まれています。例えば、浅いレイヤーには強度/色/テクスチャーなどの低レベルの情報が含まれることが多く、一方、深いレイヤーには通常、形状/オブジェクトなどの高度でより意味論的な情報が含まれます。どれも注目度の予測に役立ちます。アテンション・センターの予測では、畳み込み、逆畳み込み、および/またはサイズ変更オペレーターを集約およびシグモイド関数とともに適用して、アテンション・センターの重み付けマップを生成します。そして、オペレーター (この場合はアインシュタイン加算オペレーター) を適用して、重み付けマップから (重力) 中心を計算できます。予測されたアテンション・センターと真のアテンション・センターの間の L2 ノルムは、トレーニング損失として計算できます。
ソース: Google AI ブログの投稿。
アテンション・センター・モデルは、SALICON データセットの顕著性アノテーションが付けられた COCO データセットの画像を使用してトレーニングされています。
目次¶
%pip install "openvino>=2023.2.0"
Requirement already satisfied: openvino>=2023.2.0 in /opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-609/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages (2023.3.0)
Requirement already satisfied: numpy>=1.16.6 in /opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-609/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages (from openvino>=2023.2.0) (1.23.5)
Requirement already satisfied: openvino-telemetry>=2023.2.1 in /opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-609/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages (from openvino>=2023.2.0) (2023.2.1)
DEPRECATION: pytorch-lightning 1.6.5 has a non-standard dependency specifier torch>=1.8.*. pip 24.1 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pytorch-lightning or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063
Note: you may need to restart the kernel to use updated packages.
インポート¶
import cv2
import numpy as np
import tensorflow as tf
from pathlib import Path
import matplotlib.pyplot as plt
import openvino as ov
2024-02-09 23:49:23.781601: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable TF_ENABLE_ONEDNN_OPTS=0. 2024-02-09 23:49:23.815218: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-02-09 23:49:24.361080: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
アテンション・センター・モデルをダウンロード¶
アテンション・センター・リポジトリーの一部としてモデルをダウンロードします。リポジトリーでは、フォルダー ./model
にモデルが含まれています。
if not Path('./attention-center').exists():
! git clone https://github.com/google/attention-center
Cloning into 'attention-center'...
remote: Enumerating objects: 168, done.[K
remote: Counting objects: 0% (1/168)[K
remote: Counting objects: 1% (2/168)[K
remote: Counting objects: 2% (4/168)[K
remote: Counting objects: 3% (6/168)[K
remote: Counting objects: 4% (7/168)[K
remote: Counting objects: 5% (9/168)[K
remote: Counting objects: 6% (11/168)[K
remote: Counting objects: 7% (12/168)[K
remote: Counting objects: 8% (14/168)[K
remote: Counting objects: 9% (16/168)[K
remote: Counting objects: 10% (17/168)[K
remote: Counting objects: 11% (19/168)[K
remote: Counting objects: 12% (21/168)[K
remote: Counting objects: 13% (22/168)[K
remote: Counting objects: 14% (24/168)[K
remote: Counting objects: 15% (26/168)[K
remote: Counting objects: 16% (27/168)[K
remote: Counting objects: 17% (29/168)[K
remote: Counting objects: 18% (31/168)[K
remote: Counting objects: 19% (32/168)[K
remote: Counting objects: 20% (34/168)[K
remote: Counting objects: 21% (36/168)[K
remote: Counting objects: 22% (37/168)[K
remote: Counting objects: 23% (39/168)[K
remote: Counting objects: 24% (41/168)[K
remote: Counting objects: 25% (42/168)[K
remote: Counting objects: 26% (44/168)[K
remote: Counting objects: 27% (46/168)[K
remote: Counting objects: 28% (48/168)[K
remote: Counting objects: 29% (49/168)[K
remote: Counting objects: 30% (51/168)[K
remote: Counting objects: 31% (53/168)[K
remote: Counting objects: 32% (54/168)[K
remote: Counting objects: 33% (56/168)[K
remote: Counting objects: 34% (58/168)[K
remote: Counting objects: 35% (59/168)[K
remote: Counting objects: 36% (61/168)[K
remote: Counting objects: 37% (63/168)[K
remote: Counting objects: 38% (64/168)[K
remote: Counting objects: 39% (66/168)[K
remote: Counting objects: 40% (68/168)[K
remote: Counting objects: 41% (69/168)[K
remote: Counting objects: 42% (71/168)[K
remote: Counting objects: 43% (73/168)[K
remote: Counting objects: 44% (74/168)[K
remote: Counting objects: 45% (76/168)[K
remote: Counting objects: 46% (78/168)[K
remote: Counting objects: 47% (79/168)[K
remote: Counting objects: 48% (81/168)[K
remote: Counting objects: 49% (83/168)[K
remote: Counting objects: 50% (84/168)[K
remote: Counting objects: 51% (86/168)[K
remote: Counting objects: 52% (88/168)[K
remote: Counting objects: 53% (90/168)[K
remote: Counting objects: 54% (91/168)[K
remote: Counting objects: 55% (93/168)[K
remote: Counting objects: 56% (95/168)[K
remote: Counting objects: 57% (96/168)[K
remote: Counting objects: 58% (98/168)[K
remote: Counting objects: 59% (100/168)[K
remote: Counting objects: 60% (101/168)[K
remote: Counting objects: 61% (103/168)[K
remote: Counting objects: 62% (105/168)[K
remote: Counting objects: 63% (106/168)[K
remote: Counting objects: 64% (108/168)[K
remote: Counting objects: 65% (110/168)[K
remote: Counting objects: 66% (111/168)[K
remote: Counting objects: 67% (113/168)[K
remote: Counting objects: 68% (115/168)[K
remote: Counting objects: 69% (116/168)[K
remote: Counting objects: 70% (118/168)[K
remote: Counting objects: 71% (120/168)[K
remote: Counting objects: 72% (121/168)[K
remote: Counting objects: 73% (123/168)[K
remote: Counting objects: 74% (125/168)[K
remote: Counting objects: 75% (126/168)[K
remote: Counting objects: 76% (128/168)[K
remote: Counting objects: 77% (130/168)[K
remote: Counting objects: 78% (132/168)[K
remote: Counting objects: 79% (133/168)[K
remote: Counting objects: 80% (135/168)[K
remote: Counting objects: 81% (137/168)[K
remote: Counting objects: 82% (138/168)[K
remote: Counting objects: 83% (140/168)[K
remote: Counting objects: 84% (142/168)[K
remote: Counting objects: 85% (143/168)[K
remote: Counting objects: 86% (145/168)[K
remote: Counting objects: 87% (147/168)[K
remote: Counting objects: 88% (148/168)[K
remote: Counting objects: 89% (150/168)[K
remote: Counting objects: 90% (152/168)[K
remote: Counting objects: 91% (153/168)[K
remote: Counting objects: 92% (155/168)[K
remote: Counting objects: 93% (157/168)[K
remote: Counting objects: 94% (158/168)[K
remote: Counting objects: 95% (160/168)[K
remote: Counting objects: 96% (162/168)[K
remote: Counting objects: 97% (163/168)[K
remote: Counting objects: 98% (165/168)[K
remote: Counting objects: 99% (167/168)[K
remote: Counting objects: 100% (168/168)[K
remote: Counting objects: 100% (168/168), done.[K
remote: Compressing objects: 0% (1/132)[K
remote: Compressing objects: 1% (2/132)[K
remote: Compressing objects: 2% (3/132)[K
remote: Compressing objects: 3% (4/132)[K
remote: Compressing objects: 4% (6/132)[K
remote: Compressing objects: 5% (7/132)[K
remote: Compressing objects: 6% (8/132)[K
remote: Compressing objects: 7% (10/132)[K
remote: Compressing objects: 8% (11/132)[K
remote: Compressing objects: 9% (12/132)[K
remote: Compressing objects: 10% (14/132)[K
remote: Compressing objects: 11% (15/132)[K
remote: Compressing objects: 12% (16/132)[K
remote: Compressing objects: 13% (18/132)[K
remote: Compressing objects: 14% (19/132)[K
remote: Compressing objects: 15% (20/132)[K
remote: Compressing objects: 16% (22/132)[K
remote: Compressing objects: 17% (23/132)[K
remote: Compressing objects: 18% (24/132)[K
remote: Compressing objects: 19% (26/132)[K
remote: Compressing objects: 20% (27/132)[K
remote: Compressing objects: 21% (28/132)[K
remote: Compressing objects: 22% (30/132)[K
remote: Compressing objects: 23% (31/132)[K
remote: Compressing objects: 24% (32/132)[K
remote: Compressing objects: 25% (33/132)[K
remote: Compressing objects: 26% (35/132)[K
remote: Compressing objects: 27% (36/132)[K
remote: Compressing objects: 28% (37/132)[K
remote: Compressing objects: 29% (39/132)[K
remote: Compressing objects: 30% (40/132)[K
remote: Compressing objects: 31% (41/132)[K
remote: Compressing objects: 32% (43/132)[K
remote: Compressing objects: 33% (44/132)[K
remote: Compressing objects: 34% (45/132)[K
remote: Compressing objects: 35% (47/132)[K
remote: Compressing objects: 36% (48/132)[K
remote: Compressing objects: 37% (49/132)[K
remote: Compressing objects: 38% (51/132)[K
remote: Compressing objects: 39% (52/132)[K
remote: Compressing objects: 40% (53/132)[K
remote: Compressing objects: 41% (55/132)[K
remote: Compressing objects: 42% (56/132)[K
remote: Compressing objects: 43% (57/132)[K
remote: Compressing objects: 44% (59/132)[K
remote: Compressing objects: 45% (60/132)[K
remote: Compressing objects: 46% (61/132)[K
remote: Compressing objects: 47% (63/132)[K
remote: Compressing objects: 48% (64/132)[K
remote: Compressing objects: 49% (65/132)[K
remote: Compressing objects: 50% (66/132)[K
remote: Compressing objects: 51% (68/132)[K
remote: Compressing objects: 52% (69/132)[K
remote: Compressing objects: 53% (70/132)[K
remote: Compressing objects: 54% (72/132)[K
remote: Compressing objects: 55% (73/132)[K
remote: Compressing objects: 56% (74/132)[K
remote: Compressing objects: 57% (76/132)[K
remote: Compressing objects: 58% (77/132)[K
remote: Compressing objects: 59% (78/132)[K
remote: Compressing objects: 60% (80/132)[K
remote: Compressing objects: 61% (81/132)[K
remote: Compressing objects: 62% (82/132)[K
remote: Compressing objects: 63% (84/132)[K
remote: Compressing objects: 64% (85/132)[K
remote: Compressing objects: 65% (86/132)[K
remote: Compressing objects: 66% (88/132)[K
remote: Compressing objects: 67% (89/132)[K
remote: Compressing objects: 68% (90/132)[K
remote: Compressing objects: 69% (92/132)[K
remote: Compressing objects: 70% (93/132)[K
remote: Compressing objects: 71% (94/132)[K
remote: Compressing objects: 72% (96/132)[K
remote: Compressing objects: 73% (97/132)[K
remote: Compressing objects: 74% (98/132)[K
remote: Compressing objects: 75% (99/132)[K
remote: Compressing objects: 76% (101/132)[K
remote: Compressing objects: 77% (102/132)[K
remote: Compressing objects: 78% (103/132)[K
remote: Compressing objects: 79% (105/132)[K
remote: Compressing objects: 80% (106/132)[K
remote: Compressing objects: 81% (107/132)[K
remote: Compressing objects: 82% (109/132)[K
remote: Compressing objects: 83% (110/132)[K
remote: Compressing objects: 84% (111/132)[K
remote: Compressing objects: 85% (113/132)[K
remote: Compressing objects: 86% (114/132)[K
remote: Compressing objects: 87% (115/132)[K
remote: Compressing objects: 88% (117/132)[K
remote: Compressing objects: 89% (118/132)[K
remote: Compressing objects: 90% (119/132)[K
remote: Compressing objects: 91% (121/132)[K
remote: Compressing objects: 92% (122/132)[K
remote: Compressing objects: 93% (123/132)[K
remote: Compressing objects: 94% (125/132)[K
remote: Compressing objects: 95% (126/132)[K
remote: Compressing objects: 96% (127/132)[K
remote: Compressing objects: 97% (129/132)[K
remote: Compressing objects: 98% (130/132)[K
remote: Compressing objects: 99% (131/132)[K
remote: Compressing objects: 100% (132/132)[K
remote: Compressing objects: 100% (132/132), done.[K
Receiving objects: 0% (1/168)
Receiving objects: 1% (2/168)
Receiving objects: 2% (4/168)
Receiving objects: 3% (6/168)
Receiving objects: 4% (7/168)
Receiving objects: 5% (9/168)
Receiving objects: 6% (11/168)
Receiving objects: 7% (12/168)
Receiving objects: 8% (14/168)
Receiving objects: 9% (16/168)
Receiving objects: 10% (17/168)
Receiving objects: 11% (19/168)
Receiving objects: 12% (21/168)
Receiving objects: 13% (22/168)
Receiving objects: 14% (24/168)
Receiving objects: 15% (26/168)
Receiving objects: 16% (27/168)
Receiving objects: 17% (29/168)
Receiving objects: 18% (31/168)
Receiving objects: 19% (32/168)
Receiving objects: 20% (34/168)
Receiving objects: 21% (36/168)
Receiving objects: 22% (37/168)
Receiving objects: 23% (39/168)
Receiving objects: 24% (41/168)
Receiving objects: 25% (42/168)
Receiving objects: 26% (44/168)
Receiving objects: 27% (46/168)
Receiving objects: 28% (48/168)
Receiving objects: 29% (49/168)
Receiving objects: 30% (51/168)
Receiving objects: 31% (53/168)
Receiving objects: 32% (54/168)
Receiving objects: 33% (56/168), 1.46 MiB | 2.90 MiB/s
Receiving objects: 34% (58/168), 1.46 MiB | 2.90 MiB/s
Receiving objects: 35% (59/168), 1.46 MiB | 2.90 MiB/s
Receiving objects: 35% (59/168), 3.15 MiB | 3.13 MiB/s
Receiving objects: 35% (60/168), 6.50 MiB | 3.23 MiB/s
Receiving objects: 36% (61/168), 6.50 MiB | 3.23 MiB/s
Receiving objects: 36% (62/168), 9.93 MiB | 3.27 MiB/s
Receiving objects: 37% (63/168), 9.93 MiB | 3.27 MiB/s
Receiving objects: 38% (64/168), 9.93 MiB | 3.27 MiB/s
Receiving objects: 39% (66/168), 11.62 MiB | 3.29 MiB/s
Receiving objects: 40% (68/168), 11.62 MiB | 3.29 MiB/s
Receiving objects: 40% (68/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 41% (69/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 42% (71/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 43% (73/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 44% (74/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 45% (76/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 46% (78/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 47% (79/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 48% (81/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 49% (83/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 50% (84/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 51% (86/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 52% (88/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 53% (90/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 54% (91/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 55% (93/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 56% (95/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 57% (96/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 58% (98/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 59% (100/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 60% (101/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 61% (103/168), 13.29 MiB | 3.29 MiB/s
Receiving objects: 61% (104/168), 16.65 MiB | 3.34 MiB/s
Receiving objects: 61% (104/168), 20.00 MiB | 3.34 MiB/s
Receiving objects: 61% (104/168), 23.32 MiB | 3.33 MiB/s
Receiving objects: 62% (105/168), 23.32 MiB | 3.33 MiB/s
Receiving objects: 63% (106/168), 24.84 MiB | 3.29 MiB/s
remote: Total 168 (delta 73), reused 114 (delta 28), pack-reused 0[K
Receiving objects: 64% (108/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 65% (110/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 66% (111/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 67% (113/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 68% (115/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 69% (116/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 70% (118/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 71% (120/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 72% (121/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 73% (123/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 74% (125/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 75% (126/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 76% (128/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 77% (130/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 78% (132/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 79% (133/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 80% (135/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 81% (137/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 82% (138/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 83% (140/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 84% (142/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 85% (143/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 86% (145/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 87% (147/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 88% (148/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 89% (150/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 90% (152/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 91% (153/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 92% (155/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 93% (157/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 94% (158/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 95% (160/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 96% (162/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 97% (163/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 98% (165/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 99% (167/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 100% (168/168), 24.84 MiB | 3.29 MiB/s
Receiving objects: 100% (168/168), 26.22 MiB | 3.29 MiB/s, done.
Resolving deltas: 0% (0/73)
Resolving deltas: 8% (6/73)
Resolving deltas: 12% (9/73)
Resolving deltas: 19% (14/73)
Resolving deltas: 28% (21/73)
Resolving deltas: 38% (28/73)
Resolving deltas: 39% (29/73)
Resolving deltas: 49% (36/73)
Resolving deltas: 69% (51/73)
Resolving deltas: 73% (54/73)
Resolving deltas: 78% (57/73)
Resolving deltas: 84% (62/73)
Resolving deltas: 90% (66/73)
Resolving deltas: 97% (71/73)
Resolving deltas: 100% (73/73)
Resolving deltas: 100% (73/73), done.
Tensorflow Lite モデルを OpenVINO IR 形式に変換¶
アテンション・センター・モデルは、TensorFlow Lite 形式で事前トレーニングされたモデルです。このノートブックでは、モデル変換 API を使用してモデルが OpenVINO IR 形式に変換されます。モデル変換の詳細については、このページを参照してください。モデルがすでに変換されている場合、このステップもスキップできます。
また、TFLite モデル形式は TFLite フロントエンドによって OpenVINO でサポートされているため、モデルを core.read_model()
に直接渡すことができます。例は 002-openvino-api にあります。
tflite_model_path = Path("./attention-center/model/center.tflite")
ir_model_path = Path("./model/ir_center_model.xml")
core = ov.Core()
if not ir_model_path.exists():
model = ov.convert_model(tflite_model_path, input=[('image:0', [1,480,640,3], ov.Type.f32)])
ov.save_model(model, ir_model_path)
print("IR model saved to {}".format(ir_model_path))
else:
print("Read IR model from {}".format(ir_model_path))
model = core.read_model(ir_model_path)
IR model saved to model/ir_center_model.xml
推論デバイスの選択¶
OpenVINO を使用して推論を実行するためにドロップダウン・リストからデバイスを選択します。
import ipywidgets as widgets
device = widgets.Dropdown(
options=core.available_devices + ["AUTO"],
value='AUTO',
description='Device:',
disabled=False,
)
device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
if "GPU" in device.value:
core.set_property(device_name=device.value, properties={'INFERENCE_PRECISION_HINT': ov.Type.f32})
compiled_model = core.compile_model(model=model, device_name=device.value)
アテンション・センター・モデルで使用する画像を準備¶
アテンション・センター・モデルは、形状 (480, 640) の RGB 画像を入力として受け取ります。
class Image():
def __init__(self, model_input_image_shape, image_path=None, image=None):
self.model_input_image_shape = model_input_image_shape
self.image = None
self.real_input_image_shape = None
if image_path is not None:
self.image = cv2.imread(str(image_path))
self.real_input_image_shape = self.image.shape
elif image is not None:
self.image = image
self.real_input_image_shape = self.image.shape
else:
raise Exception("Sorry, image can't be found, please, specify image_path or image")
def prepare_image_tensor(self):
rgb_image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
resized_image = cv2.resize(rgb_image, (self.model_input_image_shape[1], self.model_input_image_shape[0]))
image_tensor = tf.constant(np.expand_dims(resized_image, axis=0),
dtype=tf.float32)
return image_tensor
def scalt_center_to_real_image_shape(self, predicted_center):
new_center_y = round(predicted_center[0] * self.real_input_image_shape[1] / self.model_input_image_shape[1])
new_center_x = round(predicted_center[1] * self.real_input_image_shape[0] / self.model_input_image_shape[0])
return (int(new_center_y), int(new_center_x))
def draw_attention_center_point(self, predicted_center):
image_with_circle = cv2.circle(self.image,
predicted_center,
radius=10,
color=(3, 3, 255),
thickness=-1)
return image_with_circle
def print_image(self, predicted_center=None):
image_to_print = self.image
if predicted_center is not None:
image_to_print = self.draw_attention_center_point(predicted_center)
plt.imshow(cv2.cvtColor(image_to_print, cv2.COLOR_BGR2RGB))
入力画像のロード¶
ファイルのロードボタンを使用して入力画像をアップロードします。
import ipywidgets as widgets
load_file_widget = widgets.FileUpload(
accept="image/*", multiple=False, description="Image file",
)
load_file_widget
FileUpload(value=(), accept='image/*', description='Image file')
import io
import PIL
from urllib.request import urlretrieve
img_path = Path("data/coco.jpg")
img_path.parent.mkdir(parents=True, exist_ok=True)
urlretrieve(
"https://storage.openvinotoolkit.org/repositories/openvino_notebooks/data/data/image/coco.jpg",
img_path,
)
# read uploaded image
image = PIL.Image.open(io.BytesIO(list(load_file_widget.value.values())[-1]['content'])) if load_file_widget.value else PIL.Image.open(img_path)
image.convert("RGB")
input_image = Image((480, 640), image=(np.ascontiguousarray(image)[:, :, ::-1]).astype(np.uint8))
image_tensor = input_image.prepare_image_tensor()
input_image.print_image()
2024-02-09 23:49:38.816368: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:266] failed call to cuInit: CUDA_ERROR_COMPAT_NOT_SUPPORTED_ON_DEVICE: forward compatibility was attempted on non supported HW
2024-02-09 23:49:38.816405: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:168] retrieving CUDA diagnostic information for host: iotg-dev-workstation-07
2024-02-09 23:49:38.816409: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:175] hostname: iotg-dev-workstation-07
2024-02-09 23:49:38.816551: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:199] libcuda reported version is: 470.223.2
2024-02-09 23:49:38.816566: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:203] kernel reported version is: 470.182.3
2024-02-09 23:49:38.816569: E tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:312] kernel version 470.182.3 does not match DSO version 470.223.2 -- cannot find working devices in this configuration
OpenVINO IR モデルで結果を取得¶
output_layer = compiled_model.output(0)
# make inference, get result in input image resolution
res = compiled_model([image_tensor])[output_layer]
# scale point to original image resulution
predicted_center = input_image.scalt_center_to_real_image_shape(res[0])
print(f'Prediction attention center point {predicted_center}')
input_image.print_image(predicted_center)
Prediction attention center point (292, 277)