| """Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py |
| """ |
| import base64 |
| import gc |
| import json |
| import os |
| import hashlib |
| import random |
| from datetime import datetime |
| from glob import glob |
|
|
| import cv2 |
| import gradio as gr |
| import numpy as np |
| import pkg_resources |
| import requests |
| import torch |
| from diffusers import (CogVideoXDDIMScheduler, DDIMScheduler, |
| DPMSolverMultistepScheduler, |
| EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, |
| FlowMatchEulerDiscreteScheduler, PNDMScheduler) |
| from omegaconf import OmegaConf |
| from PIL import Image |
| from safetensors import safe_open |
|
|
| from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio |
| from ..utils.utils import save_videos_grid |
| from ..utils.fm_solvers import FlowDPMSolverMultistepScheduler |
| from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler |
| from ..dist import set_multi_gpus_devices |
|
|
# Installed gradio version. The major version is used further down to choose
# between returning component instances (gradio >= 4) and calling the
# ``Component.update`` helpers (older releases).
gradio_version = pkg_resources.get_distribution("gradio").version
gradio_version_is_above_4 = int(gradio_version.split('.')[0]) >= 4

# CSS snippet for elements carrying the ``toolbutton`` class (presumably handed
# to the Gradio app by the UI builders -- not used in this part of the file).
# The first declaration used to read ``margin-buttom: 0em 0em 0em 0em``, an
# unknown property that browsers silently drop.
css = """
.toolbutton {
    margin-bottom: 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""

# Sampler display name -> scheduler class. "DDIM" and "DDIM_Origin" both map to
# the stock diffusers ``DDIMScheduler``.
ddpm_scheduler_dict = {
    "Euler": EulerDiscreteScheduler,
    "Euler A": EulerAncestralDiscreteScheduler,
    "DPM++": DPMSolverMultistepScheduler,
    "PNDM": PNDMScheduler,
    "DDIM": DDIMScheduler,
    "DDIM_Origin": DDIMScheduler,
    "DDIM_Cog": CogVideoXDDIMScheduler,
}
flow_scheduler_dict = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}
# Union of both tables (the key sets are disjoint).
all_cheduler_dict = {**ddpm_scheduler_dict, **flow_scheduler_dict}
# Correctly spelled alias; the misspelled name above stays for existing importers.
all_scheduler_dict = all_cheduler_dict
|
|
class Fun_Controller:
    """Shared state and helpers for the Gradio generation controllers.

    The class tracks model/config directories below the current working
    directory, the currently selected checkpoint and LoRA paths, and provides
    input validation and output saving. ``update_diffusion_transformer`` and
    ``generate`` are no-op placeholders here (presumably overridden by
    model-specific subclasses -- not visible in this file).
    """

    def __init__(
        self, GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
        config_path=None, ulysses_degree=1, ring_degree=1,
        fsdp_dit=False, fsdp_text_encoder=False, compile_dit=False,
        weight_dtype=None, savedir_sample=None,
    ):
        """Set up directories, bookkeeping attributes and the compute device.

        Args:
            GPU_memory_mode: Stored as-is; not interpreted by this class.
            scheduler_dict: Stored as-is (sampler name -> scheduler class).
            model_name: When not ``None`` it is passed to
                ``update_diffusion_transformer`` at the end of construction.
            model_type: ``"Inpaint"`` or ``"Control"``; read by ``input_check``
                and ``get_height_width_from_reference``.
            config_path: Optional config file loaded with ``OmegaConf.load``.
            ulysses_degree, ring_degree: Passed to ``set_multi_gpus_devices``
                to obtain ``self.device``; when their product exceeds 1,
                ``save_outputs`` writes only on distributed rank 0.
            fsdp_dit, fsdp_text_encoder, compile_dit, weight_dtype: Stored
                as-is; not interpreted by this class.
            savedir_sample: Output directory. Defaults to a timestamped
                folder under ``<cwd>/samples``. It is created if missing.
        """
        # All directories are resolved relative to the current working directory.
        self.basedir = os.getcwd()
        self.config_dir = os.path.join(self.basedir, "config")
        self.diffusion_transformer_dir = os.path.join(self.basedir, "models", "Diffusion_Transformer")
        self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
        self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
        if savedir_sample is None:
            self.savedir_sample = os.path.join(
                self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S")
            )
        else:
            self.savedir_sample = savedir_sample
        os.makedirs(self.savedir_sample, exist_ok=True)

        self.GPU_memory_mode = GPU_memory_mode
        self.model_name = model_name
        self.diffusion_transformer_dropdown = model_name
        self.scheduler_dict = scheduler_dict
        self.model_type = model_type
        if config_path is not None:
            self.config_path = os.path.realpath(config_path)
            self.config = OmegaConf.load(config_path)
        else:
            self.config_path = None
            # Always define the attribute so later reads cannot hit AttributeError.
            self.config = None
        self.ulysses_degree = ulysses_degree
        self.ring_degree = ring_degree
        self.fsdp_dit = fsdp_dit
        self.fsdp_text_encoder = fsdp_text_encoder
        self.compile_dit = compile_dit
        self.weight_dtype = weight_dtype
        self.device = set_multi_gpus_devices(self.ulysses_degree, self.ring_degree)

        self.diffusion_transformer_list = []
        self.motion_module_list = []
        self.personalized_model_list = []
        self.config_list = []

        # Model components stay empty until a model is selected.
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.transformer = None
        self.transformer_2 = None
        self.pipeline = None
        # The string "none" (not None) marks "nothing selected" for these paths.
        self.base_model_path = "none"
        self.base_model_2_path = "none"
        self.lora_model_path = "none"
        self.lora_model_2_path = "none"

        self.refresh_config()
        self.refresh_diffusion_transformer()
        self.refresh_personalized_model()
        if model_name is not None:
            self.update_diffusion_transformer(model_name)

    def refresh_config(self):
        """Rescan ``config_dir`` recursively for ``.yaml`` / ``.yml`` files."""
        self.config_list = [
            os.path.join(root, file_name)
            for root, _dirs, files in os.walk(self.config_dir)
            for file_name in files
            if file_name.endswith(('.yaml', '.yml'))
        ]

    def refresh_diffusion_transformer(self):
        """Rescan ``diffusion_transformer_dir`` for sub-directories (sorted paths)."""
        self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/")))

    def refresh_personalized_model(self):
        """Rescan ``personalized_model_dir`` for ``*.safetensors`` files (sorted base names)."""
        personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
        self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]

    def update_model_type(self, model_type):
        """Store the selected model type."""
        self.model_type = model_type

    def update_config(self, config_dropdown):
        """Load the selected config file and remember its path."""
        self.config_path = config_dropdown
        self.config = OmegaConf.load(config_dropdown)
        print(f"Update config: {config_dropdown}")

    def update_diffusion_transformer(self, diffusion_transformer_dropdown):
        """Placeholder; does nothing in this class."""
        pass

    def update_base_model(self, base_model_dropdown, is_checkpoint_2=False):
        """Load a personalized checkpoint into the first or second transformer.

        Args:
            base_model_dropdown: File name inside ``personalized_model_dir``,
                or ``"none"`` to only record the selection.
            is_checkpoint_2: Target ``transformer_2`` instead of ``transformer``.

        Returns:
            A ``gr.update`` object; ``gr.update(value=None)`` when the target
            transformer has not been loaded yet.
        """
        if not is_checkpoint_2:
            self.base_model_path = base_model_dropdown
        else:
            self.base_model_2_path = base_model_dropdown
        print(f"Update base model: {base_model_dropdown}")
        if base_model_dropdown == "none":
            return gr.update()

        target_transformer = self.transformer_2 if is_checkpoint_2 else self.transformer
        if target_transformer is None:
            gr.Info("Please select a pretrained model path.")
            print("Please select a pretrained model path.")
            return gr.update(value=None)

        checkpoint_path = os.path.join(self.personalized_model_dir, base_model_dropdown)
        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
            base_model_state_dict = {key: f.get_tensor(key) for key in f.keys()}
        # strict=False: keys that do not match between checkpoint and model are ignored.
        target_transformer.load_state_dict(base_model_state_dict, strict=False)
        print("Update base model done")
        return gr.update()

    def update_lora_model(self, lora_model_dropdown, is_checkpoint_2=False):
        """Record the selected LoRA file for the first or second checkpoint.

        ``"none"`` clears the selection of the addressed slot only; any other
        value is stored as a path inside ``personalized_model_dir``.
        """
        print(f"Update lora model: {lora_model_dropdown}")
        if lora_model_dropdown == "none":
            lora_path = "none"
        else:
            lora_path = os.path.join(self.personalized_model_dir, lora_model_dropdown)
        # Previously "none" always reset the primary slot, even for checkpoint 2.
        if is_checkpoint_2:
            self.lora_model_2_path = lora_path
        else:
            self.lora_model_path = lora_path
        return gr.update()

    @staticmethod
    def _release_cached_memory():
        """Run the garbage collector and, when CUDA is usable, free cached CUDA memory."""
        gc.collect()
        # ipc_collect() initialises CUDA and raises on hosts without it.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    def clear_cache(self,):
        """Free Python garbage and cached CUDA memory."""
        self._release_cached_memory()

    def auto_model_clear_cache(self, model):
        """Move ``model`` to the CPU, free cached memory, then move it back.

        ``model`` must expose ``.device`` and ``.to``.
        """
        origin_device = model.device
        model = model.to("cpu")
        self._release_cached_memory()
        model = model.to(origin_device)

    @staticmethod
    def _input_check_failure(message, is_api):
        """Report a validation failure: return it for API calls, raise ``gr.Error`` for the UI."""
        if is_api:
            return "", message
        raise gr.Error(message)

    def input_check(self,
        resize_method,
        generation_method,
        start_image,
        end_image,
        validation_video,
        control_video,
        is_api = False,
    ):
        """Validate the combination of inputs before generation.

        Returns:
            ``("", "OK")`` when everything is consistent. With ``is_api=True``
            a failure is returned as ``("", message)``.

        Raises:
            gr.Error: On the first failed check when ``is_api`` is false.
        """
        if self.transformer is None:
            return self._input_check_failure("Please select a pretrained model path.", is_api)

        if control_video is not None and self.model_type == "Inpaint":
            return self._input_check_failure(
                'If specifying the control video, please set the model_type == "Control". ', is_api
            )

        if control_video is None and self.model_type == "Control":
            return self._input_check_failure(
                'If set the model_type == "Control", please specifying the control video. ', is_api
            )

        if resize_method == "Resize according to Reference":
            if start_image is None and validation_video is None and control_video is None:
                return self._input_check_failure(
                    'Please upload an image when using "Resize according to Reference".', is_api
                )

        # Equal channel counts are treated as "not an image-to-video model"
        # by the two checks below.
        channels_match = self.transformer.config.in_channels == self.vae.config.latent_channels
        if channels_match and start_image is not None:
            return self._input_check_failure(
                "Please select an image to video pretrained model while using image to video.", is_api
            )

        if channels_match and generation_method == "Long Video Generation":
            return self._input_check_failure(
                "Please select an image to video pretrained model while using long video generation.", is_api
            )

        if start_image is None and end_image is not None:
            return self._input_check_failure(
                "If specifying the ending image of the video, please specify a starting image of the video.",
                is_api,
            )
        return "", "OK"

    @staticmethod
    def _first_frame_size(video_path):
        """Return ``(width, height)`` of the first frame of ``video_path``."""
        capture = cv2.VideoCapture(video_path)
        try:
            success, frame = capture.read()
        finally:
            # Release the capture handle even when reading fails.
            capture.release()
        if not success or frame is None:
            raise ValueError(f"Could not read a frame from video: {video_path}")
        frame_height, frame_width = frame.shape[:2]
        return frame_width, frame_height

    def get_height_width_from_reference(
        self,
        base_resolution,
        start_image,
        validation_video,
        control_video,
    ):
        """Derive the output height/width from the reference image or video.

        The reference is ``validation_video`` or ``start_image`` for the
        ``"Inpaint"`` model type and ``control_video`` otherwise. Its aspect
        ratio is matched against ``ASPECT_RATIO_512`` scaled to
        ``base_resolution``, and the result is rounded down to a multiple of
        twice the VAE spatial compression ratio (8 when the VAE config does
        not define one).

        Returns:
            ``(height, width)`` as integers.

        Raises:
            ValueError: If no frame can be read from the reference video.
        """
        spatial_compression_ratio = getattr(self.vae.config, "spatial_compression_ratio", 8)
        aspect_ratio_sample_size = {
            key: [x / 512 * base_resolution for x in value]
            for key, value in ASPECT_RATIO_512.items()
        }
        if self.model_type == "Inpaint":
            if validation_video is not None:
                original_width, original_height = self._first_frame_size(validation_video)
            elif isinstance(start_image, list):
                # A list holds already opened images; the first one is the reference.
                original_width, original_height = start_image[0].size
            else:
                with Image.open(start_image) as reference_image:
                    original_width, original_height = reference_image.size
        else:
            original_width, original_height = self._first_frame_size(control_video)
        closest_size, closest_ratio = get_closest_ratio(
            original_height, original_width, ratios=aspect_ratio_sample_size
        )
        multiple = spatial_compression_ratio * 2
        height_slider, width_slider = [
            int(x / spatial_compression_ratio / 2) * multiple for x in closest_size
        ]
        return height_slider, width_slider

    def save_outputs(self, is_image, length_slider, sample, fps):
        """Write ``sample`` to ``savedir_sample`` as a PNG or an MP4.

        The file name is ``<8-digit index>-<md5 of the tensor bytes>`` where
        the index is the number of entries already in the directory plus one.
        A PNG is written when ``is_image`` is true or ``length_slider == 1``.

        Returns:
            The saved path, or ``None`` on ranks other than 0 when
            ``ulysses_degree * ring_degree > 1``.
        """
        def save_results():
            os.makedirs(self.savedir_sample, exist_ok=True)
            index = len(os.listdir(self.savedir_sample)) + 1
            prefix = str(index).zfill(8)

            # One CPU copy serves both the hash and the image conversion, so a
            # tensor living on the GPU no longer breaks ``.numpy()``.
            sample_cpu = sample.cpu()
            md5_hash = hashlib.md5(sample_cpu.numpy().tobytes()).hexdigest()

            if is_image or length_slider == 1:
                save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.png")
                print(f"Saving to {save_sample_path}")
                # First batch item, first frame; move axis 0 last for PIL.
                image = sample_cpu[0, :, 0].permute(1, 2, 0)
                image = (image * 255).numpy().astype(np.uint8)
                Image.fromarray(image).save(save_sample_path)
            else:
                save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.mp4")
                print(f"Saving to {save_sample_path}")
                save_videos_grid(sample, save_sample_path, fps=fps)
            return save_sample_path

        if self.ulysses_degree * self.ring_degree > 1:
            import torch.distributed as dist
            # Only rank 0 writes the file.
            if dist.get_rank() == 0:
                save_sample_path = save_results()
            else:
                save_sample_path = None
        else:
            save_sample_path = save_results()
        return save_sample_path

    def generate(
        self,
        diffusion_transformer_dropdown,
        base_model_dropdown,
        lora_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        resize_method,
        width_slider,
        height_slider,
        base_resolution,
        generation_method,
        length_slider,
        overlap_video_length,
        partial_video_length,
        cfg_scale_slider,
        start_image,
        end_image,
        validation_video,
        validation_video_mask,
        control_video,
        denoise_strength,
        seed_textbox,
        enable_teacache = None,
        teacache_threshold = None,
        num_skip_start_steps = None,
        teacache_offload = None,
        cfg_skip_ratio = None,
        enable_riflex = None,
        riflex_k = None,
        is_api = False,
    ):
        """Placeholder; does nothing in this class."""
        pass
|
|
def _encode_file_base64(path):
    """Return the content of the file at ``path`` as base64 text, or ``None`` when ``path`` is ``None``."""
    if path is None:
        return None
    with open(path, 'rb') as file:
        return base64.b64encode(file.read()).decode('utf-8')


def post_to_host(
    diffusion_transformer_dropdown,
    base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
    prompt_textbox, negative_prompt_textbox,
    sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
    base_resolution, generation_method, length_slider, cfg_scale_slider,
    start_image, end_image, validation_video, validation_video_mask, denoise_strength, seed_textbox,
    ref_image = None, enable_teacache = None, teacache_threshold = None, num_skip_start_steps = None,
    teacache_offload = None, cfg_skip_ratio = None,enable_riflex = None, riflex_k = None,
):
    """Send a generation request to the remote inference endpoint.

    ``start_image``, ``end_image``, ``validation_video``,
    ``validation_video_mask`` and ``ref_image`` are file paths (or ``None``);
    their contents are sent base64-encoded. All other arguments are forwarded
    unchanged, except ``diffusion_transformer_dropdown`` which is accepted for
    signature compatibility but is not part of the payload.

    The endpoint is ``$EAS_URL/videox_fun/infer_forward`` and the value of
    ``$EAS_TOKEN`` is sent as the ``Authorization`` header. The request times
    out after 300 seconds.

    Returns:
        The decoded JSON body of the response.
    """
    datas = {
        "base_model_path": base_model_dropdown,
        "lora_model_path": lora_model_dropdown,
        "lora_alpha_slider": lora_alpha_slider,
        "prompt_textbox": prompt_textbox,
        "negative_prompt_textbox": negative_prompt_textbox,
        "sampler_dropdown": sampler_dropdown,
        "sample_step_slider": sample_step_slider,
        "resize_method": resize_method,
        "width_slider": width_slider,
        "height_slider": height_slider,
        "base_resolution": base_resolution,
        "generation_method": generation_method,
        "length_slider": length_slider,
        "cfg_scale_slider": cfg_scale_slider,
        "start_image": _encode_file_base64(start_image),
        "end_image": _encode_file_base64(end_image),
        "validation_video": _encode_file_base64(validation_video),
        "validation_video_mask": _encode_file_base64(validation_video_mask),
        "denoise_strength": denoise_strength,
        "seed_textbox": seed_textbox,

        "ref_image": _encode_file_base64(ref_image),
        "enable_teacache": enable_teacache,
        "teacache_threshold": teacache_threshold,
        "num_skip_start_steps": num_skip_start_steps,
        "teacache_offload": teacache_offload,
        "cfg_skip_ratio": cfg_skip_ratio,
        "enable_riflex": enable_riflex,
        "riflex_k": riflex_k,
    }

    # The context manager closes the session's connections once the call is done.
    with requests.Session() as session:
        session.headers.update({"Authorization": os.environ.get("EAS_TOKEN")})
        response = session.post(
            url=f'{os.environ.get("EAS_URL")}/videox_fun/infer_forward', json=datas, timeout=300
        )
        outputs = response.json()
    return outputs
|
|
|
|
class Fun_Controller_Client:
    """Thin client that delegates generation to a remote host and stores the result locally."""

    def __init__(self, scheduler_dict, savedir_sample):
        """Remember the scheduler table and create the output directory.

        Args:
            scheduler_dict: Stored as-is.
            savedir_sample: Output directory; ``None`` selects a timestamped
                folder under ``<cwd>/samples``. It is created if missing.
        """
        self.basedir = os.getcwd()
        if savedir_sample is None:
            self.savedir_sample = os.path.join(
                self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S")
            )
        else:
            self.savedir_sample = savedir_sample
        os.makedirs(self.savedir_sample, exist_ok=True)

        self.scheduler_dict = scheduler_dict

    @staticmethod
    def _component_outputs(image_value, image_visible, video_value, video_visible, message):
        """Build the ``(image, video, message)`` result tuple for the installed gradio version."""
        if gradio_version_is_above_4:
            return (
                gr.Image(value=image_value, visible=image_visible),
                gr.Video(value=video_value, visible=video_visible),
                message,
            )
        # Older gradio releases express component changes through ``.update``.
        return (
            gr.Image.update(value=image_value, visible=image_visible),
            gr.Video.update(value=video_value, visible=video_visible),
            message,
        )

    def generate(
        self,
        diffusion_transformer_dropdown,
        base_model_dropdown,
        lora_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        resize_method,
        width_slider,
        height_slider,
        base_resolution,
        generation_method,
        length_slider,
        cfg_scale_slider,
        start_image,
        end_image,
        validation_video,
        validation_video_mask,
        denoise_strength,
        seed_textbox,
        ref_image = None,
        enable_teacache = None,
        teacache_threshold = None,
        num_skip_start_steps = None,
        teacache_offload = None,
        cfg_skip_ratio = None,
        enable_riflex = None,
        riflex_k = None,
    ):
        """Request a generation from the host and save the returned media.

        The response is expected to hold the media under ``"base64_encoding"``;
        otherwise its ``"message"`` entry is reported to the UI. The media is
        written as PNG when ``generation_method == "Image Generation"`` or
        ``length_slider == 1`` and as MP4 otherwise.

        Returns:
            ``(image component/update, video component/update, status text)``;
            the status text is ``"Success"`` after a file has been written.
        """
        is_image = generation_method == "Image Generation"

        outputs = post_to_host(
            diffusion_transformer_dropdown,
            base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
            prompt_textbox, negative_prompt_textbox,
            sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
            base_resolution, generation_method, length_slider, cfg_scale_slider,
            start_image, end_image, validation_video, validation_video_mask, denoise_strength,
            seed_textbox, ref_image = ref_image, enable_teacache = enable_teacache, teacache_threshold = teacache_threshold,
            num_skip_start_steps = num_skip_start_steps, teacache_offload = teacache_offload,
            cfg_skip_ratio = cfg_skip_ratio, enable_riflex = enable_riflex, riflex_k = riflex_k,
        )

        try:
            base64_encoding = outputs["base64_encoding"]
        except (KeyError, TypeError):
            # No media in the response: surface the host's message instead.
            if isinstance(outputs, dict) and "message" in outputs:
                message = outputs["message"]
            else:
                message = f"Unexpected response from host: {outputs}"
            return self._component_outputs(None, False, None, True, message)

        decoded_data = base64.b64decode(base64_encoding)

        os.makedirs(self.savedir_sample, exist_ok=True)
        md5_hash = hashlib.md5(decoded_data).hexdigest()

        # Running index: number of entries already in the directory plus one.
        index = len(os.listdir(self.savedir_sample)) + 1
        prefix = str(index).zfill(8)

        save_as_image = is_image or length_slider == 1
        extension = "png" if save_as_image else "mp4"
        save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.{extension}")
        print(f"Saving to {save_sample_path}")
        with open(save_sample_path, "wb") as file:
            file.write(decoded_data)

        if save_as_image:
            return self._component_outputs(save_sample_path, True, None, False, "Success")
        return self._component_outputs(None, False, save_sample_path, True, "Success")
|
|