U^2-Net と OpenVINO™ を使用した画像の背景削除#

この Jupyter ノートブックはオンラインで起動でき、ブラウザーのウィンドウで対話型環境を開きます。ローカルにインストールすることもできます。次のオプションのいずれかを選択します:

BinderGoogle ColabGitHub

このノートブックでは、U2-Net と OpenVINO を使用して画像の背景を除去する方法を示します。

U2-Net のソースコードやテストデータの詳細については、GitHub ページと研究論文を参照してください: U^2-Net: 顕著なオブジェクトの検出のためネストされた U 構造をさらに深く理解する

PyTorch U2-Net モデルは OpenVINO IR 形式に変換されます。モデルのソースはこちらから入手できます。

目次:

準備#

要件をインストール#

import platform 

%pip install -q "openvino>=2023.1.0" 
%pip install -q --extra-index-url https://download.pytorch.org/whl/cpu "torch>=2.1" opencv-python 
%pip install -q "gdown<4.6.4" 

if platform.system() != "Windows":
     %pip install -q "matplotlib>=3.4" 
else:
     %pip install -q "matplotlib>=3.4,<3.7"
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.

PyTorch ライブラリーと U 2-Net をインポート#

import os 
import time 
from collections import namedtuple 
from pathlib import Path 

import cv2 
import matplotlib.pyplot as plt 
import numpy as np 
import openvino as ov 
import torch 
from IPython.display import HTML, FileLink, display
# ローカルモジュールをインポート 
import requests 

if not Path("./notebook_utils.py").exists():     # Fetch `notebook_utils` module 

    r = requests.get( 

url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py", 
    ) 

    open("notebook_utils.py", "w").write(r.text) 

from notebook_utils import load_image, download_file 

if not Path("./model/u2net.py").exists(): 
    download_file( 

url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/vision-background-removal/model/u2net.py", directory="model" 
    ) 
from model.u2net import U2NET, U2NETP

設定#

このチュートリアルでは、オリジナルの U2-Net 顕著物体検出モデルと、より小型の U2NETP バージョンをサポートしています。元のモデルでは、顕著なオブジェクトの検出と人間のセグメント化の 2 つの重みセットがサポートされています。

model_config = namedtuple("ModelConfig", ["name", "url", "model", "model_args"]) 

u2net_lite = model_config( 
    name="u2net_lite", 
    url="https://drive.google.com/uc?id=1W8E4FHIlTVstfRkYmNOjbr0VDXTZm0jD", 
    model=U2NETP, 
    model_args=(), 
) 
u2net = model_config( 
    name="u2net", 
    url="https://drive.google.com/uc?id=1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ", 
    model=U2NET, 
    model_args=(3, 1), 
) 
u2net_human_seg = model_config( 
    name="u2net_human_seg", 
    url="https://drive.google.com/uc?id=1m_Kgs91b21gayc2XLW0ou8yugAIadWVP", 
    model=U2NET, 
    model_args=(3, 1), 
) 

# u2net_model を上記の 3 つの構成のいずれかに設定 
u2net_model = u2net_lite
# ダウンロードおよび変換されたモデルのファイル名
MODEL_DIR = "model" 
model_path = Path(MODEL_DIR) / u2net_model.name / Path(u2net_model.name).with_suffix(".pth")

U 2-Net モデルをロード#

U2-Net ヒューマンセグメント化モデルの重みは Google ドライブに保存されます。存在しない場合はダウンロードされます。次のセルは、モデルと事前トレーニングされた重みをロードします。

if not model_path.exists(): 
    import gdown 

    os.makedirs(name=model_path.parent, exist_ok=True) 
    print("Start downloading model weights file...") 
    with open(model_path, "wb") as model_file: 
        gdown.download(url=u2net_model.url, output=model_file) 
        print(f"Model weights have been downloaded to {model_path}")
Start downloading model weights file...
Downloading... From: https://drive.google.com/uc?id=1W8E4FHIlTVstfRkYmNOjbr0VDXTZm0jD 
To: <_io.BufferedWriter name='model/u2net_lite/u2net_lite.pth'> 100%|██████████| 4.68M/4.68M [00:00<00:00, 34.0MB/s]
Model weights have been downloaded to model/u2net_lite/u2net_lite.pth
# モデルをロード 
net = u2net_model.model(*u2net_model.model_args) 
net.eval() 

# 重みをロード 
print(f"Loading model weights from: '{model_path}'") 
net.load_state_dict(state_dict=torch.load(model_path, map_location="cpu"))
Loading model weights from: 'model/u2net_lite/u2net_lite.pth'
<All keys matched successfully>

PyTorch U 2-Net モデルを OpenVINO IR に変換#

モデル変換 Python API を使用して、Pytorch モデルを OpenVINO IR 形式に変換します。次のコマンドの実行には時間がかかる場合があります。

model_ir = ov.convert_model(net, example_input=torch.zeros((1, 3, 512, 512)), input=([1, 3, 512, 512]))
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-727/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/torch/nn/functional.py:3782: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.
  warnings.warn("nn.functional.upsample is deprecated.Use nn.functional.interpolate instead.")
['x']

入力画像のロードと前処理#

OpenCV は BGR 形式で画像を読み取りますが、OpenVINO IR モデルは RGB の画像を想定しています。したがって、画像を RGB に変換し、サイズを 512 x 512 に変更して、OpenVINO IR モデルが期待する形式に次元を置き換えます。

平均値を画像テンソルに追加し、標準偏差で入力をスケーリングします。これは、ネットワークを通じて伝播する前の入力データの正規化と呼ばれます。平均値と標準偏差の値は、U^2-Net リポジトリーデータローダー・ファイルにあり、0 ~ 255 のピクセル値を持つ画像をサポートするため 255 倍されます。

