# diagnosis/ai_engine/features.py
"""
Feature extraction for IndicWav2Vec Hindi ASR.

This module provides feature extraction using the IndicWav2Vec Hindi model.
It focuses on ASR transcription features rather than hybrid
acoustic+linguistic features.
"""

import logging
from typing import Any, Dict

import numpy as np
import torch
from transformers import AutoProcessor, Wav2Vec2ForCTC

logger = logging.getLogger(__name__)


class ASRFeatureExtractor:
    """
    Feature extractor using IndicWav2Vec Hindi for Automatic Speech Recognition.

    This extractor focuses on:
    - Audio feature extraction via IndicWav2Vec
    - Transcription confidence scores
    - Frame-level predictions and logits
    - Word-level alignments (estimated)

    Model: ai4bharat/indicwav2vec-hindi
    """

    def __init__(self, model: Wav2Vec2ForCTC, processor: AutoProcessor, device: str = "cpu"):
        """
        Initialize the ASR feature extractor.

        Args:
            model: Pre-loaded IndicWav2Vec Hindi model
            processor: Pre-loaded processor for the model
            device: Device to run inference on ('cpu' or 'cuda')
        """
        self.model = model
        self.processor = processor
        self.device = device
        self.model.eval()
        logger.info(f"✅ ASRFeatureExtractor initialized on {device}")

    def extract_audio_features(self, audio: np.ndarray, sample_rate: int = 16000) -> Dict[str, Any]:
        """
        Extract features from audio using IndicWav2Vec Hindi.

        Args:
            audio: Audio waveform as numpy array
            sample_rate: Sample rate of the audio (default: 16000)

        Returns:
            Dictionary containing:
            - input_values: Processed audio features
            - attention_mask: Attention mask (if available)
        """
        try:
            # Run the waveform through the processor (normalization + tensorization)
            inputs = self.processor(
                audio,
                sampling_rate=sample_rate,
                return_tensors="pt"
            ).to(self.device)

            return {
                'input_values': inputs.input_values,
                'attention_mask': inputs.get('attention_mask', None)
            }
        except Exception as e:
            logger.error(f"❌ Error extracting audio features: {e}")
            raise
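
    # A quick usage sketch for the method above (illustrative only; the
    # waveform is hypothetical -- any 16 kHz mono float32 array works):
    #
    #     extractor = ASRFeatureExtractor(model, processor, device="cpu")
    #     feats = extractor.extract_audio_features(np.zeros(16000, dtype=np.float32))
    #     feats['input_values'].shape  # torch.Size([1, 16000])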

    def get_transcription_features(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000
    ) -> Dict[str, Any]:
        """
        Get transcription features including logits, predictions, and confidence.

        Args:
            audio: Audio waveform as numpy array
            sample_rate: Sample rate of the audio (default: 16000)

        Returns:
            Dictionary containing:
            - transcript: Transcribed text
            - logits: Model logits (raw predictions)
            - predicted_ids: Predicted token IDs
            - probabilities: Softmax probabilities
            - confidence: Average confidence score
            - frame_confidence: Per-frame confidence scores
            - num_frames: Number of output frames
        """
        try:
            # Process audio
            inputs = self.processor(
                audio,
                sampling_rate=sample_rate,
                return_tensors="pt"
            ).to(self.device)

            # Get model predictions
            with torch.no_grad():
                outputs = self.model(**inputs)
                logits = outputs.logits
                predicted_ids = torch.argmax(logits, dim=-1)

            # Calculate probabilities and confidence
            probs = torch.softmax(logits, dim=-1)
            max_probs = torch.max(probs, dim=-1)[0]  # Max probability per frame
            frame_confidence = max_probs[0].cpu().numpy()
            avg_confidence = torch.mean(max_probs).item()

            # Decode transcript
            transcript = ""
            try:
                if hasattr(self.processor, 'tokenizer'):
                    transcript = self.processor.tokenizer.decode(
                        predicted_ids[0],
                        skip_special_tokens=True
                    )
                elif hasattr(self.processor, 'batch_decode'):
                    transcript = self.processor.batch_decode(
                        predicted_ids,
                        skip_special_tokens=True
                    )[0]

                # Clean up transcript: strip any residual wav2vec2 special
                # tokens and map the '|' word-delimiter token to a space
                if transcript:
                    for token in ('<s>', '</s>', '<pad>'):
                        transcript = transcript.replace(token, '')
                    transcript = transcript.replace('|', ' ')
                    transcript = ' '.join(transcript.split())
            except Exception as e:
                logger.warning(f"⚠️ Decode error: {e}")
                transcript = ""

            return {
                'transcript': transcript,
                'logits': logits.cpu().numpy(),
                'predicted_ids': predicted_ids.cpu().numpy(),
                'probabilities': probs.cpu().numpy(),
                'confidence': avg_confidence,
                'frame_confidence': frame_confidence,
                'num_frames': logits.shape[1]
            }
        except Exception as e:
            logger.error(f"❌ Error getting transcription features: {e}")
            raise

    def get_word_level_features(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000
    ) -> Dict[str, Any]:
        """
        Get word-level features including timestamps and confidence.

        Timestamps are estimated by dividing the audio duration uniformly
        across the decoded words; this is a rough heuristic, not a forced
        alignment.

        Args:
            audio: Audio waveform as numpy array
            sample_rate: Sample rate of the audio (default: 16000)

        Returns:
            Dictionary containing:
            - words: List of words
            - word_timestamps: List of dicts with 'word', 'start', 'end' keys
            - word_confidence: Confidence score for each word
            - transcript: Transcribed text
        """
        try:
            # Get transcription features
            features = self.get_transcription_features(audio, sample_rate)
            transcript = features['transcript']
            frame_confidence = features['frame_confidence']
            num_frames = features['num_frames']

            # Estimate word-level timestamps by uniform division of the
            # audio duration (simplified; see docstring)
            words = transcript.split() if transcript else []
            audio_duration = len(audio) / sample_rate
            time_per_word = audio_duration / len(words) if words else 0.0

            word_timestamps = []
            word_confidence = []

            for i, word in enumerate(words):
                start_time = i * time_per_word
                end_time = (i + 1) * time_per_word

                # Estimate confidence for this word as the mean confidence
                # of the frames that fall inside its time window
                start_frame = int((start_time / audio_duration) * num_frames)
                end_frame = int((end_time / audio_duration) * num_frames)
                word_conf = (
                    float(np.mean(frame_confidence[start_frame:end_frame]))
                    if end_frame > start_frame else 0.5
                )

                word_timestamps.append({
                    'word': word,
                    'start': start_time,
                    'end': end_time
                })
                word_confidence.append(word_conf)

            return {
                'words': words,
                'word_timestamps': word_timestamps,
                'word_confidence': word_confidence,
                'transcript': transcript
            }
        except Exception as e:
            logger.error(f"❌ Error getting word-level features: {e}")
            raise
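

if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming the ai4bharat/indicwav2vec-hindi
    # checkpoint can be fetched via transformers (network access on first
    # run). The one-second silent waveform is a stand-in for real speech,
    # so the transcript will typically be empty.
    logging.basicConfig(level=logging.INFO)

    model_id = "ai4bharat/indicwav2vec-hindi"
    model = Wav2Vec2ForCTC.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id)
    extractor = ASRFeatureExtractor(model, processor, device="cpu")

    dummy_audio = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz
    result = extractor.get_word_level_features(dummy_audio)
    print("transcript:", repr(result['transcript']))
    print("words:", result['words'])
    print("word timestamps:", result['word_timestamps'])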