import os
import spaces
import gradio as gr
import tempfile
from typing import Tuple
import torch
import io
import re
import numpy as np
import imageio
from diffusers import StableDiffusionPipeline
from PIL import Image
from pydub import AudioSegment
from scipy.io.wavfile import write as wavwrite
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    SpeechT5Processor,
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    pipeline,
)
import imageio_ffmpeg
os.environ["IMAGEIO_FFMPEG_EXE"] = imageio_ffmpeg.get_ffmpeg_exe()
AVAILABLE_BGS = {
    "none": None,
    "rain": "assets/Rain-Drips",
    "calm": "assets/calming-piano-tune",
    "calmmusic": "assets/calmimg-piano-tune1",
    "forest": "assets/forest-rain-lullaby-music",
    "piano": "assets/nature-calming",
}
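# Values are extension-less basenames; load_background() below tries ".wav"
# and then ".mp3" for each key.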
kids_day_options = [
    "I played with my toys",
    "I went to the park",
    "I read a storybook",
    "I had a fun snack",
    "I helped at home",
]
DISALLOWED_KEYWORDS = [
    "nude", "naked", "sexual", "porn", "vagina", "penis", "nsfw",
    "blood", "gore", "violence",
    "political", "president", "prime minister", "election",
    "religion", "terrorist", "extremist",
]
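# Note: this is a simple substring check on user input, not a full content
# filter; the Stable Diffusion pipeline's built-in safety checker still
# screens the generated images.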
def is_prompt_safe(prompt: str) -> bool:
    prompt_lower = prompt.lower()
    return not any(keyword in prompt_lower for keyword in DISALLOWED_KEYWORDS)
# ------------------------
# DEVICE SETUP
# ------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)
ffmpeg_path = imageio_ffmpeg.get_ffmpeg_exe()
print("Using ffmpeg:", ffmpeg_path)
# Install system deps (ffmpeg, espeak-ng). Put in startup script if preferred.
os.system("apt-get update && apt-get install -y espeak-ng ffmpeg --no-install-recommends")
os.environ["ESPEAK_PATH"] = "/usr/bin/espeak-ng"
speaker_options = {"Loading...": None}
# ------------------------
# TEXT GENERATION MODEL
# ------------------------
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b-instruct")
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b-instruct",
    device_map="auto",
    torch_dtype="auto",
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
local_text_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)
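# With device_map="auto", accelerate decides the model's device placement,
# so the pipeline is created without an explicit `device` argument.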
# ------------------------
# Story memory
# ------------------------
story_memory = ""
print("🔄 Loading Stable Diffusion v1.5 with safety checker…")
sd_pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
).to(device)
sd_pipe.enable_attention_slicing()
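# Attention slicing trades a little speed for a lower peak-VRAM footprint,
# which helps when sharing the GPU with the Falcon-7B text model.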
def generate_story_images(story_text, num_images=3):
    """
    Generate images illustrating the story.
    Splits the story into parts and generates one image per part.
    """
    if not is_prompt_safe(story_text):
        raise ValueError("Prompt contains disallowed content per Stable Diffusion license policy.")
    paragraphs = re.split(r'\n\n', story_text)
    if len(paragraphs) < num_images:
        paragraphs = paragraphs + [paragraphs[-1]] * (num_images - len(paragraphs))
    images = []
    for para in paragraphs[:num_images]:
        prompt = f"Illustrate this scene: {para} in soft, magical, calming style, suitable for bedtime story."
        img = sd_pipe(prompt).images[0].convert("RGB")
        images.append(img)
    return images
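# pydub reads WAV natively but needs an ffmpeg binary on PATH (installed via
# apt-get above) to decode MP3 files.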
def load_background(bg_key: str):
    """
    Loads background audio. Supports both WAV and MP3 automatically.
    Returns an AudioSegment or None.
    """
    base_path = AVAILABLE_BGS.get(bg_key)
    if base_path is None:
        return None
    # Try WAV and MP3 versions
    candidates = [base_path + ".wav", base_path + ".mp3"]
    for path in candidates:
        if os.path.exists(path):
            try:
                ext = os.path.splitext(path)[1].lower()
                if ext == ".wav":
                    return AudioSegment.from_wav(path)
                elif ext == ".mp3":
                    return AudioSegment.from_mp3(path)
            except Exception as e:
                print(f"Error loading background audio {path}: {e}")
    print(f"⚠️ No valid background audio found for key '{bg_key}'.")
    return None
def images_to_video(images, fps=1):
    # libx264 requires even frame dimensions; 576x320 satisfies that.
    frames = [
        np.array(img.convert("RGB").resize((576, 320)))
        for img in images
    ]
    tmp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    tmp_video.close()
    # IMAGEIO_FFMPEG_EXE is already pointed at the bundled ffmpeg at import time.
    imageio.mimwrite(
        tmp_video.name,
        frames,
        fps=fps,
        codec="libx264",
    )
    return tmp_video.name
def call_text_model(prompt: str, max_tokens: int = 700):
    # If the local model is not available, raise an informative error
    if local_text_gen is None:
        raise RuntimeError("Local text generation model not available in this environment.")
    outputs = local_text_gen(
        prompt,
        max_new_tokens=max_tokens,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()
# ------------------------
# TTS SETUP
# ------------------------
use_speecht5 = True
try:
    print("🔊 Loading SpeechT5...")
    processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
    from datasets import load_dataset
    embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    speaker_options.clear()
    speaker_options.update({
        "Female (Slt)": torch.tensor(embeddings_dataset[0]["xvector"]).unsqueeze(0).to(device),
        "Male (Bdl)": torch.tensor(embeddings_dataset[116]["xvector"]).unsqueeze(0).to(device),
        "Male (Rms)": torch.tensor(embeddings_dataset[227]["xvector"]).unsqueeze(0).to(device),
        "Female (Clb)": torch.tensor(embeddings_dataset[420]["xvector"]).unsqueeze(0).to(device),
        "Female (Aew)": torch.tensor(embeddings_dataset[550]["xvector"]).unsqueeze(0).to(device),
        "Male (Ksp)": torch.tensor(embeddings_dataset[670]["xvector"]).unsqueeze(0).to(device),
        "Gentle Narrator": torch.randn((1, 512)).to(device),
    })
except Exception as e:
    print("⚠️ SpeechT5 failed to load, switching to lightweight TTS:", e)
    use_speecht5 = False
    text_speech = pipeline("text-to-speech", model="kakao-enterprise/vits-ljs")
    speaker_options = {"Default Voice": None}
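# SpeechT5 truncates long inputs (its text encoder has a limited number of
# position embeddings), so narration is synthesized in sentence-sized chunks
# and the audio segments are concatenated afterwards.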
def split_text_into_chunks(text, max_chars=300):
    sentences = re.split(r'(?<=[.!?]) +', text)
    chunks, current_chunk = [], ""
    for sent in sentences:
        if len(current_chunk) + len(sent) > max_chars:
            if current_chunk.strip():  # avoid emitting an empty first chunk
                chunks.append(current_chunk.strip())
            current_chunk = sent
        else:
            current_chunk += " " + sent
    if current_chunk.strip():
        chunks.append(current_chunk.strip())
    return chunks
def call_tts_model(text: str, voice: str):
    chunks = split_text_into_chunks(text)
    audio_segments = []
    if use_speecht5:
        speaker_emb = speaker_options.get(voice, list(speaker_options.values())[0])
        if speaker_emb is None:
            speaker_emb = torch.randn((1, 512)).to(device)
        for chunk in chunks:
            inputs = processor(text=chunk, return_tensors="pt").to(device)
            with torch.no_grad():
                speech = tts_model.generate_speech(
                    input_ids=inputs["input_ids"],
                    speaker_embeddings=speaker_emb,
                    vocoder=vocoder,
                )
            audio_segments.append(speech.cpu().numpy().flatten())
        full_audio = np.concatenate(audio_segments, axis=0)
        sr = 16000  # SpeechT5/HiFi-GAN output sample rate
    else:
        # The lightweight pipeline returns dicts with 'audio' and 'sampling_rate'
        sr = None
        for chunk in chunks:
            narrated = text_speech(chunk)
            audio_segments.append(narrated["audio"].flatten())
            sr = narrated["sampling_rate"]
        full_audio = np.concatenate(audio_segments, axis=0)
    # Normalize to avoid clipping, guarding against an all-silent signal
    peak = np.max(np.abs(full_audio))
    if peak > 0:
        full_audio = full_audio / peak
    full_audio = full_audio * 0.7
    buffer = io.BytesIO()
    wavwrite(buffer, sr, (full_audio * 32767).astype(np.int16))
    buffer.seek(0)
    return buffer.read()
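# Hypothetical standalone usage (assumes the models above loaded successfully):
#   wav_bytes = call_tts_model("Once upon a time, a sleepy fox yawned.", "Female (Slt)")
#   with open("narration.wav", "wb") as f:
#       f.write(wav_bytes)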
# ------------------------
# Prompt builder
# ------------------------
def build_story_prompt(day_text: str, mood: str, theme: str, for_kids: bool, story_length: str) -> str:
    persona = "Write a soothing bedtime story"
    if for_kids:
        persona += " for a young child (age 3–8)."
    else:
        persona += " for an adult who wants to fall asleep."
    extras = f"Tone: {mood}. Theme: {theme}."
    # Check the more specific labels first: "Extra Long" also contains "Long",
    # so testing "Long" before "Extra" would misclassify it as ~700 words.
    if "Short" in story_length:
        length_instruction = "Length: about 300 words."
    elif "Extra" in story_length:
        length_instruction = "Length: about 1500 words."
    elif "Epic" in story_length:
        length_instruction = "Length: about 3000 words."
    elif "Long" in story_length:
        length_instruction = "Length: about 700 words."
    else:
        length_instruction = "Length: about 6000 words."
    instructions = (
        f"{length_instruction} Avoid tension. "
        "Feature a calming guide character. End with: "
        "'And soon, everything drifted into quiet dreams.'"
    )
    return (
        f"{persona}"
        f"\nSummary of the day: {day_text}\n"
        f"{extras}\n"
        f"{instructions}\n\n"
        "The story must remain wholesome, family-friendly, safe, and non-political. "
        "No adult content, no violence, no sexual themes, and no harmful scenarios.\n\n"
        "Begin the story with:\n"
        "Once upon a time,"
    )
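# Example: build_story_prompt("I went to the park", "calm", "gentle forest",
# True, "Short (~300 words)") asks for a ~300-word, child-friendly story that
# opens with "Once upon a time," and ends with the fixed closing line.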
def apply_fade(audio: AudioSegment, fade_ms: int = 200):
    if fade_ms > 0:
        return audio.fade_in(fade_ms).fade_out(fade_ms)
    return audio
# ------------------------
# MAIN FUNCTION
# ------------------------
@spaces.GPU
def generate_and_render(day_text, kids_day_choice, mood, theme, for_kids, voice, story_length, continue_story, bg_music_on, bg_choice, bg_volume):
    """
    Generates story text, TTS audio (with optional background), images, and a video slideshow of the story.
    """
    global story_memory
    # RESET MEMORY if the user does NOT want a continuation
    if not continue_story:
        story_memory = ""
    # SAFETY CHECK ONLY ON USER INPUT
    if for_kids:
        safe_check_source = kids_day_choice
        day_summary = kids_day_choice
    else:
        safe_check_source = day_text
        day_summary = day_text
    if not is_prompt_safe(safe_check_source):
        return "❌ Your input contains disallowed content. Please change it.", None, None
    # 1️⃣ Build prompt
    if continue_story and story_memory.strip():
        prompt = story_memory + " Continue the story softly:"
    else:
        prompt = build_story_prompt(day_summary, mood, theme, for_kids, story_length)
    # 2️⃣ Generate story
    max_tokens_map = {
        "Short (~300 words)": 700,
        "Long (~700 words)": 1400,
        "Extra Long (~1500 words)": 2500,
        "Epic (~3000 words)": 4096,
        "Mega (~6000 words)": 5500,
    }
    max_tokens = max_tokens_map.get(story_length, 700)
    try:
        story = call_text_model(prompt, max_tokens=max_tokens)
    except Exception as e:
        return f"Error generating story: {e}", None, None
    # SAVE the story to memory ONLY if continuation is enabled
    if continue_story:
        story_memory += " " + story
    # 3️⃣ TTS generation
    try:
        audio_bytes = call_tts_model(story, voice)
    except Exception as e:
        return f"Story generated, but TTS failed: {e}\n\n{story}", None, None
    tmp_voice = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    tmp_voice.write(audio_bytes)
    tmp_voice.close()
    # 4️⃣ Background music mixing
    if bg_music_on:
        try:
            voice_audio = AudioSegment.from_file(tmp_voice.name)
            bg_music = load_background(bg_choice)
            if bg_music:
                bg_music = bg_music.set_frame_rate(voice_audio.frame_rate).set_channels(voice_audio.channels)
                # Loop or trim the background track to match the narration length
                if len(bg_music) < len(voice_audio):
                    loops = int(len(voice_audio) / len(bg_music)) + 1
                    bg_music = (bg_music * loops)[:len(voice_audio)]
                else:
                    bg_music = bg_music[:len(voice_audio)]
                bg_music = apply_fade(bg_music, fade_ms=500)
                voice_audio = apply_fade(voice_audio, fade_ms=80)
                # Map the 0–100 slider onto a -10 dB ... +6 dB gain range
                vol_db = -10 + (bg_volume / 100) * 16
                bg_music = bg_music + vol_db
                mixed = bg_music.overlay(voice_audio)
                final_audio = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
                mixed.export(final_audio.name, format="mp3")
            else:
                final_audio = tmp_voice
        except Exception as e:
            print("Background mixing failed:", e)
            final_audio = tmp_voice
    else:
        final_audio = tmp_voice
    # 5️⃣ Generate story images
    try:
        images = generate_story_images(story, num_images=5)
        video_path = images_to_video(images, fps=2)
    except Exception as e:
        print("Story image/video generation failed:", e)
        video_path = None
    return story, final_audio.name, video_path
# ------------------------
# Memory reset
# ------------------------
def reset_memory():
    global story_memory
    story_memory = ""
    return "Story memory cleared."
# ------------------------
# GRADIO UI
# ------------------------
with gr.Blocks(title="AI Sleep Story Generator") as demo:
    gr.Markdown("# 🌙 AI Sleep Story Generator")
    gr.Markdown("⚠️ This app follows Stable Diffusion’s OpenRAIL-M license and blocks NSFW, political, and harmful content.")
    with gr.Row():
        with gr.Column(scale=2):
            day_text = gr.Textbox(label="Tell me about your day", lines=4)
            # Dropdown alternative for kids who cannot type
            kids_day_dropdown = gr.Dropdown(
                choices=kids_day_options,
                value=kids_day_options[0],
                label="Select your day (for kids who cannot write)",
            )
            mood = gr.Dropdown(
                ["calm", "soothing", "nostalgic", "drowsy", "dreamlike"],
                value="calm",
                label="Mood",
            )
            theme = gr.Dropdown(
                ["gentle forest", "ocean breeze", "starlit sky", "quiet village", "magical garden"],
                value="gentle forest",
                label="Theme",
            )
            voice_dropdown = gr.Dropdown(
                choices=list(speaker_options.keys()),
                value=list(speaker_options.keys())[0],
                label="Choose Voice",
            )
            story_length = gr.Dropdown(
                [
                    "Short (~300 words)",
                    "Long (~700 words)",
                    "Extra Long (~1500 words)",
                    "Epic (~3000 words)",
                    "Mega (~6000 words)",
                ],
                value="Short (~300 words)",
                label="Story Length",
            )
            bg_music_on = gr.Checkbox(label="Play soothing music", value=True)
            bg_choice = gr.Dropdown(list(AVAILABLE_BGS.keys()), value="forest", label="Background sound")
            bg_volume = gr.Slider(0, 100, 30, step=1, label="Background volume")
            for_kids = gr.Checkbox(label="Child-friendly (3–8 yrs)", value=False)
            continue_story = gr.Checkbox(label="Continue previous story", value=False)
            generate_btn = gr.Button("Generate Story")
            reset_btn = gr.Button("Reset Story Memory")
        with gr.Column(scale=3):
            story_output = gr.Textbox(label="Generated story", lines=20)
            audio_output = gr.Audio(label="Play story audio", type="filepath")
            video_output = gr.Video(label="Story video", format="mp4")
    generate_btn.click(
        fn=generate_and_render,
        inputs=[day_text, kids_day_dropdown, mood, theme, for_kids, voice_dropdown, story_length, continue_story, bg_music_on, bg_choice, bg_volume],
        outputs=[story_output, audio_output, video_output],
    )
    reset_btn.click(fn=reset_memory, inputs=None, outputs=[story_output])
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))