File size: 7,747 Bytes
74a089b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
# diagnosis/ai_engine/features.py
"""
Feature extraction for IndicWav2Vec Hindi ASR
This module provides feature extraction capabilities using the IndicWav2Vec Hindi model.
Focused on ASR transcription features rather than hybrid acoustic+linguistic features.
"""
import torch
import numpy as np
import logging
from typing import Dict, Any, Tuple, Optional
from transformers import Wav2Vec2ForCTC, AutoProcessor
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class ASRFeatureExtractor:
    """
    Feature extractor using IndicWav2Vec Hindi for Automatic Speech Recognition.

    This extractor focuses on:
    - Audio feature extraction via IndicWav2Vec
    - Transcription confidence scores
    - Frame-level predictions and logits
    - Word-level alignments (estimated)

    Model: ai4bharat/indicwav2vec-hindi
    """

    # Confidence reported for a word whose estimated span contains no frames
    # (this happens when there are more words than output frames).
    _FALLBACK_WORD_CONFIDENCE = 0.5

    def __init__(self, model: "Wav2Vec2ForCTC", processor: "AutoProcessor", device: str = "cpu"):
        """
        Initialize the ASR feature extractor.

        The model is switched to evaluation mode. It is not moved to
        ``device`` here; only the processed inputs are moved at inference time.

        Args:
            model: Pre-loaded IndicWav2Vec Hindi model
            processor: Pre-loaded processor for the model
            device: Device to run inference on ('cpu' or 'cuda')
        """
        self.model = model
        self.processor = processor
        self.device = device
        self.model.eval()
        logger.info(f"✅ ASRFeatureExtractor initialized on {device}")

    def _prepare_inputs(self, audio: np.ndarray, sample_rate: int):
        """Run the processor on raw audio and move the result to ``self.device``."""
        return self.processor(
            audio,
            sampling_rate=sample_rate,
            return_tensors="pt",
        ).to(self.device)

    def _decode_transcript(self, predicted_ids) -> str:
        """Best-effort decode of predicted token IDs; returns "" if decoding fails."""
        transcript = ""
        try:
            if hasattr(self.processor, 'tokenizer'):
                transcript = self.processor.tokenizer.decode(
                    predicted_ids[0],
                    skip_special_tokens=True
                )
            elif hasattr(self.processor, 'batch_decode'):
                transcript = self.processor.batch_decode(predicted_ids)[0]

            if transcript:
                # Drop leftover special tokens, turn the '|' word delimiter
                # into a space, and collapse runs of whitespace.
                for token in ('<pad>', '<s>', '</s>'):
                    transcript = transcript.replace(token, '')
                transcript = ' '.join(transcript.replace('|', ' ').split())
        except Exception as e:
            # Decoding is deliberately best-effort: the numeric features are
            # still returned with an empty transcript.
            logger.warning(f"⚠️ Decode error: {e}")
            transcript = ""
        return transcript

    def extract_audio_features(self, audio: np.ndarray, sample_rate: int = 16000) -> Dict[str, Any]:
        """
        Extract features from audio using IndicWav2Vec Hindi.

        Args:
            audio: Audio waveform as numpy array
            sample_rate: Sample rate of the audio (default: 16000)

        Returns:
            Dictionary containing:
            - input_values: Processed audio features
            - attention_mask: Attention mask, or None if the processor
              did not produce one

        Raises:
            Exception: Any processor error is logged and re-raised.
        """
        try:
            inputs = self._prepare_inputs(audio, sample_rate)
            return {
                'input_values': inputs.input_values,
                'attention_mask': inputs.get('attention_mask', None)
            }
        except Exception as e:
            logger.error(f"❌ Error extracting audio features: {e}")
            raise

    def get_transcription_features(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000
    ) -> Dict[str, Any]:
        """
        Get transcription features including logits, predictions, and confidence.

        Args:
            audio: Audio waveform as numpy array
            sample_rate: Sample rate of the audio (default: 16000)

        Returns:
            Dictionary containing:
            - transcript: Transcribed text ("" if decoding failed)
            - logits: Model logits (raw predictions) as a numpy array
            - predicted_ids: Predicted token IDs (argmax over the last axis)
            - probabilities: Softmax probabilities
            - confidence: Mean of the per-frame maximum probability
            - frame_confidence: Per-frame maximum probability for the
              first item of the batch
            - num_frames: Size of the logits' second dimension

        Raises:
            Exception: Any processing/inference error is logged and re-raised.
        """
        try:
            inputs = self._prepare_inputs(audio, sample_rate)

            with torch.no_grad():
                outputs = self.model(**inputs)
                logits = outputs.logits
                predicted_ids = torch.argmax(logits, dim=-1)

                # Confidence is the winning class probability of each frame.
                probs = torch.softmax(logits, dim=-1)
                max_probs = torch.max(probs, dim=-1)[0]

            frame_confidence = max_probs[0].cpu().numpy()
            avg_confidence = float(torch.mean(max_probs).item())

            transcript = self._decode_transcript(predicted_ids)

            return {
                'transcript': transcript,
                'logits': logits.cpu().numpy(),
                'predicted_ids': predicted_ids.cpu().numpy(),
                'probabilities': probs.cpu().numpy(),
                'confidence': avg_confidence,
                'frame_confidence': frame_confidence,
                'num_frames': logits.shape[1]
            }
        except Exception as e:
            logger.error(f"❌ Error getting transcription features: {e}")
            raise

    def get_word_level_features(
        self,
        audio: np.ndarray,
        sample_rate: int = 16000
    ) -> Dict[str, Any]:
        """
        Get word-level features including estimated timestamps and confidence.

        Timestamps are a simplified estimate: the audio duration is divided
        evenly between the words of the transcript. Each word's confidence is
        the mean frame confidence over its equal share of the output frames.

        Args:
            audio: Audio waveform as numpy array
            sample_rate: Sample rate of the audio (default: 16000)

        Returns:
            Dictionary containing:
            - words: List of words
            - word_timestamps: List of dicts with keys 'word', 'start', 'end'
              (times in seconds)
            - word_confidence: Confidence score for each word
            - transcript: The full transcript

        Raises:
            Exception: Any error is logged and re-raised.
        """
        try:
            features = self.get_transcription_features(audio, sample_rate)
            transcript = features['transcript']
            frame_confidence = features['frame_confidence']
            num_frames = features['num_frames']

            words = transcript.split() if transcript else []
            num_words = len(words)
            audio_duration = len(audio) / sample_rate
            time_per_word = audio_duration / num_words if num_words else 0

            word_timestamps = []
            word_confidence = []
            for i, word in enumerate(words):
                # Frame boundaries come straight from the word index using
                # integer arithmetic. This keeps consecutive spans contiguous,
                # makes the last word end exactly at num_frames, and avoids
                # dividing by a zero audio duration.
                start_frame = (i * num_frames) // num_words
                end_frame = ((i + 1) * num_frames) // num_words

                if end_frame > start_frame:
                    word_conf = float(np.mean(frame_confidence[start_frame:end_frame]))
                else:
                    word_conf = self._FALLBACK_WORD_CONFIDENCE

                word_timestamps.append({
                    'word': word,
                    'start': i * time_per_word,
                    'end': (i + 1) * time_per_word
                })
                word_confidence.append(word_conf)

            return {
                'words': words,
                'word_timestamps': word_timestamps,
                'word_confidence': word_confidence,
                'transcript': transcript
            }
        except Exception as e:
            logger.error(f"❌ Error getting word-level features: {e}")
            raise
|