import io
import os
import re
import tempfile
from typing import Optional, Tuple

# `spaces` is deliberately kept ahead of torch/diffusers/transformers, matching
# the original import order (it came before `torch` there as well).
import spaces

import gradio as gr
import imageio
import imageio_ffmpeg
import numpy as np
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from pydub import AudioSegment
from scipy.io.wavfile import write as wavwrite
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    pipeline,
)

# Point imageio at the ffmpeg binary bundled with imageio-ffmpeg.
os.environ["IMAGEIO_FFMPEG_EXE"] = imageio_ffmpeg.get_ffmpeg_exe()

# Background-sound key -> audio file path WITHOUT extension (the loader tries
# ".wav" and ".mp3"). "none" disables background audio.
# NOTE(review): "calmimg-piano-tune1" looks like a typo for "calming", and the
# "piano" key points at "nature-calming" -- confirm against the files in assets/.
AVAILABLE_BGS = {
    "none": None,
    "rain": "assets/Rain-Drips",
    "calm": "assets/calming-piano-tune",
    "calmmusic": "assets/calmimg-piano-tune1",
    "forest": "assets/forest-rain-lullaby-music",
    "piano": "assets/nature-calming",
}

# Preset "how was your day" answers offered to children who cannot type.
kids_day_options = [
    "I played with my toys",
    "I went to the park",
    "I read a storybook",
    "I had a fun snack",
    "I helped at home"
]

# Lower-case substrings that cause a prompt to be rejected.
DISALLOWED_KEYWORDS = [
    "nude", "naked", "sexual", "porn", "vagina", "penis", "nsfw",
    "blood", "gore", "violence",
    "political", "president", "prime minister", "election",
    "religion", "terrorist", "extremist"
]


def is_prompt_safe(prompt: Optional[str]) -> bool:
    """Return True when *prompt* contains none of the disallowed keywords.

    The check is a case-insensitive substring match against
    ``DISALLOWED_KEYWORDS``. An empty or missing prompt (``""`` or ``None``)
    has nothing to reject and is treated as safe instead of raising
    ``AttributeError`` on ``None.lower()``.
    """
    if not prompt:
        return True
    prompt_lower = str(prompt).lower()
    return not any(keyword in prompt_lower for keyword in DISALLOWED_KEYWORDS)


# ------------------------
# DEVICE SETUP
# ------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)
ffmpeg_path = imageio_ffmpeg.get_ffmpeg_exe()
print("Using ffmpeg:", ffmpeg_path)

# Install system deps (ffmpeg, espeak-ng). Put in startup script if preferred.
os.system("apt-get update && apt-get install -y espeak-ng ffmpeg --no-install-recommends")
os.environ["ESPEAK_PATH"] = "/usr/bin/espeak-ng"

# Voice name -> speaker embedding. Placeholder until the TTS setup below
# replaces the contents.
speaker_options = {"Loading...": None}

# ------------------------
# TEXT GENERATION MODEL
# ------------------------
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b-instruct")
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b-instruct",
    device_map="auto",
    torch_dtype="auto"
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

local_text_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)

# ------------------------
# Story memory
# ------------------------
# Module-level, so it is shared by every caller of this process.
story_memory = ""

print("🔄 Loading Stable Diffusion v1.5 with safety checker…")
sd_pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
).to(device)
sd_pipe.enable_attention_slicing()


def generate_story_images(story_text, num_images=3):
    """Generate images illustrating the story, one per paragraph.

    The story is split on blank lines; blank paragraphs are dropped so that
    no image is rendered from an empty scene description. If there are fewer
    paragraphs than ``num_images`` the last paragraph is repeated; extra
    paragraphs beyond ``num_images`` are ignored.

    Args:
        story_text: Full story text.
        num_images: Number of images to produce.

    Returns:
        A list of RGB PIL images.

    Raises:
        ValueError: If the story fails the keyword safety check or contains
            no text to illustrate.
    """
    if not is_prompt_safe(story_text):
        raise ValueError("Prompt contains disallowed content per Stable Diffusion license policy.")

    paragraphs = [part.strip() for part in re.split(r'\n\n', story_text) if part.strip()]
    if not paragraphs:
        raise ValueError("Story text is empty; nothing to illustrate.")
    if len(paragraphs) < num_images:
        paragraphs = paragraphs + [paragraphs[-1]] * (num_images - len(paragraphs))

    images = []
    for para in paragraphs[:num_images]:
        prompt = f"Illustrate this scene: {para} in soft, magical, calming style, suitable for bedtime story."
        img = sd_pipe(prompt).images[0].convert("RGB")
        images.append(img)
    return images


