RMBG v1.4 と OpenVINO による背景除去#
この Jupyter ノートブックは、ローカルへのインストール後にのみ起動できます。
背景マッティングは、画像やビデオ内の前景オブジェクトを正確に推定するプロセスです。これは、画像やビデオの編集アプリケーション、特に視覚効果を作成する映画制作において重要な技術です。画像の分割の場合、ピクセルにラベルを付けて画像を前景と背景に分割します。画像セグメント化では、ピクセルが前景または背景のいずれかに属するバイナリー画像が生成されます。ただし、画像マッティングは画像セグメント化とは異なり、一部のピクセルは背景だけでなく前景にも属し、そのようなピクセルは部分ピクセルまたは混合ピクセルと呼ばれます。画像内の前景と背景を完全に分離するには、部分ピクセルまたは混合ピクセルのアルファ値を正確に推定する必要があります。
RMBG v1.4 は、さまざまなカテゴリーと画像タイプで前景と背景を効率よく分離するように設計された背景除去モデルです。このモデルは、一般的なストック画像、電子商取引、ゲーム、広告コンテンツなど、厳選されたデータセットでトレーニングされており、大規模なエンタープライズ・コンテンツ作成をサポートする商用ユースケースに適しています。精度、効率、汎用性は、現在、主要なソース利用可能モデルに匹敵します。
モデルの詳細については、モデルカードをご覧ください。
このチュートリアルでは、OpenVINO を使用してこのモデルを変換して実行する方法を説明します。
目次:
必要条件#
必要な依存関係をインストールします
%pip install -q torch torchvision pillow huggingface_hub "openvino>=2024.0.0" matplotlib "gradio>=4.15" "transformers>=4.39.1" tqdm --extra-index-url https://download.pytorch.org/whl/cpu
Note: you may need to restart the kernel to use updated packages.
HuggingFace ハブからモデルコードをダウンロード
from huggingface_hub import hf_hub_download
from pathlib import Path
repo_id = "briaai/RMBG-1.4"
download_files = ["utilities.py", "example_input.jpg"]
for file_for_downloading in download_files:
if not Path(file_for_downloading).exists():
hf_hub_download(repo_id=repo_id, filename=file_for_downloading, local_dir=".")
utilities.py: 0%| | 0.00/980 [00:00<?, ?B/s]
example_input.jpg: 0%| | 0.00/327k [00:00<?, ?B/s]
PyTorch モデルのロード#
PyTorch を使用してモデルをロードするには、AutoModelForImageSegmentation.from_pretrained
メソッドを使用します。モデルの重みは、モデルの初回使用時に自動的ダウンロードされます。しばらく時間がかかる場合があります。
from transformers import AutoModelForImageSegmentation
net = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-727/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: resume_download is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use force_download=True. warnings.warn(
PyTorch モデル推論を実行#
preprocess_image
関数は、モデル固有の形式で入力データを準備します。postprocess_image
関数は、モデル出力の後処理を担当します。後処理後、生成された背景マスクをアルファチャネルとして元の画像に挿入できます。
import torch
from PIL import Image
from utilities import preprocess_image, postprocess_image
import numpy as np
from matplotlib import pyplot as plt
def visualize_result(orig_img: Image, mask: Image, result_img: Image): """
Helper for results visualization
parameters:
orig_img (Image): input image
mask (Image): background mask
result_img (Image) output image
returns:
plt.Figure: plot with 3 images for visualization
"""
titles = ["Original", "Background Mask", "Without background"]
im_w, im_h = orig_img.size
is_horizontal = im_h <= im_w
figsize = (20, 20)
num_images = 3
fig, axs = plt.subplots(
num_images if is_horizontal else 1,
1 if is_horizontal else num_images,
figsize=figsize,
sharex="all",
sharey="all",
)
fig.patch.set_facecolor("white")
list_axes = list(axs.flat)
for a in list_axes:
a.set_xticklabels([])
a.set_yticklabels([])
a.get_xaxis().set_visible(False)
a.get_yaxis().set_visible(False)
a.grid(False)
list_axes[0].imshow(np.array(orig_img))
list_axes[1].imshow(np.array(mask), cmap="gray")
list_axes[0].set_title(titles[0], fontsize=15)
list_axes[1].set_title(titles[1], fontsize=15)
list_axes[2].imshow(np.array(result_img))
list_axes[2].set_title(titles[2], fontsize=15)
fig.subplots_adjust(wspace=0.01 if is_horizontal else 0.00, hspace=0.01 if is_horizontal else 0.1)
fig.tight_layout()
return fig
im_path = "./example_input.jpg"
# 入力の準備
model_input_size = [1024, 1024]
orig_im = np.array(Image.open(im_path))
orig_im_size = orig_im.shape[0:2]
image = preprocess_image(orig_im, model_input_size)
# 推論
result = net(image)
# 後処理
result_image = postprocess_image(result[0][0], orig_im_size)
# 結果を保存
pil_im = Image.fromarray(result_image)
no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
orig_image = Image.open(im_path)
no_bg_image.paste(orig_image, mask=pil_im)
no_bg_image.save("example_image_no_bg.png") visualize_result(orig_image, pil_im, no_bg_image);

モデルを OpenVINO 中間表現形式に変換#
OpenVINO は、OpenVINO 中間表現 (IR) への変換により PyTorch モデルをサポートします。これには、OpenVINO モデル・トランスフォーメーション API を使用する必要があります。ov.convert_model
関数は、元の PyTorch モデル・インスタンスとトレース用のサンプル入力を受け取り、OpenVINO フレームワークでこのモデルを表す ov.Model
を返します。変換されたモデルは、ov.save_model
関数を使用してディスクに保存するか、core.complie_model
を使用してデバイスに直接ロードできます。
import openvino as ov
ov_model_path = Path("rmbg-1.4.xml")
if not ov_model_path.exists():
ov_model = ov.convert_model(net, example_input=image, input=[1, 3, *model_input_size])
ov.save_model(ov_model, ov_model_path)
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-727/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/modeling_utils.py:4371: FutureWarning: _is_quantized_training_enabled is going to be deprecated in transformers 4.39.0. Please use model.hf_quantizer.is_trainable instead warnings.warn(
['x']
OpenVINO モデル推論を実行#
変換が完了したら、変換されたモデルをコンパイルし、指定されたデバイスで OpenVINO を使用して実行できます。推論デバイスの選択には、以下のドロップダウン・リストを使用してください:
import ipywidgets as widgets
core = ov.Core()
device = widgets.Dropdown(
options=core.available_devices + ["AUTO"],
value="AUTO",
description="Device:",
disabled=False,
)
device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
以前 PyTorch モデルを起動したのと同じイメージでモデルを実行してみます。OpenVINO モデルの入力と出力は、元の前処理および後処理の手順と完全に互換性があるため、再利用できます。
ov_compiled_model = core.compile_model(ov_model_path, device.value)
result = ov_compiled_model(image)[0]
# 後処理
result_image = postprocess_image(torch.from_numpy(result), orig_im_size)
# 結果を保存
pil_im = Image.fromarray(result_image)
no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
orig_image = Image.open(im_path)
no_bg_image.paste(orig_image, mask=pil_im)
no_bg_image.save("example_image_no_bg.png")
visualize_result(orig_image, pil_im, no_bg_image);

インタラクティブなデモ#
import gradio as gr
title = "# RMBG background removal with OpenVINO"
def get_background_mask(model, image):
return model(image)[0]
with gr.Blocks() as demo:
gr.Markdown(title)
with gr.Row():
input_image = gr.Image(label="Input Image", type="numpy")
background_image = gr.Image(label="Background removal Image")
submit = gr.Button("Submit")
def on_submit(image):
original_image = image.copy()
h, w = image.shape[:2]
image = preprocess_image(original_image, model_input_size)
mask = get_background_mask(ov_compiled_model, image)
result_image = postprocess_image(torch.from_numpy(mask), (h, w))
pil_im = Image.fromarray(result_image)
orig_img = Image.fromarray(original_image)
no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
no_bg_image.paste(orig_img, mask=pil_im)
return no_bg_image
submit.click(on_submit, inputs=[input_image], outputs=[background_image])
examples = gr.Examples(
examples=["./example_input.jpg"],
inputs=[input_image],
outputs=[background_image],
fn=on_submit,
cache_examples=False,
)
if __name__ == "__main__":
try:
demo.launch(debug=False)
except Exception:
demo.launch(share=True, debug=False)
# リモートで起動する場合は、server_name と server_port を指定します
# demo.launch(server_name='your server name', server_port='server port in int')
# 詳細については、ドキュメントをご覧ください: https://gradio.app/docs/
ローカル URL で実行中: http://127.0.0.1:7860 パブリックリンクを作成するには、launch() で share=True を設定します。