Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import spaces | |
| import gradio as gr | |
| import tempfile | |
| from typing import Tuple | |
| import torch | |
| import io | |
| import re | |
| import numpy as np | |
| import imageio | |
| from diffusers import StableDiffusionPipeline | |
| from PIL import Image | |
| from pydub import AudioSegment | |
| from scipy.io.wavfile import write as wavwrite | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| SpeechT5Processor, | |
| SpeechT5ForTextToSpeech, | |
| SpeechT5HifiGan, | |
| pipeline | |
| ) | |
| import imageio_ffmpeg | |
| os.environ["IMAGEIO_FFMPEG_EXE"] = imageio_ffmpeg.get_ffmpeg_exe() | |
# Background-audio catalogue: UI key -> asset path WITHOUT a file extension.
# "none" maps to None, meaning no background track.
AVAILABLE_BGS = {
    "none": None,
    **{
        key: f"assets/{stem}"
        for key, stem in (
            ("rain", "Rain-Drips"),
            ("calm", "calming-piano-tune"),
            ("calmmusic", "calmimg-piano-tune1"),
            ("forest", "forest-rain-lullaby-music"),
            ("piano", "nature-calming"),
        )
    },
}

# Pre-written day summaries offered in the kids dropdown.
kids_day_options = list(
    (
        "I played with my toys",
        "I went to the park",
        "I read a storybook",
        "I had a fun snack",
        "I helped at home",
    )
)

# Keywords rejected by the prompt safety check, grouped by topic.
DISALLOWED_KEYWORDS = (
    ["nude", "naked", "sexual", "porn", "vagina", "penis", "nsfw"]
    + ["blood", "gore", "violence"]
    + ["political", "president", "prime minister", "election"]
    + ["religion", "terrorist", "extremist"]
)
def is_prompt_safe(prompt: str) -> bool:
    """Return True when no entry of DISALLOWED_KEYWORDS occurs in ``prompt``.

    The comparison is a case-insensitive substring match, so a keyword also
    matches when it appears inside a longer word.
    """
    lowered = prompt.lower()
    for banned in DISALLOWED_KEYWORDS:
        if banned in lowered:
            return False
    return True
# ------------------------
# DEVICE SETUP
# ------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

ffmpeg_path = imageio_ffmpeg.get_ffmpeg_exe()
print("Using ffmpeg:", ffmpeg_path)

# Install system deps (ffmpeg, espeak-ng). Put in startup script if preferred.
# The install is best-effort: a non-zero exit status is reported instead of
# being silently ignored, but it does not abort start-up.
_apt_status = os.system("apt-get update && apt-get install -y espeak-ng ffmpeg --no-install-recommends")
if _apt_status != 0:
    print(f"apt-get install of espeak-ng/ffmpeg failed (exit status {_apt_status}); continuing without it.")
os.environ["ESPEAK_PATH"] = "/usr/bin/espeak-ng"

# Placeholder entry; replaced once the TTS voices have been loaded.
speaker_options = {"Loading...": None}
# ------------------------
# TEXT GENERATION MODEL
# ------------------------
_TEXT_MODEL_ID = "tiiuae/falcon-7b-instruct"

tokenizer = AutoTokenizer.from_pretrained(_TEXT_MODEL_ID)
# Reuse the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    _TEXT_MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
)

# Text-generation pipeline used by call_text_model().
local_text_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)
# ------------------------
# Story memory
# ------------------------
# Accumulated story text used when the user asks to continue a story.
story_memory = ""

