import gradio as gr
import torch
import librosa
import numpy as np
from transformers import AutoFeatureExtractor, AutoModel, AutoTokenizer
import sys

# Configuration
MODEL_ID = "abr-ai/asr-19m-v2-en-32b"


def load_components():
    """
    Loads the model, feature extractor, and tokenizer with remote code trust.
    """
    print(f"Loading components for: {MODEL_ID}...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Running on device: {device}")

    try:
        # Load Feature Extractor
        feature_extractor = AutoFeatureExtractor.from_pretrained(
            MODEL_ID, trust_remote_code=True
        )

        # Load Model
        model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True)
        model = model.to(device)
        model.eval()  # Set to evaluation mode

        # Load Tokenizer
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

        return feature_extractor, model, tokenizer, device
    except Exception as e:
        print(f"Error loading components: {e}")
        sys.exit(1)


# Initialize components once at startup
feature_extractor, model, tokenizer, device = load_components()
def transcribe_audio(audio_filepath):
    """
    Gradio callback function to transcribe audio using the custom inference loop.
    """
    if audio_filepath is None:
        return "Please provide an audio input."

    try:
        print(f"Processing audio file: {audio_filepath}")

        # 1. Load and resample audio
        # The model specifically requires 16000 Hz mono-channel audio.
        target_sr = 16000
        # librosa.load handles loading, converting to mono, and resampling.
        # Ensure float32, which is standard for PyTorch.
        audio_array, _ = librosa.load(audio_filepath, sr=target_sr, mono=True)
        audio_array = audio_array.astype(np.float32)

        # 2. Extract Features
        # return_tensors="pt" creates PyTorch tensors with a batch dimension.
        features = feature_extractor(audio_array, return_tensors="pt")

        input_tensor = None
        mask = None

        # Check if features is directly a Tensor
        # (fixes the "Tensor.__contains__" error).
        if isinstance(features, torch.Tensor):
            input_tensor = features.to(device)
        elif isinstance(features, dict):
            # Move input features to the device
            if "input_features" in features:
                input_tensor = features["input_features"].to(device)
            elif "input_values" in features:
                input_tensor = features["input_values"].to(device)
            else:
                # Fallback: report the available keys for debugging
                keys = list(features.keys())
                raise ValueError(f"Unknown feature key. Available keys: {keys}")

            # Handle the mask
            mask = features.get("mask")
            if mask is None:
                mask = features.get("attention_mask")
            if mask is not None:
                mask = mask.to(device)
        else:
            raise ValueError(f"Unexpected return type from feature_extractor: {type(features)}")

        # 3. Model Inference
        with torch.no_grad():
            if mask is not None:
                outputs = model(input_tensor, mask=mask)
            else:
                outputs = model(input_tensor)

        # 4. Decode
        # Extract logits from whichever container the model returns.
        if isinstance(outputs, dict):
            logits = outputs["logits"]
        elif isinstance(outputs, (tuple, list)):
            logits = outputs[0]
        else:
            logits = outputs

        # Do NOT remove the batch dimension: decode_from_logits expects
        # 3D input (batch, time, vocab). The mask is also passed to the
        # decoder, per the usage documentation.
        transcription = tokenizer.decode_from_logits(logits, mask=mask)

        # The decoder likely returns a list of strings (one per batch item)
        if isinstance(transcription, list):
            return transcription[0]
        return transcription

    except Exception as e:
        print(f"Error: {str(e)}")
        # Return the error to the UI for easier debugging
        return f"Error during transcription: {str(e)}"
# Define the CSS for the interface
custom_css = """
.container { max-width: 800px; margin: auto; }
.header-link { font-size: 0.9rem; color: #666; text-decoration: none; }
.header-link:hover { color: #ff7e5f; text-decoration: underline; }
"""

# Build the Gradio Interface
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.Markdown(f"# 🎙️ ASR Demo: {MODEL_ID}")
    gr.Markdown(
        "Upload an audio file or record from your microphone to transcribe it "
        "using the ABR-AI model with custom inference."
    )

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Audio Input"
            )
            submit_btn = gr.Button("Transcribe", variant="primary")
        with gr.Column():
            text_output = gr.Textbox(
                label="Transcription",
                show_copy_button=True,
                lines=10,
                max_lines=30
            )

    submit_btn.click(
        fn=transcribe_audio,
        inputs=audio_input,
        outputs=text_output
    )

if __name__ == "__main__":
    demo.launch()