RMBG v1.4 と OpenVINO による背景除去#

この Jupyter ノートブックは、ローカルへのインストール後にのみ起動できます。

GitHub

背景マッティングは、画像やビデオ内の前景オブジェクトを正確に推定するプロセスです。これは、画像やビデオの編集アプリケーション、特に視覚効果を作成する映画制作において重要な技術です。画像の分割の場合、ピクセルにラベルを付けて画像を前景と背景に分割します。画像セグメント化では、ピクセルが前景または背景のいずれかに属するバイナリー画像が生成されます。ただし、画像マッティングは画像セグメント化とは異なり、一部のピクセルは背景だけでなく前景にも属し、そのようなピクセルは部分ピクセルまたは混合ピクセルと呼ばれます。画像内の前景と背景を完全に分離するには、部分ピクセルまたは混合ピクセルのアルファ値を正確に推定する必要があります。

RMBG v1.4 は、さまざまなカテゴリーと画像タイプで前景と背景を効率よく分離するように設計された背景除去モデルです。このモデルは、一般的なストック画像、電子商取引、ゲーム、広告コンテンツなど、厳選されたデータセットでトレーニングされており、大規模なエンタープライズ・コンテンツ作成をサポートする商用ユースケースに適しています。精度、効率、汎用性は、現在、主要なソース利用可能モデルに匹敵します。

モデルの詳細については、モデルカードをご覧ください。

このチュートリアルでは、OpenVINO を使用してこのモデルを変換して実行する方法を説明します。

目次:

必要条件#

必要な依存関係をインストールします

%pip install -q torch torchvision pillow huggingface_hub "openvino>=2024.0.0" matplotlib "gradio>=4.15" "transformers>=4.39.1" tqdm --extra-index-url https://download.pytorch.org/whl/cpu
Note: you may need to restart the kernel to use updated packages.

HuggingFace ハブからモデルコードをダウンロード

from huggingface_hub import hf_hub_download 
from pathlib import Path 

repo_id = "briaai/RMBG-1.4" 

download_files = ["utilities.py", "example_input.jpg"] 

for file_for_downloading in download_files: 
    if not Path(file_for_downloading).exists(): 
        hf_hub_download(repo_id=repo_id, filename=file_for_downloading, local_dir=".")
utilities.py: 0%|          | 0.00/980 [00:00<?, ?B/s]
example_input.jpg: 0%|          | 0.00/327k [00:00<?, ?B/s]

PyTorch モデルのロード#

PyTorch を使用してモデルをロードするには、AutoModelForImageSegmentation.from_pretrained メソッドを使用します。モデルの重みは、モデルの初回使用時に自動的ダウンロードされます。しばらく時間がかかる場合があります。

from transformers import AutoModelForImageSegmentation 

net = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-727/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: resume_download is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use force_download=True. 
warnings.warn(

PyTorch モデル推論を実行#

preprocess_image 関数は、モデル固有の形式で入力データを準備します。postprocess_image 関数は、モデル出力の後処理を担当します。後処理後、生成された背景マスクをアルファチャネルとして元の画像に挿入できます。

import torch 
from PIL import Image 
from utilities import preprocess_image, postprocess_image 
import numpy as np 
from matplotlib import pyplot as plt 

def visualize_result(orig_img: Image, mask: Image, result_img: Image):     """
    Helper for results visualization 

    parameters: 
        orig_img (Image): input image 
        mask (Image): background mask 
        result_img (Image) output image 
    returns: 
        plt.Figure: plot with 3 images for visualization 
    """ 
    titles = ["Original", "Background Mask", "Without background"] 
    im_w, im_h = orig_img.size 
    is_horizontal = im_h <= im_w 
    figsize = (20, 20) 
    num_images = 3 
    fig, axs = plt.subplots( 
        num_images if is_horizontal else 1, 
        1 if is_horizontal else num_images, 
        figsize=figsize, 
        sharex="all", 
        sharey="all", 
    ) 
    fig.patch.set_facecolor("white") 
    list_axes = list(axs.flat) 
    for a in list_axes: 
        a.set_xticklabels([]) 
        a.set_yticklabels([]) 
        a.get_xaxis().set_visible(False) 
        a.get_yaxis().set_visible(False) 
        a.grid(False) 
    list_axes[0].imshow(np.array(orig_img)) 
    list_axes[1].imshow(np.array(mask), cmap="gray") 
    list_axes[0].set_title(titles[0], fontsize=15) 
    list_axes[1].set_title(titles[1], fontsize=15) 
    list_axes[2].imshow(np.array(result_img)) 
    list_axes[2].set_title(titles[2], fontsize=15) 

    fig.subplots_adjust(wspace=0.01 if is_horizontal else 0.00, hspace=0.01 if is_horizontal else 0.1) 
    fig.tight_layout() 
    return fig 

im_path = "./example_input.jpg" 

# 入力の準備 
model_input_size = [1024, 1024] 
orig_im = np.array(Image.open(im_path)) 
orig_im_size = orig_im.shape[0:2] 
image = preprocess_image(orig_im, model_input_size) 

# 推論 
result = net(image) 

# 後処理 
result_image = postprocess_image(result[0][0], orig_im_size) 

# 結果を保存 
pil_im = Image.fromarray(result_image) 
no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0)) 
orig_image = Image.open(im_path) 
no_bg_image.paste(orig_image, mask=pil_im) 
no_bg_image.save("example_image_no_bg.png") visualize_result(orig_image, pil_im, no_bg_image);
../_images/rmbg-background-removal-with-output_8_0.png