print("🔄 Loading Stable Diffusion v1.5 with safety checker…")
# Half precision when CUDA is available, full precision otherwise.
_sd_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
sd_pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=_sd_dtype,
)
sd_pipe = sd_pipe.to(device)
sd_pipe.enable_attention_slicing()
def generate_story_images(story_text, num_images=3):
    """Generate images illustrating the story, one per paragraph.

    The story is split on blank lines; blank or whitespace-only paragraphs
    are discarded so no image is generated for an empty scene. When the story
    has fewer paragraphs than ``num_images``, the last paragraph is reused to
    fill the remaining slots.

    Args:
        story_text: Full story text.
        num_images: Number of images to produce; values <= 0 yield ``[]``.

    Returns:
        A list of RGB PIL images, at most ``num_images`` long.

    Raises:
        ValueError: If the text fails the keyword safety check, or contains
            no non-blank paragraph to illustrate.
    """
    if not is_prompt_safe(story_text):
        raise ValueError("Prompt contains disallowed content per Stable Diffusion license policy.")
    if num_images <= 0:
        return []

    # A line holding only whitespace also counts as a paragraph break.
    paragraphs = [part.strip() for part in re.split(r'\n\s*\n', story_text) if part.strip()]
    if not paragraphs:
        raise ValueError("Story text is empty; nothing to illustrate.")

    if len(paragraphs) < num_images:
        paragraphs = paragraphs + [paragraphs[-1]] * (num_images - len(paragraphs))

    images = []
    for para in paragraphs[:num_images]:
        prompt = f"Illustrate this scene: {para} in soft, magical, calming style, suitable for bedtime story."
        images.append(sd_pipe(prompt).images[0].convert("RGB"))
    return images
def load_background(bg_key: str):
    """Load the background track registered under ``bg_key``.

    Looks for ``<base>.wav`` first and ``<base>.mp3`` second, where ``<base>``
    comes from AVAILABLE_BGS. A file that exists but fails to decode is
    logged and the next candidate is tried.

    Returns:
        An AudioSegment, or None when the key is unknown, maps to None, or no
        candidate file could be loaded.
    """
    base_path = AVAILABLE_BGS.get(bg_key)
    if base_path is None:
        return None

    decoders = {".wav": AudioSegment.from_wav, ".mp3": AudioSegment.from_mp3}
    for suffix in (".wav", ".mp3"):
        candidate = base_path + suffix
        if not os.path.exists(candidate):
            continue
        decode = decoders.get(os.path.splitext(candidate)[1].lower())
        if decode is None:
            continue
        try:
            return decode(candidate)
        except Exception as e:
            print(f"Error loading background audio {candidate}: {e}")

    print(f"⚠️ No valid background audio found for key '{bg_key}'.")
    return None
def images_to_video(images, fps=1):
    """Encode ``images`` as an H.264 MP4 slideshow and return the file path.

    Every image is converted to RGB and resized to 576x320 before encoding.

    Args:
        images: Iterable of PIL images, one per video frame.
        fps: Frames per second of the resulting video.

    Returns:
        Path of a temporary ``.mp4`` file. The caller owns the file.

    Raises:
        ValueError: If ``images`` is empty.
        Exception: Whatever the encoder raises; the temporary file is removed
            before the exception propagates.
    """
    images = list(images)
    if not images:
        raise ValueError("At least one image is required to build a video.")

    frames = [np.array(img.convert("RGB").resize((576, 320))) for img in images]

    # Ensure imageio uses correct ffmpeg
    import imageio_ffmpeg
    os.environ["IMAGEIO_FFMPEG_EXE"] = imageio_ffmpeg.get_ffmpeg_exe()

    # Only the path is needed; close the handle so the encoder can write to it.
    tmp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    tmp_video.close()
    try:
        imageio.mimwrite(
            tmp_video.name,
            frames,
            fps=fps,
            codec="libx264",
        )
    except Exception:
        # Do not leave a partial or empty file behind when encoding fails.
        try:
            os.remove(tmp_video.name)
        except OSError:
            pass
        raise
    return tmp_video.name
def call_text_model(prompt: str, max_tokens: int = 700):
    """Generate a continuation of ``prompt`` with the local text pipeline.

    Sampling is enabled (temperature 0.7, top-p 0.9) and only the newly
    generated text is returned, stripped of surrounding whitespace.

    Raises:
        RuntimeError: If the local pipeline is unavailable (``None``).
    """
    # If local model not available, raise an informative error
    if local_text_gen is None:
        raise RuntimeError("Local text generation model not available in this environment.")

    generation_kwargs = dict(
        max_new_tokens=max_tokens,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        return_full_text=False,
    )
    first_result = local_text_gen(prompt, **generation_kwargs)[0]
    return first_result["generated_text"].strip()