IMAGE_URI = "https://storage.openvinotoolkit.org/repositories/openvino_notebooks/data/data/image/coco_hollywood.jpg" 

input_mean = np.array([123.675, 116.28, 103.53]).reshape(1, 3, 1, 1) 
input_scale = np.array([58.395, 57.12, 57.375]).reshape(1, 3, 1, 1) 

image = cv2.cvtColor( 
    src=load_image(IMAGE_URI), 
    code=cv2.COLOR_BGR2RGB, 
) 

resized_image = cv2.resize(src=image, dsize=(512, 512)) 
# 画像の形状を、OpenVINO IR モデルのネットワークが期待する形状と 
# データタイプに変換: (1, 3, 512, 512). 
input_image = np.expand_dims(np.transpose(resized_image, (2, 0, 1)), 0) 

input_image = (input_image - input_mean) / input_scale

推論デバイスの選択#

OpenVINO を使用して推論を実行するためにドロップダウン・リストからデバイスを選択します

import ipywidgets as widgets 

core = ov.Core() 
device = widgets.Dropdown( 
    options=core.available_devices + ["AUTO"], 
    value="AUTO", 
    description="Device:", 
    disabled=False, 
) 

device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')

OpenVINO IR モデルで推論を行う#

OpenVINO IR モデルを OpenVINO ランタイムにロードし、推論を実行します。

core = ov.Core() 
# ネットワークを OpenVINO ランタイムにロード 
compiled_model_ir = core.compile_model(model=model_ir, device_name=device.value) 
# 入力と出力レイヤーの名前を取得 
input_layer_ir = compiled_model_ir.input(0) 
output_layer_ir = compiled_model_ir.output(0) 

# 入力画像に対して推論を実行 
start_time = time.perf_counter() 
result = compiled_model_ir([input_image])[output_layer_ir] 
end_time = time.perf_counter() 
print(f"Inference finished. Inference time: {end_time-start_time:.3f} seconds, " f"FPS: {1/(end_time-start_time):.2f}.")
Inference finished. Inference time: 0.119 seconds, FPS: 8.43 です。

結果を可視化#

元の画像、セグメント化の結果、背景を除去した元の画像を表示します。

# ネットワークの結果を画像の形状に合わせてサイズを変更し、 
# 値を 0 (背景) と 1 (前景) に丸めます
# ネットワーク結果の形状は (1,1,512,512) になります。`np.squeeze` 関数はこれを (512, 512) に変換します。 
resized_result = np.rint(cv2.resize(src=np.squeeze(result), dsize=(image.shape[1], image.shape[0]))).astype(np.uint8) 

# 画像のコピーを作成し、すべての背景値を 255 (白) に設定 
bg_removed_result = image.copy() 
bg_removed_result[resized_result == 0] = 255 

fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(20, 7)) 
ax[0].imshow(image) 
ax[1].imshow(resized_result, cmap="gray") 
ax[2].imshow(bg_removed_result) 
for a in ax: 
    a.axis("off")
../_images/vision-background-removal-with-output_22_0.png

背景画像を追加#

セグメント化の結果では、すべての前景ピクセルの値は 1、すべての背景ピクセルの値は 0 になります。背景画像を次のように置き換えます:

  • 新しい background_image をロードします。

  • 元の画像と同じ画像サイズに変更します。

  • background_image で、サイズ変更されたセグメント化の結果の値が 1 であるピクセル (元の画像の前景ピクセル) を 0 に設定します。

  • 前の手順で作成した bg_removed_result (元の画像のうち、前景ピクセルのみを含む部分) を、background_image に追加します。

BACKGROUND_FILE = "https://storage.openvinotoolkit.org/repositories/openvino_notebooks/data/data/image/wall.jpg" 
OUTPUT_DIR = "output" 

os.makedirs(name=OUTPUT_DIR, exist_ok=True) 

background_image = cv2.cvtColor(src=load_image(BACKGROUND_FILE), code=cv2.COLOR_BGR2RGB) 
background_image = cv2.resize(src=background_image, dsize=(image.shape[1], image.shape[0])) 

# 結果からの背景画像のすべての前景ピクセルを 0 に設定し、 
# 背景を削除した画像を追加します。 
background_image[resized_result == 1] = 0 
new_image = background_image + bg_removed_result 

# Save the generated image. 
new_image_path = Path(f"{OUTPUT_DIR}/{Path(IMAGE_URI).stem}-{Path(BACKGROUND_FILE).stem}.jpg") 
cv2.imwrite(filename=str(new_image_path), img=cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)) 

# 元の画像と新しい背景の画像を並べて表示 
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(18, 7)) 
ax[0].imshow(image) 
ax[1].imshow(new_image) 
for a in ax: 
    a.axis("off") 
plt.show() 

# 画像をダウンロードするリンクを作成 
image_link = FileLink(new_image_path) 
image_link.html_link_str = "<a href='%s' download>%s</a>" 
display( 
    HTML( 
        f"The generated image <code>
        {new_image_path.name}</code> is saved in " f"the directory <code>{new_image_path.parent}</code>.You can also " 
        "download the image by clicking on this link: " 
        f"{image_link._repr_html_()}" 
    ) 
)
../_images/vision-background-removal-with-output_24_0.png 生成されたイメージ coco_hollywood-wall.jpg は、ディレクトリー output に保存されます。このリンクをクリックして画像をダウンロードすることもできます: output/coco_hollywood-wall.jpg

関連情報#