Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import numpy | |
| import json | |
| import zipfile | |
| import torch | |
| from PIL import Image | |
| # from transformers import CLIPImageProcessor | |
| from torch.utils.data import Dataset | |
| import io | |
| from omegaconf import OmegaConf | |
| import numpy as np | |
| # from torchvision import transforms | |
| # from einops import rearrange | |
| # import random | |
| # import os | |
| # from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, DDIMScheduler | |
| # import time | |
| # import io | |
| # import array | |
| # import numpy as np | |
| # | |
| # from training.triplane import TriPlaneGenerator | |
def to_rgb_image(maybe_rgba: Image.Image):
    """Return an RGB version of *maybe_rgba*, flattening alpha onto mid-gray.

    Args:
        maybe_rgba: A PIL image whose mode is ``'RGB'`` or ``'RGBA'``.

    Returns:
        The input image itself (not a copy) when it is already ``'RGB'``.
        For ``'RGBA'`` input, a new ``'RGB'`` image of the same size in which
        the input is pasted over a uniform (127, 127, 127) background, using
        the input's alpha channel as the paste mask.

    Raises:
        ValueError: If the image mode is neither ``'RGB'`` nor ``'RGBA'``.
    """
    if maybe_rgba.mode == 'RGB':
        return maybe_rgba
    if maybe_rgba.mode == 'RGBA':
        # Build the solid gray canvas directly. The previous implementation
        # obtained the same constant fill from randint(127, 128), which can
        # only ever produce 127, and then converted the array to an image.
        background = Image.new('RGB', maybe_rgba.size, (127, 127, 127))
        background.paste(maybe_rgba, mask=maybe_rgba.getchannel('A'))
        return background
    raise ValueError("Unsupported image type.", maybe_rgba.mode)
| # image(contain style),z,pose,text | |
class TriplaneDataset(Dataset):
    """Dataset that reads serialized tensors out of per-model zip archives.

    A JSON index maps each sample name to a record. The keys this class reads
    from a record are ``'model_name'`` (which archive holds the sample),
    ``'z_dir'`` and ``'vert_dir'`` (member paths inside that archive). Records
    may carry further keys (the original notes mention ``'pose_dir'`` and
    ``'img_dir'``); they are ignored here.
    """

    def __init__(self, json_file, data_base_dir, model_names):
        """Load the sample index and open one zip archive per model.

        Args:
            json_file: Path to the JSON index described on the class.
            data_base_dir: Directory containing ``<model_name>.zip`` archives.
            model_names: Path to a config loadable by ``OmegaConf.load`` whose
                ``'gan_models'`` mapping has one key per model name.
        """
        super().__init__()
        # Use a context manager so the index file handle is closed promptly.
        with open(json_file) as index_file:
            self.dict_data_image = json.load(index_file)
        self.data_base_dir = data_base_dir
        self.data_list = list(self.dict_data_image.keys())
        self.zip_file_dict = {}
        config_gan_model = OmegaConf.load(model_names)
        # NOTE(review): the archives are opened once and kept for the
        # dataset's lifetime. If this dataset is used with multi-process
        # DataLoader workers, confirm that sharing these handles is safe.
        for model_name in config_gan_model['gan_models'].keys():
            zipfile_path = os.path.join(self.data_base_dir, model_name + '.zip')
            self.zip_file_dict[model_name] = zipfile.ZipFile(zipfile_path)

    @staticmethod
    def _load_tensor(archive, member_name):
        """Read *member_name* from *archive* and deserialize it with torch."""
        with archive.open(member_name, 'r') as member:
            payload = member.read()
        with io.BytesIO(payload) as buffer:
            # NOTE(review): torch.load deserializes pickled data; only point
            # this dataset at archives from a trusted source.
            return torch.load(buffer)

    def getdata(self, idx):
        """Load the sample at position *idx* of ``self.data_list``.

        Returns:
            A dict with ``"data_z"`` and ``"data_vert"`` (the objects loaded
            from the record's ``'z_dir'`` and ``'vert_dir'`` members) and
            ``"data_model_name"`` (the record's ``'model_name'``).
        """
        data_name = self.data_list[idx]
        record = self.dict_data_image[data_name]
        data_model_name = record['model_name']
        zipfile_loaded = self.zip_file_dict[data_model_name]
        data_z = self._load_tensor(zipfile_loaded, record['z_dir'])
        data_vert = self._load_tensor(zipfile_loaded, record['vert_dir'])
        return {
            "data_z": data_z,
            "data_vert": data_vert,
            "data_model_name": data_model_name
        }

    def __getitem__(self, idx):
        """Return the sample at *idx*, substituting a random one on failure.

        Loading is attempted up to 20 times; after each failure the error is
        printed and a uniformly random index is tried instead.

        Raises:
            RuntimeError: If all 20 attempts fail. The last underlying
                exception is attached as ``__cause__``.
        """
        last_error = None
        for _ in range(20):
            try:
                return self.getdata(idx)
            except Exception as e:
                # Deliberate best-effort: skip unreadable samples rather than
                # aborting the whole run on one bad record.
                print(f"Error details: {str(e)}")
                last_error = e
                idx = np.random.randint(len(self))
        raise RuntimeError('Too many bad data.') from last_error

    def __len__(self):
        """Return the number of samples listed in the JSON index."""
        return len(self.data_list)
| # for zip files | |