Object segmentation with EfficientSAM and OpenVINO

This Jupyter notebook can be launched only after a local installation.


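The Segment Anything Model (SAM) has emerged as a powerful tool for numerous vision applications. A key component behind its impressive zero-shot transfer performance and high versatility is a very large Transformer model trained on the extensive, high-quality SA-1B dataset. Despite these benefits, the huge computational cost of SAM has limited its use in a wider range of real-world applications. To address this limitation, EfficientSAM was proposed: a lightweight SAM model that delivers decent performance with greatly reduced complexity. The idea behind EfficientSAM is to leverage masked image pretraining (SAMI), which learns to reconstruct features from the SAM image encoder for effective visual representation learning.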

overview.png


For more details about the model, see the paper, the model web page, and the original repository.

This tutorial shows how to convert and run EfficientSAM using OpenVINO. It also shows how to quantize the model using NNCF.


Prerequisites

%pip install -q "openvino>=2023.3.0" "nncf>=2.7" opencv-python matplotlib "gradio>=4.13" torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu
DEPRECATION: pytorch-lightning 1.6.5 has a non-standard dependency specifier torch>=1.8.*. pip 24.1 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pytorch-lightning or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063
Note: you may need to restart the kernel to use updated packages.
from pathlib import Path

repo_dir = Path("EfficientSAM")

if not repo_dir.exists():
    !git clone https://github.com/yformer/EfficientSAM.git
%cd $repo_dir
Cloning into 'EfficientSAM'...
remote: Enumerating objects: 424, done.
remote: Counting objects: 100% (140/140), done.

remote: Compressing objects: 100% (85/85), done.
remote: Total 424 (delta 84), reused 99 (delta 55), pack-reused 284
Receiving objects: 100% (424/424), 334.57 MiB | 3.35 MiB/s, done.

Resolving deltas: 100% (226/226), done.

/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-609/.workspace/scm/ov-notebook/notebooks/274-efficient-sam/EfficientSAM

Load PyTorch model

The repository provides several models:

  • efficient-sam-vitt - EfficientSAM with Vision Transformer Tiny (VIT-T) as the image encoder. The smallest and fastest model in the EfficientSAM family.

  • efficient-sam-vits - EfficientSAM with Vision Transformer Small (VIT-S) as the image encoder. Heavier than efficient-sam-vitt, but more accurate.

EfficientSAM provides a unified interface for interacting with its models. This means that all the steps in this notebook for converting and running a model are the same for every model. Below, you can select one of them as an example.

from efficient_sam.build_efficient_sam import build_efficient_sam_vitt, build_efficient_sam_vits
import zipfile

MODELS_LIST = {"efficient-sam-vitt": build_efficient_sam_vitt, "efficient-sam-vits": build_efficient_sam_vits}

# Since EfficientSAM-S checkpoint file is >100MB, we store the zip file.
with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", 'r') as zip_ref:
    zip_ref.extractall("weights")

Select one of the supported models.

import ipywidgets as widgets

model_ids = list(MODELS_LIST)

model_id = widgets.Dropdown(
    options=model_ids,
    value=model_ids[0],
    description="Model:",
    disabled=False,
)

model_id
Dropdown(description='Model:', options=('efficient-sam-vitt', 'efficient-sam-vits'), value='efficient-sam-vitt…

Build PyTorch model

pt_model = MODELS_LIST[model_id.value]()

pt_model.eval();

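Run PyTorch model inference

Now that the PyTorch model has been selected and loaded, we can check its results.

Prepare input data

First, we need to prepare input data for the model. The model has 3 inputs:

  • image tensor - a tensor with the normalized input image.
  • input points - a tensor with user-provided points. These can be specific points on the image (for example, provided by user clicks on the screen) or bounding box coordinates given as the top-left and bottom-right corner points.
  • input labels - a tensor that defines the point type for each provided point:
    • 1 - regular point
    • 2 - top-left point of a bounding box
    • 3 - bottom-right point of a bounding box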
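
As a minimal sketch of the label convention above, a bounding box prompt can be expressed as two points labeled 2 and 3. The helper name `box_prompt`, the example coordinates, and the `float32` dtype are assumptions made for illustration; the `[1, 1, -1, 2]` and `[1, 1, -1]` shapes mirror the reshaping done by the notebook's input helper.

```python
import numpy as np


def box_prompt(x0, y0, x1, y1):
    """Hypothetical helper: encode a bounding box as two labeled points."""
    # Top-left and bottom-right corners, shaped [batch, queries, points, 2].
    points = np.array([[x0, y0], [x1, y1]], dtype=np.float32).reshape(1, 1, -1, 2)
    # Label 2 marks the top-left corner, label 3 the bottom-right corner.
    labels = np.array([2, 3], dtype=np.float32).reshape(1, 1, -1)
    return points, labels


points, labels = box_prompt(100, 50, 400, 300)
print(points.shape, labels.shape)
```

A regular click prompt follows the same layout, with one row per click and a label of 1 for each.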
from PIL import Image

image_path = "figs/examples/dogs.jpg"

image = Image.open(image_path)
image
../_images/274-efficient-sam-with-output_11_0.png

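Define helpers for input and output processing

The code below defines helpers for preparing model input and postprocessing inference results. The input format is the one accepted by the model, as described above. The model predicts mask logits for each pixel of the image, together with an intersection over union (IoU) score for each predicted region, reflecting how close it is to the provided points. Some helper functions for visualizing the results are also provided.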

import torch
import matplotlib.pyplot as plt
import numpy as np


def prepare_input(input_image, points, labels, torch_tensor=True):
    img_tensor = np.ascontiguousarray(input_image)[None, ...].astype(np.float32) / 255
    img_tensor = np.transpose(img_tensor, (0, 3, 1, 2))
    pts_sampled = np.reshape(np.ascontiguousarray(points), [1, 1, -1, 2])
    pts_labels = np.reshape(np.ascontiguousarray(labels), [1, 1, -1])
    if torch_tensor:
        img_tensor = torch.from_numpy(img_tensor)
        pts_sampled = torch.from_numpy(pts_sampled)
        pts_labels = torch.from_numpy(pts_labels)
    return img_tensor, pts_sampled, pts_labels


def postprocess_results(predicted_iou, predicted_logits):
    sorted_ids = np.argsort(-predicted_iou, axis=-1)
    predicted_iou = np.take_along_axis(predicted_iou, sorted_ids, axis=2)
    predicted_logits = np.take_along_axis(
        predicted_logits, sorted_ids[..., None, None], axis=2
    )

    return predicted_logits[0, 0, 0, :, :] >= 0


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(
        pos_points[:, 0],
        pos_points[:, 1],
        color="green",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )
    ax.scatter(
        neg_points[:, 0],
        neg_points[:, 1],
        color="red",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(
        plt.Rectangle((x0, y0), w, h, edgecolor="yellow", facecolor=(0, 0, 0, 0), lw=5)
    )


def show_anns(mask, ax):
    ax.set_autoscale_on(False)
    img = np.ones((mask.shape[0], mask.shape[1], 4))
    img[:, :, 3] = 0
    # for ann in mask:
    #     m = ann
    color_mask = np.concatenate([np.random.random(3), [0.5]])
    img[mask] = color_mask
    ax.imshow(img)
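
To make the postprocessing step easier to follow, here is a small self-contained sketch of the same idea on synthetic data: candidate masks are reordered by predicted IoU, and the best one is thresholded at a logit of 0. The function name `select_best_mask` and the toy arrays are assumptions for illustration; the array layout (`[batch, queries, masks]` for IoU and `[batch, queries, masks, height, width]` for logits) is inferred from the indexing used in the helpers above.

```python
import numpy as np


def select_best_mask(predicted_iou, predicted_logits):
    """Return the boolean mask of the candidate with the highest predicted IoU."""
    # Sort candidates by descending IoU along the last axis.
    order = np.argsort(-predicted_iou, axis=-1)
    # Reorder the logits the same way; the index is broadcast over H and W.
    sorted_logits = np.take_along_axis(
        predicted_logits, order[..., None, None], axis=2
    )
    # A logit >= 0 corresponds to a foreground probability >= 0.5.
    return sorted_logits[0, 0, 0] >= 0


# Three candidate 2x2 masks; candidate 1 has the highest IoU score.
iou = np.array([[[0.2, 0.9, 0.5]]], dtype=np.float32)
logits = np.full((1, 1, 3, 2, 2), -1.0, dtype=np.float32)
logits[0, 0, 0] = 5.0        # low-IoU candidate, all foreground
logits[0, 0, 1, 0, 0] = 2.0  # best candidate, one foreground pixel

best_mask = select_best_mask(iou, logits)
print(best_mask)
```

Only the top-left pixel ends up in the mask, because the all-foreground candidate is ranked lower by its IoU score.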

A complete model inference example is shown below.

input_points = [[580, 350], [650, 350]]
input_labels = [1, 1]

example_input = prepare_input(image, input_points, input_labels)

predicted_logits, predicted_iou = pt_model(*example_input)

predicted_mask = postprocess_results(predicted_iou.detach().numpy(), predicted_logits.detach().numpy())
image = Image.open(image_path)

plt.figure(figsize=(20, 20))
plt.axis("off")
plt.imshow(image)
show_points(np.array(input_points), np.array(input_labels), plt.gca())
plt.figure(figsize=(20, 20))
plt.axis("off")
plt.imshow(image)
show_anns(predicted_mask, plt.gca())
plt.title(f"PyTorch {model_id.value}", fontsize=18)
plt.show()
../_images/274-efficient-sam-with-output_16_0.png ../_images/274-efficient-sam-with-output_16_1.png

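Convert model to OpenVINO IR format

OpenVINO supports PyTorch models through conversion to the Intermediate Representation (IR) format using the OpenVINO model conversion API. The openvino.convert_model function accepts an instance of a PyTorch model and an example input (which helps with correct tracing of model operations and shape inference) and returns an openvino.Model object that represents the model in the OpenVINO framework. This openvino.Model is ready to be loaded on a device with ov.Core.compile_model and can also be saved to disk with openvino.save_model.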

import openvino as ov

core = ov.Core()

ov_model_path = Path(f'{model_id.value}.xml')

if not ov_model_path.exists():
    ov_model = ov.convert_model(pt_model, example_input=example_input)
    ov.save_model(ov_model, ov_model_path)
else:
    ov_model = core.read_model(ov_model_path)
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-609/.workspace/scm/ov-notebook/notebooks/274-efficient-sam/EfficientSAM/efficient_sam/efficient_sam.py:220: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if (
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-609/.workspace/scm/ov-notebook/notebooks/274-efficient-sam/EfficientSAM/efficient_sam/efficient_sam_encoder.py:241: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert (
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-609/.workspace/scm/ov-notebook/notebooks/274-efficient-sam/EfficientSAM/efficient_sam/efficient_sam_encoder.py:163: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  size = int(math.sqrt(xy_num))
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-609/.workspace/scm/ov-notebook/notebooks/274-efficient-sam/EfficientSAM/efficient_sam/efficient_sam_encoder.py:164: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert size * size == xy_num
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-609/.workspace/scm/ov-notebook/notebooks/274-efficient-sam/EfficientSAM/efficient_sam/efficient_sam_encoder.py:166: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if size != h or size != w:
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-609/.workspace/scm/ov-notebook/notebooks/274-efficient-sam/EfficientSAM/efficient_sam/efficient_sam_encoder.py:251: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert x.shape[2] == num_patches
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-609/.workspace/scm/ov-notebook/notebooks/274-efficient-sam/EfficientSAM/efficient_sam/efficient_sam.py:85: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if num_pts > self.decoder_max_num_input_points:
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-609/.workspace/scm/ov-notebook/notebooks/274-efficient-sam/EfficientSAM/efficient_sam/efficient_sam.py:92: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  elif num_pts < self.decoder_max_num_input_points:
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-609/.workspace/scm/ov-notebook/notebooks/274-efficient-sam/EfficientSAM/efficient_sam/efficient_sam.py:126: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if output_w > 0 and output_h > 0:

Run OpenVINO model inference

Select inference device from dropdown list

device = widgets.Dropdown(
    options=core.available_devices + ["AUTO"],
    value="AUTO",
    description="Device:",
    disabled=False,
)

device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')

Compile OpenVINO model

compiled_model = core.compile_model(ov_model, device.value)

Inference and visualize result

Now let's take a look at the OpenVINO model prediction.

example_input = prepare_input(image, input_points, input_labels, torch_tensor=False)
result = compiled_model(example_input)

predicted_logits, predicted_iou = result[0], result[1]

predicted_mask = postprocess_results(predicted_iou, predicted_logits)

plt.figure(figsize=(20, 20))
plt.axis("off")
plt.imshow(image)
show_points(np.array(input_points), np.array(input_labels), plt.gca())
plt.figure(figsize=(20, 20))
plt.axis("off")
plt.imshow(image)
show_anns(predicted_mask, plt.gca())
plt.title(f"OpenVINO {model_id.value}", fontsize=18)
plt.show()
../_images/274-efficient-sam-with-output_24_0.png ../_images/274-efficient-sam-with-output_24_1.png
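
The helper `postprocess_results` used above is defined earlier in the notebook. As a rough illustration of the idea (not that helper itself), the sketch below picks the candidate mask with the highest predicted IoU and binarizes its logits, which mirrors the selection logic in the interactive demo at the end of this notebook. The function name `select_best_mask` and the toy values are assumptions made for this example. Note that thresholding `sigmoid(x) >= 0.5` is equivalent to `x >= 0`.

```python
import math


def sigmoid(x):
    """Logistic function, as used in the demo code."""
    return 1.0 / (1.0 + math.exp(-x))


def select_best_mask(mask_logits, iou_scores, threshold=0.5):
    """Return (index, binary_mask) for the candidate with the highest predicted IoU.

    mask_logits: list of 2D lists (one H x W grid of logits per candidate mask)
    iou_scores:  list of floats, one predicted IoU per candidate
    """
    best = max(range(len(iou_scores)), key=lambda i: iou_scores[i])
    binary = [[sigmoid(v) >= threshold for v in row] for row in mask_logits[best]]
    return best, binary


# Toy example: two 2x2 candidate masks (hypothetical values).
logits = [
    [[-3.0, 2.0], [0.5, -0.1]],
    [[4.0, 4.0], [-4.0, -4.0]],
]
ious = [0.62, 0.91]

best_index, best_mask = select_best_mask(logits, ious)
print(best_index, best_mask)
```

Plain nested lists stand in for the NumPy arrays returned by the compiled model, so the sketch runs without any dependencies.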

Quantization

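NNCF enables post-training quantization by adding quantization layers to the model graph and then using a subset of the training dataset to initialize the parameters of these additional quantization layers. The framework is designed so that changes to the original training code are minimal.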

The optimization process consists of the following steps:

  1. Create a calibration dataset for quantization.

  2. Run nncf.quantize to obtain the quantized model.

  3. Serialize the INT8 model using the openvino.save_model function.
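
To make the calibration step more concrete, here is a minimal pure-Python sketch of how a quantizer's parameters could be initialized from a subset of samples: it tracks the observed minimum and maximum and derives a scale and zero point for an unsigned 8-bit range. This is a simplified illustration of the general idea, not NNCF's actual algorithm; `calibrate_min_max`, `quantize`, `dequantize`, and the sample values are made up for the example.

```python
def calibrate_min_max(samples, subset_size, num_levels=256):
    """Derive (scale, zero_point) from the value range seen in the first
    `subset_size` samples. Each sample is a flat list of activations."""
    lo, hi = float("inf"), float("-inf")
    for sample in samples[:subset_size]:
        lo = min(lo, min(sample))
        hi = max(hi, max(sample))
    # The representable range must include zero so that 0.0 maps to an integer.
    lo, hi = min(lo, 0.0), max(hi, 0.0)
    scale = (hi - lo) / (num_levels - 1)
    zero_point = round(-lo / scale)
    return scale, zero_point


def quantize(x, scale, zero_point, num_levels=256):
    """Map a float to an integer level, clamping to the valid range."""
    q = round(x / scale) + zero_point
    return max(0, min(num_levels - 1, q))


def dequantize(q, scale, zero_point):
    """Map an integer level back to its float value."""
    return (q - zero_point) * scale


calibration_samples = [[-1.0, 0.2, 0.5], [0.1, 3.0, -0.5], [9.0, 9.0, 9.0]]
# Only the first two samples are used for calibration.
scale, zero_point = calibrate_min_max(calibration_samples, subset_size=2)
print(scale, zero_point)
```

Values outside the calibrated range, such as the 9.0 in the unused third sample, are clamped to the last level, which is why the calibration subset should be representative of real inputs.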

Note: Quantization is a time- and memory-consuming operation. Running the quantization code below may take some time.

Choose below whether to run EfficientSAM quantization.

to_quantize = widgets.Checkbox(
    value=True,
    description='Quantization',
    disabled=False,
)

to_quantize
Checkbox(value=True, description='Quantization')
import urllib.request

urllib.request.urlretrieve(
    url='https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/main/notebooks/utils/skip_kernel_extension.py',
    filename='skip_kernel_extension.py'
)

%load_ext skip_kernel_extension

Prepare the calibration dataset

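The first step is to prepare a calibration dataset for quantization. We use the coco128 dataset. This dataset is normally used for object detection tasks, and its annotations provide box coordinates for the images. Here, the box coordinates serve as input points for object segmentation. The code below downloads the dataset and creates a DataLoader that prepares inputs for the EfficientSAM model.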
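
As a small illustration of how a box becomes a point prompt, the sketch below packs two corner points and their labels into nested lists with the `[batch, queries, points, 2]` and `[batch, queries, points]` layouts that appear in the `-data_shape` argument of the benchmark command later in this notebook. The labels 2 and 3 for the two box corners follow the demo code at the end of the notebook. `box_to_prompt` and `shape` are hypothetical helpers, and plain lists stand in for the arrays produced by `prepare_input`.

```python
def box_to_prompt(x1, y1, x2, y2):
    """Pack one box into EfficientSAM-style point and label prompts.

    Returns (points, labels) shaped [1, 1, 2, 2] and [1, 1, 2].
    Labels 2 and 3 mark the two box corners, as in the demo code.
    """
    points = [[[[x1, y1], [x2, y2]]]]
    labels = [[[2, 3]]]
    return points, labels


def shape(nested):
    """Return the shape of a regular nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return dims


points, labels = box_to_prompt(40, 60, 200, 180)
print(shape(points), shape(labels))
```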

%%skip not $to_quantize.value

from zipfile import ZipFile

urllib.request.urlretrieve(
    url='https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/main/notebooks/utils/notebook_utils.py',
    filename='notebook_utils.py'
)

from notebook_utils import download_file

DATA_URL = "https://ultralytics.com/assets/coco128.zip"
OUT_DIR = Path('.')

download_file(DATA_URL, directory=OUT_DIR, show_progress=True)

if not (OUT_DIR / "coco128/images/train2017").exists():
    with ZipFile('coco128.zip' , "r") as zip_ref:
        zip_ref.extractall(OUT_DIR)
coco128.zip:   0%|          | 0.00/6.66M [00:00<?, ?B/s]
%%skip not $to_quantize.value

import torch.utils.data as data

class COCOLoader(data.Dataset):
    def __init__(self, images_path):
        self.images = list(Path(images_path).iterdir())
        self.labels_dir = images_path.parents[1] / 'labels' / images_path.name

    def get_points(self, image_path, image_width, image_height):
        file_name = image_path.name.replace('.jpg', '.txt')
        label_file =  self.labels_dir / file_name
        if not label_file.exists():
            x1, x2 = np.random.randint(low=0, high=image_width, size=(2, ))
            y1, y2 = np.random.randint(low=0, high=image_height, size=(2, ))
        else:
            with label_file.open("r") as f:
                box_line = f.readline()
            _, x1, y1, x2, y2 = box_line.split()
            x1 = int(float(x1) * image_width)
            y1 = int(float(y1) * image_height)
            x2 = int(float(x2) * image_width)
            y2 = int(float(y2) * image_height)
        return [[x1, y1], [x2, y2]]

    def __getitem__(self, index):
        image_path = self.images[index]
        image = Image.open(image_path)
        image = image.convert('RGB')
        w, h = image.size
        points = self.get_points(image_path, w, h)
        labels = [1, 1] if index % 2 == 0 else [2, 3]
        batched_images, batched_points, batched_point_labels = prepare_input(image, points, labels, torch_tensor=False)
        return {'batched_images': np.ascontiguousarray(batched_images)[0], 'batched_points': np.ascontiguousarray(batched_points)[0], 'batched_point_labels': np.ascontiguousarray(batched_point_labels)[0]}

    def __len__(self):
        return len(self.images)
%%skip not $to_quantize.value

coco_dataset = COCOLoader(OUT_DIR / 'coco128/images/train2017')
calibration_loader = torch.utils.data.DataLoader(coco_dataset)
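
`get_points` above reads the four numbers after the class id as two corner points. If the label files follow the Ultralytics YOLO convention (`class x_center y_center width height`, all normalized to [0, 1]), which is an assumption worth checking against your copy of the dataset, the corners could instead be derived as in this sketch. `yolo_to_corners` is a hypothetical helper and is not part of the notebook.

```python
def yolo_to_corners(label_line, image_width, image_height):
    """Convert one 'class xc yc w h' line (normalized) to pixel corner points."""
    _, xc, yc, w, h = label_line.split()
    xc, yc, w, h = float(xc), float(yc), float(w), float(h)
    x1 = int((xc - w / 2) * image_width)
    y1 = int((yc - h / 2) * image_height)
    x2 = int((xc + w / 2) * image_width)
    y2 = int((yc + h / 2) * image_height)
    # Clamp to the image so the points never fall outside it.
    x1, x2 = max(0, x1), min(image_width - 1, x2)
    y1, y2 = max(0, y1), min(image_height - 1, y2)
    return [[x1, y1], [x2, y2]]


# Hypothetical label: class 0, box centered in the image, half its size.
corners = yolo_to_corners("0 0.5 0.5 0.5 0.5", 640, 480)
print(corners)
```

For calibration purposes either interpretation yields valid prompts inside the image, so this mainly matters if the points are meant to outline the annotated objects.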

Run model quantization

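The nncf.quantize function provides the interface for model quantization. It requires an instance of an OpenVINO model and a quantization dataset. Optionally, additional parameters for the quantization process can be provided (number of samples for quantization, preset, ignored scope, and so on). EfficientSAM contains non-ReLU activation functions, which require asymmetric quantization of activations. To achieve better results, we use the mixed quantization preset. The encoder part of the model is based on the Vision Transformer architecture, so model_type must be set to transformer to enable the special optimizations for this architecture type.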
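
To see why activations that are not centered on zero benefit from asymmetric quantization, consider values that span a small negative range and a much larger positive one, as the outputs of some non-ReLU activations do. The sketch below compares the step size of a symmetric and an asymmetric 8-bit grid over such a range. It illustrates the general principle only; the range values are made up and the functions are not NNCF APIs.

```python
def symmetric_step(lo, hi, num_levels=256):
    """Step size when the grid is centered on zero and covers [-m, m]."""
    m = max(abs(lo), abs(hi))
    return 2 * m / (num_levels - 1)


def asymmetric_step(lo, hi, num_levels=256):
    """Step size when the grid covers exactly [lo, hi]."""
    return (hi - lo) / (num_levels - 1)


# Hypothetical activation range: slightly negative minimum, larger maximum.
lo, hi = -0.17, 6.0

sym = symmetric_step(lo, hi)
asym = asymmetric_step(lo, hi)
print(f"symmetric step:  {sym:.5f}")
print(f"asymmetric step: {asym:.5f}")
print(f"ratio: {sym / asym:.2f}")
```

A smaller step means finer resolution for the same number of levels. In this example the symmetric grid spends nearly half of its levels on negative values that never occur.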

%%skip not $to_quantize.value

import nncf

calibration_dataset = nncf.Dataset(calibration_loader)

model = core.read_model(ov_model_path)
quantized_model = nncf.quantize(model,
                                calibration_dataset,
                                model_type=nncf.parameters.ModelType.TRANSFORMER,
                                preset=nncf.common.quantization.structs.QuantizationPreset.MIXED, subset_size=128)
print("model quantization finished")
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, tensorflow, onnx, openvino
2024-02-10 00:36:45.274123: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable TF_ENABLE_ONEDNN_OPTS=0.
2024-02-10 00:36:45.306722: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-02-10 00:36:45.919421: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Output()
INFO:nncf:57 ignored nodes were found by name in the NNCFGraph
INFO:nncf:88 ignored nodes were found by name in the NNCFGraph
Output()
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-609/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/nncf/experimental/tensor/tensor.py:84: RuntimeWarning: invalid value encountered in multiply
  return Tensor(self.data * unwrap_tensor_data(other))
Output()
model quantization finished

Verify quantized model inference

%%skip not $to_quantize.value

compiled_model = core.compile_model(quantized_model, device.value)

result = compiled_model(example_input)

predicted_logits, predicted_iou = result[0], result[1]

predicted_mask = postprocess_results(predicted_iou, predicted_logits)

plt.figure(figsize=(20, 20))
plt.axis("off")
plt.imshow(image)
show_points(np.array(input_points), np.array(input_labels), plt.gca())
plt.figure(figsize=(20, 20))
plt.axis("off")
plt.imshow(image)
show_anns(predicted_mask, plt.gca())
plt.title(f"OpenVINO INT8 {model_id.value}", fontsize=18)
plt.show()
../_images/274-efficient-sam-with-output_35_0.png ../_images/274-efficient-sam-with-output_35_1.png

Save the quantized model to disk

%%skip not $to_quantize.value

quantized_model_path = Path(f"{model_id.value}_int8.xml")
ov.save_model(quantized_model, quantized_model_path)

Compare quantized model size

%%skip not $to_quantize.value

fp16_weights = ov_model_path.with_suffix('.bin')
quantized_weights = quantized_model_path.with_suffix('.bin')

print(f"Size of FP16 model is {fp16_weights.stat().st_size / 1024 / 1024:.2f} MB")
print(f"Size of INT8 quantized model is {quantized_weights.stat().st_size / 1024 / 1024:.2f} MB")
print(f"Compression rate for INT8 model: {fp16_weights.stat().st_size / quantized_weights.stat().st_size:.3f}")
Size of FP16 model is 21.50 MB
Size of INT8 quantized model is 10.96 MB
Compression rate for INT8 model: 1.962
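
The compression rate printed above is the ratio of the two file sizes. A quick check with the reported (rounded) sizes:

```python
fp16_size_mb = 21.50   # reported size of the FP16 weights
int8_size_mb = 10.96   # reported size of the INT8 weights

compression_rate = fp16_size_mb / int8_size_mb
saved_fraction = 1 - int8_size_mb / fp16_size_mb

print(f"Compression rate: {compression_rate:.3f}")
print(f"Size reduction:   {saved_fraction:.1%}")
```

The ratio is close to 2, which is consistent with weights going from 16 bits to 8 bits each.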

Compare inference time of the FP16 and INT8 models

To measure the inference performance of the FP16 and INT8 models, we use benchmark_app.

Note: For the most accurate performance estimate, it is recommended to run benchmark_app in a terminal/command prompt after closing other applications.

!benchmark_app -m $ov_model_path -d $device.value -data_shape "batched_images[1,3,512,512],batched_points[1,1,2,2],batched_point_labels[1,1,2]" -t 15
[Step 1/11] Parsing and validating input arguments
[ INFO ] Parsing input parameters
[Step 2/11] Loading OpenVINO Runtime
[ INFO ] OpenVINO:
[ INFO ] Build ................................. 2023.3.0-13775-ceeafaf64f3-releases/2023/3
[ INFO ]
[ INFO ] Device info:
[ INFO ] AUTO
[ INFO ] Build ................................. 2023.3.0-13775-ceeafaf64f3-releases/2023/3
[ INFO ]
[ INFO ]
[Step 3/11] Setting device configuration
[ WARNING ] Performance hint was not explicitly specified in command line. Device(AUTO) performance hint will be set to PerformanceMode.THROUGHPUT.
[Step 4/11] Reading model files
[ INFO ] Loading model files
[ INFO ] Read model took 42.80 ms
[ INFO ] Original model I/O parameters:
[ INFO ] Model inputs:
[ INFO ]     batched_images (node: batched_images) : f32 / [...] / [?,?,?,?]
[ INFO ]     batched_points (node: batched_points) : i64 / [...] / [?,?,?,?]
[ INFO ]     batched_point_labels (node: batched_point_labels) : i64 / [...] / [?,?,?]
[ INFO ] Model outputs:
[ INFO ]     133 (node: aten::reshape/Reshape_3) : f32 / [...] / [?,?,?,?,?]
[ INFO ]     135 (node: aten::reshape/Reshape_2) : f32 / [...] / [?,?,?]
[Step 5/11] Resizing model to match image sizes and given batch
[ INFO ] Model batch size: 1
[Step 6/11] Configuring input of the model
[ INFO ] Model inputs:
[ INFO ]     batched_images (node: batched_images) : f32 / [...] / [?,?,?,?]
[ INFO ]     batched_points (node: batched_points) : i64 / [...] / [?,?,?,?]
[ INFO ]     batched_point_labels (node: batched_point_labels) : i64 / [...] / [?,?,?]
[ INFO ] Model outputs:
[ INFO ]     133 (node: aten::reshape/Reshape_3) : f32 / [...] / [?,?,?,?,?]
[ INFO ]     135 (node: aten::reshape/Reshape_2) : f32 / [...] / [?,?,?]
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 1174.09 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] Model:
[ INFO ]   NETWORK_NAME: Model0
[ INFO ]   EXECUTION_DEVICES: ['CPU']
[ INFO ]   PERFORMANCE_HINT: PerformanceMode.THROUGHPUT
[ INFO ]   OPTIMAL_NUMBER_OF_INFER_REQUESTS: 6
[ INFO ]   MULTI_DEVICE_PRIORITIES: CPU
[ INFO ]   CPU:
[ INFO ]     AFFINITY: Affinity.CORE
[ INFO ]     CPU_DENORMALS_OPTIMIZATION: False
[ INFO ]     CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE: 1.0
[ INFO ]     ENABLE_CPU_PINNING: True
[ INFO ]     ENABLE_HYPER_THREADING: True
[ INFO ]     EXECUTION_DEVICES: ['CPU']
[ INFO ]     EXECUTION_MODE_HINT: ExecutionMode.PERFORMANCE
[ INFO ]     INFERENCE_NUM_THREADS: 24
[ INFO ]     INFERENCE_PRECISION_HINT: <Type: 'float32'>
[ INFO ]     NETWORK_NAME: Model0
[ INFO ]     NUM_STREAMS: 6
[ INFO ]     OPTIMAL_NUMBER_OF_INFER_REQUESTS: 6
[ INFO ]     PERFORMANCE_HINT: THROUGHPUT
[ INFO ]     PERFORMANCE_HINT_NUM_REQUESTS: 0
[ INFO ]     PERF_COUNT: NO
[ INFO ]     SCHEDULING_CORE_TYPE: SchedulingCoreType.ANY_CORE
[ INFO ]   MODEL_PRIORITY: Priority.MEDIUM
[ INFO ]   LOADED_FROM_CACHE: False
[Step 9/11] Creating infer requests and preparing input tensors
[ WARNING ] No input files were given for input 'batched_images'!. This input will be filled with random values!
[ WARNING ] No input files were given for input 'batched_points'!. This input will be filled with random values!
[ WARNING ] No input files were given for input 'batched_point_labels'!. This input will be filled with random values!
[ INFO ] Fill input 'batched_images' with random values
[ INFO ] Fill input 'batched_points' with random values
[ INFO ] Fill input 'batched_point_labels' with random values
[Step 10/11] Measuring performance (Start inference asynchronously, 6 inference requests, limits: 15000 ms duration)
[ INFO ] Benchmarking in full mode (inputs filling are included in measurement loop).
[ INFO ] First inference took 644.38 ms
[Step 11/11] Dumping statistics report
[ INFO ] Execution Devices:['CPU']
[ INFO ] Count:            50 iterations
[ INFO ] Duration:         16093.09 ms
[ INFO ] Latency:
[ INFO ]    Median:        1872.68 ms
[ INFO ]    Average:       1848.02 ms
[ INFO ]    Min:           1047.31 ms
[ INFO ]    Max:           1952.79 ms
[ INFO ] Throughput:   3.11 FPS
if to_quantize.value:
    !benchmark_app -m $quantized_model_path -d $device.value -data_shape "batched_images[1,3,512,512],batched_points[1,1,2,2],batched_point_labels[1,1,2]" -t 15
[Step 1/11] Parsing and validating input arguments
[ INFO ] Parsing input parameters
[Step 2/11] Loading OpenVINO Runtime
[ INFO ] OpenVINO:
[ INFO ] Build ................................. 2023.3.0-13775-ceeafaf64f3-releases/2023/3
[ INFO ]
[ INFO ] Device info:
[ INFO ] AUTO
[ INFO ] Build ................................. 2023.3.0-13775-ceeafaf64f3-releases/2023/3
[ INFO ]
[ INFO ]
[Step 3/11] Setting device configuration
[ WARNING ] Performance hint was not explicitly specified in command line. Device(AUTO) performance hint will be set to PerformanceMode.THROUGHPUT.
[Step 4/11] Reading model files
[ INFO ] Loading model files
[ INFO ] Read model took 65.77 ms
[ INFO ] Original model I/O parameters:
[ INFO ] Model inputs:
[ INFO ]     batched_images (node: batched_images) : f32 / [...] / [?,?,?,?]
[ INFO ]     batched_points (node: batched_points) : i64 / [...] / [?,?,?,?]
[ INFO ]     batched_point_labels (node: batched_point_labels) : i64 / [...] / [?,?,?]
[ INFO ] Model outputs:
[ INFO ]     133 (node: aten::reshape/Reshape_3) : f32 / [...] / [?,?,?,?,?]
[ INFO ]     135 (node: aten::reshape/Reshape_2) : f32 / [...] / [?,?,?]
[Step 5/11] Resizing model to match image sizes and given batch
[ INFO ] Model batch size: 1
[Step 6/11] Configuring input of the model
[ INFO ] Model inputs:
[ INFO ]     batched_images (node: batched_images) : f32 / [...] / [?,?,?,?]
[ INFO ]     batched_points (node: batched_points) : i64 / [...] / [?,?,?,?]
[ INFO ]     batched_point_labels (node: batched_point_labels) : i64 / [...] / [?,?,?]
[ INFO ] Model outputs:
[ INFO ]     133 (node: aten::reshape/Reshape_3) : f32 / [...] / [?,?,?,?,?]
[ INFO ]     135 (node: aten::reshape/Reshape_2) : f32 / [...] / [?,?,?]
[Step 7/11] Loading the model to the device
[ INFO ] Compile model took 1539.94 ms
[Step 8/11] Querying optimal runtime parameters
[ INFO ] Model:
[ INFO ]   NETWORK_NAME: Model0
[ INFO ]   EXECUTION_DEVICES: ['CPU']
[ INFO ]   PERFORMANCE_HINT: PerformanceMode.THROUGHPUT
[ INFO ]   OPTIMAL_NUMBER_OF_INFER_REQUESTS: 6
[ INFO ]   MULTI_DEVICE_PRIORITIES: CPU
[ INFO ]   CPU:
[ INFO ]     AFFINITY: Affinity.CORE
[ INFO ]     CPU_DENORMALS_OPTIMIZATION: False
[ INFO ]     CPU_SPARSE_WEIGHTS_DECOMPRESSION_RATE: 1.0
[ INFO ]     ENABLE_CPU_PINNING: True
[ INFO ]     ENABLE_HYPER_THREADING: True
[ INFO ]     EXECUTION_DEVICES: ['CPU']
[ INFO ]     EXECUTION_MODE_HINT: ExecutionMode.PERFORMANCE
[ INFO ]     INFERENCE_NUM_THREADS: 24
[ INFO ]     INFERENCE_PRECISION_HINT: <Type: 'float32'>
[ INFO ]     NETWORK_NAME: Model0
[ INFO ]     NUM_STREAMS: 6
[ INFO ]     OPTIMAL_NUMBER_OF_INFER_REQUESTS: 6
[ INFO ]     PERFORMANCE_HINT: THROUGHPUT
[ INFO ]     PERFORMANCE_HINT_NUM_REQUESTS: 0
[ INFO ]     PERF_COUNT: NO
[ INFO ]     SCHEDULING_CORE_TYPE: SchedulingCoreType.ANY_CORE
[ INFO ]   MODEL_PRIORITY: Priority.MEDIUM
[ INFO ]   LOADED_FROM_CACHE: False
[Step 9/11] Creating infer requests and preparing input tensors
[ WARNING ] No input files were given for input 'batched_images'!. This input will be filled with random values!
[ WARNING ] No input files were given for input 'batched_points'!. This input will be filled with random values!
[ WARNING ] No input files were given for input 'batched_point_labels'!. This input will be filled with random values!
[ INFO ] Fill input 'batched_images' with random values
[ INFO ] Fill input 'batched_points' with random values
[ INFO ] Fill input 'batched_point_labels' with random values
[Step 10/11] Measuring performance (Start inference asynchronously, 6 inference requests, limits: 15000 ms duration)
[ INFO ] Benchmarking in full mode (inputs filling are included in measurement loop).
[ INFO ] First inference took 578.80 ms
[Step 11/11] Dumping statistics report
[ INFO ] Execution Devices:['CPU']
[ INFO ] Count:            55 iterations
[ INFO ] Duration:         16094.50 ms
[ INFO ] Latency:
[ INFO ]    Median:        1735.52 ms
[ INFO ]    Average:       1715.26 ms
[ INFO ]    Min:           507.00 ms
[ INFO ]    Max:           1794.23 ms
[ INFO ] Throughput:   3.42 FPS
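
The two reports above can be compared directly. The sketch below extracts the throughput and median latency from benchmark_app-style report lines with a regular expression and computes the relative change. The values are copied from the outputs above; `parse_report` is a hypothetical helper. Because benchmark_app fills the inputs with random values here, these numbers describe speed only and say nothing about accuracy.

```python
import re


def parse_report(text):
    """Extract median latency (ms) and throughput (FPS) from a benchmark_app report."""
    median = float(re.search(r"Median:\s+([\d.]+) ms", text).group(1))
    fps = float(re.search(r"Throughput:\s+([\d.]+) FPS", text).group(1))
    return {"median_ms": median, "fps": fps}


fp16_report = """
[ INFO ]    Median:        1872.68 ms
[ INFO ] Throughput:   3.11 FPS
"""
int8_report = """
[ INFO ]    Median:        1735.52 ms
[ INFO ] Throughput:   3.42 FPS
"""

fp16 = parse_report(fp16_report)
int8 = parse_report(int8_report)

speedup = int8["fps"] / fp16["fps"]
latency_drop = 1 - int8["median_ms"] / fp16["median_ms"]
print(f"Throughput speedup: {speedup:.2f}x")
print(f"Median latency reduction: {latency_drop:.1%}")
```

On this particular run the INT8 model is roughly 10% faster in throughput; results will vary with hardware and system load.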

Interactive segmentation demo

import copy
import gradio as gr
import numpy as np
from PIL import ImageDraw, Image
import cv2
import matplotlib.pyplot as plt

example_images = [
    "https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/b8083dd5-1ce7-43bf-8b09-a2ebc280c86e",
    "https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/9a90595d-70e7-469b-bdaf-469ef4f56fa2",
    "https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/b626c123-9fa2-4aa6-9929-30565991bf0c",
]

examples_dir = Path("examples")
examples_dir.mkdir(exist_ok=True)

for img_id, image_url in enumerate(example_images):
    urllib.request.urlretrieve(image_url, examples_dir / f"example_{img_id}.jpg")


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def clear():
    return None, None, [], []


def format_results(masks, scores, logits, filter=0):
    annotations = []
    n = len(scores)
    for i in range(n):
        annotation = {}

        mask = masks[i]
        tmp = np.where(mask != 0)
        if np.sum(mask) < filter:
            continue
        annotation["id"] = i
        annotation["segmentation"] = mask
        annotation["bbox"] = [np.min(tmp[0]), np.min(tmp[1]), np.max(tmp[1]), np.max(tmp[0])]
        annotation["score"] = scores[i]
        annotation["area"] = annotation["segmentation"].sum()
        annotations.append(annotation)
    return annotations


def point_prompt(masks, points, point_label, target_height, target_width):  # numpy
    h = masks[0]["segmentation"].shape[0]
    w = masks[0]["segmentation"].shape[1]
    if h != target_height or w != target_width:
        points = [
            [int(point[0] * w / target_width), int(point[1] * h / target_height)]
            for point in points
        ]
    onemask = np.zeros((h, w))
    for annotation in masks:
        if isinstance(annotation, dict):
            mask = annotation["segmentation"]
        else:
            mask = annotation
        for i, point in enumerate(points):
            if point[1] < mask.shape[0] and point[0] < mask.shape[1]:
                if mask[point[1], point[0]] == 1 and point_label[i] == 1:
                    onemask += mask
                if mask[point[1], point[0]] == 1 and point_label[i] == 0:
                    onemask -= mask
    onemask = onemask >= 1
    return onemask, 0


def show_mask(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    mask_sum = annotation.shape[0]
    height = annotation.shape[1]
    width = annotation.shape[2]
    # sort masks by area (ascending), so smaller masks win in the argmax below
    areas = np.sum(annotation, axis=(1, 2))
    sorted_indices = np.argsort(areas)[::1]
    annotation = annotation[sorted_indices]

    # for each pixel, index of the first (smallest) mask that covers it
    index = (annotation != 0).argmax(axis=0)
    if random_color:
        color = np.random.random((mask_sum, 1, 1, 3))
    else:
        color = np.ones((mask_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
    transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
    visual = np.concatenate([color, transparency], axis=-1)
    mask_image = np.expand_dims(annotation, -1) * visual

    mask = np.zeros((height, width, 4))

    h_indices, w_indices = np.meshgrid(
        np.arange(height), np.arange(width), indexing="ij"
    )
    indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))

    mask[h_indices, w_indices, :] = mask_image[indices]
    if bbox is not None:
        x1, y1, x2, y2 = bbox
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1))

    if not retinamask:
        mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)

    return mask


def process(
    annotations,
    image,
    scale,
    better_quality=False,
    mask_random_color=True,
    bbox=None,
    points=None,
    use_retina=True,
    withContours=True,
):
    if isinstance(annotations[0], dict):
        annotations = [annotation["segmentation"] for annotation in annotations]

    original_h = image.height
    original_w = image.width
    if better_quality:
        if isinstance(annotations[0], torch.Tensor):
            annotations = np.array(annotations)
        for i, mask in enumerate(annotations):
            mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
            annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
    annotations = np.array(annotations)
    inner_mask = show_mask(
        annotations,
        plt.gca(),
        random_color=mask_random_color,
        bbox=bbox,
        retinamask=use_retina,
        target_height=original_h,
        target_width=original_w,
    )

    if isinstance(annotations, torch.Tensor):
        annotations = annotations.cpu().numpy()

    if withContours:
        contour_all = []
        temp = np.zeros((original_h, original_w, 1))
        for i, mask in enumerate(annotations):
            if isinstance(mask, dict):
                mask = mask["segmentation"]
            annotation = mask.astype(np.uint8)
            if not use_retina:
                annotation = cv2.resize(
                    annotation,
                    (original_w, original_h),
                    interpolation=cv2.INTER_NEAREST,
                )
            contours, _ = cv2.findContours(
                annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
            )
            for contour in contours:
                contour_all.append(contour)
        cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
        color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
        contour_mask = temp / 255 * color.reshape(1, 1, -1)

    image = image.convert("RGBA")
    overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), "RGBA")
    image.paste(overlay_inner, (0, 0), overlay_inner)

    if withContours:
        overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), "RGBA")
        image.paste(overlay_contour, (0, 0), overlay_contour)

    return image



# Description
title = "<center><strong><font size='8'>Efficient Segment Anything with OpenVINO and EfficientSAM <font></strong></center>"


description_p = """# Interactive Instance Segmentation
                - Point-prompt instruction
                <ol>
                <li> Click on the left image (point input), visualizing the point on the right image </li>
                <li> Click the button of Segment with Point Prompt </li>
                </ol>
                - Box-prompt instruction
                <ol>
                <li> Click on the left image (one point input), visualizing the point on the right image </li>
                <li> Click on the left image (another point input), visualizing the point and the box on the right image</li>
                <li> Click the button of Segment with Box Prompt </li>
                </ol>
              """

# examples
examples = [[img] for img in examples_dir.glob("*.jpg")]

default_example = examples[0]

css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"


def segment_with_boxs(
    image,
    seg_image,
    global_points,
    global_point_label,
    input_size=1024,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
):
    if global_points is None or len(global_points) < 2 or global_points[0] is None:
        return image, global_points, global_point_label

    input_size = int(input_size)
    w, h = image.size
    scale = input_size / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    image = image.resize((new_w, new_h))

    scaled_points = np.array([[int(x * scale) for x in point] for point in global_points])
    scaled_points = scaled_points[:2]
    scaled_point_label = np.array(global_point_label)[:2]

    if scaled_points.size == 0 and scaled_point_label.size == 0:
        return image, global_points, global_point_label

    nd_image = np.array(image)
    img_tensor = nd_image.astype(np.float32) / 255
    img_tensor = np.transpose(img_tensor, (2, 0, 1))

    pts_sampled = np.reshape(scaled_points, [1, 1, -1, 2])
    pts_sampled = pts_sampled[:, :, :2, :]
    pts_labels = np.reshape(np.array([2, 3]), [1, 1, 2])

    results = compiled_model([img_tensor[None, ...], pts_sampled, pts_labels])
    predicted_logits = results[0]
    predicted_iou = results[1]
    all_masks = sigmoid(predicted_logits[0, 0, :, :, :]) >= 0.5
    predicted_iou = predicted_iou[0, 0, ...]


    max_predicted_iou = -1
    selected_mask_using_predicted_iou = None
    selected_predicted_iou = None

    for m in range(all_masks.shape[0]):
        curr_predicted_iou = predicted_iou[m]
        if curr_predicted_iou > max_predicted_iou or selected_mask_using_predicted_iou is None:
            max_predicted_iou = curr_predicted_iou
            selected_mask_using_predicted_iou = all_masks[m:m + 1]
            selected_predicted_iou = predicted_iou[m:m + 1]

    results = format_results(selected_mask_using_predicted_iou, selected_predicted_iou, predicted_logits, 0)

    annotations = results[0]["segmentation"]
    annotations = np.array([annotations])
    fig = process(
        annotations=annotations,
        image=image,
        scale=(1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        use_retina=use_retina,
        bbox=scaled_points.reshape([4]),
        withContours=withContours,
    )

    global_points = []
    global_point_label = []
    return fig, global_points, global_point_label


def segment_with_points(
    image,
    global_points,
    global_point_label,
    input_size=1024,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
):
    input_size = int(input_size)
    w, h = image.size
    scale = input_size / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    image = image.resize((new_w, new_h))

    if global_points is None or len(global_points) < 1 or global_points[0] is None:
        return image, global_points, global_point_label
    scaled_points = np.array([[int(x * scale) for x in point] for point in global_points])
    scaled_point_label = np.array(global_point_label)

    if scaled_points.size == 0 and scaled_point_label.size == 0:
        return image, global_points, global_point_label

    nd_image = np.array(image)
    img_tensor = (nd_image).astype(np.float32) / 255
    img_tensor = np.transpose(img_tensor, (2, 0, 1))

    pts_sampled = np.reshape(scaled_points, [1, 1, -1, 2])
    pts_labels = np.reshape(np.array(global_point_label), [1, 1, -1])

    results = compiled_model([img_tensor[None, ...], pts_sampled, pts_labels])
    predicted_logits = results[0]
    predicted_iou = results[1]
    all_masks = sigmoid(predicted_logits[0, 0, :, :, :]) >= 0.5
    predicted_iou = predicted_iou[0, 0, ...]

    results = format_results(all_masks, predicted_iou, predicted_logits, 0)
    annotations, _ = point_prompt(results, scaled_points, scaled_point_label, new_h, new_w)
    annotations = np.array([annotations])

    fig = process(
        annotations=annotations,
        image=image,
        scale=(1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        points=scaled_points,
        bbox=None,
        use_retina=use_retina,
        withContours=withContours,
    )

    global_points = []
    global_point_label = []
    # return fig, None
    return fig, global_points, global_point_label


def get_points_with_draw(image, cond_image, global_points, global_point_label, evt: gr.SelectData):
    print(global_points)
    if len(global_points) == 0:
        image = copy.deepcopy(cond_image)
    x, y = evt.index[0], evt.index[1]
    label = "Add Mask"
    point_radius, point_color = 15, (255, 255, 0) if label == "Add Mask" else (255, 0, 255)
    global_points.append([x, y])
    global_point_label.append(1 if label == "Add Mask" else 0)

    if image is not None:
        draw = ImageDraw.Draw(image)

        draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)

    return image, global_points, global_point_label


def get_points_with_draw_(image, cond_image, global_points, global_point_label, evt: gr.SelectData):
    if len(global_points) == 0:
        image = copy.deepcopy(cond_image)
    if len(global_points) > 2:
        return image, global_points, global_point_label
    x, y = evt.index[0], evt.index[1]
    label = "Add Mask"
    point_radius, point_color = 15, (255, 255, 0) if label == "Add Mask" else (255, 0, 255)
    global_points.append([x, y])
    global_point_label.append(1 if label == "Add Mask" else 0)

    if image is not None:
        draw = ImageDraw.Draw(image)
        draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)

    if len(global_points) == 2:
        x1, y1 = global_points[0]
        x2, y2 = global_points[1]
        if x1 < x2 and y1 < y2:
            draw.rectangle([x1, y1, x2, y2], outline="red", width=5)
        elif x1 < x2 and y1 >= y2:
            draw.rectangle([x1, y2, x2, y1], outline="red", width=5)
            global_points[0][0] = x1
            global_points[0][1] = y2
            global_points[1][0] = x2
            global_points[1][1] = y1
        elif x1 >= x2 and y1 < y2:
            draw.rectangle([x2, y1, x1, y2], outline="red", width=5)
            global_points[0][0] = x2
            global_points[0][1] = y1
            global_points[1][0] = x1
            global_points[1][1] = y2
        elif x1 >= x2 and y1 >= y2:
            draw.rectangle([x2, y2, x1, y1], outline="red", width=5)
            global_points[0][0] = x2
            global_points[0][1] = y2
            global_points[1][0] = x1
            global_points[1][1] = y1

    return image, global_points, global_point_label


cond_img_p = gr.Image(label="Input with Point", value=default_example[0], type="pil")
cond_img_b = gr.Image(label="Input with Box", value=default_example[0], type="pil")

segm_img_p = gr.Image(label="Segmented Image with Point-Prompt", interactive=False, type="pil")
segm_img_b = gr.Image(label="Segmented Image with Box-Prompt", interactive=False, type="pil")


with gr.Blocks(css=css, title="Efficient SAM") as demo:
    global_points = gr.State([])
    global_point_label = gr.State([])
    with gr.Row():
        with gr.Column(scale=1):
            # Title
            gr.Markdown(title)

    with gr.Tab("Point mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_p.render()

            with gr.Column(scale=1):
                segm_img_p.render()

        # Submit & Clear
        # ###
        with gr.Row():
            with gr.Column():

                with gr.Column():
                    segment_btn_p = gr.Button(
                        "Segment with Point Prompt", variant="primary"
                    )
                    clear_btn_p = gr.Button("Clear", variant="secondary")

                gr.Markdown("Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[cond_img_p],
                    examples_per_page=4,
                )

            with gr.Column():
                # Description
                gr.Markdown(description_p)

    with gr.Tab("Box mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_b.render()

            with gr.Column(scale=1):
                segm_img_b.render()

        # Submit & Clear
        with gr.Row():
            with gr.Column():
                with gr.Column():
                    segment_btn_b = gr.Button("Segment with Box Prompt", variant="primary")
                    clear_btn_b = gr.Button("Clear", variant="secondary")

                gr.Markdown("Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[cond_img_b],
                    examples_per_page=4,
                )

            with gr.Column():
                # Description
                gr.Markdown(description_p)

    # Record each click on the input image and draw the prompt on it
    cond_img_p.select(
        get_points_with_draw,
        inputs=[segm_img_p, cond_img_p, global_points, global_point_label],
        outputs=[segm_img_p, global_points, global_point_label],
    )

    cond_img_b.select(
        get_points_with_draw_,
        inputs=[segm_img_b, cond_img_b, global_points, global_point_label],
        outputs=[segm_img_b, global_points, global_point_label],
    )

    segment_btn_p.click(
        segment_with_points,
        inputs=[cond_img_p, global_points, global_point_label],
        outputs=[segm_img_p, global_points, global_point_label],
    )

    segment_btn_b.click(
        segment_with_boxs,
        inputs=[cond_img_b, segm_img_b, global_points, global_point_label],
        outputs=[segm_img_b, global_points, global_point_label],
    )

    clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p, global_points, global_point_label])
    clear_btn_b.click(clear, outputs=[cond_img_b, segm_img_b, global_points, global_point_label])

demo.queue()
try:
    demo.launch(debug=False)
except Exception:
    # If the local launch fails, retry with a public share link.
    demo.launch(share=True, debug=False)
# If you are launching remotely, specify server_name and server_port:
# demo.launch(server_name='your server name', server_port='server port in int')
# Read more in the docs: https://gradio.app/docs/
Running on local URL:  http://127.0.0.1:7860

To create a public link, set share=True in launch().
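
The box-drawing handler in the cell above reorders the two clicked corners so that the stored box prompt runs from the top-left point to the bottom-right point, whichever corner was clicked first. The same normalization can be sketched more compactly with min/max. This is a minimal illustration, assuming each corner is an (x, y) pair; `normalize_box` is a hypothetical helper and is not part of the notebook or of EfficientSAM.

```python
from typing import List, Sequence


def normalize_box(p1: Sequence[float], p2: Sequence[float]) -> List[List[float]]:
    """Return [[x_min, y_min], [x_max, y_max]] for two opposite box corners.

    Hypothetical helper: the corners may be supplied in any order.
    """
    x1, y1 = p1
    x2, y2 = p2
    top_left = [min(x1, x2), min(y1, y2)]
    bottom_right = [max(x1, x2), max(y1, y2)]
    return [top_left, bottom_right]


# The second click is above and to the left of the first one,
# which corresponds to the `x1 >= x2 and y1 >= y2` branch of the handler.
box = normalize_box((120, 90), (40, 30))
print(box)
```

Because the result depends only on the per-axis minimum and maximum, one expression covers every ordering of the two clicks, including the case where they share a coordinate.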