# ------------------------
# STORY MEMORY (NEW)
# ------------------------
story_memory = ""

# ------------------------
# TTS SETUP (unchanged)
# ------------------------
# Try SpeechT5 first; any failure while loading it switches to the VITS
# text-to-speech pipeline with a single default voice.
use_speecht5 = True
try:
    print("🔊 Loading SpeechT5...")
    processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)

    from datasets import load_dataset
    embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")

    # Voice label -> row of the x-vector dataset used as speaker embedding.
    _voice_rows = (
        ("Female (Slt)", 0),
        ("Male (Bdl)", 116),
        ("Male (Rms)", 227),
        ("Female (Clb)", 420),
        ("Female (Aew)", 550),
        ("Male (Ksp)", 670),
    )
    _voices = {
        label: torch.tensor(embeddings_dataset[row]["xvector"]).unsqueeze(0).to(device)
        for label, row in _voice_rows
    }
    # This voice uses a random 512-dim embedding rather than a dataset row.
    _voices["Gentle Narrator"] = torch.randn((1, 512)).to(device)

    speaker_options.clear()
    speaker_options.update(_voices)
except Exception as e:
    print("⚠️ SpeechT5 failed to load, switching to lightweight TTS:", e)
    use_speecht5 = False
    text_speech = pipeline("text-to-speech", model="kakao-enterprise/vits-ljs")
    speaker_options = {"Default Voice": None}
def split_text_into_chunks(text, max_chars=300):
    """Split ``text`` into chunks of at most ``max_chars`` characters.

    Sentences (split after ``.``, ``!`` or ``?`` followed by spaces) are
    packed greedily into chunks. A single sentence longer than ``max_chars``
    is wrapped at word boundaries, breaking over-long words if necessary, so
    it no longer becomes one oversized chunk.

    Args:
        text: Text to split; ``None`` or blank text yields ``[]``.
        max_chars: Maximum chunk length. Over-long sentences are only wrapped
            when this is positive.

    Returns:
        A list of non-empty, whitespace-stripped chunks.
    """
    import textwrap  # local import keeps this block self-contained

    chunks, current_chunk = [], ""

    for sent in re.split(r'(?<=[.!?]) +', text or ""):
        if max_chars > 0 and len(sent) > max_chars:
            # Flush what has been collected, then wrap the over-long sentence.
            if current_chunk.strip():
                chunks.append(current_chunk.strip())
            pieces = textwrap.wrap(sent, width=max_chars)
            chunks.extend(pieces[:-1])
            # The last piece may still have room for following sentences.
            current_chunk = pieces[-1] if pieces else ""
        elif len(current_chunk) + len(sent) > max_chars:
            # Skip the flush when nothing was collected yet; this used to
            # emit an empty chunk.
            if current_chunk.strip():
                chunks.append(current_chunk.strip())
            current_chunk = sent
        else:
            current_chunk += " " + sent

    if current_chunk.strip():
        chunks.append(current_chunk.strip())
    return chunks