def load_background(bg_key: str):
    """Load the background audio registered under ``bg_key``.

    Looks up the extension-less base path in ``AVAILABLE_BGS`` and tries the
    ``.wav`` file first, then the ``.mp3`` file. A file that exists but fails
    to decode is reported and the next candidate is tried.

    Returns:
        A pydub ``AudioSegment``, or ``None`` when the key is unknown, maps
        to ``None``, or no candidate file could be loaded.
    """
    base_path = AVAILABLE_BGS.get(bg_key)
    if base_path is None:
        return None

    loaders = (
        (base_path + ".wav", AudioSegment.from_wav),
        (base_path + ".mp3", AudioSegment.from_mp3),
    )
    for path, loader in loaders:
        if not os.path.exists(path):
            continue
        try:
            return loader(path)
        except Exception as e:
            print(f"Error loading background audio {path}: {e}")

    print(f"⚠️ No valid background audio found for key '{bg_key}'.")
    return None


def images_to_video(images, fps=1):
    """Write ``images`` as an H.264 slideshow and return the .mp4 path.

    Every image is converted to RGB and resized to 576x320 before encoding.
    The output is a temp file created with ``delete=False``; the caller owns
    its cleanup.

    Raises:
        ValueError: If ``images`` is empty (checked before any temp file is
            created, so nothing is leaked).
    """
    if not images:
        raise ValueError("No images to render into a video.")

    frames = [
        np.array(img.convert("RGB").resize((576, 320)))
        for img in images
    ]

    tmp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    tmp_video.close()

    # Ensure imageio uses correct ffmpeg
    os.environ["IMAGEIO_FFMPEG_EXE"] = imageio_ffmpeg.get_ffmpeg_exe()

    imageio.mimwrite(
        tmp_video.name,
        frames,
        fps=fps,
        codec="libx264"
    )
    return tmp_video.name


def call_text_model(prompt: str, max_tokens: int = 700):
    """Generate a continuation of ``prompt`` with the local text pipeline.

    Sampling uses temperature 0.7 and top_p 0.9; only the newly generated
    text is returned (``return_full_text=False``), stripped of surrounding
    whitespace.

    Raises:
        RuntimeError: If ``local_text_gen`` is ``None``.
    """
    # If local model not available, raise an informative error
    if local_text_gen is None:
        raise RuntimeError("Local text generation model not available in this environment.")

    outputs = local_text_gen(
        prompt,
        max_new_tokens=max_tokens,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()


# ------------------------
# TTS SETUP
# ------------------------
# Prefer SpeechT5 with x-vector speaker embeddings; on any load failure fall
# back to a single-voice VITS pipeline.
use_speecht5 = True
try:
    print("🔊 Loading SpeechT5...")
    processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)

    from datasets import load_dataset
    embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")

    speaker_options.clear()
    speaker_options.update({
        "Female (Slt)": torch.tensor(embeddings_dataset[0]["xvector"]).unsqueeze(0).to(device),
        "Male (Bdl)": torch.tensor(embeddings_dataset[116]["xvector"]).unsqueeze(0).to(device),
        "Male (Rms)": torch.tensor(embeddings_dataset[227]["xvector"]).unsqueeze(0).to(device),
        "Female (Clb)": torch.tensor(embeddings_dataset[420]["xvector"]).unsqueeze(0).to(device),
        "Female (Aew)": torch.tensor(embeddings_dataset[550]["xvector"]).unsqueeze(0).to(device),
        "Male (Ksp)": torch.tensor(embeddings_dataset[670]["xvector"]).unsqueeze(0).to(device),
        # Random embedding drawn once at import time.
        "Gentle Narrator": torch.randn((1,512)).to(device),
    })