モデルを OpenVINO 中間表現形式に変換#

OpenVINO は、OpenVINO 中間表現 (IR) への変換により PyTorch モデルをサポートします。これには、OpenVINO モデル・トランスフォーメーション API を使用する必要があります。ov.convert_model 関数は、元の PyTorch モデル・インスタンスとトレース用のサンプル入力を受け取り、OpenVINO フレームワークでこのモデルを表す ov.Model を返します。変換されたモデルは、ov.save_model 関数を使用してディスクに保存するか、core.complie_model を使用してデバイスに直接ロードできます。

import openvino as ov 

ov_model_path = Path("rmbg-1.4.xml") 

if not ov_model_path.exists(): 
    ov_model = ov.convert_model(net, example_input=image, input=[1, 3, *model_input_size]) 
    ov.save_model(ov_model, ov_model_path)
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-727/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/modeling_utils.py:4371: FutureWarning: _is_quantized_training_enabled is going to be deprecated in transformers 4.39.0. Please use model.hf_quantizer.is_trainable instead warnings.warn(
['x']

OpenVINO モデル推論を実行#

変換が完了したら、変換されたモデルをコンパイルし、指定されたデバイスで OpenVINO を使用して実行できます。推論デバイスの選択には、以下のドロップダウン・リストを使用してください:

import ipywidgets as widgets 

core = ov.Core() 
device = widgets.Dropdown( 
    options=core.available_devices + ["AUTO"], 
    value="AUTO", 
    description="Device:", 
    disabled=False, 
) 

device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')

以前 PyTorch モデルを起動したのと同じイメージでモデルを実行してみます。OpenVINO モデルの入力と出力は、元の前処理および後処理の手順と完全に互換性があるため、再利用できます。

ov_compiled_model = core.compile_model(ov_model_path, device.value) 

result = ov_compiled_model(image)[0] 

# 後処理 
result_image = postprocess_image(torch.from_numpy(result), orig_im_size) 

# 結果を保存 
pil_im = Image.fromarray(result_image) 
no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0)) 
orig_image = Image.open(im_path) 
no_bg_image.paste(orig_image, mask=pil_im) 
no_bg_image.save("example_image_no_bg.png") 

visualize_result(orig_image, pil_im, no_bg_image);
../_images/rmbg-background-removal-with-output_14_0.png

インタラクティブなデモ#

import gradio as gr 

title = "# RMBG background removal with OpenVINO" 

def get_background_mask(model, image): 
    return model(image)[0] 

with gr.Blocks() as demo: 
    gr.Markdown(title) 

    with gr.Row(): 
        input_image = gr.Image(label="Input Image", type="numpy") 
       background_image = gr.Image(label="Background removal Image") 
    submit = gr.Button("Submit") 

    def on_submit(image): 
        original_image = image.copy() 

        h, w = image.shape[:2] 
        image = preprocess_image(original_image, model_input_size) 

        mask = get_background_mask(ov_compiled_model, image) 
        result_image = postprocess_image(torch.from_numpy(mask), (h, w)) 
        pil_im = Image.fromarray(result_image) 
        orig_img = Image.fromarray(original_image) 
        no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0)) 
        no_bg_image.paste(orig_img, mask=pil_im) 

        return no_bg_image 

    submit.click(on_submit, inputs=[input_image], outputs=[background_image]) 
    examples = gr.Examples( 
        examples=["./example_input.jpg"], 
        inputs=[input_image], 
        outputs=[background_image], 
        fn=on_submit, 
        cache_examples=False, 
    ) 

if __name__ == "__main__": 
    try: 
        demo.launch(debug=False) 
    except Exception: 
        demo.launch(share=True, debug=False) 
# リモートで起動する場合は、server_name と server_port を指定します 
# demo.launch(server_name='your server name', server_port='server port in int') 
# 詳細については、ドキュメントをご覧ください: https://gradio.app/docs/
ローカル URL で実行中: http://127.0.0.1:7860 パブリックリンクを作成するには、launch()share=True を設定します。