def call_tts_model(text: str, voice: str):
    """Synthesize ``text`` to speech and return it as WAV bytes.

    The text is chunked with split_text_into_chunks(), each chunk is
    synthesized separately, and the results are concatenated, peak-normalized
    to 70% of full scale and encoded as 16-bit PCM.

    Args:
        text: Text to narrate.
        voice: Key into ``speaker_options``. An unknown key falls back to the
            first registered voice. Only used on the SpeechT5 path.

    Returns:
        The bytes of a WAV file (16 kHz on the SpeechT5 path, otherwise the
        sampling rate reported by the fallback pipeline).

    Raises:
        ValueError: If ``text`` contains nothing to synthesize.
    """
    chunks = [chunk for chunk in split_text_into_chunks(text) if chunk.strip()]
    if not chunks:
        # Fail with a clear message instead of an obscure concatenate error.
        raise ValueError("No text to synthesize.")

    audio_segments = []
    if use_speecht5:
        speaker_emb = speaker_options.get(voice, next(iter(speaker_options.values())))
        if speaker_emb is None:
            speaker_emb = torch.randn((1, 512)).to(device)
        for chunk in chunks:
            inputs = processor(text=chunk, return_tensors="pt").to(device)
            with torch.no_grad():
                speech = tts_model.generate_speech(
                    input_ids=inputs["input_ids"],
                    speaker_embeddings=speaker_emb,
                    vocoder=vocoder,
                )
            audio_segments.append(speech.cpu().numpy().flatten())
        sr = 16000
    else:
        # lightweight pipeline returns dicts with 'audio' and 'sampling_rate'
        sr = None
        for chunk in chunks:
            narrated = text_speech(chunk)
            audio_segments.append(narrated["audio"].flatten())
            sr = narrated["sampling_rate"]

    full_audio = np.concatenate(audio_segments, axis=0)

    # Normalize and convert to 16-bit PCM. All-silent audio has a zero peak;
    # skip the division there to avoid NaN samples.
    peak = np.max(np.abs(full_audio)) if full_audio.size else 0
    if peak > 0:
        full_audio = full_audio / peak
    full_audio = full_audio * 0.7

    buffer = io.BytesIO()
    wavwrite(buffer, sr, (full_audio * 32767).astype(np.int16))
    buffer.seek(0)
    return buffer.read()
| # ------------------------ | |
| # Prompt builder | |
| # ------------------------ | |
def build_story_prompt(day_text: str, mood: str, theme: str, for_kids: bool, story_length: str) -> str:
    """Build the text-generation prompt for a bedtime story.

    Args:
        day_text: Summary of the listener's day, embedded verbatim.
        mood: Desired tone of the story.
        theme: Desired setting or theme.
        for_kids: Target a young child when True, an adult otherwise.
        story_length: Length label such as ``"Extra Long (~1500 words)"``.
            Labels matching none of the known markers get the longest target
            (about 6000 words).

    Returns:
        The complete prompt, ending with the story opener "Once upon a time,".
    """
    persona = "Write a soothing bedtime story"
    if for_kids:
        persona += " for a young child (age 3–8)."
    else:
        persona += " for an adult who wants to fall asleep."

    extras = f"Tone: {mood}. Theme: {theme}."

    # "Extra" must be tested before "Long": the label "Extra Long (...)"
    # contains both markers and would otherwise get the 700-word target.
    length_targets = (
        ("Short", 300),
        ("Extra", 1500),
        ("Long", 700),
        ("Epic", 3000),
    )
    word_target = next(
        (words for marker, words in length_targets if marker in story_length),
        6000,
    )
    length_instruction = f"Length: about {word_target} words."

    instructions = (
        f"{length_instruction} Avoid tension. "
        "Feature a calming guide character. End with: "
        "'And soon, everything drifted into quiet dreams.'"
    )

    return (
        f"{persona}"
        f"\nSummary of the day: {day_text}\n"
        f"{extras}\n"
        f"{instructions}\n\n"
        "The story must remain wholesome, family-friendly, safe, and non-political. "
        "No adult content, no violence, no sexual themes, and no harmful scenarios.\n\n"
        "Begin the story with:\n"
        "Once upon a time,"
    )
def apply_fade(audio: AudioSegment, fade_ms: int = 200):
    """Apply a fade-in and a fade-out of ``fade_ms`` each to ``audio``.

    When ``fade_ms`` is not greater than zero the audio is returned untouched.
    """
    if not fade_ms > 0:
        return audio
    faded_in = audio.fade_in(fade_ms)
    return faded_in.fade_out(fade_ms)
