Spaces: Running on Zero

xinjie.wang committed · Commit c948747
1 Parent(s): 7f124e2

update

Browse files
- app.py +1 -6
- common.py +50 -250
- embodied_gen/data/asset_converter.py +17 -0
- embodied_gen/models/sam3d.py +32 -27
- embodied_gen/scripts/gen_scene3d.py +17 -0
- embodied_gen/scripts/gen_texture.py +17 -0
- embodied_gen/scripts/imageto3d.py +32 -52
- embodied_gen/scripts/render_gs.py +0 -1
- embodied_gen/scripts/textto3d.py +3 -3
- embodied_gen/utils/gpt_clients.py +10 -3
- embodied_gen/utils/inference.py +59 -0
- embodied_gen/utils/monkey_patches.py +62 -12
- embodied_gen/utils/process_media.py +1 -1
- embodied_gen/utils/tags.py +1 -1
- embodied_gen/utils/trender.py +51 -0
- embodied_gen/validators/aesthetic_predictor.py +5 -1
- embodied_gen/validators/quality_checkers.py +27 -6
- embodied_gen/validators/urdf_convertor.py +5 -2
app.py
CHANGED
@@ -17,8 +17,6 @@
 
 import os
 
-os.environ["GRADIO_APP"] = "textto3d"
-
 # GRADIO_APP == "textto3d_sam3d", sam3d object model, by default.
 # GRADIO_APP == "textto3d", TRELLIS model.
 os.environ["GRADIO_APP"] = "textto3d_sam3d"
@@ -35,19 +33,16 @@ from common import (
     get_cached_image,
     get_seed,
     get_selected_image,
+    image_to_3d,
     start_session,
     text2image_fn,
 )
 
 app_name = os.getenv("GRADIO_APP")
 if app_name == "textto3d_sam3d":
-    from common import image_to_3d_sam3d as image_to_3d
-
     enable_pre_resize = False
     sample_step = 25
 elif app_name == "textto3d":
-    from common import image_to_3d
-
     enable_pre_resize = True
     sample_step = 12
 
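The backend-specific imports are gone: both app variants now pull the same `image_to_3d` from `common`, and the backend is chosen only by the `GRADIO_APP` environment variable set at the top of `app.py`. A minimal sketch of switching the variant for a local run (assumed usage; the variable must be set before `common` is imported, since `common.py` reads it at import time):

    import os

    # Hypothetical local launch: pick the TRELLIS backend instead of the
    # default SAM3D one. Must happen before importing common/app.
    os.environ["GRADIO_APP"] = "textto3d"  # default is "textto3d_sam3d"

    from common import image_to_3d  # same entrypoint for both backends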
common.py
CHANGED
@@ -14,6 +14,10 @@
 # implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
+from embodied_gen.utils.monkey_patches import monkey_path_trellis
+
+monkey_path_trellis()
+
 import gc
 import logging
 import os
@@ -27,9 +31,7 @@ import gradio as gr
 import numpy as np
 import spaces
 import torch
-import torch.nn.functional as F
 import trimesh
-from easydict import EasyDict as edict
 from PIL import Image
 from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
 from embodied_gen.data.backproject_v3 import entrypoint as backproject_api_v3
@@ -37,6 +39,7 @@ from embodied_gen.data.differentiable_render import entrypoint as render_api
 from embodied_gen.data.utils import trellis_preprocess, zip_files
 from embodied_gen.models.delight_model import DelightingModel
 from embodied_gen.models.gs_model import GaussianOperator
+from embodied_gen.models.sam3d import Sam3dInference
 from embodied_gen.models.segment_model import (
     BMGG14Remover,
     RembgRemover,
@@ -57,7 +60,7 @@ from embodied_gen.utils.process_media import (
     merge_images_video,
 )
 from embodied_gen.utils.tags import VERSION
-from embodied_gen.utils.trender import render_video
+from embodied_gen.utils.trender import pack_state, render_video, unpack_state
 from embodied_gen.validators.quality_checkers import (
     BaseChecker,
     ImageAestheticChecker,
@@ -70,15 +73,6 @@ current_file_path = os.path.abspath(__file__)
 current_dir = os.path.dirname(current_file_path)
 sys.path.append(os.path.join(current_dir, ".."))
 from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
-from thirdparty.TRELLIS.trellis.representations import (
-    Gaussian,
-    MeshExtractResult,
-)
-from thirdparty.TRELLIS.trellis.representations.gaussian.general_utils import (
-    build_scaling_rotation,
-    inverse_sigmoid,
-    strip_symmetric,
-)
 from thirdparty.TRELLIS.trellis.utils import postprocessing_utils
 
 logging.basicConfig(
@@ -86,79 +80,24 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
-
-os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
-    "~/.cache/torch_extensions"
-)
 os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
-os.environ["SPCONV_ALGO"] = "native"
 
 MAX_SEED = 100000
 
-
-def patched_setup_functions(self):
-    def inverse_softplus(x):
-        return x + torch.log(-torch.expm1(-x))
-
-    def build_covariance_from_scaling_rotation(
-        scaling, scaling_modifier, rotation
-    ):
-        L = build_scaling_rotation(scaling_modifier * scaling, rotation)
-        actual_covariance = L @ L.transpose(1, 2)
-        symm = strip_symmetric(actual_covariance)
-        return symm
-
-    if self.scaling_activation_type == "exp":
-        self.scaling_activation = torch.exp
-        self.inverse_scaling_activation = torch.log
-    elif self.scaling_activation_type == "softplus":
-        self.scaling_activation = F.softplus
-        self.inverse_scaling_activation = inverse_softplus
-
-    self.covariance_activation = build_covariance_from_scaling_rotation
-    self.opacity_activation = torch.sigmoid
-    self.inverse_opacity_activation = inverse_sigmoid
-    self.rotation_activation = F.normalize
-
-    self.scale_bias = self.inverse_scaling_activation(
-        torch.tensor(self.scaling_bias)
-    ).to(self.device)
-    self.rots_bias = torch.zeros((4)).to(self.device)
-    self.rots_bias[0] = 1
-    self.opacity_bias = self.inverse_opacity_activation(
-        torch.tensor(self.opacity_bias)
-    ).to(self.device)
-
-
-Gaussian.setup_functions = patched_setup_functions
-
-
 # DELIGHT = DelightingModel()
 # IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
 # IMAGESR_MODEL = ImageStableSR()
-if os.getenv("GRADIO_APP") == "imageto3d":
-    RBG_REMOVER = RembgRemover()
-    RBG14_REMOVER = BMGG14Remover()
-    SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
-    PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
-        "microsoft/TRELLIS-image-large"
-    )
-    # PIPELINE.cuda()
-    SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
-    GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
-    AESTHETIC_CHECKER = ImageAestheticChecker()
-    CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
-    TMP_DIR = os.path.join(
-        os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
-    )
-    os.makedirs(TMP_DIR, exist_ok=True)
-elif os.getenv("GRADIO_APP") == "imageto3d_sam3d":
-    from embodied_gen.models.sam3d import Sam3dInference
-
+if os.getenv("GRADIO_APP").startswith("imageto3d"):
     RBG_REMOVER = RembgRemover()
     RBG14_REMOVER = BMGG14Remover()
     SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
+    if "sam3d" in os.getenv("GRADIO_APP"):
+        PIPELINE = Sam3dInference()
+    else:
+        PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
+            "microsoft/TRELLIS-image-large"
+        )
+        # PIPELINE.cuda()
     SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
     GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
     AESTHETIC_CHECKER = ImageAestheticChecker()
@@ -167,30 +106,16 @@ elif os.getenv("GRADIO_APP") == "imageto3d_sam3d":
         os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
     )
     os.makedirs(TMP_DIR, exist_ok=True)
-elif os.getenv("GRADIO_APP") == "textto3d":
+elif os.getenv("GRADIO_APP").startswith("textto3d"):
     RBG_REMOVER = RembgRemover()
     RBG14_REMOVER = BMGG14Remover()
-    SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
-    GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
-    AESTHETIC_CHECKER = ImageAestheticChecker()
-    CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
-    TMP_DIR = os.path.join(
-        os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
-    )
-    os.makedirs(TMP_DIR, exist_ok=True)
-elif os.getenv("GRADIO_APP") == "textto3d_sam3d":
-    from embodied_gen.models.sam3d import Sam3dInference
-
-    RBG_REMOVER = RembgRemover()
-    RBG14_REMOVER = BMGG14Remover()
-    PIPELINE = Sam3dInference()
+    if "sam3d" in os.getenv("GRADIO_APP"):
+        PIPELINE = Sam3dInference()
+    else:
+        PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
+            "microsoft/TRELLIS-image-large"
+        )
+        # PIPELINE.cuda()
     text_model_dir = "weights/Kolors"
     PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3)
     PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
@@ -302,50 +227,6 @@ def get_cached_image(image_path: str) -> Image.Image:
     return Image.open(image_path).resize((512, 512))
 
 
-@spaces.GPU
-def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
-    return {
-        "gaussian": {
-            **gs.init_params,
-            "_xyz": gs._xyz.cpu().numpy(),
-            "_features_dc": gs._features_dc.cpu().numpy(),
-            "_scaling": gs._scaling.cpu().numpy(),
-            "_rotation": gs._rotation.cpu().numpy(),
-            "_opacity": gs._opacity.cpu().numpy(),
-        },
-        "mesh": {
-            "vertices": mesh.vertices.cpu().numpy(),
-            "faces": mesh.faces.cpu().numpy(),
-        },
-    }
-
-
-def unpack_state(state: dict, device: str = "cpu") -> tuple[Gaussian, dict]:
-    gs = Gaussian(
-        aabb=state["gaussian"]["aabb"],
-        sh_degree=state["gaussian"]["sh_degree"],
-        mininum_kernel_size=state["gaussian"]["mininum_kernel_size"],
-        scaling_bias=state["gaussian"]["scaling_bias"],
-        opacity_bias=state["gaussian"]["opacity_bias"],
-        scaling_activation=state["gaussian"]["scaling_activation"],
-        device=device,
-    )
-    gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device=device)
-    gs._features_dc = torch.tensor(
-        state["gaussian"]["_features_dc"], device=device
-    )
-    gs._scaling = torch.tensor(state["gaussian"]["_scaling"], device=device)
-    gs._rotation = torch.tensor(state["gaussian"]["_rotation"], device=device)
-    gs._opacity = torch.tensor(state["gaussian"]["_opacity"], device=device)
-
-    mesh = edict(
-        vertices=torch.tensor(state["mesh"]["vertices"], device=device),
-        faces=torch.tensor(state["mesh"]["faces"], device=device),
-    )
-
-    return gs, mesh
-
-
 def get_seed(randomize_seed: bool, seed: int, max_seed: int = MAX_SEED) -> int:
     return np.random.randint(0, max_seed) if randomize_seed else seed
 
@@ -399,87 +280,48 @@ def image_to_3d(
     if is_sam_image:
         seg_image = filter_image_small_connected_components(sam_image)
         seg_image = Image.fromarray(seg_image, mode="RGBA")
-        seg_image = trellis_preprocess(seg_image)
     else:
         seg_image = image
 
     if isinstance(seg_image, np.ndarray):
         seg_image = Image.fromarray(seg_image)
 
+    if isinstance(PIPELINE, Sam3dInference):
+        outputs = PIPELINE.run(
+            seg_image,
+            seed=seed,
+            stage1_inference_steps=ss_sampling_steps,
+            stage2_inference_steps=slat_sampling_steps,
+        )
+    else:
+        PIPELINE.cuda()
+        seg_image = trellis_preprocess(seg_image)
+        outputs = PIPELINE.run(
+            seg_image,
+            seed=seed,
+            formats=["gaussian", "mesh"],
+            preprocess_image=False,
+            sparse_structure_sampler_params={
+                "steps": ss_sampling_steps,
+                "cfg_strength": ss_guidance_strength,
+            },
+            slat_sampler_params={
+                "steps": slat_sampling_steps,
+                "cfg_strength": slat_guidance_strength,
+            },
+        )
+        # Set back to cpu for memory saving.
+        PIPELINE.cpu()
 
     gs_model = outputs["gaussian"][0]
     mesh_model = outputs["mesh"][0]
     color_images = render_video(gs_model, r=1.85)["color"]
     normal_images = render_video(mesh_model, r=1.85)["normal"]
 
-    video_path = os.path.join(output_root, "gs_mesh.mp4")
-    merge_images_video(color_images, normal_images, video_path)
-    state = pack_state(gs_model, mesh_model)
-
-    gc.collect()
-    torch.cuda.empty_cache()
-
-    return state, video_path
-
-
-@spaces.GPU
-def image_to_3d_sam3d(
-    image: Image.Image,
-    seed: int,
-    ss_sampling_steps: int,
-    slat_sampling_steps: int,
-    raw_image_cache: Image.Image,
-    ss_guidance_strength: float = None,
-    slat_guidance_strength: float = None,
-    sam_image: Image.Image = None,
-    is_sam_image: bool = False,
-    req: gr.Request = None,
-) -> tuple[dict, str]:
-    if is_sam_image:
-        seg_image = filter_image_small_connected_components(sam_image)
-        seg_image = Image.fromarray(seg_image, mode="RGBA")
-    else:
-        seg_image = image
-
-    if isinstance(seg_image, np.ndarray):
-        seg_image = Image.fromarray(seg_image)
-
     output_root = os.path.join(TMP_DIR, str(req.session_hash))
     os.makedirs(output_root, exist_ok=True)
     seg_image.save(f"{output_root}/seg_image.png")
     raw_image_cache.save(f"{output_root}/raw_image.png")
-    outputs = PIPELINE.run(
-        seg_image,
-        seed=seed,
-        stage1_inference_steps=ss_sampling_steps,
-        stage2_inference_steps=slat_sampling_steps,
-    )
-
-    gs_model = outputs["gaussian"][0]
-    mesh_model = outputs["mesh"][0]
-    color_images = render_video(gs_model, r=1.85)["color"]
-    normal_images = render_video(mesh_model, r=1.85)["normal"]
 
     video_path = os.path.join(output_root, "gs_mesh.mp4")
     merge_images_video(color_images, normal_images, video_path)
@@ -491,56 +333,13 @@ def image_to_3d_sam3d(
     return state, video_path
 
 
-@spaces.GPU
-def extract_3d_representations(
-    state: dict, enable_delight: bool, texture_size: int, req: gr.Request
-):
-    output_root = TMP_DIR
-    output_root = os.path.join(output_root, str(req.session_hash))
-    gs_model, mesh_model = unpack_state(state, device="cuda")
-
-    mesh = postprocessing_utils.to_glb(
-        gs_model,
-        mesh_model,
-        simplify=0.9,
-        texture_size=1024,
-        verbose=True,
-    )
-    filename = "sample"
-    gs_path = os.path.join(output_root, f"{filename}_gs.ply")
-    gs_model.save_ply(gs_path)
-
-    # Rotate mesh and GS by 90 degrees around Z-axis.
-    rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
-    # Addtional rotation for GS to align mesh.
-    gs_rot = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) @ np.array(
-        rot_matrix
-    )
-    pose = GaussianOperator.trans_to_quatpose(gs_rot)
-    aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
-    GaussianOperator.resave_ply(
-        in_ply=gs_path,
-        out_ply=aligned_gs_path,
-        instance_pose=pose,
-    )
-
-    mesh.vertices = mesh.vertices @ np.array(rot_matrix)
-    mesh_obj_path = os.path.join(output_root, f"{filename}.obj")
-    mesh.export(mesh_obj_path)
-    mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
-    mesh.export(mesh_glb_path)
-
-    torch.cuda.empty_cache()
-
-    return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
-
-
 def extract_3d_representations_v2(
     state: dict,
     enable_delight: bool,
     texture_size: int,
     req: gr.Request,
 ):
+    """Back-Projection Version of Texture Super-Resolution."""
     output_root = TMP_DIR
     user_dir = os.path.join(output_root, str(req.session_hash))
     gs_model, mesh_model = unpack_state(state, device="cpu")
@@ -607,6 +406,7 @@ def extract_3d_representations_v3(
     texture_size: int,
     req: gr.Request,
 ):
+    """Back-Projection Version with Optimization-Based."""
     output_root = TMP_DIR
     user_dir = os.path.join(output_root, str(req.session_hash))
     gs_model, mesh_model = unpack_state(state, device="cpu")
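The former pair of entrypoints (`image_to_3d` and `image_to_3d_sam3d`) is now a single function that dispatches on the type of the module-level `PIPELINE`. A simplified sketch of the pattern (the helper name `run_backend` is hypothetical; argument names are taken from the diff, error handling and output packing omitted):

    from embodied_gen.models.sam3d import Sam3dInference

    def run_backend(pipeline, seg_image, seed, ss_steps, slat_steps):
        # Sketch of the isinstance dispatch used inside common.image_to_3d.
        if isinstance(pipeline, Sam3dInference):
            return pipeline.run(
                seg_image,
                seed=seed,
                stage1_inference_steps=ss_steps,
                stage2_inference_steps=slat_steps,
            )
        # TRELLIS path: move to GPU only for the call, then back to CPU.
        pipeline.cuda()
        outputs = pipeline.run(seg_image, seed=seed, formats=["gaussian", "mesh"])
        pipeline.cpu()
        return outputs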
embodied_gen/data/asset_converter.py
CHANGED
@@ -1,3 +1,20 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
 from __future__ import annotations
 
 import logging
embodied_gen/models/sam3d.py
CHANGED
@@ -19,7 +19,6 @@ from embodied_gen.utils.monkey_patches import monkey_patch_sam3d
 monkey_patch_sam3d()
 import os
 import sys
-from typing import Optional, Union
 
 import numpy as np
 from hydra.utils import instantiate
@@ -31,29 +30,38 @@ from PIL import Image
 current_file_path = os.path.abspath(__file__)
 current_dir = os.path.dirname(current_file_path)
 sys.path.append(os.path.join(current_dir, "../.."))
+from loguru import logger
 from thirdparty.sam3d.sam3d_objects.pipeline.inference_pipeline_pointmap import (
     InferencePipelinePointMap,
 )
 
+logger.remove()
+logger.add(lambda _: None, level="ERROR")
+
+
 __all__ = ["Sam3dInference"]
 
 
-    image = np.array(image)
-    image = image.astype(np.uint8)
-    return image
-
-    if mask.ndim == 3:
-        mask = mask[..., -1]
-    return mask
-
 class Sam3dInference:
+    """Wrapper for the SAM-3D-Objects inference pipeline.
+
+    This class handles loading the SAM-3D-Objects model, configuring it for inference,
+    and running the pipeline on input images (optionally with masks and pointmaps).
+    It supports distillation options and inference step customization.
+
+    Args:
+        local_dir (str): Directory to store or load model weights and configs.
+        compile (bool): Whether to compile the model for faster inference.
+
+    Methods:
+        merge_mask_to_rgba(image, mask):
+            Merges a binary mask into the alpha channel of an RGB image.
+
+        run(image, mask=None, seed=None, pointmap=None, use_stage1_distillation=False,
+            use_stage2_distillation=False, stage1_inference_steps=25, stage2_inference_steps=25):
+            Runs the inference pipeline and returns the output dictionary.
+    """
+
     def __init__(
         self, local_dir: str = "weights/sam-3d-objects", compile: bool = False
     ) -> None:
@@ -65,7 +73,7 @@ class Sam3dInference:
         config.rendering_engine = "nvdiffrast"
         config.compile_model = compile
         config.workspace_dir = os.path.dirname(config_file)
-        # Generate 4 gs in each pixel.
+        # Generate 4 instead of 32 gs in each pixel for efficient storage.
         config["slat_decoder_gs_config_path"] = config.pop(
             "slat_decoder_gs_4_config_path", "slat_decoder_gs_4.yaml"
         )
@@ -118,25 +126,22 @@ class Sam3dInference:
 if __name__ == "__main__":
     pipeline = Sam3dInference()
 
-
-    image = load_image(
-        "/home/users/xinjie.wang/xinjie/sam-3d-objects/notebook/images/shutterstock_stylish_kidsroom_1640806567/image.png"
-    )
-    mask = load_mask(
-        "/home/users/xinjie.wang/xinjie/sam-3d-objects/notebook/images/shutterstock_stylish_kidsroom_1640806567/13.png"
-    )
+    from time import time
 
     import torch
+    from embodied_gen.models.segment_model import RembgRemover
+
+    input_image = "apps/assets/example_image/sample_00.jpg"
+    output_gs = "outputs/splat.ply"
+    remover = RembgRemover()
+    clean_image = remover(input_image)
 
     if torch.cuda.is_available():
         torch.cuda.reset_peak_memory_stats()
         torch.cuda.empty_cache()
 
-    from time import time
-
     start = time()
-
-    output = pipeline.run(image, mask, seed=42)
+    output = pipeline.run(clean_image, seed=42)
     print(f"Running cost: {round(time()-start, 1)}")
 
     if torch.cuda.is_available():
@@ -145,5 +150,5 @@ if __name__ == "__main__":
 
     print(f"End: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
 
-    output["gs"].save_ply(
-    print("
+    output["gs"].save_ply(output_gs)
+    print(f"Saved to {output_gs}")
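The rewritten `__main__` block doubles as the minimal usage pattern for the wrapper: background removal, one `run()` call, and a Gaussian-splat export. The same steps, condensed (paths mirror the example in the diff):

    # Minimal sketch mirroring the new __main__ block in sam3d.py.
    from embodied_gen.models.sam3d import Sam3dInference
    from embodied_gen.models.segment_model import RembgRemover

    pipeline = Sam3dInference()  # loads weights/sam-3d-objects by default
    clean_image = RembgRemover()("apps/assets/example_image/sample_00.jpg")
    output = pipeline.run(clean_image, seed=42)
    output["gs"].save_ply("outputs/splat.ply")  # export the 3D Gaussians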
embodied_gen/scripts/gen_scene3d.py
CHANGED
@@ -1,3 +1,20 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
 import logging
 import os
 import random
embodied_gen/scripts/gen_texture.py
CHANGED
@@ -1,3 +1,20 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
 import os
 import shutil
 from dataclasses import dataclass
embodied_gen/scripts/imageto3d.py
CHANGED
@@ -14,30 +14,30 @@
 # implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-
 import argparse
 import os
 import random
-import sys
 from glob import glob
 from shutil import copy, copytree, rmtree
 
 import numpy as np
-import torch
 import trimesh
 from PIL import Image
 from embodied_gen.data.backproject_v3 import entrypoint as backproject_api
 from embodied_gen.data.utils import delete_dir
 
+# from embodied_gen.models.sr_model import ImageRealESRGAN
 # from embodied_gen.models.delight_model import DelightingModel
 from embodied_gen.models.gs_model import GaussianOperator
 from embodied_gen.models.segment_model import RembgRemover
-
-# from embodied_gen.models.sr_model import ImageRealESRGAN
 from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
 from embodied_gen.utils.gpt_clients import GPT_CLIENT
+from embodied_gen.utils.inference import image3d_model_infer
 from embodied_gen.utils.log import logger
-from embodied_gen.utils.process_media import
+from embodied_gen.utils.process_media import (
+    combine_images_to_grid,
+    merge_images_video,
+)
 from embodied_gen.utils.tags import VERSION
 from embodied_gen.utils.trender import render_video
 from embodied_gen.validators.quality_checkers import (
@@ -48,26 +48,24 @@ from embodied_gen.validators.quality_checkers import (
 )
 from embodied_gen.validators.urdf_convertor import URDFGenerator
 
-logger.info("Loading Image3D Models...")
+# random.seed(0)
+IMAGE3D_MODEL = "SAM3D"  # TRELLIS or SAM3D
+logger.info(f"Loading {IMAGE3D_MODEL} as Image3D Models...")
+if IMAGE3D_MODEL == "TRELLIS":
+    from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
+
+    PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
+        "microsoft/TRELLIS-image-large"
+    )
+    # PIPELINE.cuda()
+elif IMAGE3D_MODEL == "SAM3D":
+    from embodied_gen.models.sam3d import Sam3dInference
+
+    PIPELINE = Sam3dInference()
+
 # DELIGHT = DelightingModel()
 # IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
 RBG_REMOVER = RembgRemover()
-PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
-    "microsoft/TRELLIS-image-large"
-)
-# PIPELINE.cuda()
 SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
 GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
 AESTHETIC_CHECKER = ImageAestheticChecker()
@@ -151,7 +149,6 @@ def entrypoint(**kwargs):
         # Segmentation: Get segmented image using Rembg.
         seg_path = f"{output_root}/{filename}_cond.png"
         seg_image = RBG_REMOVER(image) if image.mode != "RGBA" else image
-        seg_image = trellis_preprocess(seg_image)
         seg_image.save(seg_path)
 
         seed = args.seed
@@ -162,27 +159,8 @@ def entrypoint(**kwargs):
             logger.info(
                 f"Try: {try_idx + 1}/{args.n_retry}, Seed: {seed}, Prompt: {seg_path}"
             )
-            # Run the pipeline
             try:
-                PIPELINE.cuda()
-                outputs = PIPELINE.run(
-                    seg_image,
-                    preprocess_image=False,
-                    seed=(
-                        random.randint(0, 100000) if seed is None else seed
-                    ),
-                    # Optional parameters
-                    # sparse_structure_sampler_params={
-                    #     "steps": 12,
-                    #     "cfg_strength": 7.5,
-                    # },
-                    # slat_sampler_params={
-                    #     "steps": 12,
-                    #     "cfg_strength": 3,
-                    # },
-                )
-                PIPELINE.cpu()
-                torch.cuda.empty_cache()
+                outputs = image3d_model_infer(PIPELINE, seg_image, seed)
             except Exception as e:
                 logger.error(
                     f"[Pipeline Failed] process {image_path}: {e}, skip."
@@ -215,14 +193,13 @@ def entrypoint(**kwargs):
                 render_gs_api(
                     input_gs=aligned_gs_path,
                     output_path=color_path,
-                    elevation=[
-                    num_images=
+                    elevation=[30, -30],
+                    num_images=4,
                 )
-
                 color_img = Image.open(color_path)
+                geo_flag, geo_result = GEO_CHECKER(
+                    [color_img], text=asset_node
+                )
                 logger.warning(
                     f"{GEO_CHECKER.__class__.__name__}: {geo_result} for {seg_path}"
                 )
@@ -232,8 +209,8 @@ def entrypoint(**kwargs):
                 seed = random.randint(0, 100000) if seed is not None else None
 
                 # Render the video for generated 3D asset.
-                color_images = render_video(gs_model)["color"]
-                normal_images = render_video(mesh_model)["normal"]
+                color_images = render_video(gs_model, r=1.85)["color"]
+                normal_images = render_video(mesh_model, r=1.85)["normal"]
                 video_path = os.path.join(output_root, "gs_mesh.mp4")
                 merge_images_video(color_images, normal_images, video_path)
@@ -312,7 +289,7 @@ def entrypoint(**kwargs):
             image_paths = glob(f"{image_dir}/*.png")
             images_list = []
             for checker in CHECKERS:
-                images = image_paths
+                images = combine_images_to_grid(image_paths)
                 if isinstance(checker, ImageSegChecker):
                     images = [
                         f"{output_root}/{filename}_raw.png",
@@ -334,9 +311,12 @@ def entrypoint(**kwargs):
                 f"{result_dir}/{urdf_convertor.output_mesh_dir}",
             )
             copy(video_path, f"{result_dir}/video.mp4")
+
             if not args.keep_intermediate:
                 delete_dir(output_root, keep_subs=["result"])
 
+            logger.info(f"Saved results for {image_path} in {result_dir}")
+
         except Exception as e:
             logger.error(f"Failed to process {image_path}: {e}, skip.")
             continue
embodied_gen/scripts/render_gs.py
CHANGED
@@ -27,7 +27,6 @@ from tqdm import tqdm
 from embodied_gen.data.utils import (
     CameraSetting,
     init_kal_camera,
-    normalize_vertices_array,
 )
 from embodied_gen.models.gs_model import load_gs_model
 from embodied_gen.utils.process_media import combine_images_to_grid
embodied_gen/scripts/textto3d.py
CHANGED
@@ -30,6 +30,7 @@ from embodied_gen.utils.gpt_clients import GPT_CLIENT
 from embodied_gen.utils.log import logger
 from embodied_gen.utils.process_media import (
     check_object_edge_truncated,
+    combine_images_to_grid,
     render_asset3d,
 )
 from embodied_gen.validators.quality_checkers import (
@@ -51,7 +52,6 @@ BG_REMOVER = RembgRemover()
 
 
 __all__ = [
-    "text_to_image",
     "text_to_3d",
 ]
 
@@ -176,12 +176,12 @@ def text_to_3d(**kwargs) -> dict:
     image_path = render_asset3d(
         mesh_path,
         output_root=f"{node_save_dir}/result",
-        num_images=
+        num_images=4,
         elevation=(30, -30),
         output_subdir="renders",
         no_index_file=True,
     )
-
+    image_path = combine_images_to_grid(image_path)
     check_text = asset_type if asset_type is not None else prompt
     qa_flag, qa_result = TXTGEN_CHECKER(check_text, image_path)
     logger.warning(
embodied_gen/utils/gpt_clients.py
CHANGED
@@ -21,13 +21,14 @@ import os
 from io import BytesIO
 from typing import Optional
 
+import openai
 import yaml
 from openai import AzureOpenAI, OpenAI  # pip install openai
 from PIL import Image
 from tenacity import (
     retry,
+    retry_if_not_exception_type,
     stop_after_attempt,
-    stop_after_delay,
     wait_random_exponential,
 )
 from embodied_gen.utils.process_media import combine_images_to_grid
@@ -106,8 +107,9 @@ class GPTclient:
         logger.info(f"Using GPT model: {self.model_name}.")
 
     @retry(
+        retry=retry_if_not_exception_type(openai.BadRequestError),
+        wait=wait_random_exponential(min=1, max=10),
+        stop=stop_after_attempt(5),
     )
     def completion_with_backoff(self, **kwargs):
         """Performs a chat completion request with retry/backoff."""
@@ -246,3 +248,8 @@ GPT_CLIENT = GPTclient(
     model_name=model_name,
     check_connection=False,
 )
+
+
+if __name__ == "__main__":
+    response = GPT_CLIENT.query("What is the capital of China?")
+    print(f"Response: {response}")
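The retry policy is now explicit: back off with randomized exponential waits, give up after five attempts, and never retry an `openai.BadRequestError` (a malformed request will not get better on retry). A standalone sketch of the same tenacity configuration, with a placeholder function standing in for `completion_with_backoff`:

    import openai
    from tenacity import (
        retry,
        retry_if_not_exception_type,
        stop_after_attempt,
        wait_random_exponential,
    )

    @retry(
        retry=retry_if_not_exception_type(openai.BadRequestError),
        wait=wait_random_exponential(min=1, max=10),  # 1-10 s, randomized backoff
        stop=stop_after_attempt(5),                   # at most five tries
    )
    def flaky_completion(client, **kwargs):
        # Placeholder for GPTclient.completion_with_backoff: any exception
        # other than BadRequestError triggers another attempt.
        return client.chat.completions.create(**kwargs)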
embodied_gen/utils/inference.py
ADDED
@@ -0,0 +1,59 @@
+from embodied_gen.utils.monkey_patches import monkey_path_trellis
+
+monkey_path_trellis()
+import random
+
+import torch
+from PIL import Image
+from embodied_gen.data.utils import trellis_preprocess
+from embodied_gen.models.sam3d import Sam3dInference
+from embodied_gen.utils.trender import pack_state, unpack_state
+from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
+
+__all__ = [
+    "image3d_model_infer",
+]
+
+
+def image3d_model_infer(
+    pipe: TrellisImageTo3DPipeline | Sam3dInference,
+    seg_image: Image.Image,
+    seed: int = None,
+    **kwargs: dict,
+) -> dict[str, any]:
+    if isinstance(pipe, TrellisImageTo3DPipeline):
+        pipe.cuda()
+        seg_image = trellis_preprocess(seg_image)
+        outputs = pipe.run(
+            seg_image,
+            preprocess_image=False,
+            seed=(random.randint(0, 100000) if seed is None else seed),
+            # Optional parameters
+            # sparse_structure_sampler_params={
+            #     "steps": 12,
+            #     "cfg_strength": 7.5,
+            # },
+            # slat_sampler_params={
+            #     "steps": 12,
+            #     "cfg_strength": 3,
+            # },
+            **kwargs,
+        )
+        pipe.cpu()
+    elif isinstance(pipe, Sam3dInference):
+        outputs = pipe.run(
+            seg_image,
+            seed=(random.randint(0, 100000) if seed is None else seed),
+            # stage1_inference_steps=25,
+            # stage2_inference_steps=25,
+            **kwargs,
+        )
+        state = pack_state(outputs["gaussian"][0], outputs["mesh"][0])
+        # Align GS3D from SAM3D with TRELLIS format.
+        outputs["gaussian"][0], _ = unpack_state(state, device="cuda")
+    else:
+        raise ValueError(f"Unsupported pipeline type: {type(pipe)}")
+
+    torch.cuda.empty_cache()
+
+    return outputs
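A small usage sketch of the new helper (the segmented input image and output handling are assumed; either pipeline type can be passed, as `imageto3d.py` now does):

    # Assumed usage of image3d_model_infer: it hides the differences between
    # the TRELLIS and SAM3D run() signatures and the GPU placement dance.
    from PIL import Image
    from embodied_gen.models.sam3d import Sam3dInference
    from embodied_gen.utils.inference import image3d_model_infer

    pipeline = Sam3dInference()
    seg_image = Image.open("outputs/seg_image.png")  # RGBA, background removed
    outputs = image3d_model_infer(pipeline, seg_image, seed=42)
    gs_model, mesh_model = outputs["gaussian"][0], outputs["mesh"][0]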
embodied_gen/utils/monkey_patches.py
CHANGED
|
@@ -32,6 +32,67 @@ __all__ = [
|
|
| 32 |
]
|
| 33 |
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
def monkey_patch_pano2room():
|
| 36 |
current_file_path = os.path.abspath(__file__)
|
| 37 |
current_dir = os.path.dirname(current_file_path)
|
|
@@ -240,8 +301,6 @@ def monkey_patch_sam3d():
|
|
| 240 |
if sam3d_root not in sys.path:
|
| 241 |
sys.path.insert(0, sam3d_root)
|
| 242 |
|
| 243 |
-
print(f"[MonkeyPatch] Added to sys.path: {sam3d_root}")
|
| 244 |
-
|
| 245 |
def patch_pointmap_infer_pipeline():
|
| 246 |
from copy import deepcopy
|
| 247 |
|
|
@@ -317,9 +376,6 @@ def monkey_patch_sam3d():
|
|
| 317 |
)
|
| 318 |
)
|
| 319 |
|
| 320 |
-
logger.info(
|
| 321 |
-
f"Rescaling scale by {ss_return_dict['downsample_factor']} after downsampling"
|
| 322 |
-
)
|
| 323 |
ss_return_dict["scale"] = (
|
| 324 |
ss_return_dict["scale"]
|
| 325 |
* ss_return_dict["downsample_factor"]
|
|
@@ -471,11 +527,6 @@ def monkey_patch_sam3d():
|
|
| 471 |
self.rendering_engine = rendering_engine
|
| 472 |
self.device = torch.device(device)
|
| 473 |
self.compile_model = compile_model
|
| 474 |
-
logger.info(f"self.device: {self.device}")
|
| 475 |
-
logger.info(
|
| 476 |
-
f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', None)}"
|
| 477 |
-
)
|
| 478 |
-
logger.info(f"Actually using GPU: {torch.cuda.current_device()}")
|
| 479 |
with self.device:
|
| 480 |
self.decode_formats = decode_formats
|
| 481 |
self.pad_size = pad_size
|
|
@@ -511,7 +562,6 @@ def monkey_patch_sam3d():
|
|
| 511 |
)
|
| 512 |
self.slat_preprocessor = slat_preprocessor
|
| 513 |
|
| 514 |
-
logger.info("Loading model weights...")
|
| 515 |
raw_device = self.device
|
| 516 |
self.device = torch.device("cpu")
|
| 517 |
ss_generator = self.init_ss_generator(
|
|
@@ -578,7 +628,7 @@ def monkey_patch_sam3d():
|
|
| 578 |
"slat_decoder_mesh": slat_decoder_mesh,
|
| 579 |
}
|
| 580 |
)
|
| 581 |
-
logger.info("Loading model weights completed
|
| 582 |
|
| 583 |
if self.compile_model:
|
| 584 |
logger.info("Compiling model...")
|
|
|
|
| 32 |
]
|
| 33 |
|
| 34 |
|
| 35 |
+
def monkey_path_trellis():
|
| 36 |
+
import torch.nn.functional as F
|
| 37 |
+
|
| 38 |
+
current_file_path = os.path.abspath(__file__)
|
| 39 |
+
current_dir = os.path.dirname(current_file_path)
|
| 40 |
+
sys.path.append(os.path.join(current_dir, "../.."))
|
| 41 |
+
|
| 42 |
+
from thirdparty.TRELLIS.trellis.representations import Gaussian
|
| 43 |
+
from thirdparty.TRELLIS.trellis.representations.gaussian.general_utils import (
|
| 44 |
+
build_scaling_rotation,
|
| 45 |
+
inverse_sigmoid,
|
| 46 |
+
strip_symmetric,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
|
| 50 |
+
"~/.cache/torch_extensions"
|
| 51 |
+
)
|
| 52 |
+
os.environ["SPCONV_ALGO"] = "auto" # Can be 'native' or 'auto'
|
| 53 |
+
os.environ['ATTN_BACKEND'] = (
|
| 54 |
+
"xformers" # Can be 'flash-attn' or 'xformers'
|
| 55 |
+
)
|
| 56 |
+
from thirdparty.TRELLIS.trellis.modules.sparse import set_attn
|
| 57 |
+
|
| 58 |
+
set_attn("xformers")
|
| 59 |
+
|
| 60 |
+
def patched_setup_functions(self):
|
| 61 |
+
def inverse_softplus(x):
|
| 62 |
+
return x + torch.log(-torch.expm1(-x))
|
| 63 |
+
|
| 64 |
+
def build_covariance_from_scaling_rotation(
|
| 65 |
+
scaling, scaling_modifier, rotation
|
| 66 |
+
):
|
| 67 |
+
L = build_scaling_rotation(scaling_modifier * scaling, rotation)
|
| 68 |
+
actual_covariance = L @ L.transpose(1, 2)
|
| 69 |
+
symm = strip_symmetric(actual_covariance)
|
| 70 |
+
return symm
|
| 71 |
+
|
| 72 |
+
if self.scaling_activation_type == "exp":
|
| 73 |
+
self.scaling_activation = torch.exp
|
| 74 |
+
self.inverse_scaling_activation = torch.log
|
| 75 |
+
elif self.scaling_activation_type == "softplus":
|
| 76 |
+
self.scaling_activation = F.softplus
|
| 77 |
+
self.inverse_scaling_activation = inverse_softplus
|
| 78 |
+
|
| 79 |
+
self.covariance_activation = build_covariance_from_scaling_rotation
|
| 80 |
+
self.opacity_activation = torch.sigmoid
|
| 81 |
+
self.inverse_opacity_activation = inverse_sigmoid
|
| 82 |
+
self.rotation_activation = F.normalize
|
| 83 |
+
|
| 84 |
+
self.scale_bias = self.inverse_scaling_activation(
|
| 85 |
+
torch.tensor(self.scaling_bias)
|
| 86 |
+
).to(self.device)
|
| 87 |
+
self.rots_bias = torch.zeros((4)).to(self.device)
|
| 88 |
+
self.rots_bias[0] = 1
|
| 89 |
+
self.opacity_bias = self.inverse_opacity_activation(
|
| 90 |
+
torch.tensor(self.opacity_bias)
|
| 91 |
+
).to(self.device)
|
| 92 |
+
|
| 93 |
+
Gaussian.setup_functions = patched_setup_functions
|
| 94 |
+
|
| 95 |
+
|
| 96 |
def monkey_patch_pano2room():
|
| 97 |
current_file_path = os.path.abspath(__file__)
|
| 98 |
current_dir = os.path.dirname(current_file_path)
|
 301      if sam3d_root not in sys.path:
 302          sys.path.insert(0, sam3d_root)
 303
 304  def patch_pointmap_infer_pipeline():
 305      from copy import deepcopy
 306

 376                  )
 377              )
 378
 379              ss_return_dict["scale"] = (
 380                  ss_return_dict["scale"]
 381                  * ss_return_dict["downsample_factor"]

 527          self.rendering_engine = rendering_engine
 528          self.device = torch.device(device)
 529          self.compile_model = compile_model
 530          with self.device:
 531              self.decode_formats = decode_formats
 532              self.pad_size = pad_size

 562          )
 563          self.slat_preprocessor = slat_preprocessor
 564
 565          raw_device = self.device
 566          self.device = torch.device("cpu")
 567          ss_generator = self.init_ss_generator(

 628              "slat_decoder_mesh": slat_decoder_mesh,
 629          }
 630      )
 631  +    logger.info("Loading SAM3D model weights completed.")
 632
 633      if self.compile_model:
 634          logger.info("Compiling model...")
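A minimal, self-contained sanity check (not part of the commit; assumes only PyTorch) that the inverse_softplus used in the patched setup_functions really inverts F.softplus, which is what lets the scaling bias be stored in pre-activation form:

import torch
import torch.nn.functional as F


def inverse_softplus(x: torch.Tensor) -> torch.Tensor:
    # softplus(y) = log(1 + exp(y)) = x  =>  y = x + log(1 - exp(-x)) for x > 0
    return x + torch.log(-torch.expm1(-x))


x = torch.tensor([0.1, 1.0, 5.0])
# Round-tripping through softplus recovers the original positive values.
assert torch.allclose(F.softplus(inverse_softplus(x)), x, atol=1e-5)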
embodied_gen/utils/process_media.py
CHANGED
@@ -96,7 +96,7 @@ def render_asset3d(
         image_paths = render_asset3d(
             mesh_path="path_to_mesh.obj",
             output_root="path_to_save_dir",
-            num_images=
+            num_images=4,
             elevation=(30, -30),
             output_subdir="renders",
             no_index_file=True,
embodied_gen/utils/tags.py
CHANGED
@@ -1 +1 @@
-VERSION = "v0.1.
+VERSION = "v0.1.7"
embodied_gen/utils/trender.py
CHANGED
@@ -21,18 +21,25 @@ from collections import defaultdict
 import numpy as np
 import spaces
 import torch
+from easydict import EasyDict as edict
 from tqdm import tqdm

 current_file_path = os.path.abspath(__file__)
 current_dir = os.path.dirname(current_file_path)
 sys.path.append(os.path.join(current_dir, "../.."))
 from thirdparty.TRELLIS.trellis.renderers import GaussianRenderer, MeshRenderer
+from thirdparty.TRELLIS.trellis.representations import (
+    Gaussian,
+    MeshExtractResult,
+)
 from thirdparty.TRELLIS.trellis.utils.render_utils import (
     yaw_pitch_r_fov_to_extrinsics_intrinsics,
 )

 __all__ = [
     "render_video",
+    "pack_state",
+    "unpack_state",
 ]
@@ -140,3 +147,47 @@ def render_video(
     )

     return result
+
+
+@spaces.GPU
+def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
+    return {
+        "gaussian": {
+            **gs.init_params,
+            "_xyz": gs._xyz.cpu().numpy(),
+            "_features_dc": gs._features_dc.cpu().numpy(),
+            "_scaling": gs._scaling.cpu().numpy(),
+            "_rotation": gs._rotation.cpu().numpy(),
+            "_opacity": gs._opacity.cpu().numpy(),
+        },
+        "mesh": {
+            "vertices": mesh.vertices.cpu().numpy(),
+            "faces": mesh.faces.cpu().numpy(),
+        },
+    }
+
+
+def unpack_state(state: dict, device: str = "cpu") -> tuple[Gaussian, dict]:
+    gs = Gaussian(
+        aabb=state["gaussian"]["aabb"],
+        sh_degree=state["gaussian"]["sh_degree"],
+        mininum_kernel_size=state["gaussian"]["mininum_kernel_size"],
+        scaling_bias=state["gaussian"]["scaling_bias"],
+        opacity_bias=state["gaussian"]["opacity_bias"],
+        scaling_activation=state["gaussian"]["scaling_activation"],
+        device=device,
+    )
+    gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device=device)
+    gs._features_dc = torch.tensor(
+        state["gaussian"]["_features_dc"], device=device
+    )
+    gs._scaling = torch.tensor(state["gaussian"]["_scaling"], device=device)
+    gs._rotation = torch.tensor(state["gaussian"]["_rotation"], device=device)
+    gs._opacity = torch.tensor(state["gaussian"]["_opacity"], device=device)
+
+    mesh = edict(
+        vertices=torch.tensor(state["mesh"]["vertices"], device=device),
+        faces=torch.tensor(state["mesh"]["faces"], device=device),
+    )
+
+    return gs, mesh
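Illustrative sketch (not from the commit) of the serialization pattern that pack_state/unpack_state implement: tensors are dumped to NumPy so the state dict can leave the GPU worker (e.g. across the @spaces.GPU boundary) and are rebuilt on a target device later. Only the mesh half is shown so the snippet runs without TRELLIS installed:

import torch
from easydict import EasyDict as edict

vertices = torch.rand(8, 3)
faces = torch.randint(0, 8, (12, 3))

# "Pack": move everything to plain NumPy arrays (picklable, device-free).
state = {
    "mesh": {
        "vertices": vertices.cpu().numpy(),
        "faces": faces.cpu().numpy(),
    }
}

# "Unpack": rebuild tensors on the requested device.
mesh = edict(
    vertices=torch.tensor(state["mesh"]["vertices"], device="cpu"),
    faces=torch.tensor(state["mesh"]["faces"], device="cpu"),
)
assert torch.equal(mesh.vertices, vertices) and torch.equal(mesh.faces, faces)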
embodied_gen/validators/aesthetic_predictor.py
CHANGED
@@ -125,7 +125,11 @@ class AestheticPredictor:
         Returns:
             float: Predicted aesthetic score.
         """
-
+        if isinstance(image_path, str):
+            pil_image = Image.open(image_path)
+        else:
+            pil_image = image_path
+
         image = self.preprocess(pil_image).unsqueeze(0).to(self.device)

         with torch.no_grad():
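The change above lets predict() accept either an image path or an already-loaded PIL image; a standalone sketch of the same dispatch (the helper name is hypothetical, not part of the code):

from PIL import Image


def to_pil(image_or_path):
    # Hypothetical helper mirroring the isinstance branch added in predict().
    if isinstance(image_or_path, str):
        return Image.open(image_or_path)
    return image_or_path


img = Image.new("RGB", (64, 64))
assert to_pil(img) is img  # a PIL image passes through unchanged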
embodied_gen/validators/quality_checkers.py
CHANGED
@@ -126,6 +126,30 @@ class MeshGeoChecker(BaseChecker):
         super().__init__(prompt, verbose)
         self.gpt_client = gpt_client
         if self.prompt is None:
+            # Old version for TRELLIS.
+            # self.prompt = """
+            # You are an expert in evaluating the geometry quality of generated 3D asset.
+            # You will be given rendered views of a generated 3D asset, type {}, with black background.
+            # Your task is to evaluate the quality of the 3D asset generation,
+            # including geometry, structure, and appearance, based on the rendered views.
+            # Criteria:
+            # - Is the object in the image a single, complete, and well-formed instance,
+            # without truncation, missing parts, overlapping duplicates, or redundant geometry?
+            # - Minor flaws, asymmetries, or simplifications (e.g., less detail on sides or back,
+            # soft edges) are acceptable if the object is structurally sound and recognizable.
+            # - Only evaluate geometry. Do not assess texture quality.
+            # - The asset should not contain any unrelated elements, such as
+            # ground planes, platforms, or background props (e.g., paper, flooring).
+
+            # If all the above criteria are met, return "YES". Otherwise, return
+            # "NO" followed by a brief explanation (no more than 20 words).
+
+            # Example:
+            # Images show a yellow cup standing on a flat white plane -> NO
+            # -> Response: NO: extra white surface under the object.
+            # Image shows a chair with simplified back legs and soft edges -> YES
+            # """
+
             self.prompt = """
 You are an expert in evaluating the geometry quality of generated 3D asset.
 You will be given rendered views of a generated 3D asset, type {}, with black background.
@@ -137,16 +161,13 @@ class MeshGeoChecker(BaseChecker):
 - Minor flaws, asymmetries, or simplifications (e.g., less detail on sides or back,
 soft edges) are acceptable if the object is structurally sound and recognizable.
 - Only evaluate geometry. Do not assess texture quality.
-- The asset should not contain any unrelated elements, such as
-ground planes, platforms, or background props (e.g., paper, flooring).

-If all the above criteria are met, return "YES". Otherwise, return
+If all the above criteria are met, return "YES" only. Otherwise, return
 "NO" followed by a brief explanation (no more than 20 words).

 Example:
-
-
-Image shows a chair with simplified back legs and soft edges → YES
+Image shows a chair with one leg missing -> NO: the chair missing leg.
+Image shows a geometrically complete cup -> YES
 """

     def query(
embodied_gen/validators/urdf_convertor.py
CHANGED
@@ -27,7 +27,10 @@ import trimesh
 from scipy.spatial.transform import Rotation
 from embodied_gen.data.convex_decomposer import decompose_convex_mesh
 from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
-from embodied_gen.utils.process_media import
+from embodied_gen.utils.process_media import (
+    combine_images_to_grid,
+    render_asset3d,
+)
 from embodied_gen.utils.tags import VERSION

 logging.basicConfig(level=logging.INFO)
@@ -482,7 +485,7 @@ class URDFGenerator(object):
             output_subdir=self.output_render_dir,
             no_index_file=True,
         )
-
+        # image_path = combine_images_to_grid(image_path)
         response = self.gpt_client.query(text_prompt, image_path)
         # logger.info(response)
         if response is None: