import spaces
import gradio as gr
import os
import random
import re
import numpy as np
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

# Custom tiling pipeline plus the diffusers building blocks it is assembled from.
from pipeline_z_image_mod import ZImageMoDTilingPipeline
from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    GGUFQuantizationConfig,
)
from diffusers.models import ZImageTransformer2DModel
from huggingface_hub import hf_hub_download

# Optional prompt safety checker: when the helper module is missing, the app
# still starts, with a stub that reports every prompt as safe.
try:
    from prompt_check import is_unsafe_prompt

    UNSAFE_CHECK_AVAILABLE = True
except ImportError:
    print("Warning: 'prompt_check.py' not found. NSFW prompt check will be disabled.")
    UNSAFE_CHECK_AVAILABLE = False

    def is_unsafe_prompt(*args, **kwargs):
        """Fallback used when `prompt_check` is unavailable; never flags a prompt."""
        return False


# 1. Environment Variables for Model Paths

# Repository providing the VAE, text encoder and tokenizer.
BASE_MODEL_ID = os.getenv("BASE_MODEL_ID", "Tongyi-MAI/Z-Image-Turbo")

# Repository and file name of the GGUF-quantized transformer weights.
GGUF_REPO_ID = os.getenv("GGUF_REPO_ID", "jayn7/Z-Image-Turbo-GGUF")
GGUF_FILENAME = os.getenv("GGUF_FILENAME", "z_image_turbo-Q4_K_M.gguf")

# Optional download directory for the GGUF file; unset means None.
# Set this path for local use, e.g., "F:\\models\\Z-Image-Turbo"
GGUF_LOCAL_DIR = os.getenv("GGUF_LOCAL_DIR")

# USE_SPACES is on unless the variable is one of the "disabled" spellings below
# (compared case-insensitively); it defaults to off.
_DISABLED_FLAG_VALUES = frozenset({"false", "0", "no", "none"})
USE_SPACES_ENV = os.getenv("USE_SPACES", "false").lower()
USE_SPACES = USE_SPACES_ENV not in _DISABLED_FLAG_VALUES

# System prompt for the safety checker (None when the variable is unset).
UNSAFE_PROMPT_CHECK = os.getenv("UNSAFE_PROMPT_CHECK")

# Largest seed value offered to the user: the int32 maximum.
MAX_SEED = np.iinfo(np.int32).max

# 2.
# Load Models (GGUF + Standard Components)
print("--- Loading Models ---")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Each component is loaded first and moved to `device` only when USE_SPACES is
# set; otherwise it stays wherever the loader placed it.
print("Loading VAE...")
vae = AutoencoderKL.from_pretrained(
    BASE_MODEL_ID, subfolder="vae", torch_dtype=torch.bfloat16
)
if USE_SPACES:
    vae = vae.to(device)

print("Loading Text Encoder and Tokenizer...")
text_encoder = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID, subfolder="text_encoder", dtype=torch.bfloat16
)
if USE_SPACES:
    text_encoder = text_encoder.to(device)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, subfolder="tokenizer")

print(f"Loading Transformer from GGUF: {GGUF_REPO_ID}/{GGUF_FILENAME}...")
if GGUF_LOCAL_DIR:
    # Download into the user-supplied directory.
    # NOTE(review): `local_dir_use_symlinks` is reportedly deprecated in recent
    # huggingface_hub releases — confirm before upgrading the dependency.
    transformer_path = hf_hub_download(
        GGUF_REPO_ID,
        GGUF_FILENAME,
        local_dir=GGUF_LOCAL_DIR,
        local_dir_use_symlinks=False,
    )
else:
    # No directory configured: let huggingface_hub pick the download location.
    transformer_path = hf_hub_download(GGUF_REPO_ID, GGUF_FILENAME)

transformer = ZImageTransformer2DModel.from_single_file(
    transformer_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    dtype=torch.bfloat16,
)
if USE_SPACES:
    transformer = transformer.to(device)