| # ------------------------ | |
| # MAIN FUNCTION | |
| # ------------------------ | |
def _mix_voice_with_background(voice_path, bg_choice, bg_volume):
    """Overlay the narration at ``voice_path`` on the chosen background track.

    Returns the path of a new temporary MP3, or None when no background audio
    is available for ``bg_choice``.
    """
    voice_audio = AudioSegment.from_file(voice_path)
    bg_music = load_background(bg_choice)
    if not bg_music:
        return None

    bg_music = bg_music.set_frame_rate(voice_audio.frame_rate).set_channels(voice_audio.channels)
    # Loop or trim the background so it lasts exactly as long as the narration.
    if len(bg_music) < len(voice_audio):
        loops = int(len(voice_audio) / len(bg_music)) + 1
        bg_music = (bg_music * loops)[:len(voice_audio)]
    else:
        bg_music = bg_music[:len(voice_audio)]

    bg_music = apply_fade(bg_music, fade_ms=500)
    voice_audio = apply_fade(voice_audio, fade_ms=80)

    # Map the 0-100 slider onto a gain between -10 dB and +6 dB.
    vol_db = -10 + (bg_volume / 100) * 16
    mixed = (bg_music + vol_db).overlay(voice_audio)

    # Only the path is needed: close the handle right away so it is not leaked.
    mixed_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
    mixed_file.close()
    try:
        mixed.export(mixed_file.name, format="mp3")
    except Exception:
        try:
            os.remove(mixed_file.name)
        except OSError:
            pass
        raise
    return mixed_file.name


def generate_and_render(day_text, kids_day_choice, mood, theme, for_kids, voice, story_length, continue_story, bg_music_on, bg_choice, bg_volume):
    """Generate a story plus its narration audio and illustration video.

    Pipeline: safety-check the user input, build the prompt (or continue the
    remembered story), generate the text, synthesize speech, optionally mix
    in background music, then render a slideshow of generated images.

    Args:
        day_text: Free-text day summary (used when ``for_kids`` is False).
        kids_day_choice: Preset day summary (used when ``for_kids`` is True).
        mood, theme, story_length: Prompt options chosen in the UI.
        for_kids: Whether the story targets a young child.
        voice: Voice label passed to the TTS model.
        continue_story: Continue from ``story_memory`` instead of starting over.
        bg_music_on: Whether to mix in a background track.
        bg_choice: Key of the background track.
        bg_volume: Background volume slider value (0-100).

    Returns:
        ``(story_or_message, audio_path_or_None, video_path_or_None)``.

    NOTE: ``story_memory`` is a module-level global, so it is shared by every
    caller of this process.
    NOTE(review): ``spaces`` is imported by this file but nothing here uses
    it; confirm whether GPU allocation on ZeroGPU needs handling.
    """
    global story_memory

    # RESET MEMORY if user does NOT want continuation
    if not continue_story:
        story_memory = ""

    # SAFETY CHECK ONLY ON USER INPUT. A missing value (None) is treated as
    # an empty summary instead of crashing the safety check.
    day_summary = (kids_day_choice if for_kids else day_text) or ""
    if not is_prompt_safe(day_summary):
        return "❌ Your input contains disallowed content. Please change it.", None, None

    # 1. Build prompt
    if continue_story and story_memory.strip():
        prompt = story_memory + " Continue the story softly:"
    else:
        prompt = build_story_prompt(day_summary, mood, theme, for_kids, story_length)

    # 2. Generate story
    max_tokens_map = {
        "Short (~300 words)": 700,
        "Long (~700 words)": 1400,
        "Extra Long (~1500 words)": 2500,
        "Epic (~3000 words)": 4096,
        "Mega (~6000 words)": 5500,
    }
    max_tokens = max_tokens_map.get(story_length, 700)
    try:
        story = call_text_model(prompt, max_tokens=max_tokens)
    except Exception as e:
        return f"Error generating story: {e}", None, None

    # SAVE story to memory ONLY if continuation is allowed
    if continue_story:
        story_memory += " " + story

    # 3. TTS generation
    try:
        audio_bytes = call_tts_model(story, voice)
    except Exception as e:
        # Separate the error from the story text so both stay readable.
        return f"Story generated, but TTS failed: {e}\n\n{story}", None, None

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_voice:
        tmp_voice.write(audio_bytes)
    audio_path = tmp_voice.name

    # 4. Background music mixing (falls back to the plain narration on failure)
    if bg_music_on:
        try:
            mixed_path = _mix_voice_with_background(tmp_voice.name, bg_choice, bg_volume)
        except Exception as e:
            print("Background mixing failed:", e)
        else:
            if mixed_path is not None:
                audio_path = mixed_path
                # The raw narration is no longer needed once the mix exists.
                try:
                    os.remove(tmp_voice.name)
                except OSError:
                    pass

    # 5. Generate story images and slideshow
    try:
        images = generate_story_images(story, num_images=5)
        video_path = images_to_video(images, fps=2)
    except Exception as e:
        print("Story image/video generation failed:", e)
        video_path = None

    return story, audio_path, video_path
