Skip to content

Prediction is not aligned with depth or mask #162

@alex-bene

Description

@alex-bene

Hello, thanks a lot for this great work!
I am trying to use SAM3D-Objects to predict object mesh and pose in images; however, the SAM3D prediction does not seem to really be aligned with either the depth pointmap or the binary mask.

Is this the expected behaviour from SAM3D?

Example 1

image used: https://guitar.com/wp-content/uploads/2020/03/person-playing-classical-guitar-at-home@1400x1050.jpg

Input

Image

Output

Image

Example 2

image used: https://media.istockphoto.com/id/1131330399/photo/confectioner-woman-making-delicious-cream-for-cupcakes.jpg?s=612x612&w=0&k=20&c=vUVlu8HuyZtKQiP6hFKO-CFUeO2a47ALrh9j-ya-bzk=

Input

Image

Output

Image

Example 3

image used: https://media.istockphoto.com/id/1359675618/photo/young-man-sitting-on-a-gray-sofa-caresses-the-head-of-a-brown-tabby-cat.jpg?s=612x612&w=0&k=20&c=wWQuof2KFLNKcrDxKzc0WayKNZS7qVo2y4R2Ubeqo0A=

Input

Image

Output

Image

Reproducibility

Here's the actual code to reproduce the results.

import os

import numpy as np
import torch
from inference import (
    Inference,
    display_image,
)
from pytorch3d.transforms import Transform3d, quaternion_to_matrix
# Locate the pipeline config (checkpoints live one level above the CWD)
# and build the SAM3D inference pipeline; torch.compile is disabled.
PATH = os.getcwd()
TAG = "hf"
config_path = "/".join([PATH, "..", "checkpoints", TAG, "pipeline.yaml"])
inference = Inference(config_path, compile=False)

from PIL import Image

# TODO: point these at a real RGB image and its binary object mask.
# (The original snippet had `.....` placeholders, which is invalid syntax.)
image_path = "path/to/image.jpg"
mask_path = "path/to/mask.png"
image = Image.open(image_path).convert("RGB")
# Downscale to width 512 while preserving the aspect ratio.
image = image.resize((512, 512 * image.height // image.width))
# Nearest-neighbour resampling keeps the mask strictly binary after resizing.
mask = Image.open(mask_path).convert("1").resize(image.size, Image.NEAREST)
image = np.array(image).astype(np.uint8)
mask = np.array(mask) > 0.5

display_image(image, [mask])

from moge.model.v1 import MoGeModel

device = "cuda:0"
# MoGe monocular-geometry model; eval mode for inference.
depth_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device).eval()
# Scale pixels to [0, 1], reorder HWC -> CHW, and add a batch dimension.
image_batch = (
    torch.tensor(image / 255.0).permute(2, 0, 1).unsqueeze(0).float().to(device)
)
# force_projection=False returns the raw (unprojected) pointmap.
depth_output = depth_model.infer(image_batch, force_projection=False)

# SAM3D is fed the MoGe pointmap with the X and Y axes negated
# (sign convention mismatch between the two models — TODO confirm).
axis_signs = torch.tensor([-1, -1, 1], device=device)
sam3d_output = inference(
    image,
    mask,
    seed=42,
    pointmap=depth_output["points"][0] * axis_signs,
)

# Basis-change matrices used to map SAM3D output back to camera space.
# Y-up (with X mirrored) -> Z-up convention.
R_yup_to_zup = torch.tensor([[-1, 0, 0], [0, 0, 1], [0, 1, 0]], dtype=torch.float32)
# Negate the Z axis only.
R_flip_z = torch.diag(torch.tensor([1.0, 1.0, -1.0]))
# PyTorch3D camera axes -> conventional camera axes (negate X and Y).
R_pytorch3d_to_cam = torch.diag(torch.tensor([-1.0, -1.0, 1.0]))


def transform_mesh_vertices(vertices, rotation, translation, scale):
    """Apply SAM3D's predicted pose to mesh vertices.

    Re-orients the vertices (Z flip, then Y-up -> Z-up), applies the
    predicted scale/rotation/translation via a PyTorch3D ``Transform3d``,
    and finally converts from the PyTorch3D camera convention.

    Args:
        vertices: (N, 3) numpy array or torch tensor of mesh vertices.
        rotation: quaternion tensor accepted by ``quaternion_to_matrix``.
        translation: length-3 tensor (tx, ty, tz).
        scale: scalar scale factor.

    Returns:
        (N, 3) tensor of transformed vertices.
    """
    if isinstance(vertices, np.ndarray):
        vertices = torch.tensor(vertices, dtype=torch.float32)

    device = vertices.device
    batched = vertices.unsqueeze(0)  # add batch dimension -> [1, N, 3]
    batched = batched @ R_flip_z.to(device)
    batched = batched @ R_yup_to_zup.to(device)

    rot_mat = quaternion_to_matrix(rotation.to(device))
    pose = (
        Transform3d(dtype=batched.dtype, device=device)
        .scale(scale)
        .rotate(rot_mat)
        .translate(translation[0], translation[1], translation[2])
    )
    world = pose.transform_points(batched)
    world = world @ R_pytorch3d_to_cam.to(device)

    return world[0]  # drop the batch dimension


# Work on a copy so the original GLB mesh from SAM3D stays untouched.
mesh = sam3d_output["glb"].copy()
vertices = mesh.vertices
# NOTE(review): removed unused local `vertices_tensor` — it was created
# from `vertices` but never read.

# Predicted pose parameters (batch index 0), moved to CPU as float32.
S = sam3d_output["scale"][0].cpu().float()
T = sam3d_output["translation"][0].cpu().float()
R = sam3d_output["rotation"].squeeze().cpu().float()

# Bake the predicted pose into the mesh vertices for visualization.
vertices_transformed = transform_mesh_vertices(vertices, R, T, S)
mesh.vertices = vertices_transformed.cpu().numpy().astype(np.float32)

from plotly import graph_objects as go
import numpy as np

# 3D scene with Y axis pointing up and an aspect ratio driven by the data.
scene_layout = go.Layout(
    scene_camera={"up": {"x": 0, "y": 1, "z": 0}},
    margin={"l": 10, "r": 10, "b": 10, "t": 10},
    height=600,
    scene={
        "xaxis_title": "X Axis",
        "yaxis_title": "Y Axis",
        "zaxis_title": "Z Axis",
        "aspectmode": "data",
        "aspectratio": {"x": 1, "y": 1, "z": 1},
    },
)
fig = go.Figure(layout=scene_layout)

# Mesh trace: the posed SAM3D mesh rendered as translucent gray geometry.
verts = np.asarray(mesh.vertices)
faces = np.asarray(mesh.faces)
fig.add_trace(
    go.Mesh3d(
        x=verts[:, 0],
        y=verts[:, 1],
        z=verts[:, 2],
        i=faces[:, 0],
        j=faces[:, 1],
        k=faces[:, 2],
        color="lightgray",
        opacity=0.8,
        name="mesh",
    )
)

# Point cloud trace: the MoGe pointmap colored by the source image pixels.
points = depth_output["points"][0].detach().cpu().numpy().reshape(-1, 3)

pixel_colors = image.reshape(-1, 3).astype(np.uint8)
color_str = [f"rgb({r},{g},{b})" for r, g, b in pixel_colors]

fig.add_trace(
    go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode="markers",
        marker={"size": 1, "color": color_str},
        name="pointcloud",
    )
)

fig.show()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions