xinjie.wang committed on
Commit
c948747
1 Parent(s): 7f124e2
app.py CHANGED
@@ -17,8 +17,6 @@
 
 import os
 
-os.environ["GRADIO_APP"] = "textto3d"
-
 # GRADIO_APP == "textto3d_sam3d", sam3d object model, by default.
 # GRADIO_APP == "textto3d", TRELLIS model.
 os.environ["GRADIO_APP"] = "textto3d_sam3d"
@@ -35,19 +33,16 @@ from common import (
     get_cached_image,
     get_seed,
     get_selected_image,
+    image_to_3d,
     start_session,
     text2image_fn,
 )
 
 app_name = os.getenv("GRADIO_APP")
 if app_name == "textto3d_sam3d":
-    from common import image_to_3d_sam3d as image_to_3d
-
     enable_pre_resize = False
     sample_step = 25
 elif app_name == "textto3d":
-    from common import image_to_3d
-
     enable_pre_resize = True
     sample_step = 12
 
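Note: per the comments kept in app.py, the backend is chosen by the hard-coded GRADIO_APP value rather than an external environment variable. A minimal sketch (assuming you edit app.py directly) of selecting the TRELLIS backend instead of the default SAM3D one:

# In app.py, keep exactly one assignment before importing from common:
# os.environ["GRADIO_APP"] = "textto3d"        # TRELLIS model
os.environ["GRADIO_APP"] = "textto3d_sam3d"    # SAM3D object model (default)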
common.py CHANGED
@@ -14,6 +14,10 @@
 # implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
+from embodied_gen.utils.monkey_patches import monkey_path_trellis
+
+monkey_path_trellis()
+
 import gc
 import logging
 import os
@@ -27,9 +31,7 @@ import gradio as gr
 import numpy as np
 import spaces
 import torch
-import torch.nn.functional as F
 import trimesh
-from easydict import EasyDict as edict
 from PIL import Image
 from embodied_gen.data.backproject_v2 import entrypoint as backproject_api
 from embodied_gen.data.backproject_v3 import entrypoint as backproject_api_v3
@@ -37,6 +39,7 @@ from embodied_gen.data.differentiable_render import entrypoint as render_api
 from embodied_gen.data.utils import trellis_preprocess, zip_files
 from embodied_gen.models.delight_model import DelightingModel
 from embodied_gen.models.gs_model import GaussianOperator
+from embodied_gen.models.sam3d import Sam3dInference
 from embodied_gen.models.segment_model import (
     BMGG14Remover,
     RembgRemover,
@@ -57,7 +60,7 @@ from embodied_gen.utils.process_media import (
     merge_images_video,
 )
 from embodied_gen.utils.tags import VERSION
-from embodied_gen.utils.trender import render_video
+from embodied_gen.utils.trender import pack_state, render_video, unpack_state
 from embodied_gen.validators.quality_checkers import (
     BaseChecker,
     ImageAestheticChecker,
@@ -70,15 +73,6 @@ current_file_path = os.path.abspath(__file__)
 current_dir = os.path.dirname(current_file_path)
 sys.path.append(os.path.join(current_dir, ".."))
 from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
-from thirdparty.TRELLIS.trellis.representations import (
-    Gaussian,
-    MeshExtractResult,
-)
-from thirdparty.TRELLIS.trellis.representations.gaussian.general_utils import (
-    build_scaling_rotation,
-    inverse_sigmoid,
-    strip_symmetric,
-)
 from thirdparty.TRELLIS.trellis.utils import postprocessing_utils
 
 logging.basicConfig(
@@ -86,79 +80,24 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
-
-os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
-    "~/.cache/torch_extensions"
-)
 os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
-os.environ["SPCONV_ALGO"] = "native"
 
 MAX_SEED = 100000
 
-
-def patched_setup_functions(self):
-    def inverse_softplus(x):
-        return x + torch.log(-torch.expm1(-x))
-
-    def build_covariance_from_scaling_rotation(
-        scaling, scaling_modifier, rotation
-    ):
-        L = build_scaling_rotation(scaling_modifier * scaling, rotation)
-        actual_covariance = L @ L.transpose(1, 2)
-        symm = strip_symmetric(actual_covariance)
-        return symm
-
-    if self.scaling_activation_type == "exp":
-        self.scaling_activation = torch.exp
-        self.inverse_scaling_activation = torch.log
-    elif self.scaling_activation_type == "softplus":
-        self.scaling_activation = F.softplus
-        self.inverse_scaling_activation = inverse_softplus
-
-    self.covariance_activation = build_covariance_from_scaling_rotation
-    self.opacity_activation = torch.sigmoid
-    self.inverse_opacity_activation = inverse_sigmoid
-    self.rotation_activation = F.normalize
-
-    self.scale_bias = self.inverse_scaling_activation(
-        torch.tensor(self.scaling_bias)
-    ).to(self.device)
-    self.rots_bias = torch.zeros((4)).to(self.device)
-    self.rots_bias[0] = 1
-    self.opacity_bias = self.inverse_opacity_activation(
-        torch.tensor(self.opacity_bias)
-    ).to(self.device)
-
-
-Gaussian.setup_functions = patched_setup_functions
-
-
 # DELIGHT = DelightingModel()
 # IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
 # IMAGESR_MODEL = ImageStableSR()
-if os.getenv("GRADIO_APP") == "imageto3d":
-    RBG_REMOVER = RembgRemover()
-    RBG14_REMOVER = BMGG14Remover()
-    SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
-    PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
-        "microsoft/TRELLIS-image-large"
-    )
-    # PIPELINE.cuda()
-    SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
-    GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
-    AESTHETIC_CHECKER = ImageAestheticChecker()
-    CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
-    TMP_DIR = os.path.join(
-        os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
-    )
-    os.makedirs(TMP_DIR, exist_ok=True)
-elif os.getenv("GRADIO_APP") == "imageto3d_sam3d":
-    from embodied_gen.models.sam3d import Sam3dInference
-
+if os.getenv("GRADIO_APP").startswith("imageto3d"):
     RBG_REMOVER = RembgRemover()
     RBG14_REMOVER = BMGG14Remover()
     SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu")
-    PIPELINE = Sam3dInference()
+    if "sam3d" in os.getenv("GRADIO_APP"):
+        PIPELINE = Sam3dInference()
+    else:
+        PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
+            "microsoft/TRELLIS-image-large"
+        )
+        # PIPELINE.cuda()
     SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
     GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
     AESTHETIC_CHECKER = ImageAestheticChecker()
@@ -167,30 +106,16 @@ elif os.getenv("GRADIO_APP") == "imageto3d_sam3d":
         os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d"
     )
     os.makedirs(TMP_DIR, exist_ok=True)
-elif os.getenv("GRADIO_APP") == "textto3d":
+elif os.getenv("GRADIO_APP").startswith("textto3d"):
     RBG_REMOVER = RembgRemover()
     RBG14_REMOVER = BMGG14Remover()
-    PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
-        "microsoft/TRELLIS-image-large"
-    )
-    # PIPELINE.cuda()
-    text_model_dir = "weights/Kolors"
-    PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3)
-    PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
-    SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
-    GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
-    AESTHETIC_CHECKER = ImageAestheticChecker()
-    CHECKERS = [GEO_CHECKER, SEG_CHECKER, AESTHETIC_CHECKER]
-    TMP_DIR = os.path.join(
-        os.path.dirname(os.path.abspath(__file__)), "sessions/textto3d"
-    )
-    os.makedirs(TMP_DIR, exist_ok=True)
-elif os.getenv("GRADIO_APP") == "textto3d_sam3d":
-    from embodied_gen.models.sam3d import Sam3dInference
-
-    RBG_REMOVER = RembgRemover()
-    RBG14_REMOVER = BMGG14Remover()
-    PIPELINE = Sam3dInference()
+    if "sam3d" in os.getenv("GRADIO_APP"):
+        PIPELINE = Sam3dInference()
+    else:
+        PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
+            "microsoft/TRELLIS-image-large"
+        )
+        # PIPELINE.cuda()
     text_model_dir = "weights/Kolors"
     PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3)
     PIPELINE_IMG = build_text2img_pipeline(text_model_dir)
@@ -302,50 +227,6 @@ def get_cached_image(image_path: str) -> Image.Image:
     return Image.open(image_path).resize((512, 512))
 
 
-@spaces.GPU
-def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
-    return {
-        "gaussian": {
-            **gs.init_params,
-            "_xyz": gs._xyz.cpu().numpy(),
-            "_features_dc": gs._features_dc.cpu().numpy(),
-            "_scaling": gs._scaling.cpu().numpy(),
-            "_rotation": gs._rotation.cpu().numpy(),
-            "_opacity": gs._opacity.cpu().numpy(),
-        },
-        "mesh": {
-            "vertices": mesh.vertices.cpu().numpy(),
-            "faces": mesh.faces.cpu().numpy(),
-        },
-    }
-
-
-def unpack_state(state: dict, device: str = "cpu") -> tuple[Gaussian, dict]:
-    gs = Gaussian(
-        aabb=state["gaussian"]["aabb"],
-        sh_degree=state["gaussian"]["sh_degree"],
-        mininum_kernel_size=state["gaussian"]["mininum_kernel_size"],
-        scaling_bias=state["gaussian"]["scaling_bias"],
-        opacity_bias=state["gaussian"]["opacity_bias"],
-        scaling_activation=state["gaussian"]["scaling_activation"],
-        device=device,
-    )
-    gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device=device)
-    gs._features_dc = torch.tensor(
-        state["gaussian"]["_features_dc"], device=device
-    )
-    gs._scaling = torch.tensor(state["gaussian"]["_scaling"], device=device)
-    gs._rotation = torch.tensor(state["gaussian"]["_rotation"], device=device)
-    gs._opacity = torch.tensor(state["gaussian"]["_opacity"], device=device)
-
-    mesh = edict(
-        vertices=torch.tensor(state["mesh"]["vertices"], device=device),
-        faces=torch.tensor(state["mesh"]["faces"], device=device),
-    )
-
-    return gs, mesh
-
-
 def get_seed(randomize_seed: bool, seed: int, max_seed: int = MAX_SEED) -> int:
     return np.random.randint(0, max_seed) if randomize_seed else seed
 
@@ -399,87 +280,48 @@ def image_to_3d(
     if is_sam_image:
         seg_image = filter_image_small_connected_components(sam_image)
         seg_image = Image.fromarray(seg_image, mode="RGBA")
-        seg_image = trellis_preprocess(seg_image)
     else:
        seg_image = image
 
     if isinstance(seg_image, np.ndarray):
         seg_image = Image.fromarray(seg_image)
 
-    output_root = os.path.join(TMP_DIR, str(req.session_hash))
-    os.makedirs(output_root, exist_ok=True)
-    seg_image.save(f"{output_root}/seg_image.png")
-    raw_image_cache.save(f"{output_root}/raw_image.png")
-    PIPELINE.cuda()
-    outputs = PIPELINE.run(
-        seg_image,
-        seed=seed,
-        formats=["gaussian", "mesh"],
-        preprocess_image=False,
-        sparse_structure_sampler_params={
-            "steps": ss_sampling_steps,
-            "cfg_strength": ss_guidance_strength,
-        },
-        slat_sampler_params={
-            "steps": slat_sampling_steps,
-            "cfg_strength": slat_guidance_strength,
-        },
-    )
-    # Set to cpu for memory saving.
-    PIPELINE.cpu()
+    if isinstance(PIPELINE, Sam3dInference):
+        outputs = PIPELINE.run(
+            seg_image,
+            seed=seed,
+            stage1_inference_steps=ss_sampling_steps,
+            stage2_inference_steps=slat_sampling_steps,
+        )
+    else:
+        PIPELINE.cuda()
+        seg_image = trellis_preprocess(seg_image)
+        outputs = PIPELINE.run(
+            seg_image,
+            seed=seed,
+            formats=["gaussian", "mesh"],
+            preprocess_image=False,
+            sparse_structure_sampler_params={
+                "steps": ss_sampling_steps,
+                "cfg_strength": ss_guidance_strength,
+            },
+            slat_sampler_params={
+                "steps": slat_sampling_steps,
+                "cfg_strength": slat_guidance_strength,
+            },
+        )
+        # Set back to cpu for memory saving.
+        PIPELINE.cpu()
 
     gs_model = outputs["gaussian"][0]
     mesh_model = outputs["mesh"][0]
     color_images = render_video(gs_model, r=1.85)["color"]
     normal_images = render_video(mesh_model, r=1.85)["normal"]
 
-    video_path = os.path.join(output_root, "gs_mesh.mp4")
-    merge_images_video(color_images, normal_images, video_path)
-    state = pack_state(gs_model, mesh_model)
-
-    gc.collect()
-    torch.cuda.empty_cache()
-
-    return state, video_path
-
-
-@spaces.GPU
-def image_to_3d_sam3d(
-    image: Image.Image,
-    seed: int,
-    ss_sampling_steps: int,
-    slat_sampling_steps: int,
-    raw_image_cache: Image.Image,
-    ss_guidance_strength: float = None,
-    slat_guidance_strength: float = None,
-    sam_image: Image.Image = None,
-    is_sam_image: bool = False,
-    req: gr.Request = None,
-) -> tuple[dict, str]:
-    if is_sam_image:
-        seg_image = filter_image_small_connected_components(sam_image)
-        seg_image = Image.fromarray(seg_image, mode="RGBA")
-    else:
-        seg_image = image
-
-    if isinstance(seg_image, np.ndarray):
-        seg_image = Image.fromarray(seg_image)
-
     output_root = os.path.join(TMP_DIR, str(req.session_hash))
     os.makedirs(output_root, exist_ok=True)
     seg_image.save(f"{output_root}/seg_image.png")
     raw_image_cache.save(f"{output_root}/raw_image.png")
-    outputs = PIPELINE.run(
-        seg_image,
-        seed=seed,
-        stage1_inference_steps=ss_sampling_steps,
-        stage2_inference_steps=slat_sampling_steps,
-    )
-
-    gs_model = outputs["gaussian"][0]
-    mesh_model = outputs["mesh"][0]
-    color_images = render_video(gs_model, r=1.85)["color"]
-    normal_images = render_video(mesh_model, r=1.85)["normal"]
 
     video_path = os.path.join(output_root, "gs_mesh.mp4")
     merge_images_video(color_images, normal_images, video_path)
@@ -491,56 +333,13 @@ def image_to_3d_sam3d(
     return state, video_path
 
 
-@spaces.GPU
-def extract_3d_representations(
-    state: dict, enable_delight: bool, texture_size: int, req: gr.Request
-):
-    output_root = TMP_DIR
-    output_root = os.path.join(output_root, str(req.session_hash))
-    gs_model, mesh_model = unpack_state(state, device="cuda")
-
-    mesh = postprocessing_utils.to_glb(
-        gs_model,
-        mesh_model,
-        simplify=0.9,
-        texture_size=1024,
-        verbose=True,
-    )
-    filename = "sample"
-    gs_path = os.path.join(output_root, f"{filename}_gs.ply")
-    gs_model.save_ply(gs_path)
-
-    # Rotate mesh and GS by 90 degrees around Z-axis.
-    rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]]
-    # Addtional rotation for GS to align mesh.
-    gs_rot = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) @ np.array(
-        rot_matrix
-    )
-    pose = GaussianOperator.trans_to_quatpose(gs_rot)
-    aligned_gs_path = gs_path.replace(".ply", "_aligned.ply")
-    GaussianOperator.resave_ply(
-        in_ply=gs_path,
-        out_ply=aligned_gs_path,
-        instance_pose=pose,
-    )
-
-    mesh.vertices = mesh.vertices @ np.array(rot_matrix)
-    mesh_obj_path = os.path.join(output_root, f"{filename}.obj")
-    mesh.export(mesh_obj_path)
-    mesh_glb_path = os.path.join(output_root, f"{filename}.glb")
-    mesh.export(mesh_glb_path)
-
-    torch.cuda.empty_cache()
-
-    return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path
-
-
 def extract_3d_representations_v2(
     state: dict,
     enable_delight: bool,
     texture_size: int,
     req: gr.Request,
 ):
+    """Back-Projection Version of Texture Super-Resolution."""
     output_root = TMP_DIR
     user_dir = os.path.join(output_root, str(req.session_hash))
     gs_model, mesh_model = unpack_state(state, device="cpu")
@@ -607,6 +406,7 @@ def extract_3d_representations_v3(
     texture_size: int,
     req: gr.Request,
 ):
+    """Back-Projection Version with Optimization-Based."""
     output_root = TMP_DIR
     user_dir = os.path.join(output_root, str(req.session_hash))
     gs_model, mesh_model = unpack_state(state, device="cpu")
embodied_gen/data/asset_converter.py CHANGED
@@ -1,3 +1,20 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
 from __future__ import annotations
 
 import logging
embodied_gen/models/sam3d.py CHANGED
@@ -19,7 +19,6 @@ from embodied_gen.utils.monkey_patches import monkey_patch_sam3d
 monkey_patch_sam3d()
 import os
 import sys
-from typing import Optional, Union
 
 import numpy as np
 from hydra.utils import instantiate
@@ -31,29 +30,38 @@ from PIL import Image
 current_file_path = os.path.abspath(__file__)
 current_dir = os.path.dirname(current_file_path)
 sys.path.append(os.path.join(current_dir, "../.."))
+from loguru import logger
 from thirdparty.sam3d.sam3d_objects.pipeline.inference_pipeline_pointmap import (
     InferencePipelinePointMap,
 )
 
+logger.remove()
+logger.add(lambda _: None, level="ERROR")
+
+
 __all__ = ["Sam3dInference"]
 
 
-def load_image(path: str) -> np.ndarray:
-    image = Image.open(path)
-    image = np.array(image)
-    image = image.astype(np.uint8)
-    return image
-
-
-def load_mask(path: str) -> np.ndarray:
-    mask = load_image(path)
-    mask = mask > 0
-    if mask.ndim == 3:
-        mask = mask[..., -1]
-    return mask
-
-class Sam3dInference:
+class Sam3dInference:
+    """Wrapper for the SAM-3D-Objects inference pipeline.
+
+    This class handles loading the SAM-3D-Objects model, configuring it for inference,
+    and running the pipeline on input images (optionally with masks and pointmaps).
+    It supports distillation options and inference step customization.
+
+    Args:
+        local_dir (str): Directory to store or load model weights and configs.
+        compile (bool): Whether to compile the model for faster inference.
+
+    Methods:
+        merge_mask_to_rgba(image, mask):
+            Merges a binary mask into the alpha channel of an RGB image.
+
+        run(image, mask=None, seed=None, pointmap=None, use_stage1_distillation=False,
+            use_stage2_distillation=False, stage1_inference_steps=25, stage2_inference_steps=25):
+            Runs the inference pipeline and returns the output dictionary.
+    """
+
     def __init__(
         self, local_dir: str = "weights/sam-3d-objects", compile: bool = False
     ) -> None:
@@ -65,7 +73,7 @@ class Sam3dInference:
         config.rendering_engine = "nvdiffrast"
         config.compile_model = compile
         config.workspace_dir = os.path.dirname(config_file)
-        # Generate 4 gs in each pixel.
+        # Generate 4 instead of 32 gs in each pixel for efficient storage.
         config["slat_decoder_gs_config_path"] = config.pop(
             "slat_decoder_gs_4_config_path", "slat_decoder_gs_4.yaml"
         )
@@ -118,25 +126,22 @@ class Sam3dInference:
 if __name__ == "__main__":
     pipeline = Sam3dInference()
 
-    # load image
-    image = load_image(
-        "/home/users/xinjie.wang/xinjie/sam-3d-objects/notebook/images/shutterstock_stylish_kidsroom_1640806567/image.png"
-    )
-    mask = load_mask(
-        "/home/users/xinjie.wang/xinjie/sam-3d-objects/notebook/images/shutterstock_stylish_kidsroom_1640806567/13.png"
-    )
+    from time import time
 
     import torch
+    from embodied_gen.models.segment_model import RembgRemover
+
+    input_image = "apps/assets/example_image/sample_00.jpg"
+    output_gs = "outputs/splat.ply"
+    remover = RembgRemover()
+    clean_image = remover(input_image)
 
     if torch.cuda.is_available():
         torch.cuda.reset_peak_memory_stats()
         torch.cuda.empty_cache()
 
-    from time import time
-
     start = time()
-
-    output = pipeline.run(image, mask, seed=42)
+    output = pipeline.run(clean_image, seed=42)
     print(f"Running cost: {round(time()-start, 1)}")
 
     if torch.cuda.is_available():
@@ -145,5 +150,5 @@ if __name__ == "__main__":
 
     print(f"End: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
 
-    output["gs"].save_ply(f"outputs/splat.ply")
-    print("Your reconstruction has been saved to outputs/splat.ply")
+    output["gs"].save_ply(output_gs)
+    print(f"Saved to {output_gs}")
embodied_gen/scripts/gen_scene3d.py CHANGED
@@ -1,3 +1,20 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
 import logging
 import os
 import random
embodied_gen/scripts/gen_texture.py CHANGED
@@ -1,3 +1,20 @@
+# Project EmbodiedGen
+#
+# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
 import os
 import shutil
 from dataclasses import dataclass
embodied_gen/scripts/imageto3d.py CHANGED
@@ -14,30 +14,30 @@
 # implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-
 import argparse
 import os
 import random
-import sys
 from glob import glob
 from shutil import copy, copytree, rmtree
 
 import numpy as np
-import torch
 import trimesh
 from PIL import Image
 from embodied_gen.data.backproject_v3 import entrypoint as backproject_api
-from embodied_gen.data.utils import delete_dir, trellis_preprocess
+from embodied_gen.data.utils import delete_dir
 
+# from embodied_gen.models.sr_model import ImageRealESRGAN
 # from embodied_gen.models.delight_model import DelightingModel
 from embodied_gen.models.gs_model import GaussianOperator
 from embodied_gen.models.segment_model import RembgRemover
-
-# from embodied_gen.models.sr_model import ImageRealESRGAN
 from embodied_gen.scripts.render_gs import entrypoint as render_gs_api
 from embodied_gen.utils.gpt_clients import GPT_CLIENT
+from embodied_gen.utils.inference import image3d_model_infer
 from embodied_gen.utils.log import logger
-from embodied_gen.utils.process_media import merge_images_video
+from embodied_gen.utils.process_media import (
+    combine_images_to_grid,
+    merge_images_video,
+)
 from embodied_gen.utils.tags import VERSION
 from embodied_gen.utils.trender import render_video
 from embodied_gen.validators.quality_checkers import (
@@ -48,26 +48,24 @@ from embodied_gen.validators.quality_checkers import (
 )
 from embodied_gen.validators.urdf_convertor import URDFGenerator
 
-current_file_path = os.path.abspath(__file__)
-current_dir = os.path.dirname(current_file_path)
-sys.path.append(os.path.join(current_dir, "../.."))
-from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
+# random.seed(0)
+IMAGE3D_MODEL = "SAM3D"  # TRELLIS or SAM3D
+logger.info(f"Loading {IMAGE3D_MODEL} as Image3D Models...")
+if IMAGE3D_MODEL == "TRELLIS":
+    from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
 
-os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
-    "~/.cache/torch_extensions"
-)
-os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"
-os.environ["SPCONV_ALGO"] = "native"
-random.seed(0)
+    PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
+        "microsoft/TRELLIS-image-large"
+    )
+    # PIPELINE.cuda()
+elif IMAGE3D_MODEL == "SAM3D":
+    from embodied_gen.models.sam3d import Sam3dInference
+
+    PIPELINE = Sam3dInference()
 
-logger.info("Loading Image3D Models...")
 # DELIGHT = DelightingModel()
 # IMAGESR_MODEL = ImageRealESRGAN(outscale=4)
 RBG_REMOVER = RembgRemover()
-PIPELINE = TrellisImageTo3DPipeline.from_pretrained(
-    "microsoft/TRELLIS-image-large"
-)
-# PIPELINE.cuda()
 SEG_CHECKER = ImageSegChecker(GPT_CLIENT)
 GEO_CHECKER = MeshGeoChecker(GPT_CLIENT)
 AESTHETIC_CHECKER = ImageAestheticChecker()
@@ -151,7 +149,6 @@ def entrypoint(**kwargs):
         # Segmentation: Get segmented image using Rembg.
         seg_path = f"{output_root}/{filename}_cond.png"
         seg_image = RBG_REMOVER(image) if image.mode != "RGBA" else image
-        seg_image = trellis_preprocess(seg_image)
         seg_image.save(seg_path)
 
         seed = args.seed
@@ -162,27 +159,8 @@ def entrypoint(**kwargs):
             logger.info(
                 f"Try: {try_idx + 1}/{args.n_retry}, Seed: {seed}, Prompt: {seg_path}"
            )
-            # Run the pipeline
             try:
-                PIPELINE.cuda()
-                outputs = PIPELINE.run(
-                    seg_image,
-                    preprocess_image=False,
-                    seed=(
-                        random.randint(0, 100000) if seed is None else seed
-                    ),
-                    # Optional parameters
-                    # sparse_structure_sampler_params={
-                    #     "steps": 12,
-                    #     "cfg_strength": 7.5,
-                    # },
-                    # slat_sampler_params={
-                    #     "steps": 12,
-                    #     "cfg_strength": 3,
-                    # },
-                )
-                PIPELINE.cpu()
-                torch.cuda.empty_cache()
+                outputs = image3d_model_infer(PIPELINE, seg_image, seed)
             except Exception as e:
                 logger.error(
                     f"[Pipeline Failed] process {image_path}: {e}, skip."
@@ -215,14 +193,13 @@ def entrypoint(**kwargs):
                 render_gs_api(
                     input_gs=aligned_gs_path,
                     output_path=color_path,
-                    elevation=[20, -10, 60, -50],
-                    num_images=12,
+                    elevation=[30, -30],
+                    num_images=4,
                 )
-
                 color_img = Image.open(color_path)
-                keep_height = int(color_img.height * 2 / 3)
-                crop_img = color_img.crop((0, 0, color_img.width, keep_height))
-                geo_flag, geo_result = GEO_CHECKER([crop_img], text=asset_node)
+                geo_flag, geo_result = GEO_CHECKER(
+                    [color_img], text=asset_node
+                )
                 logger.warning(
                     f"{GEO_CHECKER.__class__.__name__}: {geo_result} for {seg_path}"
                 )
@@ -232,8 +209,8 @@ def entrypoint(**kwargs):
             seed = random.randint(0, 100000) if seed is not None else None
 
         # Render the video for generated 3D asset.
-        color_images = render_video(gs_model)["color"]
-        normal_images = render_video(mesh_model)["normal"]
+        color_images = render_video(gs_model, r=1.85)["color"]
+        normal_images = render_video(mesh_model, r=1.85)["normal"]
         video_path = os.path.join(output_root, "gs_mesh.mp4")
         merge_images_video(color_images, normal_images, video_path)
 
@@ -312,7 +289,7 @@ def entrypoint(**kwargs):
         image_paths = glob(f"{image_dir}/*.png")
         images_list = []
         for checker in CHECKERS:
-            images = image_paths
+            images = combine_images_to_grid(image_paths)
             if isinstance(checker, ImageSegChecker):
                 images = [
                     f"{output_root}/{filename}_raw.png",
@@ -334,9 +311,12 @@ def entrypoint(**kwargs):
                 f"{result_dir}/{urdf_convertor.output_mesh_dir}",
            )
             copy(video_path, f"{result_dir}/video.mp4")
+
             if not args.keep_intermediate:
                 delete_dir(output_root, keep_subs=["result"])
 
+            logger.info(f"Saved results for {image_path} in {result_dir}")
+
         except Exception as e:
             logger.error(f"Failed to process {image_path}: {e}, skip.")
             continue
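Note: the batch script now selects its image-to-3D backend through the module-level IMAGE3D_MODEL constant instead of a separate code path. A minimal sketch (editing the constant at the top of embodied_gen/scripts/imageto3d.py) of switching back to TRELLIS:

IMAGE3D_MODEL = "TRELLIS"  # TRELLIS or SAM3D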
embodied_gen/scripts/render_gs.py CHANGED
@@ -27,7 +27,6 @@ from tqdm import tqdm
 from embodied_gen.data.utils import (
     CameraSetting,
     init_kal_camera,
-    normalize_vertices_array,
 )
 from embodied_gen.models.gs_model import load_gs_model
 from embodied_gen.utils.process_media import combine_images_to_grid
embodied_gen/scripts/textto3d.py CHANGED
@@ -30,6 +30,7 @@ from embodied_gen.utils.gpt_clients import GPT_CLIENT
 from embodied_gen.utils.log import logger
 from embodied_gen.utils.process_media import (
     check_object_edge_truncated,
+    combine_images_to_grid,
     render_asset3d,
 )
 from embodied_gen.validators.quality_checkers import (
@@ -51,7 +52,6 @@ BG_REMOVER = RembgRemover()
 
 
 __all__ = [
-    "text_to_image",
     "text_to_3d",
 ]
 
@@ -176,12 +176,12 @@ def text_to_3d(**kwargs) -> dict:
     image_path = render_asset3d(
         mesh_path,
         output_root=f"{node_save_dir}/result",
-        num_images=6,
+        num_images=4,
         elevation=(30, -30),
         output_subdir="renders",
         no_index_file=True,
     )
-
+    image_path = combine_images_to_grid(image_path)
     check_text = asset_type if asset_type is not None else prompt
     qa_flag, qa_result = TXTGEN_CHECKER(check_text, image_path)
     logger.warning(
embodied_gen/utils/gpt_clients.py CHANGED
@@ -21,13 +21,14 @@ import os
 from io import BytesIO
 from typing import Optional
 
+import openai
 import yaml
 from openai import AzureOpenAI, OpenAI  # pip install openai
 from PIL import Image
 from tenacity import (
     retry,
+    retry_if_not_exception_type,
     stop_after_attempt,
-    stop_after_delay,
     wait_random_exponential,
 )
 from embodied_gen.utils.process_media import combine_images_to_grid
@@ -106,8 +107,9 @@ class GPTclient:
         logger.info(f"Using GPT model: {self.model_name}.")
 
     @retry(
-        wait=wait_random_exponential(min=1, max=20),
-        stop=(stop_after_attempt(10) | stop_after_delay(30)),
+        retry=retry_if_not_exception_type(openai.BadRequestError),
+        wait=wait_random_exponential(min=1, max=10),
+        stop=stop_after_attempt(5),
     )
     def completion_with_backoff(self, **kwargs):
         """Performs a chat completion request with retry/backoff."""
@@ -246,3 +248,8 @@ GPT_CLIENT = GPTclient(
     model_name=model_name,
     check_connection=False,
 )
+
+
+if __name__ == "__main__":
+    response = GPT_CLIENT.query("What is the capital of China?")
+    print(f"Response: {response}")
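The new retry policy stops retrying on openai.BadRequestError (a non-transient client error) while keeping jittered exponential backoff for everything else. A minimal, self-contained sketch of the same tenacity pattern, with a hypothetical flaky_call() standing in for the chat-completion request:

import openai
from tenacity import (
    retry,
    retry_if_not_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)


@retry(
    retry=retry_if_not_exception_type(openai.BadRequestError),
    wait=wait_random_exponential(min=1, max=10),
    stop=stop_after_attempt(5),
)
def flaky_call():
    # Transient failures (timeouts, rate limits) are retried up to 5 times
    # with 1-10 s jittered backoff; a BadRequestError propagates immediately.
    ...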
embodied_gen/utils/inference.py ADDED
@@ -0,0 +1,59 @@
+from embodied_gen.utils.monkey_patches import monkey_path_trellis
+
+monkey_path_trellis()
+import random
+
+import torch
+from PIL import Image
+from embodied_gen.data.utils import trellis_preprocess
+from embodied_gen.models.sam3d import Sam3dInference
+from embodied_gen.utils.trender import pack_state, unpack_state
+from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
+
+__all__ = [
+    "image3d_model_infer",
+]
+
+
+def image3d_model_infer(
+    pipe: TrellisImageTo3DPipeline | Sam3dInference,
+    seg_image: Image.Image,
+    seed: int = None,
+    **kwargs: dict,
+) -> dict[str, any]:
+    if isinstance(pipe, TrellisImageTo3DPipeline):
+        pipe.cuda()
+        seg_image = trellis_preprocess(seg_image)
+        outputs = pipe.run(
+            seg_image,
+            preprocess_image=False,
+            seed=(random.randint(0, 100000) if seed is None else seed),
+            # Optional parameters
+            # sparse_structure_sampler_params={
+            #     "steps": 12,
+            #     "cfg_strength": 7.5,
+            # },
+            # slat_sampler_params={
+            #     "steps": 12,
+            #     "cfg_strength": 3,
+            # },
+            **kwargs,
+        )
+        pipe.cpu()
+    elif isinstance(pipe, Sam3dInference):
+        outputs = pipe.run(
+            seg_image,
+            seed=(random.randint(0, 100000) if seed is None else seed),
+            # stage1_inference_steps=25,
+            # stage2_inference_steps=25,
+            **kwargs,
+        )
+        state = pack_state(outputs["gaussian"][0], outputs["mesh"][0])
+        # Align GS3D from SAM3D with TRELLIS format.
+        outputs["gaussian"][0], _ = unpack_state(state, device="cuda")
+    else:
+        raise ValueError(f"Unsupported pipeline type: {type(pipe)}")
+
+    torch.cuda.empty_cache()
+
+    return outputs
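A minimal usage sketch of the new image3d_model_infer helper, assuming a background-removed RGBA input image (the file path is illustrative only):

from PIL import Image
from embodied_gen.models.sam3d import Sam3dInference
from embodied_gen.utils.inference import image3d_model_infer

pipeline = Sam3dInference()  # or TrellisImageTo3DPipeline.from_pretrained(...)
seg_image = Image.open("outputs/seg_image.png")  # illustrative path
outputs = image3d_model_infer(pipeline, seg_image, seed=42)
# Both backends return the same structure: one Gaussian splat and one mesh.
gs_model, mesh_model = outputs["gaussian"][0], outputs["mesh"][0]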
embodied_gen/utils/monkey_patches.py CHANGED
@@ -32,6 +32,67 @@ __all__ = [
 ]
 
 
+def monkey_path_trellis():
+    import torch.nn.functional as F
+
+    current_file_path = os.path.abspath(__file__)
+    current_dir = os.path.dirname(current_file_path)
+    sys.path.append(os.path.join(current_dir, "../.."))
+
+    from thirdparty.TRELLIS.trellis.representations import Gaussian
+    from thirdparty.TRELLIS.trellis.representations.gaussian.general_utils import (
+        build_scaling_rotation,
+        inverse_sigmoid,
+        strip_symmetric,
+    )
+
+    os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser(
+        "~/.cache/torch_extensions"
+    )
+    os.environ["SPCONV_ALGO"] = "auto"  # Can be 'native' or 'auto'
+    os.environ['ATTN_BACKEND'] = (
+        "xformers"  # Can be 'flash-attn' or 'xformers'
+    )
+    from thirdparty.TRELLIS.trellis.modules.sparse import set_attn
+
+    set_attn("xformers")
+
+    def patched_setup_functions(self):
+        def inverse_softplus(x):
+            return x + torch.log(-torch.expm1(-x))
+
+        def build_covariance_from_scaling_rotation(
+            scaling, scaling_modifier, rotation
+        ):
+            L = build_scaling_rotation(scaling_modifier * scaling, rotation)
+            actual_covariance = L @ L.transpose(1, 2)
+            symm = strip_symmetric(actual_covariance)
+            return symm
+
+        if self.scaling_activation_type == "exp":
+            self.scaling_activation = torch.exp
+            self.inverse_scaling_activation = torch.log
+        elif self.scaling_activation_type == "softplus":
+            self.scaling_activation = F.softplus
+            self.inverse_scaling_activation = inverse_softplus
+
+        self.covariance_activation = build_covariance_from_scaling_rotation
+        self.opacity_activation = torch.sigmoid
+        self.inverse_opacity_activation = inverse_sigmoid
+        self.rotation_activation = F.normalize
+
+        self.scale_bias = self.inverse_scaling_activation(
+            torch.tensor(self.scaling_bias)
+        ).to(self.device)
+        self.rots_bias = torch.zeros((4)).to(self.device)
+        self.rots_bias[0] = 1
+        self.opacity_bias = self.inverse_opacity_activation(
+            torch.tensor(self.opacity_bias)
+        ).to(self.device)
+
+    Gaussian.setup_functions = patched_setup_functions
+
+
 def monkey_patch_pano2room():
     current_file_path = os.path.abspath(__file__)
     current_dir = os.path.dirname(current_file_path)
@@ -240,8 +301,6 @@ def monkey_patch_sam3d():
     if sam3d_root not in sys.path:
         sys.path.insert(0, sam3d_root)
 
-    print(f"[MonkeyPatch] Added to sys.path: {sam3d_root}")
-
     def patch_pointmap_infer_pipeline():
         from copy import deepcopy
 
@@ -317,9 +376,6 @@ def monkey_patch_sam3d():
             )
         )
 
-        logger.info(
-            f"Rescaling scale by {ss_return_dict['downsample_factor']} after downsampling"
-        )
         ss_return_dict["scale"] = (
             ss_return_dict["scale"]
             * ss_return_dict["downsample_factor"]
@@ -471,11 +527,6 @@ def monkey_patch_sam3d():
         self.rendering_engine = rendering_engine
         self.device = torch.device(device)
         self.compile_model = compile_model
-        logger.info(f"self.device: {self.device}")
-        logger.info(
-            f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', None)}"
-        )
-        logger.info(f"Actually using GPU: {torch.cuda.current_device()}")
         with self.device:
             self.decode_formats = decode_formats
             self.pad_size = pad_size
@@ -511,7 +562,6 @@ def monkey_patch_sam3d():
         )
         self.slat_preprocessor = slat_preprocessor
 
-        logger.info("Loading model weights...")
         raw_device = self.device
         self.device = torch.device("cpu")
         ss_generator = self.init_ss_generator(
@@ -578,7 +628,7 @@ def monkey_patch_sam3d():
                 "slat_decoder_mesh": slat_decoder_mesh,
             }
         )
-        logger.info("Loading model weights completed!")
+        logger.info("Loading SAM3D model weights completed.")
 
         if self.compile_model:
             logger.info("Compiling model...")
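monkey_path_trellis() patches Gaussian.setup_functions and sets the TRELLIS environment variables (TORCH_EXTENSIONS_DIR, SPCONV_ALGO, ATTN_BACKEND), so it has to run before any TRELLIS module is imported, which is the ordering common.py and inference.py now follow. A minimal ordering sketch:

from embodied_gen.utils.monkey_patches import monkey_path_trellis

monkey_path_trellis()  # must precede any thirdparty.TRELLIS import

from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline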
embodied_gen/utils/process_media.py CHANGED
@@ -96,7 +96,7 @@ def render_asset3d(
         image_paths = render_asset3d(
             mesh_path="path_to_mesh.obj",
             output_root="path_to_save_dir",
-            num_images=6,
+            num_images=4,
             elevation=(30, -30),
             output_subdir="renders",
             no_index_file=True,
embodied_gen/utils/tags.py CHANGED
@@ -1 +1 @@
-VERSION = "v0.1.6"
+VERSION = "v0.1.7"
embodied_gen/utils/trender.py CHANGED
@@ -21,18 +21,25 @@ from collections import defaultdict
 import numpy as np
 import spaces
 import torch
+from easydict import EasyDict as edict
 from tqdm import tqdm
 
 current_file_path = os.path.abspath(__file__)
 current_dir = os.path.dirname(current_file_path)
 sys.path.append(os.path.join(current_dir, "../.."))
 from thirdparty.TRELLIS.trellis.renderers import GaussianRenderer, MeshRenderer
+from thirdparty.TRELLIS.trellis.representations import (
+    Gaussian,
+    MeshExtractResult,
+)
 from thirdparty.TRELLIS.trellis.utils.render_utils import (
     yaw_pitch_r_fov_to_extrinsics_intrinsics,
 )
 
 __all__ = [
     "render_video",
+    "pack_state",
+    "unpack_state",
 ]
 
 
@@ -140,3 +147,47 @@ def render_video(
     )
 
     return result
+
+
+@spaces.GPU
+def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
+    return {
+        "gaussian": {
+            **gs.init_params,
+            "_xyz": gs._xyz.cpu().numpy(),
+            "_features_dc": gs._features_dc.cpu().numpy(),
+            "_scaling": gs._scaling.cpu().numpy(),
+            "_rotation": gs._rotation.cpu().numpy(),
+            "_opacity": gs._opacity.cpu().numpy(),
+        },
+        "mesh": {
+            "vertices": mesh.vertices.cpu().numpy(),
+            "faces": mesh.faces.cpu().numpy(),
+        },
+    }
+
+
+def unpack_state(state: dict, device: str = "cpu") -> tuple[Gaussian, dict]:
+    gs = Gaussian(
+        aabb=state["gaussian"]["aabb"],
+        sh_degree=state["gaussian"]["sh_degree"],
+        mininum_kernel_size=state["gaussian"]["mininum_kernel_size"],
+        scaling_bias=state["gaussian"]["scaling_bias"],
+        opacity_bias=state["gaussian"]["opacity_bias"],
+        scaling_activation=state["gaussian"]["scaling_activation"],
+        device=device,
+    )
+    gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device=device)
+    gs._features_dc = torch.tensor(
+        state["gaussian"]["_features_dc"], device=device
+    )
+    gs._scaling = torch.tensor(state["gaussian"]["_scaling"], device=device)
+    gs._rotation = torch.tensor(state["gaussian"]["_rotation"], device=device)
+    gs._opacity = torch.tensor(state["gaussian"]["_opacity"], device=device)
+
+    mesh = edict(
+        vertices=torch.tensor(state["mesh"]["vertices"], device=device),
+        faces=torch.tensor(state["mesh"]["faces"], device=device),
+    )
+
+    return gs, mesh
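pack_state/unpack_state, now exported from embodied_gen.utils.trender, round-trip the Gaussian and mesh outputs through plain numpy arrays so they can be stored in Gradio session state. A minimal round-trip sketch, assuming outputs produced by one of the pipelines above:

from embodied_gen.utils.trender import pack_state, unpack_state

state = pack_state(outputs["gaussian"][0], outputs["mesh"][0])  # tensors -> numpy dict
gs_model, mesh_model = unpack_state(state, device="cpu")        # numpy dict -> tensors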
embodied_gen/validators/aesthetic_predictor.py CHANGED
@@ -125,7 +125,11 @@ class AestheticPredictor:
         Returns:
             float: Predicted aesthetic score.
         """
-        pil_image = Image.open(image_path)
+        if isinstance(image_path, str):
+            pil_image = Image.open(image_path)
+        else:
+            pil_image = image_path
+
         image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
 
         with torch.no_grad():
embodied_gen/validators/quality_checkers.py CHANGED
@@ -126,6 +126,30 @@ class MeshGeoChecker(BaseChecker):
         super().__init__(prompt, verbose)
         self.gpt_client = gpt_client
         if self.prompt is None:
+            # Old version for TRELLIS.
+            # self.prompt = """
+            # You are an expert in evaluating the geometry quality of generated 3D asset.
+            # You will be given rendered views of a generated 3D asset, type {}, with black background.
+            # Your task is to evaluate the quality of the 3D asset generation,
+            # including geometry, structure, and appearance, based on the rendered views.
+            # Criteria:
+            # - Is the object in the image a single, complete, and well-formed instance,
+            # without truncation, missing parts, overlapping duplicates, or redundant geometry?
+            # - Minor flaws, asymmetries, or simplifications (e.g., less detail on sides or back,
+            # soft edges) are acceptable if the object is structurally sound and recognizable.
+            # - Only evaluate geometry. Do not assess texture quality.
+            # - The asset should not contain any unrelated elements, such as
+            # ground planes, platforms, or background props (e.g., paper, flooring).
+
+            # If all the above criteria are met, return "YES". Otherwise, return
+            # "NO" followed by a brief explanation (no more than 20 words).
+
+            # Example:
+            # Images show a yellow cup standing on a flat white plane -> NO
+            # -> Response: NO: extra white surface under the object.
+            # Image shows a chair with simplified back legs and soft edges -> YES
+            # """
+
             self.prompt = """
             You are an expert in evaluating the geometry quality of generated 3D asset.
             You will be given rendered views of a generated 3D asset, type {}, with black background.
@@ -137,16 +161,13 @@ class MeshGeoChecker(BaseChecker):
             - Minor flaws, asymmetries, or simplifications (e.g., less detail on sides or back,
             soft edges) are acceptable if the object is structurally sound and recognizable.
             - Only evaluate geometry. Do not assess texture quality.
-            - The asset should not contain any unrelated elements, such as
-            ground planes, platforms, or background props (e.g., paper, flooring).
 
-            If all the above criteria are met, return "YES". Otherwise, return
+            If all the above criteria are met, return "YES" only. Otherwise, return
             "NO" followed by a brief explanation (no more than 20 words).
 
             Example:
-            Images show a yellow cup standing on a flat white plane -> NO
-            -> Response: NO: extra white surface under the object.
-            Image shows a chair with simplified back legs and soft edges -> YES
+            Image shows a chair with one leg missing -> NO: the chair missing leg.
+            Image shows a geometrically complete cup -> YES
             """
 
     def query(
embodied_gen/validators/urdf_convertor.py CHANGED
@@ -27,7 +27,10 @@ import trimesh
 from scipy.spatial.transform import Rotation
 from embodied_gen.data.convex_decomposer import decompose_convex_mesh
 from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient
-from embodied_gen.utils.process_media import render_asset3d
+from embodied_gen.utils.process_media import (
+    combine_images_to_grid,
+    render_asset3d,
+)
 from embodied_gen.utils.tags import VERSION
 
 logging.basicConfig(level=logging.INFO)
@@ -482,7 +485,7 @@ class URDFGenerator(object):
             output_subdir=self.output_render_dir,
             no_index_file=True,
         )
-
+        # image_path = combine_images_to_grid(image_path)
         response = self.gpt_client.query(text_prompt, image_path)
         # logger.info(response)
         if response is None: