|
|
|
|
|
""" |
|
|
Feature extraction for IndicWav2Vec Hindi ASR |
|
|
|
|
|
This module provides feature extraction capabilities using the IndicWav2Vec Hindi model. |
|
|
Focused on ASR transcription features rather than hybrid acoustic+linguistic features. |
|
|
""" |
|
|
import torch |
|
|
import numpy as np |
|
|
import logging |
|
|
from typing import Dict, Any, Tuple, Optional |
|
|
from transformers import Wav2Vec2ForCTC, AutoProcessor |
|
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
class ASRFeatureExtractor:
    """
    Feature extractor using IndicWav2Vec Hindi for Automatic Speech Recognition.

    This extractor focuses on:
    - Audio feature extraction via IndicWav2Vec
    - Transcription confidence scores
    - Frame-level predictions and logits
    - Word-level alignments (estimated by dividing the audio evenly across
      the decoded words, not by forced alignment)

    Model: ai4bharat/indicwav2vec-hindi
    """

    # Literal tokens stripped from decoded text, removed in this order.
    _SPECIAL_TOKENS = ('<pad>', '<s>', '</s>')

    def __init__(self, model: Wav2Vec2ForCTC, processor: AutoProcessor, device: str = "cpu"):
        """
        Initialize the ASR feature extractor.

        Args:
            model: Pre-loaded IndicWav2Vec Hindi model. It is moved to
                ``device`` and put in eval mode.
            processor: Pre-loaded processor for the model.
            device: Device to run inference on ('cpu' or 'cuda').
        """
        # Input tensors are always sent to ``device`` (see _prepare_inputs), so
        # the model must live there too or the forward pass fails with a
        # device-mismatch error.
        self.model = model.to(device)
        self.processor = processor
        self.device = device
        self.model.eval()
        logger.info(f"✅ ASRFeatureExtractor initialized on {device}")

    def _prepare_inputs(self, audio: np.ndarray, sample_rate: int):
        """Run the processor on ``audio`` and move the resulting tensors to ``self.device``."""
        return self.processor(
            audio,
            sampling_rate=sample_rate,
            return_tensors="pt",
        ).to(self.device)

    def _decode(self, predicted_ids: torch.Tensor) -> str:
        """
        Best-effort decode of the first sequence in ``predicted_ids`` to clean text.

        Uses ``processor.tokenizer.decode`` when the processor has a tokenizer,
        otherwise ``processor.batch_decode``. Special-token literals are
        removed, '|' is treated as a word separator, and whitespace is
        collapsed. Returns "" if decoding is unavailable, yields nothing, or
        raises.
        """
        try:
            if hasattr(self.processor, 'tokenizer'):
                transcript = self.processor.tokenizer.decode(
                    predicted_ids[0],
                    skip_special_tokens=True,
                )
            elif hasattr(self.processor, 'batch_decode'):
                transcript = self.processor.batch_decode(predicted_ids)[0]
            else:
                return ""

            # Normalise falsy results (e.g. None) to "" so callers always get a str.
            if not transcript:
                return ""

            for token in self._SPECIAL_TOKENS:
                transcript = transcript.replace(token, '')
            # split()/join collapses whitespace runs and trims both ends.
            return ' '.join(transcript.replace('|', ' ').split())
        except Exception as e:
            # Deliberately best-effort: a decode failure should not discard the
            # logits/confidence features, so log and fall back to "".
            logger.warning(f"⚠️ Decode error: {e}")
            return ""

    def extract_audio_features(self, audio: np.ndarray, sample_rate: int = 16000) -> Dict[str, Any]:
        """
        Extract features from audio using IndicWav2Vec Hindi.

        Args:
            audio: Audio waveform as numpy array.
            sample_rate: Sample rate of the audio (default: 16000).

        Returns:
            Dictionary containing:
            - input_values: Processed audio features (tensor on ``self.device``)
            - attention_mask: Attention mask, or None if the processor did not produce one

        Raises:
            Any exception raised by the processor is logged and re-raised.
        """
        try:
            inputs = self._prepare_inputs(audio, sample_rate)
            return {
                'input_values': inputs.input_values,
                'attention_mask': inputs.get('attention_mask', None),
            }
        except Exception as e:
            logger.error(f"❌ Error extracting audio features: {e}")
            raise

    def get_transcription_features(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000
    ) -> Dict[str, Any]:
        """
        Get transcription features including logits, predictions, and confidence.

        Args:
            audio: Audio waveform as numpy array.
            sample_rate: Sample rate of the audio (default: 16000).

        Returns:
            Dictionary containing:
            - transcript: Transcribed text of the first batch item ("" if decoding fails)
            - logits: Model logits as a numpy array; axis 1 is frames and the
              last axis is the token vocabulary
            - predicted_ids: Argmax token IDs per frame
            - probabilities: Softmax probabilities over the last axis
            - confidence: Mean of the per-frame max probability (over the whole batch)
            - frame_confidence: Per-frame max probability for the first batch item
            - num_frames: Number of frames (``logits.shape[1]``)

        Raises:
            Errors from preprocessing or the model forward pass are logged and
            re-raised; decoding errors are not (see ``_decode``).
        """
        try:
            inputs = self._prepare_inputs(audio, sample_rate)

            with torch.no_grad():
                logits = self.model(**inputs).logits

            predicted_ids = torch.argmax(logits, dim=-1)

            # Frame confidence = probability assigned to the argmax token.
            probs = torch.softmax(logits, dim=-1)
            max_probs = torch.max(probs, dim=-1).values
            frame_confidence = max_probs[0].cpu().numpy()
            avg_confidence = float(torch.mean(max_probs).item())

            transcript = self._decode(predicted_ids)

            return {
                'transcript': transcript,
                'logits': logits.cpu().numpy(),
                'predicted_ids': predicted_ids.cpu().numpy(),
                'probabilities': probs.cpu().numpy(),
                'confidence': avg_confidence,
                'frame_confidence': frame_confidence,
                'num_frames': logits.shape[1],
            }
        except Exception as e:
            logger.error(f"❌ Error getting transcription features: {e}")
            raise

    def get_word_level_features(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000
    ) -> Dict[str, Any]:
        """
        Get word-level features including timestamps and confidence.

        Timestamps are ESTIMATED: the audio duration (``len(audio) / sample_rate``)
        and the model frames are both split into equal shares, one per
        whitespace-separated word of the transcript.

        Args:
            audio: Audio waveform as numpy array.
            sample_rate: Sample rate of the audio (default: 16000).

        Returns:
            Dictionary containing:
            - words: List of words
            - word_timestamps: List of {'word', 'start', 'end'} dicts (seconds)
            - word_confidence: Mean frame confidence over each word's frame share
              (0.5 placeholder when a word's share contains no whole frame)
            - transcript: The full transcript

        Raises:
            Errors from ``get_transcription_features`` are logged and re-raised.
        """
        try:
            features = self.get_transcription_features(audio, sample_rate)
            transcript = features['transcript']
            frame_confidence = features['frame_confidence']
            num_frames = features['num_frames']

            words = transcript.split() if transcript else []
            n_words = len(words)
            audio_duration = len(audio) / sample_rate
            time_per_word = audio_duration / n_words if n_words else 0

            word_timestamps = []
            word_confidence = []

            for i, word in enumerate(words):
                start_time = i * time_per_word
                end_time = (i + 1) * time_per_word

                # Word i owns frames [i*F/n, (i+1)*F/n). Exact integer floor
                # division avoids float truncation (int(199.999...) == 199) and
                # never divides by audio_duration, which may be 0.
                start_frame = i * num_frames // n_words
                end_frame = (i + 1) * num_frames // n_words
                if end_frame > start_frame:
                    word_conf = float(np.mean(frame_confidence[start_frame:end_frame]))
                else:
                    # Only reachable when there are more words than frames.
                    word_conf = 0.5

                word_timestamps.append({
                    'word': word,
                    'start': start_time,
                    'end': end_time,
                })
                word_confidence.append(word_conf)

            return {
                'words': words,
                'word_timestamps': word_timestamps,
                'word_confidence': word_confidence,
                'transcript': transcript,
            }
        except Exception as e:
            logger.error(f"❌ Error getting word-level features: {e}")
            raise
|
|
|
|
|
|