Update: ML/AI logic is now in the AI engine service
74a089b
# diagnosis/ai_engine/features.py
"""
Feature extraction for IndicWav2Vec Hindi ASR
This module provides feature extraction capabilities using the IndicWav2Vec Hindi model.
Focused on ASR transcription features rather than hybrid acoustic+linguistic features.
"""
import torch
import numpy as np
import logging
from typing import Dict, Any, Tuple, Optional
from transformers import Wav2Vec2ForCTC, AutoProcessor
logger = logging.getLogger(__name__)
class ASRFeatureExtractor:
"""
Feature extractor using IndicWav2Vec Hindi for Automatic Speech Recognition.
This extractor focuses on:
- Audio feature extraction via IndicWav2Vec
- Transcription confidence scores
- Frame-level predictions and logits
- Word-level alignments (estimated)
Model: ai4bharat/indicwav2vec-hindi
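    Example (illustrative; assumes `model` and `processor` were loaded
    elsewhere, e.g. via from_pretrained, and that `audio` is a 16 kHz
    mono float numpy array):
        extractor = ASRFeatureExtractor(model, processor, device="cpu")
        result = extractor.get_transcription_features(audio)
        print(result['transcript'], result['confidence'])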
"""
def __init__(self, model: Wav2Vec2ForCTC, processor: AutoProcessor, device: str = "cpu"):
"""
Initialize the ASR feature extractor.
Args:
model: Pre-loaded IndicWav2Vec Hindi model
processor: Pre-loaded processor for the model
device: Device to run inference on ('cpu' or 'cuda')
"""
self.model = model
self.processor = processor
self.device = device
self.model.eval()
logger.info(f"✅ ASRFeatureExtractor initialized on {device}")
def extract_audio_features(self, audio: np.ndarray, sample_rate: int = 16000) -> Dict[str, Any]:
"""
Extract features from audio using IndicWav2Vec Hindi.
Args:
audio: Audio waveform as numpy array
            sample_rate: Sample rate of the audio in Hz (default: 16000)
Returns:
Dictionary containing:
- input_values: Processed audio features
- attention_mask: Attention mask (if available)
"""
try:
# Process audio through the processor
inputs = self.processor(
audio,
sampling_rate=sample_rate,
return_tensors="pt"
).to(self.device)
return {
'input_values': inputs.input_values,
'attention_mask': inputs.get('attention_mask', None)
}
except Exception as e:
logger.error(f"❌ Error extracting audio features: {e}")
raise
def get_transcription_features(
self,
audio: np.ndarray,
sample_rate: int = 16000
) -> Dict[str, Any]:
"""
Get transcription features including logits, predictions, and confidence.
Args:
audio: Audio waveform as numpy array
            sample_rate: Sample rate of the audio in Hz (default: 16000)
Returns:
Dictionary containing:
- transcript: Transcribed text
- logits: Model logits (raw predictions)
- predicted_ids: Predicted token IDs
- probabilities: Softmax probabilities
- confidence: Average confidence score
            - frame_confidence: Per-frame confidence scores
            - num_frames: Number of frames along the time axis of the logits
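        Example of reading the result (illustrative; assumes `extractor`
        is an ASRFeatureExtractor and `audio` a 16 kHz numpy array):
            feats = extractor.get_transcription_features(audio)
            print(feats['transcript'], round(feats['confidence'], 3))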
"""
try:
# Process audio
inputs = self.processor(
audio,
sampling_rate=sample_rate,
return_tensors="pt"
).to(self.device)
# Get model predictions
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs.logits
predicted_ids = torch.argmax(logits, dim=-1)
# Calculate probabilities and confidence
probs = torch.softmax(logits, dim=-1)
max_probs = torch.max(probs, dim=-1)[0] # Get max probability per frame
frame_confidence = max_probs[0].cpu().numpy()
avg_confidence = float(torch.mean(max_probs).item())
# Decode transcript
transcript = ""
try:
if hasattr(self.processor, 'tokenizer'):
transcript = self.processor.tokenizer.decode(
predicted_ids[0],
skip_special_tokens=True
)
elif hasattr(self.processor, 'batch_decode'):
transcript = self.processor.batch_decode(predicted_ids)[0]
                # Clean up transcript: drop leftover special tokens, map the
                # CTC word delimiter '|' to spaces, and normalize whitespace
                if transcript:
                    for token in ('<pad>', '<s>', '</s>'):
                        transcript = transcript.replace(token, '')
                    transcript = transcript.replace('|', ' ')
                    transcript = ' '.join(transcript.split())
except Exception as e:
logger.warning(f"⚠️ Decode error: {e}")
transcript = ""
return {
'transcript': transcript,
'logits': logits.cpu().numpy(),
'predicted_ids': predicted_ids.cpu().numpy(),
'probabilities': probs.cpu().numpy(),
'confidence': avg_confidence,
'frame_confidence': frame_confidence,
'num_frames': logits.shape[1]
}
except Exception as e:
logger.error(f"❌ Error getting transcription features: {e}")
raise
def get_word_level_features(
self,
audio: np.ndarray,
sample_rate: int = 16000
) -> Dict[str, Any]:
"""
        Get word-level features including estimated timestamps and confidence.
        Note: timestamps are uniform estimates (the audio duration is split
        evenly across words), not forced alignments.
Args:
audio: Audio waveform as numpy array
            sample_rate: Sample rate of the audio in Hz (default: 16000)
Returns:
Dictionary containing:
            - words: List of words
            - word_timestamps: List of dicts with 'word', 'start', and 'end' times in seconds
            - word_confidence: Estimated confidence score for each word
            - transcript: Full transcribed text
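        Example entry in word_timestamps (illustrative values):
            {'word': 'नमस्ते', 'start': 0.0, 'end': 0.75}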
"""
try:
# Get transcription features
features = self.get_transcription_features(audio, sample_rate)
transcript = features['transcript']
frame_confidence = features['frame_confidence']
num_frames = features['num_frames']
# Estimate word-level timestamps (simplified)
words = transcript.split() if transcript else []
audio_duration = len(audio) / sample_rate
            time_per_word = audio_duration / len(words) if words else 0.0
word_timestamps = []
word_confidence = []
for i, word in enumerate(words):
start_time = i * time_per_word
end_time = (i + 1) * time_per_word
# Estimate confidence for this word (average of corresponding frames)
start_frame = int((start_time / audio_duration) * num_frames)
end_frame = int((end_time / audio_duration) * num_frames)
word_conf = float(np.mean(frame_confidence[start_frame:end_frame])) if end_frame > start_frame else 0.5
word_timestamps.append({
'word': word,
'start': start_time,
'end': end_time
})
word_confidence.append(word_conf)
return {
'words': words,
'word_timestamps': word_timestamps,
'word_confidence': word_confidence,
'transcript': transcript
}
except Exception as e:
logger.error(f"❌ Error getting word-level features: {e}")
raise
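# Minimal usage sketch (illustrative, not part of the AI engine service
# wiring): loads the model named in the class docstring from the Hugging
# Face Hub and runs the extractor on one second of silence. The synthetic
# input is an assumption for demonstration only; real callers pass decoded
# speech audio resampled to 16 kHz.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    model_id = "ai4bharat/indicwav2vec-hindi"
    model = Wav2Vec2ForCTC.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id)
    extractor = ASRFeatureExtractor(model, processor, device="cpu")
    audio = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz
    features = extractor.get_transcription_features(audio, sample_rate=16000)
    print(f"transcript={features['transcript']!r} "
          f"confidence={features['confidence']:.3f} "
          f"num_frames={features['num_frames']}")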