TripoSR feedforward 3D reconstruction from a single image with OpenVINO#
This Jupyter notebook can be launched only after a local installation.
TripoSR is a state-of-the-art open-source model for fast feedforward 3D reconstruction from a single image, developed collaboratively by Tripo AI and Stability AI.
You can find the source code on GitHub and a demo on HuggingFace. You can also read the paper TripoSR: Fast 3D Object Reconstruction from a Single Image.

Teaser video#
Table of contents:

- Prerequisites
- Get the original model
- Convert the model to OpenVINO IR
- Compile models and prepare the pipeline
- Interactive inference
Prerequisites#
%pip install -q "gradio>=4.19" "torch==2.2.2" rembg trimesh einops "omegaconf>=2.3.0" "transformers>=4.35.0" "openvino>=2024.0.0" --extra-index-url https://download.pytorch.org/whl/cpu
%pip install -q "git+https://github.com/tatsy/torchmcubes.git"
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
descript-audiotools 0.7.2 requires protobuf<3.20,>=3.9.2, but you have protobuf 3.20.3 which is incompatible.
mobileclip 0.1.0 requires torch==1.13.1, but you have torch 2.2.2+cpu which is incompatible.
mobileclip 0.1.0 requires torchvision==0.14.1, but you have torchvision 0.18.1+cpu which is incompatible.
torchaudio 2.3.1+cpu requires torch==2.3.1, but you have torch 2.2.2+cpu which is incompatible.
torchvision 0.18.1+cpu requires torch==2.3.1, but you have torch 2.2.2+cpu which is incompatible.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
import sys
from pathlib import Path

model_repo = "TripoSR"
if not Path(model_repo).exists():
    !git clone https://huggingface.co/spaces/stabilityai/TripoSR

sys.path.append(model_repo)
Cloning into 'TripoSR'...
remote: Enumerating objects: 117, done.
remote: Counting objects: 100% (113/113), done.
remote: Compressing objects: 100% (111/111), done.
remote: Total 117 (delta 36), reused 0 (delta 0), pack-reused 4 (from 1)
Receiving objects: 100% (117/117), 569.16 KiB | 2.61 MiB/s, done.
Resolving deltas: 100% (36/36), done.
Get the original model#
import os

from tsr.system import TSR

model = TSR.from_pretrained(
    "stabilityai/TripoSR",
    config_name="config.yaml",
    weight_name="model.ckpt",
)
# The rendering chunk size trades memory for speed (0 disables chunking).
model.renderer.set_chunk_size(131072)
model.to("cpu")
TSR(
  (image_tokenizer): DINOSingleImageTokenizer(
    (model): ViTModel(
      (embeddings): ViTEmbeddings(
        (patch_embeddings): ViTPatchEmbeddings(
          (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
        )
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): ViTEncoder(
        (layer): ModuleList(
          (0-11): 12 x ViTLayer(
            (attention): ViTAttention(
              (attention): ViTSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
              )
              (output): ViTSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
              )
            )
            (intermediate): ViTIntermediate(
              (dense): Linear(in_features=768, out_features=3072, bias=True)
              (intermediate_act_fn): GELUActivation()
            )
            (output): ViTOutput(
              (dense): Linear(in_features=3072, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          )
        )
      )
      (layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (pooler): ViTPooler(
        (dense): Linear(in_features=768, out_features=768, bias=True)
        (activation): Tanh()
      )
    )
  )
  (tokenizer): Triplane1DTokenizer()
  (backbone): Transformer1D(
    (norm): GroupNorm(32, 1024, eps=1e-06, affine=True)
    (proj_in): Linear(in_features=1024, out_features=1024, bias=True)
    (transformer_blocks): ModuleList(
      (0-15): 16 x BasicTransformerBlock(
        (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn1): Attention(
          (to_q): Linear(in_features=1024, out_features=1024, bias=False)
          (to_k): Linear(in_features=1024, out_features=1024, bias=False)
          (to_v): Linear(in_features=1024, out_features=1024, bias=False)
          (to_out): ModuleList(
            (0): Linear(in_features=1024, out_features=1024, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn2): Attention(
          (to_q): Linear(in_features=1024, out_features=1024, bias=False)
          (to_k): Linear(in_features=768, out_features=1024, bias=False)
          (to_v): Linear(in_features=768, out_features=1024, bias=False)
          (to_out): ModuleList(
            (0): Linear(in_features=1024, out_features=1024, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (norm3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (ff): FeedForward(
          (net): ModuleList(
            (0): GEGLU(
              (proj): Linear(in_features=1024, out_features=8192, bias=True)
            )
            (1): Dropout(p=0.0, inplace=False)
            (2): Linear(in_features=4096, out_features=1024, bias=True)
          )
        )
      )
    )
    (proj_out): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (post_processor): TriplaneUpsampleNetwork(
    (upsample): ConvTranspose2d(1024, 40, kernel_size=(2, 2), stride=(2, 2))
  )
  (decoder): NeRFMLP(
    (layers): Sequential(
      (0): Linear(in_features=120, out_features=64, bias=True)
      (1): SiLU(inplace=True)
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): SiLU(inplace=True)
      (4): Linear(in_features=64, out_features=64, bias=True)
      (5): SiLU(inplace=True)
      (6): Linear(in_features=64, out_features=64, bias=True)
      (7): SiLU(inplace=True)
      (8): Linear(in_features=64, out_features=64, bias=True)
      (9): SiLU(inplace=True)
      (10): Linear(in_features=64, out_features=64, bias=True)
      (11): SiLU(inplace=True)
      (12): Linear(in_features=64, out_features=64, bias=True)
      (13): SiLU(inplace=True)
      (14): Linear(in_features=64, out_features=64, bias=True)
      (15): SiLU(inplace=True)
      (16): Linear(in_features=64, out_features=64, bias=True)
      (17): SiLU(inplace=True)
      (18): Linear(in_features=64, out_features=4, bias=True)
    )
  )
  (renderer): TriplaneNeRFRenderer()
)
Convert the model to OpenVINO IR#
Define the conversion function for PyTorch modules. We use the ov.convert_model function to obtain an OpenVINO Intermediate Representation object and the ov.save_model function to save it as an XML file.
import torch

import openvino as ov


def convert(model: torch.nn.Module, xml_path: str, example_input):
    xml_path = Path(xml_path)
    if not xml_path.exists():
        xml_path.parent.mkdir(parents=True, exist_ok=True)
        with torch.no_grad():
            converted_model = ov.convert_model(model, example_input=example_input)
        ov.save_model(converted_model, xml_path, compress_to_fp16=False)

        # cleanup memory
        torch._C._jit_clear_class_registry()
        torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
        torch.jit._state._clear_class_state()
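As a quick illustration of what this helper expects, a toy module can be converted the same way (the module and output path here are made-up examples, not part of the TripoSR pipeline):

# Hypothetical example: trace and save a trivial linear layer with the helper.
toy_module = torch.nn.Linear(4, 2)
convert(toy_module, "models/toy_linear_ir.xml", torch.rand([1, 4], dtype=torch.float32))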
The original model is a pipeline of several models: image_tokenizer, tokenizer, backbone and post_processor. The image_tokenizer contains a ViTModel, which consists of ViTPatchEmbeddings, ViTEncoder and ViTPooler. The tokenizer is a Triplane1DTokenizer, the backbone is a Transformer1D, and the post_processor is a TriplaneUpsampleNetwork. We convert all the internal models one by one.
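The printed structure above already shows these components; as an optional cross-check before converting, the top-level submodules can be listed directly (a minimal sketch, nothing here is required for the conversion):

# Optional: list the top-level submodules of the TSR pipeline.
for name, module in model.named_children():
    print(name, "->", type(module).__name__)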
VIT_PATCH_EMBEDDINGS_OV_PATH = Path("models/vit_patch_embeddings_ir.xml")


class PatchEmbedingWrapper(torch.nn.Module):
    def __init__(self, patch_embeddings):
        super().__init__()
        self.patch_embeddings = patch_embeddings

    def forward(self, pixel_values, interpolate_pos_encoding=True):
        outputs = self.patch_embeddings(pixel_values=pixel_values, interpolate_pos_encoding=True)
        return outputs


example_input = {
    "pixel_values": torch.rand([1, 3, 512, 512], dtype=torch.float32),
}

convert(
    PatchEmbedingWrapper(model.image_tokenizer.model.embeddings.patch_embeddings),
    VIT_PATCH_EMBEDDINGS_OV_PATH,
    example_input,
)
['pixel_values']
/opt/home/k8sworker/ci-ai/cibuilds/ov-notebook/OVNotebookOps-727/.workspace/scm/ov-notebook/.venv/lib/python3.8/site-packages/transformers/models/vit/modeling_vit.py:167: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if num_channels != self.num_channels:
VIT_ENCODER_OV_PATH = Path("models/vit_encoder_ir.xml")


class EncoderWrapper(torch.nn.Module):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(
        self,
        hidden_states=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        outputs = self.encoder(
            hidden_states=hidden_states,
        )
        return outputs.last_hidden_state


example_input = {
    "hidden_states": torch.rand([1, 1025, 768], dtype=torch.float32),
}

convert(
    EncoderWrapper(model.image_tokenizer.model.encoder),
    VIT_ENCODER_OV_PATH,
    example_input,
)
['hidden_states']
VIT_POOLER_OV_PATH = Path("models/vit_pooler_ir.xml")

convert(
    model.image_tokenizer.model.pooler,
    VIT_POOLER_OV_PATH,
    torch.rand([1, 1025, 768], dtype=torch.float32),
)
['hidden_states']
TOKENIZER_OV_PATH = Path("models/tokenizer_ir.xml")
convert(model.tokenizer, TOKENIZER_OV_PATH, torch.tensor(1))
['batch_size']
example_input = {
    "hidden_states": torch.rand([1, 1024, 3072], dtype=torch.float32),
    "encoder_hidden_states": torch.rand([1, 1025, 768], dtype=torch.float32),
}

BACKBONE_OV_PATH = Path("models/backbone_ir.xml")
convert(model.backbone, BACKBONE_OV_PATH, example_input)
['hidden_states', 'encoder_hidden_states']
POST_PROCESSOR_OV_PATH = Path("models/post_processor_ir.xml")

convert(
    model.post_processor,
    POST_PROCESSOR_OV_PATH,
    torch.rand([1, 3, 1024, 32, 32], dtype=torch.float32),
)
['triplanes']
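At this point every converted submodel should exist on disk as OpenVINO IR. An optional check, assuming the default models/ directory used above:

# Optional: verify that all converted IR files were written.
for xml_file in sorted(Path("models").glob("*.xml")):
    print(xml_file)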
Compile models and prepare the pipeline#
Select a device from the dropdown list to run inference with OpenVINO.
import ipywidgets as widgets

core = ov.Core()

device = widgets.Dropdown(
    options=core.available_devices + ["AUTO"],
    value="AUTO",
    description="Device:",
    disabled=False,
)

device
Dropdown(description='Device:', index=1, options=('CPU', 'AUTO'), value='AUTO')
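When the notebook runs non-interactively (no ipywidgets), the dropdown can be replaced with a plain device string; a minimal sketch, using AUTO as an assumed default, whose value can then be passed to the core.compile_model calls below in place of device.value:

# Non-interactive alternative to the widget: pick the device string directly.
device_value = "AUTO"  # or any entry from core.available_devices, e.g. "CPU"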
compiled_vit_patch_embeddings = core.compile_model(VIT_PATCH_EMBEDDINGS_OV_PATH, device.value)
compiled_vit_model_encoder = core.compile_model(VIT_ENCODER_OV_PATH, device.value)
compiled_vit_model_pooler = core.compile_model(VIT_POOLER_OV_PATH, device.value)
compiled_tokenizer = core.compile_model(TOKENIZER_OV_PATH, device.value)
compiled_backbone = core.compile_model(BACKBONE_OV_PATH, device.value)
compiled_post_processor = core.compile_model(POST_PROCESSOR_OV_PATH, device.value)
To enable interaction with the original TSR class, we create callable wrapper classes around the compiled models. Note that all wrapper classes return torch.Tensor instead of np.array.
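Every wrapper below follows the same round trip: convert torch inputs to numpy for the compiled model and wrap the numpy outputs back into tensors. A minimal sketch of that pattern using the pooler compiled above (the input is random, just to show shapes):

# Sketch of the numpy-in / torch-out round trip the wrappers perform.
dummy_hidden_states = torch.rand([1, 1025, 768], dtype=torch.float32)
np_out = compiled_vit_model_pooler(dummy_hidden_states.detach().numpy())[0]  # np.ndarray
print(torch.from_numpy(np_out).shape)  # torch.Size([1, 768])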
from collections import namedtuple


class VitPatchEmdeddingsWrapper(torch.nn.Module):
    def __init__(self, vit_patch_embeddings, model):
        super().__init__()
        self.vit_patch_embeddings = vit_patch_embeddings
        self.projection = model.projection

    def forward(self, pixel_values, interpolate_pos_encoding=False):
        inputs = {
            "pixel_values": pixel_values,
        }
        outs = self.vit_patch_embeddings(inputs)[0]
        return torch.from_numpy(outs)


class VitModelEncoderWrapper(torch.nn.Module):
    def __init__(self, vit_model_encoder):
        super().__init__()
        self.vit_model_encoder = vit_model_encoder

    def forward(
        self,
        hidden_states,
        head_mask,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        inputs = {
            "hidden_states": hidden_states.detach().numpy(),
        }
        outs = self.vit_model_encoder(inputs)
        outputs = namedtuple("BaseModelOutput", ("last_hidden_state", "hidden_states", "attentions"))
        return outputs(torch.from_numpy(outs[0]), None, None)


class VitModelPoolerWrapper(torch.nn.Module):
    def __init__(self, vit_model_pooler):
        super().__init__()
        self.vit_model_pooler = vit_model_pooler

    def forward(self, hidden_states):
        outs = self.vit_model_pooler(hidden_states.detach().numpy())[0]
        return torch.from_numpy(outs)


class TokenizerWrapper(torch.nn.Module):
    def __init__(self, tokenizer, model):
        super().__init__()
        self.tokenizer = tokenizer
        self.detokenize = model.detokenize

    def forward(self, batch_size):
        outs = self.tokenizer(batch_size)[0]
        return torch.from_numpy(outs)


class BackboneWrapper(torch.nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, hidden_states, encoder_hidden_states):
        inputs = {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states.detach().numpy(),
        }
        outs = self.backbone(inputs)[0]
        return torch.from_numpy(outs)


class PostProcessorWrapper(torch.nn.Module):
    def __init__(self, post_processor):
        super().__init__()
        self.post_processor = post_processor

    def forward(self, triplanes):
        outs = self.post_processor(triplanes)[0]
        return torch.from_numpy(outs)
Replace all the internal models of the original pipeline with the wrapper instances:
model.image_tokenizer.model.embeddings.patch_embeddings = VitPatchEmdeddingsWrapper(
    compiled_vit_patch_embeddings,
    model.image_tokenizer.model.embeddings.patch_embeddings,
)
model.image_tokenizer.model.encoder = VitModelEncoderWrapper(compiled_vit_model_encoder)
model.image_tokenizer.model.pooler = VitModelPoolerWrapper(compiled_vit_model_pooler)
model.tokenizer = TokenizerWrapper(compiled_tokenizer, model.tokenizer)
model.backbone = BackboneWrapper(compiled_backbone)
model.post_processor = PostProcessorWrapper(compiled_post_processor)
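Before building the demo, the OpenVINO-backed pipeline can be smoke-tested end to end on one of the example images bundled with the cloned repository; a minimal sketch (the image choice is arbitrary, and the background-removal preprocessing used by the demo below is skipped):

from PIL import Image

# Optional smoke test: run the wrapped pipeline once on a bundled example image.
example_image_path = sorted(Path("TripoSR/examples").iterdir())[0]
test_image = Image.open(example_image_path).convert("RGB")
scene_codes = model(test_image, "cpu")
test_mesh = model.extract_mesh(scene_codes)[0]
print(f"vertices: {test_mesh.vertices.shape}, faces: {test_mesh.faces.shape}")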
Interactive inference#
import tempfile

import gradio as gr
import numpy as np
import rembg
from PIL import Image
from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation

rembg_session = rembg.new_session()


def check_input_image(input_image):
    if input_image is None:
        raise gr.Error("No image uploaded!")


def preprocess(input_image, do_remove_background, foreground_ratio):
    def fill_background(image):
        image = np.array(image).astype(np.float32) / 255.0
        image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
        image = Image.fromarray((image * 255.0).astype(np.uint8))
        return image

    if do_remove_background:
        image = input_image.convert("RGB")
        image = remove_background(image, rembg_session)
        image = resize_foreground(image, foreground_ratio)
        image = fill_background(image)
    else:
        image = input_image
        if image.mode == "RGBA":
            image = fill_background(image)
    return image


def generate(image):
    scene_codes = model(image, "cpu")  # the device is provided for the image processor
    mesh = model.extract_mesh(scene_codes)[0]
    mesh = to_gradio_3d_orientation(mesh)
    mesh_path = tempfile.NamedTemporaryFile(suffix=".obj", delete=False)
    mesh.export(mesh_path.name)
    return mesh_path.name


with gr.Blocks() as demo:
    with gr.Row(variant="panel"):
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(
                    label="Input Image",
                    image_mode="RGBA",
                    sources="upload",
                    type="pil",
                    elem_id="content_image",
                )
                processed_image = gr.Image(label="Processed Image", interactive=False)
            with gr.Row():
                with gr.Group():
                    do_remove_background = gr.Checkbox(label="Remove Background", value=True)
                    foreground_ratio = gr.Slider(
                        label="Foreground Ratio",
                        minimum=0.5,
                        maximum=1.0,
                        value=0.85,
                        step=0.05,
                    )
            with gr.Row():
                submit = gr.Button("Generate", elem_id="generate", variant="primary")
        with gr.Column():
            with gr.Tab("Model"):
                output_model = gr.Model3D(
                    label="Output Model",
                    interactive=False,
                )
    with gr.Row(variant="panel"):
        gr.Examples(
            examples=[os.path.join("TripoSR/examples", img_name) for img_name in sorted(os.listdir("TripoSR/examples"))],
            inputs=[input_image],
            outputs=[processed_image, output_model],
            label="Examples",
            examples_per_page=20,
        )
    submit.click(fn=check_input_image, inputs=[input_image]).success(
        fn=preprocess,
        inputs=[input_image, do_remove_background, foreground_ratio],
        outputs=[processed_image],
    ).success(
        fn=generate,
        inputs=[processed_image],
        outputs=[output_model],
    )

try:
    demo.launch(debug=False, height=680)
except Exception:
    demo.queue().launch(share=True, debug=False, height=680)
# If you are launching remotely, specify server_name and server_port:
# demo.launch(server_name='your server name', server_port='server port in int')
# To learn more, read the docs: https://gradio.app/docs/
Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.