import sys
from collections.abc import Mapping

import gradio as gr
import librosa
import numpy as np
import torch
from transformers import AutoFeatureExtractor, AutoModel, AutoTokenizer

# Configuration
MODEL_ID = "abr-ai/asr-19m-v2-en-32b"


def load_components():
    """Load the model, feature extractor, and tokenizer, trusting the model's remote code."""
    print(f"Loading components for: {MODEL_ID}...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Running on device: {device}")

    try:
        # Load feature extractor
        feature_extractor = AutoFeatureExtractor.from_pretrained(
            MODEL_ID, trust_remote_code=True
        )

        # Load model
        model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True)
        model = model.to(device)
        model.eval()  # Set to evaluation mode

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

        return feature_extractor, model, tokenizer, device
    except Exception as e:
        print(f"Error loading components: {e}")
        sys.exit(1)


# Initialize components once at startup
feature_extractor, model, tokenizer, device = load_components()


def transcribe_audio(audio_filepath):
    """Gradio callback: transcribe audio using the custom inference loop."""
    if audio_filepath is None:
        return "Please provide an audio input."

    try:
        print(f"Processing audio file: {audio_filepath}")

        # 1. Load and resample audio.
        # The model requires 16 kHz mono audio. librosa.load handles loading,
        # mono conversion, and resampling; we cast to float32, the standard
        # dtype for PyTorch audio inputs.
        target_sr = 16000
        audio_array, _ = librosa.load(audio_filepath, sr=target_sr, mono=True)
        audio_array = audio_array.astype(np.float32)

        # 2. Extract features.
        # Depending on the model's remote code, the feature extractor may return
        # a raw tensor or a dict-like object. return_tensors="pt" requests
        # PyTorch tensors with a batch dimension.
        features = feature_extractor(audio_array, return_tensors="pt")

        input_tensor = None
        mask = None

        if isinstance(features, torch.Tensor):
            # The extractor returned the features directly as a tensor
            # (avoids the "Tensor.__contains__" error from key lookups).
            input_tensor = features.to(device)
        elif isinstance(features, Mapping):
            # Dict-like output (plain dict or a BatchFeature/UserDict):
            # move the input features to the device.
            if "input_features" in features:
                input_tensor = features["input_features"].to(device)
            elif "input_values" in features:
                input_tensor = features["input_values"].to(device)
            else:
                # Fallback: report the available keys for debugging
                keys = list(features.keys())
                raise ValueError(f"Unknown feature key. Available keys: {keys}")

            # Handle the mask, under whichever name the extractor uses
            mask = features.get("mask")
            if mask is None:
                mask = features.get("attention_mask")
            if mask is not None:
                mask = mask.to(device)
        else:
            raise ValueError(
                f"Unexpected return type from feature_extractor: {type(features)}"
            )

        # 3. Model inference
        with torch.no_grad():
            if mask is not None:
                outputs = model(input_tensor, mask=mask)
            else:
                outputs = model(input_tensor)

        # 4. Decode: extract the logits from whatever container the model returns.
        if isinstance(outputs, dict):
            logits = outputs["logits"]
        elif isinstance(outputs, (tuple, list)):
            logits = outputs[0]
        else:
            logits = outputs

        # Do NOT remove the batch dimension: decode_from_logits expects 3D input
        # (batch, time, vocab). We also pass the mask to the decoder, as per the
        # model's usage documentation.
        transcription = tokenizer.decode_from_logits(logits, mask=mask)

        # The decoder likely returns a list of strings (one per batch item)
        if isinstance(transcription, list):
            return transcription[0]
        return transcription

    except Exception as e:
        print(f"Error: {str(e)}")
        # Return the error to the UI for easier debugging
        return f"Error during transcription: {str(e)}"


# CSS for the interface
custom_css = """
.container {
    max-width: 800px;
    margin: auto;
}
.header-link {
    font-size: 0.9rem;
    color: #666;
    text-decoration: none;
}
.header-link:hover {
    color: #ff7e5f;
    text-decoration: underline;
}
"""

# Build the Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.Markdown(f"# 🎙️ ASR Demo: {MODEL_ID}")
    gr.Markdown(
        "Upload an audio file or record from your microphone to transcribe it "
        "using the ABR-AI model with custom inference."
    )

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Audio Input",
            )
            submit_btn = gr.Button("Transcribe", variant="primary")

        with gr.Column():
            text_output = gr.Textbox(
                label="Transcription",
                show_copy_button=True,
                lines=10,
                max_lines=30,
            )

    submit_btn.click(
        fn=transcribe_audio,
        inputs=audio_input,
        outputs=text_output,
    )

if __name__ == "__main__":
    demo.launch()