# 3.
Assemble the Tiling Pipeline print("\nAssembling the ZImageMoDTilingPipeline...") scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0) if USE_SPACES: pipe = ZImageMoDTilingPipeline( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, transformer=transformer, ).to(device) else: pipe = ZImageMoDTilingPipeline( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, transformer=transformer, ) print("Enabling model CPU offload...") pipe.enable_model_cpu_offload() # Load Translation Models print("Loading translation models...") try: ko_en_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en") zh_en_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en") except Exception as e: ko_en_translator, zh_en_translator = None, None print(f"Warning: Could not load translation models: {e}") print("Pipeline loaded and ready.") # Helper Functions def translate_prompt(text: str, language: str) -> str: """Translates text to English if the selected language is not English.""" if language == "English" or not text.strip(): return text translated = text if ( language == "Korean" and ko_en_translator and any("\uac00" <= char <= "\ud7a3" for char in text) ): translated = ko_en_translator(text)[0]["translation_text"] elif ( language == "Chinese" and zh_en_translator and any("\u4e00" <= char <= "\u9fff" for char in text) ): translated = zh_en_translator(text)[0]["translation_text"] return translated def create_hdr_effect(image, hdr_strength): if hdr_strength == 0: return image from PIL import ImageEnhance, Image if isinstance(image, Image.Image): image = np.array(image) from scipy.ndimage import gaussian_filter blurred = gaussian_filter(image, sigma=5) sharpened = np.clip(image + hdr_strength * (image - blurred), 0, 255).astype( np.uint8 ) pil_img = Image.fromarray(sharpened) converter = ImageEnhance.Color(pil_img) return converter.enhance(1 + hdr_strength) @spaces.GPU(duration=120) 
def generate_z_image_panorama( left_prompt, center_prompt, right_prompt, left_gs, center_gs, right_gs, overlap_pixels, steps, shift, generation_seed, tile_weighting_method, prompt_language, target_height, target_width, hdr, randomize_seed, progress=gr.Progress(track_tqdm=True), ): """ Generate a panoramic image using the Z-Image Turbo model with tiling and composition. Args: left_prompt (str): Text prompt for the left section of the panorama. center_prompt (str): Text prompt for the center section of the panorama. right_prompt (str): Text prompt for the right section of the panorama. left_gs (float): Guidance scale for the left tile. center_gs (float): Guidance scale for the center tile. right_gs (float): Guidance scale for the right tile. overlap_pixels (int): Number of pixels to overlap between tiles. steps (int): Number of inference steps for generation. shift (float): Time Shift. generation_seed (int): Random seed for reproducibility. tile_weighting_method (str): Method for weighting overlapping tile regions. prompt_language (str): Language code for prompt translation. target_height (int): Height of the generated panorama in pixels. target_width (int): Width of the generated panorama in pixels. hdr (float): HDR effect intensity. randomize_seed (boolean): Not used. progress (gr.Progress): Gradio progress tracker. Returns: PIL.Image: The generated panoramic image with optional HDR effect applied. """ if not left_prompt or not center_prompt or not right_prompt: gr.Info("⚡️ Prompts must be provided!") return gr.skip() # Safety Check prompts_to_check = [left_prompt, center_prompt, right_prompt] for p in prompts_to_check: if UNSAFE_CHECK_AVAILABLE and is_unsafe_prompt( pipe.text_encoder, device, pipe.tokenizer, system_prompt=UNSAFE_PROMPT_CHECK, user_prompt=p, ): raise gr.Error( f"Unsafe prompt detected. Please modify your prompt and try again." 
) generator = torch.Generator("cuda").manual_seed(generation_seed) pipe.scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift) final_height, final_width = int(target_height), int(target_width) translated_left = translate_prompt(left_prompt, prompt_language) translated_center = translate_prompt(center_prompt, prompt_language) translated_right = translate_prompt(right_prompt, prompt_language) image = pipe( prompt=[[translated_left, translated_center, translated_right]], height=final_height, width=final_width, num_inference_steps=steps, guidance_scale_tiles=[[left_gs, center_gs, right_gs]], tile_overlap=overlap_pixels, tile_weighting_method=tile_weighting_method, generator=generator, ).images[0] return create_hdr_effect(image, hdr) def calculate_tile_size(target_height, target_width, overlap_pixels): """ Calculate tile dimensions for panoramic image generation. Args: target_height (int): The target height of the final panoramic image in pixels. target_width (int): The target width of the final panoramic image in pixels. overlap_pixels (int): The number of overlapping pixels between adjacent tiles. 
Returns: tuple: A tuple of 2 gr.update objects containing: - final_height: Final panorama height after tiling - final_width: Final panorama width after tiling """ num_cols = 3 num_rows = 1 tile_width = (target_width + (num_cols - 1) * overlap_pixels) // num_cols tile_height = (target_height + (num_rows - 1) * overlap_pixels) // num_rows tile_width -= tile_width % 16 tile_height -= tile_height % 16 final_width = tile_width * num_cols - (num_cols - 1) * overlap_pixels final_height = tile_height * num_rows - (num_rows - 1) * overlap_pixels return ( gr.update(value=final_height), gr.update(value=final_width), ) def randomize_seed_fn(generation_seed: int, randomize_seed: bool) -> int: if randomize_seed: return random.randint(0, MAX_SEED) return generation_seed def run_for_examples( left_prompt, center_prompt, right_prompt, left_gs, center_gs, right_gs, overlap_pixels, steps, shift, generation_seed, tile_weighting_method, prompt_language, target_height, target_width, hdr, randomize_seed ): return generate_z_image_panorama( left_prompt, center_prompt, right_prompt, left_gs, center_gs, right_gs, overlap_pixels, steps, shift, generation_seed, tile_weighting_method, prompt_language, target_height, target_width, hdr, randomize_seed ) def clear_result(): return gr.update(value=None) def update_dimensions_from_preset(resolution_preset): """Updates the width and height sliders based on a preset string.""" match = re.search(r"(\d+)\s*[x]\s*(\d+)", resolution_preset) if match: width, height = int(match.group(1)), int(match.group(2)) return gr.update(value=width), gr.update(value=height) return gr.update(), gr.update() # UI Layout theme = gr.themes.Default( primary_hue="red", secondary_hue="orange", neutral_hue="gray" ).set( body_background_fill="*neutral_100", body_background_fill_dark="*neutral_900", body_text_color="*neutral_900", body_text_color_dark="*neutral_100", panel_background_fill="*neutral_800", panel_background_fill_dark="*neutral_900", input_background_fill="white", 
input_background_fill_dark="*neutral_800", button_primary_background_fill="*primary_500", button_primary_background_fill_dark="*primary_700", button_primary_text_color="white", button_primary_text_color_dark="white", button_secondary_background_fill="*secondary_500", button_secondary_background_fill_dark="*secondary_700", button_secondary_text_color="white", button_secondary_text_color_dark="white", ) css_code = "" try: with open("./style.css", "r", encoding="utf-8") as f: css_code += f.read() + "\n" except FileNotFoundError: pass title = """