| # ------------------------ | |
| # Memory reset | |
| # ------------------------ | |
def reset_memory():
    """Clear the remembered story text and return a confirmation message."""
    global story_memory
    story_memory = ""  # the next generation starts from a fresh prompt
    return "Story memory cleared."
| # ------------------------ | |
| # GRADIO UI | |
| # ------------------------ | |
# Layout: inputs in the left column, generated story/audio/video on the right.
with gr.Blocks(title="AI Sleep Story Generator") as demo:
    gr.Markdown("# 🌙 AI Sleep Story Generator")
    gr.Markdown("⚠️ This app follows Stable Diffusion’s OpenRAIL-M license and blocks NSFW, political, and harmful content.")

    with gr.Row():
        with gr.Column(scale=2):
            day_text = gr.Textbox(label="Tell me about your day", lines=4)
            # Dropdown for kids who cannot type
            kids_day_dropdown = gr.Dropdown(
                choices=kids_day_options,
                value=kids_day_options[0],
                label="Select your day (for kids who cannot write)",
            )
            # Mood and theme get explicit labels like every other input.
            mood = gr.Dropdown(
                ["calm", "soothing", "nostalgic", "drowsy", "dreamlike"],
                value="calm",
                label="Mood",
            )
            theme = gr.Dropdown(
                ["gentle forest", "ocean breeze", "starlit sky", "quiet village", "magical garden"],
                value="gentle forest",
                label="Theme",
            )
            voice_dropdown = gr.Dropdown(
                choices=list(speaker_options.keys()),
                value=list(speaker_options.keys())[0],
                label="Choose Voice",
            )
            story_length = gr.Dropdown(
                [
                    "Short (~300 words)",
                    "Long (~700 words)",
                    "Extra Long (~1500 words)",
                    "Epic (~3000 words)",
                    "Mega (~6000 words)",
                ],
                value="Short (~300 words)",
                label="Story Length",
            )
            bg_music_on = gr.Checkbox(label="Play soothing music", value=True)
            bg_choice = gr.Dropdown(list(AVAILABLE_BGS.keys()), value="forest", label="Background sound")
            bg_volume = gr.Slider(0, 100, 30, step=1, label="Background volume")
            for_kids = gr.Checkbox(label="Child-friendly (3–8 yrs)", value=False)
            continue_story = gr.Checkbox(label="Continue previous story", value=False)
            generate_btn = gr.Button("Generate Story")
            reset_btn = gr.Button("Reset Story Memory")

        with gr.Column(scale=3):
            story_output = gr.Textbox(label="Generated story", lines=20)
            audio_output = gr.Audio(label="Play story audio", type="filepath")
            video_output = gr.Video(label="Story video", format="mp4")

    # The input order must match the parameter order of generate_and_render().
    generate_btn.click(
        fn=generate_and_render,
        inputs=[day_text, kids_day_dropdown, mood, theme, for_kids, voice_dropdown, story_length, continue_story, bg_music_on, bg_choice, bg_volume],
        outputs=[story_output, audio_output, video_output],
    )
    # The confirmation message is shown in the story textbox.
    reset_btn.click(fn=reset_memory, inputs=None, outputs=[story_output])
if __name__ == "__main__":
    # Listen on all interfaces; the PORT environment variable overrides the
    # default port 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=listen_port)