except Exception as e:
    print("⚠️ SpeechT5 failed to load, switching to lightweight TTS:", e)
    use_speecht5 = False
    text_speech = pipeline("text-to-speech", model="kakao-enterprise/vits-ljs")
    speaker_options = {"Default Voice": None}


def split_text_into_chunks(text, max_chars=300):
    """Split ``text`` into sentence-aligned chunks of roughly ``max_chars``.

    Sentences are detected by ``.``, ``!`` or ``?`` followed by spaces and
    are never split, so a single sentence longer than ``max_chars`` becomes
    its own oversized chunk. Empty chunks are never emitted: empty or
    whitespace-only text yields ``[]``, and an over-long first sentence no
    longer produces a leading ``""`` chunk.
    """
    sentences = re.split(r'(?<=[.!?]) +', text)
    chunks, current_chunk = [], ""
    for sent in sentences:
        # Only flush when there is something buffered; otherwise an over-long
        # first sentence would flush an empty chunk.
        if current_chunk and len(current_chunk) + len(sent) > max_chars:
            chunks.append(current_chunk.strip())
            current_chunk = sent
        else:
            current_chunk += " " + sent
    if current_chunk.strip():
        chunks.append(current_chunk.strip())
    return [chunk for chunk in chunks if chunk]


def call_tts_model(text: str, voice: str):
    """Synthesize ``text`` and return the audio as WAV bytes (16-bit PCM).

    The text is narrated chunk by chunk and the pieces are concatenated. With
    SpeechT5 the speaker embedding is looked up by ``voice`` (falling back to
    the first registered voice, then to a random embedding) and the sample
    rate is 16000; with the lightweight pipeline the sample rate comes from
    the pipeline output. Audio is peak-normalized and scaled to 0.7 before
    conversion.

    Raises:
        ValueError: If ``text`` contains nothing to narrate.
    """
    chunks = split_text_into_chunks(text)
    if not chunks:
        raise ValueError("No text to narrate.")

    audio_segments = []

    if use_speecht5:
        speaker_emb = speaker_options.get(voice, list(speaker_options.values())[0])
        if speaker_emb is None:
            speaker_emb = torch.randn((1, 512)).to(device)
        for chunk in chunks:
            inputs = processor(text=chunk, return_tensors="pt").to(device)
            with torch.no_grad():
                speech = tts_model.generate_speech(
                    input_ids=inputs["input_ids"],
                    speaker_embeddings=speaker_emb,
                    vocoder=vocoder
                )
            audio_segments.append(speech.cpu().numpy().flatten())
        sr = 16000
    else:
        # lightweight pipeline returns dicts with 'audio' and 'sampling_rate'
        for chunk in chunks:
            narrated = text_speech(chunk)
            audio_segments.append(narrated["audio"].flatten())
            sr = narrated["sampling_rate"]

    full_audio = np.concatenate(audio_segments, axis=0)

    # Normalize and convert to 16-bit PCM. Skip the division for all-zero
    # (silent) audio, which would otherwise produce NaN samples.
    peak = np.max(np.abs(full_audio)) if full_audio.size else 0.0
    if peak > 0:
        full_audio = full_audio / peak
    full_audio = full_audio * 0.7

    buffer = io.BytesIO()
    wavwrite(buffer, sr, (full_audio * 32767).astype(np.int16))
    buffer.seek(0)
    return buffer.read()


