Object segmentation with FastSAM and OpenVINO

This Jupyter notebook can be launched online, opening an interactive environment in a browser window. It can also be installed locally. Select one of the following options:

Binder Google Colab GitHub

The Fast Segment Anything Model (FastSAM) is a real-time CNN-based model that can segment any object within an image based on various user prompts. The Segment Anything task is designed to make vision tasks easier by providing an efficient way to identify objects in an image. FastSAM significantly reduces computational demands while maintaining competitive performance, making it a practical choice for a variety of vision tasks.

FastSAM aims to overcome the limitations of the Segment Anything Model (SAM), a Transformer model that requires substantial computational resources. FastSAM decouples the segment anything task into two sequential stages: all-instance segmentation and prompt-guided selection.

In the first stage, YOLOv8-seg is used to produce segmentation masks for all instances in the image. In the second stage, FastSAM outputs the region of interest corresponding to the prompt.
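To illustrate the second stage, here is a minimal sketch of prompt-guided selection on top of the all-instance masks (a hypothetical helper, assuming the first stage returns the masks as an N×H×W boolean NumPy array and the prompt is a single foreground point):

import numpy as np

def select_mask_by_point(masks: np.ndarray, point: tuple) -> np.ndarray:
    """Pick the all-instance mask that contains the prompt point (sketch only)."""
    x, y = point
    selected = np.zeros(masks.shape[1:], dtype=bool)
    # Iterate from the largest mask to the smallest so that a smaller,
    # more specific mask containing the point overrides a larger one.
    for mask in sorted(masks, key=lambda m: m.sum(), reverse=True):
        if mask[y, x]:
            selected = mask
    return selected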

pipeline

Pipeline


Prerequisites

Install requirements

%pip install -q "ultralytics==8.0.200" onnx --extra-index-url https://download.pytorch.org/whl/cpu
%pip install -q "openvino-dev>=2023.1.0"
%pip install -q "nncf>=2.6.0"
%pip install -q "gradio>=4.13"
DEPRECATION: pytorch-lightning 1.6.5 has a non-standard dependency specifier torch>=1.8.*. pip 24.1 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pytorch-lightning or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063
Note: you may need to restart the kernel to use updated packages.

Imports

import ipywidgets as widgets
from pathlib import Path

import openvino as ov
import torch
from PIL import Image, ImageDraw
from ultralytics import FastSAM

import urllib.request
# Fetch skip_kernel_extension module
urllib.request.urlretrieve(
    url='https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/main/notebooks/utils/skip_kernel_extension.py',
    filename='skip_kernel_extension.py'
)
# Fetch `notebook_utils` module
urllib.request.urlretrieve(
    url='https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/main/notebooks/utils/notebook_utils.py',
    filename='notebook_utils.py'
)
from notebook_utils import download_file
%load_ext skip_kernel_extension

FastSAM in Ultralytics

To work with the Fast Segment Anything Model by CASIA-IVA-Lab, we will use the Ultralytics package. The Ultralytics package exposes the FastSAM class, simplifying model instantiation and weight loading. The code below shows how to initialize a FastSAM model and generate a segmentation map.

model_name = "FastSAM-x"
model = FastSAM(model_name)

# Run inference on an image
image_uri = "https://storage.openvinotoolkit.org/repositories/openvino_notebooks/data/data/image/coco_bike.jpg"
image_uri = download_file(image_uri)
results = model(image_uri, device="cpu", retina_masks=True, imgsz=1024, conf=0.6, iou=0.9)
Downloading https://github.com/ultralytics/assets/releases/download/v0.0.0/FastSAM-x.pt to 'FastSAM-x.pt'...
0%|          | 0.00/138M [00:00<?, ?B/s]
100%|██████████| 138M/138M [00:38<00:00, 3.76MB/s]
coco_bike.jpg:   0%|          | 0.00/182k [00:00<?, ?B/s]
image 1/1 /opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-609/.workspace/scm/ov-notebook/notebooks/261-fast-segment-anything/coco_bike.jpg: 768x1024 37 objects, 627.6ms
Speed: 3.3ms preprocess, 627.6ms inference, 27.0ms postprocess per image at shape (1, 3, 768, 1024)

The model returns segmentation maps for all the objects on the image. Observe the results below.

Image.fromarray(results[0].plot()[..., ::-1])
../_images/261-fast-segment-anything-with-output_9_0.png

Convert the model to OpenVINO Intermediate Representation (IR) format

The Ultralytics model export API enables conversion of PyTorch models to OpenVINO IR format. Under the hood it uses the openvino.convert_model method to obtain the OpenVINO IR version of the model. This method requires a model object and example input for model tracing. The FastSAM model itself is based on the YOLOv8 model.
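For reference, the same result could be obtained by calling the OpenVINO conversion API directly rather than going through the Ultralytics export helper. The following is only a sketch, assuming the underlying PyTorch module is reachable as model.model; the notebook itself uses model.export() below.

import torch
import openvino as ov

# Hypothetical direct conversion path (sketch, not used by the notebook below)
dummy_input = torch.zeros(1, 3, 1024, 1024)  # BCHW shape reported by the export log
ov_model_direct = ov.convert_model(model.model, example_input=dummy_input)
ov.save_model(ov_model_direct, "FastSAM-x_direct.xml")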

# instance segmentation model
ov_model_path = Path(f"{model_name}_openvino_model/{model_name}.xml")
if not ov_model_path.exists():
    ov_model = model.export(format="openvino", dynamic=True, half=False)
Ultralytics YOLOv8.0.200 🚀 Python-3.8.10 torch-2.2.0+cpu CPU (Intel Core(TM) i9-10920X 3.50GHz)
PyTorch: starting from 'FastSAM-x.pt' with input shape (1, 3, 1024, 1024) BCHW and output shape(s) ((1, 37, 21504), (1, 32, 256, 256)) (138.2 MB)
ONNX: starting export with onnx 1.15.0 opset 17...
ONNX: export success ✅ 3.8s, saved as 'FastSAM-x.onnx' (275.5 MB)
OpenVINO: starting export with openvino 2023.3.0-13775-ceeafaf64f3-releases/2023/3...
OpenVINO: export success ✅ 1.1s, saved as 'FastSAM-x_openvino_model/' (275.9 MB)
Export complete (7.7s)
Results saved to /opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-609/.workspace/scm/ov-notebook/notebooks/261-fast-segment-anything
Predict:         yolo predict task=segment model=FastSAM-x_openvino_model imgsz=1024
Validate:        yolo val task=segment model=FastSAM-x_openvino_model imgsz=1024 data=ultralytics/datasets/sa.yaml
Visualize:       https://netron.app

Embedding the converted models into the original pipeline

The OpenVINO™ Runtime Python API is used to compile the model in OpenVINO IR format. The Core class provides access to the OpenVINO Runtime API. The core object, an instance of the Core class, represents the API and is used to compile the model.

core = ov.Core()

Select inference device

Select the device used for model inference with OpenVINO from the dropdown list:

DEVICE = widgets.Dropdown(
    options=core.available_devices + ["AUTO"],
    value="AUTO",
    description="Device:",
    disabled=False,
)

DEVICE
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')

Here we create a wrapper class around the OpenVINO model that we will embed into the original inference pipeline. Here are some of the things to consider when adapting an OV model:

  • Make sure that parameters passed by the original pipeline are forwarded to the compiled OV model properly. Sometimes the OV model uses only a portion of the input arguments and some are ignored, and sometimes you need to convert an argument to another data type or unwrap data structures such as tuples or dictionaries.
  • Guarantee that the wrapper class returns results to the pipeline in the expected format. In the example below you can see how the OV model outputs are packed into a tuple of torch tensors.
  • Pay attention to the model method used in the original pipeline for calling the model — it may not be the forward method. In this example, the model is part of a predictor object and is called as an object, so we need to redefine the magic __call__ method.
class OVWrapper:
    def __init__(self, ov_model, device="CPU", stride=32) -> None:
        self.model = core.compile_model(ov_model, device_name=device)

        self.stride = stride
        self.pt = True
        self.fp16 = False
        self.names = {0: "object"}

    def __call__(self, im, **_):
        result = self.model(im)
        return torch.from_numpy(result[0]), torch.from_numpy(result[1])

Now we initialize the wrapper object and load it into the FastSAM pipeline.

wrapped_model = OVWrapper(ov_model_path, device=DEVICE.value, stride=model.predictor.model.stride)
model.predictor.model = wrapped_model

ov_results = model(image_uri, device=DEVICE.value, retina_masks=True, imgsz=640, conf=0.6, iou=0.9)
image 1/1 /opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-609/.workspace/scm/ov-notebook/notebooks/261-fast-segment-anything/coco_bike.jpg: 480x640 33 objects, 310.8ms
Speed: 3.3ms preprocess, 310.8ms inference, 26.3ms postprocess per image at shape (1, 3, 480, 640)

One can observe the converted model outputs in the next cell; they are the same as those of the original model.

Image.fromarray(ov_results[0].plot()[..., ::-1])
../_images/261-fast-segment-anything-with-output_21_0.png

Optimize the model using NNCF Post-training Quantization API

NNCF provides a suite of advanced algorithms for neural network inference optimization in OpenVINO with minimal accuracy drop. To optimize FastSAM, we will use 8-bit quantization in post-training mode (without a fine-tuning pipeline).

The optimization process contains the following steps (a bare-bones sketch of these steps follows the list):

  1. Create a dataset for quantization.

  2. Run nncf.quantize to obtain a quantized model.

  3. Save the INT8 model using the openvino.save_model() function.
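Stripped of the calibration-data collection machinery used later in this notebook, these three steps might look roughly like this (a sketch only; calibration_images stands in for any iterable of preprocessed input tensors):

import nncf
import openvino as ov

# 1. Wrap the calibration samples into an NNCF dataset
calibration_dataset = nncf.Dataset(calibration_images)

# 2. Run post-training quantization on the OpenVINO model
quantized_model = nncf.quantize(
    core.read_model(ov_model_path),
    calibration_dataset,
    preset=nncf.QuantizationPreset.MIXED,
)

# 3. Save the resulting INT8 model
ov.save_model(quantized_model, "FastSAM-x_int8_sketch.xml")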

do_quantize = widgets.Checkbox(
    value=True,
    description='Quantization',
    disabled=False,
)

do_quantize
Checkbox(value=True, description='Quantization')

The nncf.quantize function provides an interface for model quantization. It requires an instance of the OpenVINO model and a quantization dataset. Optionally, some additional parameters for the quantization process can be provided (number of quantization samples, preset, ignored scope, etc.). The YOLOv8 model backing FastSAM contains non-ReLU activation functions, which require asymmetric quantization of activations. For better results we use a mixed quantization preset, which provides symmetric quantization of weights and asymmetric quantization of activations. For more accurate results, we should keep the operations in the post-processing subgraph in floating-point precision, using the ignored_scope parameter.

The quantization algorithm is based on the YOLOv8 quantization example in the NNCF repository; refer to it for more details. You can also check out other quantization tutorials in the OV notebooks repository.

Note: Model post-training quantization is a time-consuming process. It may take several minutes depending on your hardware.

%%skip not $do_quantize.value

import pickle
from contextlib import contextmanager
from zipfile import ZipFile

import cv2
from tqdm.autonotebook import tqdm

import nncf


COLLECT_CALIBRATION_DATA = False
calibration_data = []

@contextmanager
def calibration_data_collection():
    global COLLECT_CALIBRATION_DATA
    try:
        COLLECT_CALIBRATION_DATA = True
        yield
    finally:
        COLLECT_CALIBRATION_DATA = False


class NNCFWrapper:
    def __init__(self, ov_model, stride=32) -> None:
        self.model = core.read_model(ov_model)
        self.compiled_model = core.compile_model(self.model, device_name="CPU")

        self.stride = stride
        self.pt = True
        self.fp16 = False
        self.names = {0: "object"}

    def __call__(self, im, **_):
        if COLLECT_CALIBRATION_DATA:
            calibration_data.append(im)

        result = self.compiled_model(im)
        return torch.from_numpy(result[0]), torch.from_numpy(result[1])

# Fetch data from the web and describe a dataloader
DATA_URL = "https://ultralytics.com/assets/coco128.zip"
OUT_DIR = Path('.')

download_file(DATA_URL, directory=OUT_DIR, show_progress=True)

if not (OUT_DIR / "coco128/images/train2017").exists():
    with ZipFile('coco128.zip', "r") as zip_ref:
        zip_ref.extractall(OUT_DIR)

class COCOLoader(torch.utils.data.Dataset):
    def __init__(self, images_path):
        self.images = list(Path(images_path).iterdir())

    def __getitem__(self, index):
        if isinstance(index, slice):
            return [self.read_image(image_path) for image_path in self.images[index]]
        return self.read_image(self.images[index])

    def read_image(self, image_path):
        image = cv2.imread(str(image_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return image

    def __len__(self):
        return len(self.images)


def collect_calibration_data_for_decoder(model, calibration_dataset_size: int,
                                         calibration_cache_path: Path):
    global calibration_data


    if not calibration_cache_path.exists():
        coco_dataset = COCOLoader(OUT_DIR / 'coco128/images/train2017')
        with calibration_data_collection():
            for image in tqdm(coco_dataset[:calibration_dataset_size], desc="Collecting calibration data"):
                model(image, retina_masks=True, imgsz=640, conf=0.6, iou=0.9, verbose=False)
        calibration_cache_path.parent.mkdir(parents=True, exist_ok=True)
        with open(calibration_cache_path, "wb") as f:
            pickle.dump(calibration_data, f)
    else:
        with open(calibration_cache_path, "rb") as f:
            calibration_data = pickle.load(f)

    return calibration_data


def quantize(model, save_model_path: Path, calibration_cache_path: Path,
             calibration_dataset_size: int, preset: nncf.QuantizationPreset):
    calibration_data = collect_calibration_data_for_decoder(
        model, calibration_dataset_size, calibration_cache_path)
    quantized_ov_decoder = nncf.quantize(
        model.predictor.model.model,
        calibration_dataset=nncf.Dataset(calibration_data),
        preset=preset,
        subset_size=len(calibration_data),
        fast_bias_correction=True,
        ignored_scope=nncf.IgnoredScope(
            types=["Multiply", "Subtract", "Sigmoid"],  # ignore operations
            names=[
                "/model.22/dfl/conv/Conv",  # in the post-processing subgraph
                "/model.22/Add",
                "/model.22/Add_1",
                "/model.22/Add_2",
                "/model.22/Add_3",
                "/model.22/Add_4",
                "/model.22/Add_5",
                "/model.22/Add_6",
                "/model.22/Add_7",
                "/model.22/Add_8",
                "/model.22/Add_9",
                "/model.22/Add_10",
            ],
        )
    )
    ov.save_model(quantized_ov_decoder, save_model_path)

wrapped_model = NNCFWrapper(ov_model_path, stride=model.predictor.model.stride)
model.predictor.model = wrapped_model

calibration_dataset_size = 128
quantized_model_path = Path(f"{model_name}_quantized") / "FastSAM-x.xml"
calibration_cache_path = Path(f"calibration_data/coco{calibration_dataset_size}.pkl")
if not quantized_model_path.exists():
    quantize(model, quantized_model_path, calibration_cache_path,
             calibration_dataset_size=calibration_dataset_size,
             preset=nncf.QuantizationPreset.MIXED)
INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, tensorflow, onnx, openvino
coco128.zip:   0%|          | 0.00/6.66M [00:00<?, ?B/s]
Collecting calibration data:   0%|          | 0/128 [00:00<?, ?it/s]
INFO:nncf:12 ignored nodes were found by name in the NNCFGraph
INFO:nncf:13 ignored nodes were found by types in the NNCFGraph
INFO:nncf:Not adding activation input quantizer for operation: 204 /model.22/Sigmoid
INFO:nncf:Not adding activation input quantizer for operation: 246 /model.22/dfl/conv/Conv
INFO:nncf:Not adding activation input quantizer for operation: 275 /model.22/Sub
INFO:nncf:Not adding activation input quantizer for operation: 276 /model.22/Add_10
INFO:nncf:Not adding activation input quantizer for operation: 299 /model.22/Sub_1
INFO:nncf:Not adding activation input quantizer for operation: 341 /model.22/Mul_5
Output()
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-609/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/nncf/experimental/tensor/tensor.py:84: RuntimeWarning: invalid value encountered in multiply
                                        return Tensor(self.data * unwrap_tensor_data(other))
Output()

Compare the performance of the original and quantized models

Finally, iterate both the OV model and the quantized model over the calibration dataset to measure the performance.

%%skip not $do_quantize.value

import datetime

coco_dataset = COCOLoader(OUT_DIR / 'coco128/images/train2017')
calibration_dataset_size = 128

wrapped_model = OVWrapper(ov_model_path, device=DEVICE.value, stride=model.predictor.model.stride)
model.predictor.model = wrapped_model

start_time = datetime.datetime.now()
for image in tqdm(coco_dataset, desc="Measuring inference time"):
    model(image, retina_masks=True, imgsz=640, conf=0.6, iou=0.9, verbose=False)
duration_base = (datetime.datetime.now() - start_time).seconds
print("Segmented in", duration_base, "seconds.")
print("Resulting in", round(calibration_dataset_size / duration_base, 2), "fps")
Measuring inference time:   0%|          | 0/128 [00:00<?, ?it/s]
Segmented in 21 seconds.
Resulting in 6.1 fps
%%skip not $do_quantize.value

quantized_wrapped_model = OVWrapper(quantized_model_path, device=DEVICE.value, stride=model.predictor.model.stride)
model.predictor.model = quantized_wrapped_model

start_time = datetime.datetime.now()
for image in tqdm(coco_dataset, desc="Measuring inference time"):
    model(image, retina_masks=True, imgsz=640, conf=0.6, iou=0.9, verbose=False)
duration_quantized = (datetime.datetime.now() - start_time).seconds
print("Segmented in", duration_quantized, "seconds")
print("Resulting in", round(calibration_dataset_size / duration_quantized, 2), "fps")
print("That is", round(duration_base / duration_quantized, 2), "times faster!")
Measuring inference time:   0%|          | 0/128 [00:00<?, ?it/s]
Segmented in 11 seconds
Resulting in 11.64 fps
That is 1.91 times faster!

Try out the converted pipeline

The demo app below is created using the Gradio package.

The app allows you to alter the model output interactively. Using the Pixel selector type switch you can place foreground/background points or bounding boxes on the input image.

import cv2
import numpy as np
import matplotlib.pyplot as plt

def fast_process(
    annotations,
    image,
    scale,
    better_quality=False,
    mask_random_color=True,
    bbox=None,
    use_retina=True,
    with_contours=True,
):

    original_h = image.height
    original_w = image.width

    if better_quality:
        for i, mask in enumerate(annotations):
            mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
            annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))

    inner_mask = fast_show_mask(
        annotations,
        plt.gca(),
        random_color=mask_random_color,
        bbox=bbox,
        retinamask=use_retina,
        target_height=original_h,
        target_width=original_w,
    )

    if with_contours:
        contour_all = []
        temp = np.zeros((original_h, original_w, 1))
        for i, mask in enumerate(annotations):
            annotation = mask.astype(np.uint8)
            if not use_retina:
                annotation = cv2.resize(
                    annotation,
                    (original_w, original_h),
                    interpolation=cv2.INTER_NEAREST,
                )
            contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            for contour in contours:
                contour_all.append(contour)
        cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
        color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
        contour_mask = temp / 255 * color.reshape(1, 1, -1)

    image = image.convert("RGBA")
    overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), "RGBA")
    image.paste(overlay_inner, (0, 0), overlay_inner)

    if with_contours:
        overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), "RGBA")
        image.paste(overlay_contour, (0, 0), overlay_contour)

    return image


# CPU post process
def fast_show_mask(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    mask_sum = annotation.shape[0]
    height = annotation.shape[1]
    weight = annotation.shape[2]
    #
    areas = np.sum(annotation, axis=(1, 2))
    sorted_indices = np.argsort(areas)[::1]
    annotation = annotation[sorted_indices]

    index = (annotation != 0).argmax(axis=0)
    if random_color:
        color = np.random.random((mask_sum, 1, 1, 3))
    else:
        color = np.ones((mask_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
    transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
    visual = np.concatenate([color, transparency], axis=-1)
    mask_image = np.expand_dims(annotation, -1) * visual

    mask = np.zeros((height, weight, 4))

    h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing="ij")
    indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))

    mask[h_indices, w_indices, :] = mask_image[indices]
    if bbox is not None:
        x1, y1, x2, y2 = bbox
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1))

    if not retinamask:
        mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)

    return mask
import gradio as gr

examples = [[image_uri], ["https://storage.openvinotoolkit.org/repositories/openvino_notebooks/data/data/image/empty_road_mapillary.jpg"],
            ["https://storage.openvinotoolkit.org/repositories/openvino_notebooks/data/data/image/wall.jpg"]]

object_points = []
background_points = []
bbox_points = []
last_image = examples[0][0]

This is the main callback function that is called to segment the image based on user input.

def segment(
        image,
        model_type,
        input_size=1024,
        iou_threshold=0.75,
        conf_threshold=0.4,
        better_quality=True,
        with_contours=True,
        use_retina=True,
        mask_random_color=True,
):
    if do_quantize.value and model_type == 'Quantized model':
        model.predictor.model = quantized_wrapped_model
    else:
        model.predictor.model = wrapped_model

    input_size = int(input_size)
    w, h = image.size
    scale = input_size / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    image = image.resize((new_w, new_h))

    results = model(image,
                    device=DEVICE.value,
                    retina_masks=use_retina,
                    iou=iou_threshold,
                    conf=conf_threshold,
                    imgsz=input_size,)

    masks = results[0].masks.data
    # Calculate annotations
    if not (object_points or bbox_points):
        annotations = masks.cpu().numpy()
    else:
        annotations = []

    if object_points:
        all_points = object_points + background_points
        labels = [1] * len(object_points) + [0] * len(background_points)
        scaled_points = [[int(x * scale) for x in point] for point in all_points]
        h, w = masks[0].shape[:2]
        assert max(h, w) == input_size
        onemask = np.zeros((h, w))
        for mask in sorted(masks, key=lambda x: x.sum(), reverse=True):
            mask_np = (mask == 1.0).cpu().numpy()
            for point, label in zip(scaled_points, labels):
                if mask_np[point[1], point[0]] == 1 and label == 1:
                    onemask[mask_np] = 1
                if mask_np[point[1], point[0]] == 1 and label == 0:
                    onemask[mask_np] = 0
        annotations.append(onemask >= 1)
    if len(bbox_points) >= 2:
        scaled_bbox_points = []
        for i, point in enumerate(bbox_points):
            x, y = int(point[0] * scale), int(point[1] * scale)
            x = max(min(x, new_w), 0)
            y = max(min(y, new_h), 0)
            scaled_bbox_points.append((x, y))

        for i in range(0, len(scaled_bbox_points) - 1, 2):
            x0, y0, x1, y1 = *scaled_bbox_points[i], *scaled_bbox_points[i + 1]

            intersection_area = torch.sum(masks[:, y0:y1, x0:x1], dim=(1, 2))
            masks_area = torch.sum(masks, dim=(1, 2))
            bbox_area = (y1 - y0) * (x1 - x0)

            union = bbox_area + masks_area - intersection_area
            iou = intersection_area / union
            max_iou_index = torch.argmax(iou)

            annotations.append(masks[max_iou_index].cpu().numpy())

    return fast_process(
        annotations=np.array(annotations),
        image=image,
        scale=(1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        bbox=None,
        use_retina=use_retina,
        with_contours=with_contours
    )
def select_point(img: Image.Image, point_type: str, evt: gr.SelectData) -> Image.Image:
    """Gradio select callback."""
    img = img.convert("RGBA")
    x, y = evt.index[0], evt.index[1]
    point_radius = np.round(max(img.size) / 100)
    if point_type == "Object point":
        object_points.append((x, y))
        color = (30, 255, 30, 200)
    elif point_type == "Background point":
        background_points.append((x, y))
        color = (255, 30, 30, 200)
    elif point_type == "Bounding Box":
        bbox_points.append((x, y))
        color = (10, 10, 255, 255)
        if len(bbox_points) % 2 == 0:
            # Draw a rectangle if number of points is even
            new_img = Image.new("RGBA", img.size, (255, 255, 255, 0))
            _draw = ImageDraw.Draw(new_img)
            x0, y0, x1, y1 = *bbox_points[-2], *bbox_points[-1]
            x0, x1 = sorted([x0, x1])
            y0, y1 = sorted([y0, y1])
            # Save sorted order
            bbox_points[-2] = (x0, y0)
            bbox_points[-1] = (x1, y1)
            _draw.rectangle((x0, y0, x1, y1), fill=(*color[:-1], 90))
            img = Image.alpha_composite(img, new_img)
    # Draw a point
    ImageDraw.Draw(img).ellipse(
        [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
        fill=color
    )
    return img

def clear_points() -> (Image.Image, None):
    """Gradio clear points callback."""
    global object_points, background_points, bbox_points
    # global object_points; global background_points; global bbox_points
    object_points = []
    background_points = []
    bbox_points = []
    return last_image, None

def save_last_picked_image(img: Image.Image) -> None:
    """Gradio callback saves the last used image."""
    global last_image
    last_image = img
    # If we change the input image
    # we should clear all the previous points
    clear_points()
    # Removes the segmentation map output
    return None

with gr.Blocks(title="Fast SAM") as demo:
    with gr.Row(variant="panel"):
        original_img = gr.Image(label="Input", value=examples[0][0], type="pil")
        segmented_img = gr.Image(label="Segmentation Map", type="pil")
    with gr.Row():
        point_type = gr.Radio(
            ["Object point", "Background point", "Bounding Box"],
            value="Object point", label="Pixel selector type"
        )
        model_type = gr.Radio(
            ["FP32 model", "Quantized model"] if do_quantize.value else ["FP32 model"],
            value="FP32 model", label="Select model variant"
        )
    with gr.Row(variant="panel"):
        segment_button = gr.Button("Segment", variant="primary")
        clear_button = gr.Button("Clear points", variant="secondary")
    gr.Examples(examples, inputs=original_img,
                fn=save_last_picked_image, run_on_click=True, outputs=segmented_img
    )

    # Callbacks
    original_img.select(select_point,
                        inputs=[original_img, point_type],
                        outputs=original_img)
    original_img.upload(save_last_picked_image, inputs=original_img, outputs=segmented_img)
    clear_button.click(clear_points, outputs=[original_img, segmented_img])
    segment_button.click(segment, inputs=[original_img, model_type], outputs=segmented_img)

try:
    demo.queue().launch(debug=False)
except Exception:
    demo.queue().launch(share=True, debug=False)

# If you are launching remotely, specify server_name and server_port
# EXAMPLE: `demo.launch(server_name="your server name", server_port="server port in int")`
# To learn more please refer to the Gradio docs: https://gradio.app/docs/
Running on local URL:  http://127.0.0.1:7860

To create a public link, set share=True in launch().