import torch

# custom_types is expected to provide the tensor aliases used below (T, TS, TN, T_Mesh)
# along with the typing names (Optional, Union, List, Tuple).
from ..custom_types import *
from ..options import Options
from ..utils import train_utils, mcubes_meshing, files_utils, mesh_utils
from ..models.occ_gmm import Spaghetti
from ..models import models_utils

class Inference:

    def get_occ_fun(self, z: T):
        # Wrap the occupancy network in a closure over the latent codes z; the
        # closure maps query points to occupancy values squashed into [-1, 1].
        def forward(x: T) -> T:
            nonlocal z
            x = x.unsqueeze(0)
            out = self.model.occupancy_network(x, z)[0, :]
            out = 2 * out.sigmoid_() - 1
            return out

        if z.dim() == 2:
            z = z.unsqueeze(0)
        return forward
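
    # Usage sketch (hypothetical shapes): the returned closure can be handed to
    # any sampler that queries an implicit field. Assuming zh holds merged
    # per-shape latents,
    #     occ = self.get_occ_fun(zh[0])
    #     values = occ(torch.rand(1024, 3) * 2 - 1)  # in [-1, 1]; sign encodes in/out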

    def get_mesh(self, z: T, res: int) -> Optional[T_Mesh]:
        # Inference only: no gradients are needed while marching cubes samples the field.
        with torch.no_grad():
            mesh = self.meshing.occ_meshing(self.get_occ_fun(z), res=res)
        return mesh

    def plot_occ(self, z: Union[T, TS], z_base, gmms: Optional[TS], fixed_items: T,
                 folder_name: str, res=200, verbose=False):
        for i in range(len(z)):
            mesh = self.get_mesh(z[i], res)
            name = f'{fixed_items[i]:04d}'
            if mesh is not None:
                files_utils.export_mesh(mesh, f'{self.opt.cp_folder}/{folder_name}/occ/{name}')
                files_utils.save_pickle(z_base[i].detach().cpu(), f'{self.opt.cp_folder}/{folder_name}/occ/{name}')
                if gmms is not None:
                    files_utils.export_gmm(gmms, i, f'{self.opt.cp_folder}/{folder_name}/occ/{name}')
            if verbose:
                print(f'done {i + 1:d}/{len(z):d}')

    def load_file(self, info_path, disclude: Optional[List[int]] = None):
        info = files_utils.load_pickle(''.join(info_path))
        keys = list(info['ids'].keys())
        # Keys may be raw ints or 'name_<idx>' strings; normalize them to int ids.
        items = map(lambda x: int(x.split('_')[1]) if type(x) is str else x, keys)
        items = torch.tensor(list(items), dtype=torch.int64, device=self.device)
        zh, _, gmms_sanity, _ = self.model.get_embeddings(items)
        gmms = [item for item in info['gmm']]
        zh_ = []
        split = []
        gmm_mask = torch.ones(gmms[0].shape[2], dtype=torch.bool)
        counter = 0
        for i, key in enumerate(keys):
            gaussian_inds = info['ids'][key]
            if disclude is not None:
                # Drop the excluded Gaussians from both the mask and the index lists.
                for j in range(len(gaussian_inds)):
                    gmm_mask[j + counter] = gaussian_inds[j] not in disclude
                counter += len(gaussian_inds)
                gaussian_inds = [ind for ind in gaussian_inds if ind not in disclude]
                info['ids'][key] = gaussian_inds
            gaussian_inds = torch.tensor(gaussian_inds, dtype=torch.int64)
            zh_.append(zh[i, gaussian_inds])
            # Tag every Gaussian of this shape with a running (1-based) segment id.
            split.append(len(split) + torch.ones(len(info['ids'][key]), dtype=torch.int64, device=self.device))
        zh_ = torch.cat(zh_, dim=0).unsqueeze(0).to(self.device)
        gmms = [item[:, :, gmm_mask].to(self.device) for item in info['gmm']]
        return zh_, gmms, split, info['ids']
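
    # Expected pickle layout, as implied by load_file above (inferred, not a spec):
    #     info['ids']  maps each shape key (an int, or a 'name_<idx>' string) to the
    #                  list of Gaussian indices selected for that shape;
    #     info['gmm']  is a list of GMM parameter tensors whose third dimension
    #                  indexes Gaussians, so item[:, :, gmm_mask] slices them.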

    def get_z_from_file(self, info_path):
        zh_, gmms, split, _ = self.load_file(info_path)
        zh_ = self.model.merge_zh_step_a(zh_, [gmms])
        zh, _ = self.model.affine_transformer.forward_with_attention(zh_)
        return zh, zh_, gmms, torch.cat(split)

    def plot_from_info(self, info_path, res):
        zh, zh_, gmms, split = self.get_z_from_file(info_path)
        # get_mesh takes only the merged latents and the resolution.
        mesh = self.get_mesh(zh[0], res)
        if mesh is not None:
            attention = self.get_attention_faces(mesh, zh, fixed_z=split)
        else:
            attention = None
        return mesh, attention

    @staticmethod
    def combine_and_pad(zh_a: T, zh_b: T) -> Tuple[T, TN]:
        if zh_a.shape[1] == zh_b.shape[1]:
            mask = None
        else:
            # Pad the shorter token batch and mark the padding in a boolean mask.
            pad_length = max(zh_a.shape[1], zh_b.shape[1])
            mask = torch.zeros(2, pad_length, device=zh_a.device, dtype=torch.bool)
            padding = torch.zeros(1, abs(zh_a.shape[1] - zh_b.shape[1]), zh_a.shape[-1], device=zh_a.device)
            if zh_a.shape[1] > zh_b.shape[1]:
                mask[1, zh_b.shape[1]:] = True
                zh_b = torch.cat((zh_b, padding), dim=1)
            else:
                mask[0, zh_a.shape[1]:] = True
                zh_a = torch.cat((zh_a, padding), dim=1)
        return torch.cat((zh_a, zh_b), dim=0), mask

    @staticmethod
    def get_intersection_z(z_a: T, z_b: T) -> T:
        # Mark tokens that (approximately) appear in both sets via an L1 distance test.
        diff = (z_a[0, :, None, :] - z_b[0, None]).abs().sum(-1)
        diff_a = diff.min(1)[0].lt(.1)
        diff_b = diff.min(0)[0].lt(.1)
        if diff_a.shape[0] != diff_b.shape[0]:
            padding = torch.zeros(abs(diff_a.shape[0] - diff_b.shape[0]), device=z_a.device, dtype=torch.bool)
            if diff_a.shape[0] > diff_b.shape[0]:
                diff_b = torch.cat((diff_b, padding))
            else:
                diff_a = torch.cat((diff_a, padding))
        return torch.cat((diff_a, diff_b))
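
    # Usage sketch for the two static helpers above (hypothetical inputs): given
    # token batches zh_a, zh_b of shapes [1, n_a, d] and [1, n_b, d],
    #     zh_ab, pad_mask = Inference.combine_and_pad(zh_a, zh_b)  # [2, max(n_a, n_b), d]
    #     shared = Inference.get_intersection_z(zh_a, zh_b)        # bool mask over both token sets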

    def get_attention_points(self, vs: T, zh: T, mask: TN = None, alpha: TN = None):
        vs = vs.unsqueeze(0)
        attention = self.model.occupancy_network.forward_attention(vs, zh, mask=mask, alpha=alpha)
        # Average the returned attention maps, then pick the strongest token per query point.
        attention = torch.stack(attention, 0).mean(0).mean(-1)
        attention = attention.permute(1, 0, 2).reshape(attention.shape[1], -1)
        attention_max = attention.argmax(-1)
        return attention_max

    def get_attention_faces(self, mesh: T_Mesh, zh: T, mask: TN = None, fixed_z: TN = None, alpha: TN = None):
        # Assign each face to a token by querying attention at the face centroids.
        coords = mesh[0][mesh[1]].mean(1).to(zh.device)
        attention_max = self.get_attention_points(coords, zh, mask, alpha)
        if fixed_z is not None:
            attention_select = fixed_z[attention_max].cpu()
        else:
            attention_select = attention_max
        return attention_select

    def plot_folder(self, *folders, res: int = 256):
        logger = train_utils.Logger()
        for folder in folders:
            paths = files_utils.collect(folder, '.pkl')
            logger.start(len(paths))
            for path in paths:
                name = path[1]
                out_path = f"{self.opt.cp_folder}/from_ui/{name}"
                mesh, colors = self.plot_from_info(path, res)
                if mesh is not None:
                    files_utils.export_mesh(mesh, out_path)
                    files_utils.export_list(colors.tolist(), f"{out_path}_faces")
                logger.reset_iter()
            logger.stop()

    def get_zh_from_idx(self, items: T):
        zh, _, gmms, __ = self.model.get_embeddings(items.to(self.device))
        zh, attn_b = self.model.merge_zh(zh, gmms)
        return zh, gmms

    @property
    def device(self):
        # Read as an attribute throughout this file (e.g. device=self.device).
        return self.opt.device

    def get_new_ids(self, folder_name, nums_sample):
        # Continue numbering after the largest existing export, and past the dataset ids.
        names = [int(path[1]) for path in files_utils.collect(f'{self.opt.cp_folder}/{folder_name}/occ/', '.obj')]
        ids = torch.arange(nums_sample)
        if len(names) == 0:
            return ids + self.opt.dataset_size
        return ids + max(max(names) + 1, self.opt.dataset_size)

    def random_plot(self, folder_name: str, nums_sample, res=200, verbose=False):
        zh_base, gmms = self.model.random_samples(nums_sample)
        zh, attn_b = self.model.merge_zh(zh_base, gmms)
        numbers = self.get_new_ids(folder_name, nums_sample)
        self.plot_occ(zh, zh_base, gmms, numbers, folder_name, verbose=verbose, res=res)

    def plot(self, folder_name: str, nums_sample: int, verbose=False, res: int = 200):
        if self.model.opt.dataset_size < nums_sample:
            fixed_items = torch.arange(self.model.opt.dataset_size)
        else:
            fixed_items = torch.randint(low=0, high=self.opt.dataset_size, size=(nums_sample,))
        # get_embeddings returns four values at the other call sites in this file.
        zh_base, _, gmms, _ = self.model.get_embeddings(fixed_items.to(self.device))
        zh, attn_b = self.model.merge_zh(zh_base, gmms)
        self.plot_occ(zh, zh_base, gmms, fixed_items, folder_name, verbose=verbose, res=res)

    def get_mesh_from_mid(self, gmm, included: T, res: int) -> Optional[T_Mesh]:
        if self.mid is None:
            return None
        gmm = [elem.to(self.device) for elem in gmm]
        included = included.to(device=self.device)
        # Gather the mid-level embeddings of the included (shape, gaussian) index pairs.
        mid_ = self.mid[included[:, 0], included[:, 1]].unsqueeze(0)
        zh = self.model.merge_zh(mid_, gmm)[0]
        mesh = self.get_mesh(zh[0], res)
        return mesh

    def set_items(self, items: T):
        self.mid = items.to(self.device)

    def __init__(self, opt: Options):
        self.opt = opt
        # model_lc loads the Spaghetti model together with its (possibly updated) options.
        model: Tuple[Spaghetti, Options] = train_utils.model_lc(opt)
        self.model, self.opt = model
        self.model.eval()
        self.mid: Optional[T] = None
        self.gmms: Optional[TN] = None
        self.meshing = mcubes_meshing.MarchingCubesMeshing(self.device, scale=1.)
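
# Minimal usage sketch (hypothetical; assumes a trained checkpoint resolvable by
# train_utils.model_lc from the given Options, and write access to opt.cp_folder):
#     opt = Options()
#     inference = Inference(opt)
#     inference.plot('samples', nums_sample=4, res=200, verbose=True)  # decode training items
#     inference.random_plot('random', nums_sample=4, res=200)          # decode random latents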