# ------------------------
# Prompt builder
# ------------------------
def build_story_prompt(day_text: str, mood: str, theme: str, for_kids: bool, story_length: str) -> str:
    """Build the instruction prompt for a fresh bedtime story.

    Args:
        day_text: Summary of the user's day, embedded verbatim.
        mood: Tone requested for the story.
        theme: Theme requested for the story.
        for_kids: Target a young child when true, an adult otherwise.
        story_length: UI label such as "Extra Long (~1500 words)"; matched by
            keyword. Labels containing none of Short/Extra/Long/Epic
            (including ``None``) get the 6000-word instruction.

    Returns:
        The prompt text, ending with the "Once upon a time," opener.
    """
    persona = "Write a soothing bedtime story"
    if for_kids:
        persona += " for a young child (age 3–8)."
    else:
        persona += " for an adult who wants to fall asleep."

    extras = f"Tone: {mood}. Theme: {theme}."

    length_label = story_length or ""
    # "Extra" must be tested before "Long": the label "Extra Long" contains
    # both words and would otherwise be treated as the 700-word option.
    if "Short" in length_label:
        length_instruction = "Length: about 300 words."
    elif "Extra" in length_label:
        length_instruction = "Length: about 1500 words."
    elif "Long" in length_label:
        length_instruction = "Length: about 700 words."
    elif "Epic" in length_label:
        length_instruction = "Length: about 3000 words."
    else:
        length_instruction = "Length: about 6000 words."

    instructions = (
        f"{length_instruction} Avoid tension. "
        "Feature a calming guide character. End with: "
        "'And soon, everything drifted into quiet dreams.'"
    )

    return (
        f"{persona}"
        f"\nSummary of the day: {day_text}\n"
        f"{extras}\n"
        f"{instructions}\n\n"
        "The story must remain wholesome, family-friendly, safe, and non-political. "
        "No adult content, no violence, no sexual themes, and no harmful scenarios.\n\n"
        "Begin the story with:\n"
        "Once upon a time,"
    )


def apply_fade(audio: AudioSegment, fade_ms: int = 200):
    """Return ``audio`` with a fade-in and fade-out of ``fade_ms`` each.

    A non-positive ``fade_ms`` returns the audio unchanged.
    """
    if fade_ms > 0:
        return audio.fade_in(fade_ms).fade_out(fade_ms)
    return audio


# ------------------------
# MAIN FUNCTION
# ------------------------
@spaces.GPU
def generate_and_render(day_text, kids_day_choice, mood, theme, for_kids, voice,
                        story_length, continue_story, bg_music_on, bg_choice, bg_volume):
    """Generate the story text, narration audio and slideshow video.

    Steps: safety-check the user's input, build the prompt (or continue from
    ``story_memory``), generate the story, synthesize speech, optionally mix
    in background audio, then render story images into a video.

    Returns:
        A ``(story_or_message, audio_path, video_path)`` tuple matching the
        three Gradio outputs. On failure the first item carries the message
        and the paths are ``None``; ``video_path`` alone is ``None`` when only
        image/video generation failed.
    """
    global story_memory

    # RESET MEMORY if user does NOT want continuation
    if not continue_story:
        story_memory = ""

    # SAFETY CHECK ONLY ON USER INPUT
    # Gradio can hand over None for a cleared field; treat it as empty text.
    day_summary = (kids_day_choice if for_kids else day_text) or ""

    if not is_prompt_safe(day_summary):
        return "❌ Your input contains disallowed content. Please change it.", None, None

    # 1️⃣ Build prompt
    if continue_story and story_memory.strip():
        prompt = story_memory + " Continue the story softly:"
    else:
        prompt = build_story_prompt(day_summary, mood, theme, for_kids, story_length)

    # 2️⃣ Generate story
    max_tokens_map = {
        "Short (~300 words)": 700,
        "Long (~700 words)": 1400,
        "Extra Long (~1500 words)": 2500,
        "Epic (~3000 words)": 4096,
        "Mega (~6000 words)": 5500
    }
    max_tokens = max_tokens_map.get(story_length, 700)

    try:
        story = call_text_model(prompt, max_tokens=max_tokens)
    except Exception as e:
        return f"Error generating story: {e}", None, None

    # SAVE story to memory ONLY if continuation is allowed
    if continue_story:
        story_memory += " " + story

    # 3️⃣ TTS generation
    try:
        audio_bytes = call_tts_model(story, voice)
    except Exception as e:
        # Separate the error from the story text so both stay readable.
        return f"Story generated, but TTS failed: {e}\n\n{story}", None, None

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_voice:
        tmp_voice.write(audio_bytes)
    voice_path = tmp_voice.name
    final_audio_path = voice_path

    # 4️⃣ Background music mixing
    if bg_music_on:
        try:
            voice_audio = AudioSegment.from_file(voice_path)
            bg_music = load_background(bg_choice)

            if bg_music is not None and len(bg_music) > 0:
                bg_music = bg_music.set_frame_rate(voice_audio.frame_rate).set_channels(voice_audio.channels)

                # Loop or trim the background so it spans the whole narration.
                if len(bg_music) < len(voice_audio):
                    loops = int(len(voice_audio) / len(bg_music)) + 1
                    bg_music = (bg_music * loops)[:len(voice_audio)]
                else:
                    bg_music = bg_music[:len(voice_audio)]

                bg_music = apply_fade(bg_music, fade_ms=500)
                voice_audio = apply_fade(voice_audio, fade_ms=80)

                # Map the 0-100 slider onto a gain offset of -10 to +6.
                vol_db = -10 + (bg_volume / 100) * 16
                bg_music = bg_music + vol_db

                mixed = bg_music.overlay(voice_audio)

                # Reserve a path, then close the handle before exporting so
                # no open file descriptor is left behind.
                mixed_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
                mixed_file.close()
                mixed.export(mixed_file.name, format="mp3")
                final_audio_path = mixed_file.name
        except Exception as e:
            print("Background mixing failed:", e)
            final_audio_path = voice_path

    # The raw narration file is only an intermediate once a mix exists.
    if final_audio_path != voice_path:
        try:
            os.remove(voice_path)
        except OSError as e:
            print("Could not remove temporary narration file:", e)

    # 5️⃣ Generate story images
    try:
        images = generate_story_images(story, num_images=5)
        video_path = images_to_video(images, fps=2)
    except Exception as e:
        print("Story image/video generation failed:", e)
        video_path = None

    return story, final_audio_path, video_path


