Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import os | |
| import pickle | |
| import salad.spaghetti.constants as const | |
| from salad.spaghetti.custom_types import * | |
class Options:
    """Model/training configuration for the 'spaghetti' model, with pickle persistence.

    Defaults are set in ``__init__`` and can be overridden by keyword arguments
    whose names match an attribute that ``__init__`` already defines. The whole
    object can be pickled to ``save_path`` (inside the run's checkpoint folder)
    and restored later with ``load``.
    """

    def __init__(self, **kwargs):
        """Populate the default hyper-parameters, then apply matching ``kwargs``.

        Keyword arguments that do not name an existing attribute are ignored
        (see ``fill_args``).
        """
        # NOTE(review): ``CUDA`` comes from the ``custom_types`` star import;
        # presumably it builds a torch device for GPU 0 — confirm.
        self.device = CUDA(0)
        self.tag = 'airplanes'
        self.dataset_name = 'shapenet_airplanes_wm_sphere_sym_train'
        self.epochs = 2000
        self.model_name = 'spaghetti'
        self.dim_z = 256
        # NOTE(review): presumably dim_z minus 3 raw xyz channels — confirm.
        self.pos_dim = 256 - 3
        self.dim_h = 512
        self.dim_zh = 512
        self.num_gaussians = 16
        self.min_split = 4
        self.max_split = 12
        self.gmm_weight = 1
        self.decomposition_network = 'transformer'
        self.decomposition_num_layers = 4
        self.num_layers = 4
        self.num_heads = 4
        self.num_layers_head = 6
        self.num_heads_head = 8
        self.head_occ_size = 5
        self.head_occ_type = 'skip'
        self.batch_size = 18
        self.num_samples = 2000
        # NOTE(review): -1 likely means "use the whole dataset" — confirm.
        self.dataset_size = -1
        # NOTE(review): presumably per-axis (x, y, z) symmetry flags — confirm.
        self.symmetric = (True, False, False)
        self.data_symmetric = (True, False, False)
        self.lr_decay = .9
        self.lr_decay_every = 500
        self.warm_up = 2000
        self.reg_weight = 1e-4
        self.disentanglement = True
        self.use_encoder = True
        self.disentanglement_weight = 1
        self.augmentation_rotation = 0.3
        self.augmentation_scale = .2
        self.augmentation_translation = .3
        self.fill_args(kwargs)

    @property
    def info(self) -> str:
        """Run identifier ``'<model_name>_<tag>'``, used to name the checkpoint folder."""
        return f'{self.model_name}_{self.tag}'

    @property
    def cp_folder(self) -> str:
        """Checkpoint folder of this run: ``const.CHECKPOINTS_ROOT`` followed by ``info``.

        The two parts are concatenated directly, so ``CHECKPOINTS_ROOT`` is
        assumed to end with a path separator.
        """
        return f'{const.CHECKPOINTS_ROOT}{self.info}'

    @property
    def save_path(self) -> str:
        """Path of the pickled options file inside ``cp_folder``."""
        return f'{self.cp_folder}/options.pkl'

    def load(self) -> 'Options':
        """Return the options previously saved for this run, or ``self`` if there are none.

        The file location is derived from this instance's ``model_name`` and
        ``tag``. The loaded object is given this instance's ``device`` in place
        of the one stored in the file.

        Returns:
            The unpickled options if ``save_path`` is an existing file,
            otherwise ``self`` unchanged.
        """
        path = self.save_path
        if not os.path.isfile(path):
            return self
        print(f'loading options from {path}')
        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file; only load checkpoint folders from trusted sources.
        with open(path, 'rb') as f:
            options = pickle.load(f)
        # Keep the current runtime device rather than whatever was pickled.
        options.device = self.device
        return options

    def save(self) -> None:
        """Pickle this instance to ``save_path``.

        Does nothing when ``cp_folder`` does not already exist; the folder is
        not created here.
        """
        if not os.path.isdir(self.cp_folder):
            return
        with open(self.save_path, 'wb') as f:
            pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)

    def fill_args(self, args: dict) -> None:
        """Override existing attributes with the values in the mapping ``args``.

        Keys that do not name an attribute already present on the instance are
        silently ignored, so unknown options never introduce new fields.
        """
        for name, value in args.items():
            if hasattr(self, name):
                setattr(self, name, value)