import gradio as gr
import torch
import librosa
import numpy as np
from transformers import AutoFeatureExtractor, AutoModel, AutoTokenizer, BatchFeature
import sys
# Configuration
MODEL_ID = "abr-ai/asr-19m-v2-en-32b"
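# The model's architecture, feature extractor, and tokenizer are defined by custom code
# in the Hub repo, so every from_pretrained call below needs trust_remote_code=True.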
def load_components():
"""
Loads the model, feature extractor, and tokenizer with remote code trust.
"""
print(f"Loading components for: {MODEL_ID}...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on device: {device}")
try:
# Load Feature Extractor
feature_extractor = AutoFeatureExtractor.from_pretrained(
MODEL_ID, trust_remote_code=True
)
# Load Model
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True)
model = model.to(device)
model.eval() # Set to evaluation mode
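        # eval() disables dropout and puts any normalization layers in inference mode.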
# Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
return feature_extractor, model, tokenizer, device
except Exception as e:
print(f"Error loading components: {e}")
sys.exit(1)
# Initialize components once at startup
feature_extractor, model, tokenizer, device = load_components()
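# Loading at import time means the weights are downloaded and moved to the device once,
# before the Gradio server starts handling requests.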
def transcribe_audio(audio_filepath):
"""
Gradio callback function to transcribe audio using the custom inference loop.
"""
if audio_filepath is None:
return "Please provide an audio input."
try:
print(f"Processing audio file: {audio_filepath}")
# 1. Load and resample audio
        # The model requires 16 kHz, mono-channel audio.
target_sr = 16000
# librosa.load automatically handles loading, converting to mono, and resampling.
        # We ensure the array is float32, the standard dtype for PyTorch tensors.
audio_array, _ = librosa.load(audio_filepath, sr=target_sr, mono=True)
audio_array = audio_array.astype(np.float32)
# 2. Extract Features
        # The extractor may return a dict-like BatchFeature or a raw tensor;
        # return_tensors="pt" yields PyTorch tensors with a batch dimension.
features = feature_extractor(audio_array, return_tensors="pt")
input_tensor = None
mask = None
# Check if features is directly a Tensor (fixes the "Tensor.__contains__" error)
if isinstance(features, torch.Tensor):
input_tensor = features.to(device)
        elif isinstance(features, (dict, BatchFeature)):  # BatchFeature is dict-like, not a dict subclass
# Move input features to device
if "input_features" in features:
input_tensor = features["input_features"].to(device)
elif "input_values" in features:
input_tensor = features["input_values"].to(device)
else:
# Fallback: check keys for debugging
keys = list(features.keys())
raise ValueError(f"Unknown feature key. Available keys: {keys}")
# Handle mask
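            # The mask (or attention_mask) flags valid, non-padded frames. For a single
            # unpadded clip this extractor may return None or an all-ones tensor
            # (assumption about the remote extractor's behavior).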
mask = features.get("mask")
if mask is None:
mask = features.get("attention_mask")
if mask is not None:
mask = mask.to(device)
else:
raise ValueError(f"Unexpected return type from feature_extractor: {type(features)}")
# 3. Model Inference
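        # torch.no_grad() disables gradient tracking, cutting memory use and overhead
        # during forward-only inference.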
with torch.no_grad():
if mask is not None:
outputs = model(input_tensor, mask=mask)
else:
outputs = model(input_tensor)
# 4. Decode
# Extract logits
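        # Remote-code models may return a dict-like ModelOutput, a plain tuple, or a raw
        # logits tensor of shape (Batch, Time, Vocab); handle all three cases.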
if isinstance(outputs, dict):
logits = outputs["logits"]
elif isinstance(outputs, (tuple, list)):
logits = outputs[0]
else:
logits = outputs
# Do NOT remove batch dimension. The decode_from_logits method expects 3D input (Batch, Time, Vocab).
# We also pass the mask to the decoder as per the usage documentation.
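        # decode_from_logits is provided by the repo's custom tokenizer code; presumably it
        # performs greedy decoding over the vocabulary dimension and maps token ids to text.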
transcription = tokenizer.decode_from_logits(logits, mask=mask)
# The decoder likely returns a list of strings (one per batch item)
if isinstance(transcription, list):
return transcription[0]
return transcription
except Exception as e:
print(f"Error: {str(e)}")
# Return error to UI for easier debugging
return f"Error during transcription: {str(e)}"
# Define the CSS for the interface
custom_css = """
.container { max-width: 800px; margin: auto; }
.header-link { font-size: 0.9rem; color: #666; text-decoration: none; }
.header-link:hover { color: #ff7e5f; text-decoration: underline; }
"""
# Build the Gradio Interface
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
gr.Markdown(f"# ๐ŸŽ™๏ธ ASR Demo: {MODEL_ID}")
gr.Markdown("Upload an audio file or record from your microphone to transcribe it using the ABR-AI model with custom inference.")
with gr.Row():
with gr.Column():
audio_input = gr.Audio(
sources=["microphone", "upload"],
type="filepath",
label="Audio Input"
)
submit_btn = gr.Button("Transcribe", variant="primary")
with gr.Column():
text_output = gr.Textbox(
label="Transcription",
show_copy_button=True,
lines=10,
max_lines=30
)
submit_btn.click(
fn=transcribe_audio,
inputs=audio_input,
outputs=text_output
)
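# demo.queue() could be called here to queue concurrent requests on a single GPU
# (optional; queue() is part of Gradio's public API).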
if __name__ == "__main__":
demo.launch()