# ------------------------
# Memory reset
# ------------------------
def reset_memory():
    """Clear the shared story memory and return a confirmation message."""
    global story_memory
    story_memory = ""
    return "Story memory cleared."


# ------------------------
# GRADIO UI
# ------------------------
with gr.Blocks(title="AI Sleep Story Generator") as demo:
    gr.Markdown("# 🌙 AI Sleep Story Generator")
    gr.Markdown("⚠️ This app follows Stable Diffusion’s OpenRAIL-M license and blocks NSFW, political, and harmful content.")

    with gr.Row():
        with gr.Column(scale=2):
            day_text = gr.Textbox(label="Tell me about your day", lines=4)

            # Dropdown for kids who cannot type
            kids_day_dropdown = gr.Dropdown(
                choices=kids_day_options,
                value=kids_day_options[0],
                label="Select your day (for kids who cannot write)"
            )

            mood = gr.Dropdown(
                ["calm", "soothing", "nostalgic", "drowsy", "dreamlike"],
                value="calm",
                label="Mood",
            )

            theme = gr.Dropdown(
                ["gentle forest", "ocean breeze", "starlit sky", "quiet village", "magical garden"],
                value="gentle forest",
                label="Theme",
            )

            voice_dropdown = gr.Dropdown(
                choices=list(speaker_options.keys()),
                value=list(speaker_options.keys())[0],
                label="Choose Voice"
            )

            story_length = gr.Dropdown(
                [
                    "Short (~300 words)",
                    "Long (~700 words)",
                    "Extra Long (~1500 words)",
                    "Epic (~3000 words)",
                    "Mega (~6000 words)"
                ],
                value="Short (~300 words)",
                label="Story Length",
            )

            bg_music_on = gr.Checkbox(label="Play soothing music", value=True)
            bg_choice = gr.Dropdown(list(AVAILABLE_BGS.keys()), value="forest", label="Background sound")
            bg_volume = gr.Slider(0, 100, 30, step=1, label="Background volume")

            for_kids = gr.Checkbox(label="Child-friendly (3–8 yrs)", value=False)
            continue_story = gr.Checkbox(label="Continue previous story", value=False)

            generate_btn = gr.Button("Generate Story")
            reset_btn = gr.Button("Reset Story Memory")

        with gr.Column(scale=3):
            story_output = gr.Textbox(label="Generated story", lines=20)
            audio_output = gr.Audio(label="Play story audio", type="filepath")
            video_output = gr.Video(label="Story video", format="mp4")

    generate_btn.click(
        fn=generate_and_render,
        inputs=[day_text, kids_day_dropdown, mood, theme, for_kids, voice_dropdown,
                story_length, continue_story, bg_music_on, bg_choice, bg_volume],
        outputs=[story_output, audio_output, video_output],
    )

    reset_btn.click(fn=reset_memory, inputs=None, outputs=[story_output])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))