Spaces:
Running
Running
Curinha
Refactor sound generation functions to remove user_id parameter and adjust GPU duration
671b217
| import time | |
| import spaces | |
| import torch | |
| from audiocraft.data.audio import audio_write | |
| from audiocraft.models import AudioGen, MusicGen | |
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Load the pretrained models. ``device`` is passed explicitly so the models
# really live on the device reported above instead of relying on the
# library's own implicit device selection.
sound_model = AudioGen.get_pretrained('facebook/audiogen-medium', device=device)
music_model = MusicGen.get_pretrained('facebook/musicgen-small', device=device)

# Length, in seconds, of every clip produced by either model.
GENERATION_DURATION_SECONDS = 5
sound_model.set_generation_params(duration=GENERATION_DURATION_SECONDS)
music_model.set_generation_params(duration=GENERATION_DURATION_SECONDS)
def generate_sound(prompt: str) -> str:
    """
    Generate a sound effect with AudioGen from a text prompt and save it as WAV.

    Args:
    - prompt (str): The description of the sound to generate.

    Returns:
    - str: The path (relative to the working directory) of the saved ``.wav`` file.
    """
    # A timestamp in nanoseconds keeps file names unique across calls.
    timestamp = time.time_ns()
    wav = sound_model.generate([prompt])  # batch containing a single description

    # The prompt is untrusted user text: used verbatim as a file stem it could
    # contain path separators ("../", a leading "/") that escape the working
    # directory, or be longer than the filesystem's file-name limit. Keep only
    # alphanumeric characters and cap the length before building the path.
    slug = "".join(ch if ch.isalnum() else "_" for ch in (prompt or ""))[:50].strip("_")
    output_path = f"{slug or 'sound'}_{timestamp}"

    # audio_write appends the ".wav" suffix to the stem it is given.
    audio_write(output_path, wav[0].cpu(), sound_model.sample_rate, strategy="loudness")
    return f"{output_path}.wav"
def generate_music(prompt: str) -> str:
    """
    Generate music with MusicGen from a text prompt and save it as WAV.

    Args:
    - prompt (str): The description of the music to generate.

    Returns:
    - str: The path (relative to the working directory) of the saved ``.wav`` file.
    """
    # A timestamp in nanoseconds keeps file names unique across calls.
    timestamp = time.time_ns()
    wav = music_model.generate([prompt])  # batch containing a single description

    # The prompt is untrusted user text: used verbatim as a file stem it could
    # contain path separators ("../", a leading "/") that escape the working
    # directory, or be longer than the filesystem's file-name limit. Keep only
    # alphanumeric characters and cap the length before building the path.
    slug = "".join(ch if ch.isalnum() else "_" for ch in (prompt or ""))[:50].strip("_")
    output_path = f"{slug or 'music'}_{timestamp}"

    # audio_write appends the ".wav" suffix to the stem it is given.
    audio_write(output_path, wav[0].cpu(), music_model.sample_rate, strategy="loudness")
    return f"{output_path}.wav"