# diagnosis/ai_engine/detect_stuttering.py
import os
import time
import logging
from dataclasses import dataclass, field
from typing import List, Dict, Any, Tuple

import librosa
import numpy as np
import torch
from transformers import Wav2Vec2ForCTC, AutoProcessor

# Simplified: Only using ASR transcription, removed complex signal processing libraries.
# The optional dependencies below are needed ONLY by the legacy multi-modal methods kept
# further down for reference; the ASR-only pipeline works without them.
try:
    import parselmouth                          # Praat bindings (formants, jitter/shimmer/HNR)
    from fastdtw import fastdtw                 # DTW verification for repetitions
    from scipy.spatial import ConvexHull        # Vowel space area estimation
    from sklearn.ensemble import IsolationForest  # Anomaly filtering of events
except ImportError:
    parselmouth = fastdtw = ConvexHull = IsolationForest = None

logger = logging.getLogger(__name__)

# === CONFIGURATION ===
MODEL_ID = "ai4bharat/indicwav2vec-hindi"  # Only model used - IndicWav2Vec Hindi for ASR
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face token for authenticated model access

INDIAN_LANGUAGES = {
    'hindi': 'hin', 'english': 'eng', 'tamil': 'tam', 'telugu': 'tel',
    'bengali': 'ben', 'marathi': 'mar', 'gujarati': 'guj', 'kannada': 'kan',
    'malayalam': 'mal', 'punjabi': 'pan', 'urdu': 'urd', 'assamese': 'asm',
    'odia': 'ory', 'bhojpuri': 'bho', 'maithili': 'mai'
}

# === RESEARCH-BASED THRESHOLDS (2024-2025 Literature) ===

# Prolongation Detection (Spectral Correlation + Duration)
PROLONGATION_CORRELATION_THRESHOLD = 0.90  # >0.9 spectral similarity
PROLONGATION_MIN_DURATION = 0.25           # >250 ms (Revisiting Rule-Based, 2025)

# Block Detection (Silence Analysis)
BLOCK_SILENCE_THRESHOLD = 0.35  # >350 ms silence mid-utterance
BLOCK_ENERGY_PERCENTILE = 10    # Bottom 10% energy = silence

# Repetition Detection (DTW + Text Matching)
REPETITION_DTW_THRESHOLD = 0.15   # Normalized DTW distance
REPETITION_MIN_SIMILARITY = 0.85  # Text-based similarity

# Speaking Rate Norms (syllables/second)
SPEECH_RATE_MIN = 2.0
SPEECH_RATE_MAX = 6.0
SPEECH_RATE_TYPICAL = 4.0

# Formant Analysis (Vowel Centralization - Research Finding)
# People who stutter show reduced vowel space area
VOWEL_SPACE_REDUCTION_THRESHOLD = 0.70  # 70% of typical area

# Voice Quality (Jitter, Shimmer, HNR)
JITTER_THRESHOLD = 0.01   # >1% jitter indicates instability
SHIMMER_THRESHOLD = 0.03  # >3% shimmer
HNR_THRESHOLD = 15.0      # <15 dB Harmonics-to-Noise Ratio

# Zero-Crossing Rate (Voiced/Unvoiced Discrimination)
ZCR_VOICED_THRESHOLD = 0.1    # Low ZCR = voiced
ZCR_UNVOICED_THRESHOLD = 0.3  # High ZCR = unvoiced

# Entropy-Based Uncertainty
ENTROPY_HIGH_THRESHOLD = 3.5     # High confusion in model predictions
CONFIDENCE_LOW_THRESHOLD = 0.40  # Low confidence frame threshold


@dataclass
class StutterEvent:
    """Enhanced stutter event with multi-modal features"""
    type: str  # 'repetition', 'prolongation', 'block', 'dysfluency'
    start: float
    end: float
    text: str
    confidence: float
    acoustic_features: Dict[str, float] = field(default_factory=dict)
    voice_quality: Dict[str, float] = field(default_factory=dict)
    formant_data: Dict[str, Any] = field(default_factory=dict)


class AdvancedStutterDetector:
    """
    🎤 IndicWav2Vec Hindi ASR Engine

    Simplified engine using ONLY ai4bharat/indicwav2vec-hindi for Automatic Speech Recognition.
    Features:
    - Speech-to-text transcription using IndicWav2Vec Hindi model
    - Text-based stutter analysis from transcription
    - Confidence scoring from model predictions
    - Basic dysfluency detection from transcript patterns

    Model: ai4bharat/indicwav2vec-hindi (Wav2Vec2ForCTC)
    Purpose: Automatic Speech Recognition (ASR) for Hindi and Indian languages
    """

    def __init__(self):
        logger.info(f"🚀 Initializing Advanced AI Engine on {DEVICE}...")

        if HF_TOKEN:
            logger.info("✅ HF_TOKEN found - using authenticated model access")
        else:
            logger.warning("⚠️ HF_TOKEN not found - model access may fail if authentication is required")

        try:
            # Wav2Vec2 Model Loading - IndicWav2Vec Hindi Model
            self.processor = AutoProcessor.from_pretrained(
                MODEL_ID,
                token=HF_TOKEN
            )
            self.model = Wav2Vec2ForCTC.from_pretrained(
                MODEL_ID,
                token=HF_TOKEN,
                torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
            ).to(DEVICE)
            self.model.eval()

            # Initialize feature extractor (clean architecture pattern)
            from .features import ASRFeatureExtractor
            self.feature_extractor = ASRFeatureExtractor(
                model=self.model,
                processor=self.processor,
                device=DEVICE
            )

            # Debug: Log processor structure
            logger.info(f"📋 Processor type: {type(self.processor)}")
            if hasattr(self.processor, 'tokenizer'):
                logger.info(f"📋 Tokenizer type: {type(self.processor.tokenizer)}")
            if hasattr(self.processor, 'feature_extractor'):
                logger.info(f"📋 Feature extractor type: {type(self.processor.feature_extractor)}")

            logger.info("✅ IndicWav2Vec Hindi ASR Engine Loaded with Feature Extractor")
        except Exception as e:
            logger.error(f"🔥 Engine Failure: {e}")
            raise

    def _init_common_adapters(self):
        """Not applicable - IndicWav2Vec Hindi doesn't use adapters"""
        pass

    def _activate_adapter(self, lang_code: str):
        """Not applicable - IndicWav2Vec Hindi doesn't use adapters"""
        logger.info("Using IndicWav2Vec Hindi model (optimized for Hindi)")

    # ===== LEGACY METHODS (NOT USED IN ASR-ONLY MODE) =====
    # These methods are kept for reference but not called in the simplified ASR pipeline.
    # They require the optional libraries imported at the top (parselmouth, fastdtw,
    # scipy, sklearn), which are not needed for ASR-only mode.

    def _extract_comprehensive_features(self, audio: np.ndarray, sr: int, audio_path: str) -> Dict[str, Any]:
        """Extract multi-modal acoustic features"""
        features = {}

        # MFCC (20 coefficients)
        mfcc = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=20, hop_length=512)
        features['mfcc'] = mfcc.T  # Transpose for time x features

        # Zero-Crossing Rate
        zcr = librosa.feature.zero_crossing_rate(audio, hop_length=512)[0]
        features['zcr'] = zcr

        # RMS Energy
        rms_energy = librosa.feature.rms(y=audio, hop_length=512)[0]
        features['rms_energy'] = rms_energy

        # Spectral Flux
        stft = librosa.stft(audio, hop_length=512)
        magnitude = np.abs(stft)
        spectral_flux = np.sum(np.diff(magnitude, axis=1) * (np.diff(magnitude, axis=1) > 0), axis=0)
        features['spectral_flux'] = spectral_flux

        # Energy Entropy
        frame_energy = np.sum(magnitude ** 2, axis=0)
        frame_energy = frame_energy + 1e-10  # Avoid log(0)
        energy_entropy = -np.sum(
            (magnitude ** 2 / frame_energy) * np.log(magnitude ** 2 / frame_energy + 1e-10),
            axis=0
        )
        features['energy_entropy'] = energy_entropy

        # Formant Analysis using Parselmouth
        try:
            sound = parselmouth.Sound(audio_path)
            formant = sound.to_formant_burg(time_step=0.01)
            times = np.arange(0, sound.duration, 0.01)
            f1, f2, f3, f4 = [], [], [], []
            for t in times:
                try:
                    f1.append(formant.get_value_at_time(1, t) if formant.get_value_at_time(1, t) > 0 else np.nan)
                    f2.append(formant.get_value_at_time(2, t) if formant.get_value_at_time(2, t) > 0 else np.nan)
                    f3.append(formant.get_value_at_time(3, t) if formant.get_value_at_time(3, t) > 0 else np.nan)
                    f4.append(formant.get_value_at_time(4, t) if formant.get_value_at_time(4, t) > 0 else np.nan)
                except Exception:
                    f1.append(np.nan)
                    f2.append(np.nan)
                    f3.append(np.nan)
                    f4.append(np.nan)

            formants = np.array([f1, f2, f3, f4]).T
            features['formants'] = formants

            # Calculate vowel space area (F1-F2 plane)
            valid_f1f2 = formants[~np.isnan(formants[:, 0]) & ~np.isnan(formants[:, 1]), :2]
            if len(valid_f1f2) > 0:
                # Convex hull area approximation
                try:
                    hull = ConvexHull(valid_f1f2)
                    vowel_space_area = hull.volume
                except Exception:
                    vowel_space_area = np.nan
            else:
                vowel_space_area = np.nan

            features['formant_summary'] = {
                'vowel_space_area': float(vowel_space_area) if not np.isnan(vowel_space_area) else 0.0,
                'f1_mean': float(np.nanmean(f1)) if len(f1) > 0 else 0.0,
                'f2_mean': float(np.nanmean(f2)) if len(f2) > 0 else 0.0,
                'f1_std': float(np.nanstd(f1)) if len(f1) > 0 else 0.0,
                'f2_std': float(np.nanstd(f2)) if len(f2) > 0 else 0.0
            }
        except Exception as e:
            logger.warning(f"Formant analysis failed: {e}")
            features['formants'] = np.zeros((len(audio) // 100, 4))
            features['formant_summary'] = {
                'vowel_space_area': 0.0, 'f1_mean': 0.0, 'f2_mean': 0.0,
                'f1_std': 0.0, 'f2_std': 0.0
            }

        # Voice Quality Metrics (Jitter, Shimmer, HNR)
        try:
            sound = parselmouth.Sound(audio_path)
            pitch = sound.to_pitch()
            point_process = parselmouth.praat.call([sound, pitch], "To PointProcess")
            jitter = parselmouth.praat.call(point_process, "Get jitter (local)", 0.0, 0.0, 1.1, 1.6, 1.3, 1.6)
            shimmer = parselmouth.praat.call([sound, point_process], "Get shimmer (local)", 0.0, 0.0, 0.0001, 0.02, 1.3, 1.6)
            hnr = parselmouth.praat.call(sound, "Get harmonicity (cc)", 0.0, 0.0, 0.01, 1.5, 1.0, 0.1, 1.0)
            features['voice_quality'] = {
                'jitter': float(jitter) if jitter is not None else 0.0,
                'shimmer': float(shimmer) if shimmer is not None else 0.0,
                'hnr_db': float(hnr) if hnr is not None else 20.0
            }
        except Exception as e:
            logger.warning(f"Voice quality analysis failed: {e}")
            features['voice_quality'] = {
                'jitter': 0.0, 'shimmer': 0.0, 'hnr_db': 20.0
            }

        return features

    def _transcribe_with_timestamps(self, audio: np.ndarray) -> Tuple[str, List[Dict], torch.Tensor]:
        """
        Transcribe audio and return word timestamps and logits.

        Uses the feature extractor for clean separation of concerns.
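
        Returns (assumed shapes, inferred from the dummy fallback and from
        _calculate_uncertainty below):
            transcript       -- decoded text from the feature extractor
            word_timestamps  -- list of dicts with 'word', 'start', 'end' (seconds)
            logits           -- raw CTC frame scores, shape (batch=1, frames, vocab)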
""" try: # Use feature extractor for transcription (clean architecture) features = self.feature_extractor.get_transcription_features(audio, sample_rate=16000) transcript = features['transcript'] logits = torch.from_numpy(features['logits']) # Get word-level features for timestamps word_features = self.feature_extractor.get_word_level_features(audio, sample_rate=16000) word_timestamps = word_features['word_timestamps'] logger.info(f"📝 Transcription via feature extractor: '{transcript}' (length: {len(transcript)}, words: {len(word_timestamps)})") return transcript, word_timestamps, logits except Exception as e: logger.error(f"❌ Transcription failed: {e}", exc_info=True) return "", [], torch.zeros((1, 100, 32)) # Dummy return def _calculate_uncertainty(self, logits: torch.Tensor) -> Tuple[float, List[Dict]]: """Calculate entropy-based uncertainty and low-confidence regions""" try: probs = torch.softmax(logits, dim=-1) entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1) entropy_mean = float(torch.mean(entropy).item()) # Find low-confidence regions frame_duration = 0.02 low_conf_regions = [] confidence = torch.max(probs, dim=-1)[0] for i in range(confidence.shape[1]): conf = float(confidence[0, i].item()) if conf < CONFIDENCE_LOW_THRESHOLD: low_conf_regions.append({ 'time': i * frame_duration, 'confidence': conf }) return entropy_mean, low_conf_regions except Exception as e: logger.warning(f"Uncertainty calculation failed: {e}") return 0.0, [] def _estimate_speaking_rate(self, audio: np.ndarray, sr: int) -> float: """Estimate speaking rate in syllables per second""" try: # Simple syllable estimation using energy peaks rms = librosa.feature.rms(y=audio, hop_length=512)[0] peaks, _ = librosa.util.peak_pick(rms, pre_max=3, post_max=3, pre_avg=3, post_avg=5, delta=0.1, wait=10) duration = len(audio) / sr num_syllables = len(peaks) speaking_rate = num_syllables / duration if duration > 0 else SPEECH_RATE_TYPICAL return max(SPEECH_RATE_MIN, min(SPEECH_RATE_MAX, speaking_rate)) except Exception as e: logger.warning(f"Speaking rate estimation failed: {e}") return SPEECH_RATE_TYPICAL def _detect_prolongations_advanced(self, mfcc: np.ndarray, spectral_flux: np.ndarray, speaking_rate: float, word_timestamps: List[Dict]) -> List[StutterEvent]: """Detect prolongations using spectral correlation""" events = [] frame_duration = 0.02 # Adaptive threshold based on speaking rate min_duration = PROLONGATION_MIN_DURATION * (SPEECH_RATE_TYPICAL / max(speaking_rate, 0.1)) window_size = int(min_duration / frame_duration) if window_size < 2: return events for i in range(len(mfcc) - window_size): window = mfcc[i:i+window_size] # Calculate spectral correlation if len(window) > 1: corr_matrix = np.corrcoef(window.T) avg_correlation = np.mean(corr_matrix[np.triu_indices_from(corr_matrix, k=1)]) if avg_correlation > PROLONGATION_CORRELATION_THRESHOLD: start_time = i * frame_duration end_time = (i + window_size) * frame_duration # Check if within a word boundary for word_ts in word_timestamps: if word_ts['start'] <= start_time <= word_ts['end']: events.append(StutterEvent( type='prolongation', start=start_time, end=end_time, text=word_ts.get('word', ''), confidence=float(avg_correlation), acoustic_features={ 'spectral_correlation': float(avg_correlation), 'duration': end_time - start_time } )) break return events def _detect_blocks_enhanced(self, audio: np.ndarray, sr: int, rms_energy: np.ndarray, zcr: np.ndarray, word_timestamps: List[Dict], speaking_rate: float) -> List[StutterEvent]: """Detect blocks using 
silence analysis""" events = [] frame_duration = 0.02 # Adaptive threshold silence_threshold = BLOCK_SILENCE_THRESHOLD * (SPEECH_RATE_TYPICAL / max(speaking_rate, 0.1)) energy_threshold = np.percentile(rms_energy, BLOCK_ENERGY_PERCENTILE) in_silence = False silence_start = 0 for i, energy in enumerate(rms_energy): is_silent = energy < energy_threshold and zcr[i] < ZCR_VOICED_THRESHOLD if is_silent and not in_silence: silence_start = i * frame_duration in_silence = True elif not is_silent and in_silence: silence_duration = (i * frame_duration) - silence_start if silence_duration > silence_threshold: # Check if mid-utterance (not at start/end) audio_duration = len(audio) / sr if silence_start > 0.1 and silence_start < audio_duration - 0.1: events.append(StutterEvent( type='block', start=silence_start, end=i * frame_duration, text="", confidence=0.8, acoustic_features={ 'silence_duration': silence_duration, 'energy_level': float(energy) } )) in_silence = False return events def _detect_repetitions_advanced(self, mfcc: np.ndarray, formants: np.ndarray, word_timestamps: List[Dict], transcript: str, speaking_rate: float) -> List[StutterEvent]: """Detect repetitions using DTW and text matching""" events = [] if len(word_timestamps) < 2: return events # Text-based repetition detection words = transcript.lower().split() for i in range(len(words) - 1): if words[i] == words[i+1]: # Find corresponding timestamps if i < len(word_timestamps) and i+1 < len(word_timestamps): start = word_timestamps[i]['start'] end = word_timestamps[i+1]['end'] # DTW verification on MFCC start_frame = int(start / 0.02) mid_frame = int((start + end) / 2 / 0.02) end_frame = int(end / 0.02) if start_frame < len(mfcc) and end_frame < len(mfcc): segment1 = mfcc[start_frame:mid_frame] segment2 = mfcc[mid_frame:end_frame] if len(segment1) > 0 and len(segment2) > 0: try: distance, _ = fastdtw(segment1, segment2) normalized_distance = distance / max(len(segment1), len(segment2)) if normalized_distance < REPETITION_DTW_THRESHOLD: events.append(StutterEvent( type='repetition', start=start, end=end, text=words[i], confidence=1.0 - normalized_distance, acoustic_features={ 'dtw_distance': float(normalized_distance), 'repetition_count': 2 } )) except: pass return events def _detect_voice_quality_issues(self, audio_path: str, word_timestamps: List[Dict], voice_quality: Dict[str, float]) -> List[StutterEvent]: """Detect dysfluencies based on voice quality metrics""" events = [] # Global voice quality issues if voice_quality.get('jitter', 0) > JITTER_THRESHOLD or \ voice_quality.get('shimmer', 0) > SHIMMER_THRESHOLD or \ voice_quality.get('hnr_db', 20) < HNR_THRESHOLD: # Mark regions with poor voice quality for word_ts in word_timestamps: if word_ts.get('start', 0) > 0: # Skip first word events.append(StutterEvent( type='dysfluency', start=word_ts['start'], end=word_ts['end'], text=word_ts.get('word', ''), confidence=0.6, voice_quality=voice_quality.copy() )) break # Only mark first occurrence return events def _is_overlapping(self, time: float, events: List[StutterEvent], threshold: float = 0.1) -> bool: """Check if time overlaps with existing events""" for event in events: if event.start - threshold <= time <= event.end + threshold: return True return False def _detect_anomalies(self, events: List[StutterEvent], features: Dict[str, Any]) -> List[StutterEvent]: """Use Isolation Forest to filter anomalous events""" if len(events) == 0: return events try: # Extract features for anomaly detection X = [] for event in events: feat_vec = [ 
            for event in events:
                feat_vec = [
                    event.end - event.start,  # Duration
                    event.confidence,
                    features.get('voice_quality', {}).get('jitter', 0),
                    features.get('voice_quality', {}).get('shimmer', 0)
                ]
                X.append(feat_vec)

            X = np.array(X)
            if len(X) > 1:
                # Lazily create the Isolation Forest; __init__ does not set one up in ASR-only mode
                if not hasattr(self, 'anomaly_detector'):
                    self.anomaly_detector = IsolationForest()
                self.anomaly_detector.fit(X)
                predictions = self.anomaly_detector.predict(X)
                # Keep only non-anomalous events (predictions == 1)
                filtered_events = [events[i] for i, pred in enumerate(predictions) if pred == 1]
                return filtered_events
        except Exception as e:
            logger.warning(f"Anomaly detection failed: {e}")
        return events

    def _deduplicate_events_cascade(self, events: List[StutterEvent]) -> List[StutterEvent]:
        """Remove overlapping events with priority: Block > Repetition > Prolongation > Dysfluency"""
        if len(events) == 0:
            return events

        # Sort by priority and start time
        priority = {'block': 4, 'repetition': 3, 'prolongation': 2, 'dysfluency': 1}
        events.sort(key=lambda e: (priority.get(e.type, 0), e.start), reverse=True)

        cleaned = []
        for event in events:
            overlap = False
            for existing in cleaned:
                # Check overlap
                if not (event.end < existing.start or event.start > existing.end):
                    overlap = True
                    break
            if not overlap:
                cleaned.append(event)

        # Sort by start time
        cleaned.sort(key=lambda e: e.start)
        return cleaned

    def _calculate_clinical_metrics(self, events: List[StutterEvent], duration: float,
                                    speaking_rate: float, features: Dict[str, Any]) -> Dict[str, Any]:
        """Calculate comprehensive clinical metrics"""
        total_duration = sum(e.end - e.start for e in events)
        frequency = (len(events) / duration * 60) if duration > 0 else 0

        # Calculate severity score (0-100)
        stutter_percentage = (total_duration / duration * 100) if duration > 0 else 0
        frequency_score = min(frequency / 10 * 100, 100)  # Normalize to 100
        severity_score = (stutter_percentage * 0.6 + frequency_score * 0.4)

        # Determine severity label
        if severity_score < 10:
            severity_label = 'none'
        elif severity_score < 25:
            severity_label = 'mild'
        elif severity_score < 50:
            severity_label = 'moderate'
        else:
            severity_label = 'severe'

        # Calculate confidence based on multiple factors
        voice_quality = features.get('voice_quality', {})
        confidence = 0.8  # Base confidence

        # Adjust based on voice quality metrics
        if voice_quality.get('jitter', 0) > JITTER_THRESHOLD:
            confidence -= 0.1
        if voice_quality.get('shimmer', 0) > SHIMMER_THRESHOLD:
            confidence -= 0.1
        if voice_quality.get('hnr_db', 20) < HNR_THRESHOLD:
            confidence -= 0.1
        confidence = max(0.3, min(1.0, confidence))

        return {
            'total_duration': round(total_duration, 2),
            'frequency': round(frequency, 2),
            'severity_score': round(severity_score, 2),
            'severity_label': severity_label,
            'confidence': round(confidence, 2)
        }

    def _event_to_dict(self, event: StutterEvent) -> Dict[str, Any]:
        """Convert StutterEvent to dictionary"""
        return {
            'type': event.type,
            'start': round(event.start, 2),
            'end': round(event.end, 2),
            'text': event.text,
            'confidence': round(event.confidence, 2),
            'acoustic_features': event.acoustic_features,
            'voice_quality': event.voice_quality,
            'formant_data': event.formant_data
        }

    def analyze_audio(self, audio_path: str, proper_transcript: str = "", language: str = 'hindi') -> dict:
        """
        Main ASR analysis pipeline using IndicWav2Vec Hindi model

        Focus: Automatic Speech Recognition (ASR) transcription only
        """
        start_time = time.time()

        # === STEP 1: Audio Loading & Preprocessing ===
        audio, sr = librosa.load(audio_path, sr=16000)
        duration = librosa.get_duration(y=audio, sr=sr)

        # === STEP 2: ASR Transcription using IndicWav2Vec Hindi ===
        transcript, word_timestamps, logits = self._transcribe_with_timestamps(audio)
logger.info(f"📝 ASR Transcription: '{transcript}' (length: {len(transcript)}, words: {len(word_timestamps)})") # === STEP 3: Calculate Confidence from Model Predictions === entropy_score, low_conf_regions = self._calculate_uncertainty(logits) avg_confidence = 1.0 - (entropy_score / 10.0) if entropy_score > 0 else 0.8 avg_confidence = max(0.0, min(1.0, avg_confidence)) # === STEP 4: Basic Text-based Analysis === # Simple text-based stutter detection (repetitions, hesitations) events = [] if transcript: words = transcript.split() # Detect word repetitions for i in range(len(words) - 1): if words[i] == words[i+1] and i < len(word_timestamps) - 1: events.append(StutterEvent( type='repetition', start=word_timestamps[i]['start'] if i < len(word_timestamps) else 0, end=word_timestamps[i+1]['end'] if i+1 < len(word_timestamps) else 0, text=words[i], confidence=0.7 )) # Add low confidence regions as potential dysfluencies for region in low_conf_regions[:5]: # Limit to first 5 events.append(StutterEvent( type='dysfluency', start=region['time'], end=region['time'] + 0.3, text="", confidence=0.4, acoustic_features={'entropy': entropy_score} )) # === STEP 5: Calculate Basic Metrics === total_duration = sum(e.end - e.start for e in events) frequency = (len(events) / duration * 60) if duration > 0 else 0 stutter_percentage = (total_duration / duration * 100) if duration > 0 else 0 # Simple severity assessment if stutter_percentage < 5: severity = 'none' elif stutter_percentage < 15: severity = 'mild' elif stutter_percentage < 30: severity = 'moderate' else: severity = 'severe' # === STEP 6: Return ASR Results === actual_transcript = transcript if transcript else "" target_transcript = proper_transcript if proper_transcript else "" logger.info(f"📝 Final ASR result - Actual: '{actual_transcript}' (len: {len(actual_transcript)}), Target: '{target_transcript}' (len: {len(target_transcript)})") return { 'actual_transcript': actual_transcript, 'target_transcript': target_transcript, 'mismatched_chars': [f"{r['time']:.2f}s" for r in low_conf_regions[:10]], 'mismatch_percentage': round(stutter_percentage, 2), 'ctc_loss_score': round(entropy_score, 4), 'stutter_timestamps': [self._event_to_dict(e) for e in events], 'total_stutter_duration': round(total_duration, 2), 'stutter_frequency': round(frequency, 2), 'severity': severity, 'confidence_score': round(avg_confidence, 2), 'speaking_rate_sps': round(len(word_timestamps) / duration if duration > 0 else 0, 2), 'analysis_duration_seconds': round(time.time() - start_time, 2), 'model_version': 'indicwav2vec-hindi-asr-v1' } # Legacy methods - kept for backward compatibility but may not work without additional model initialization # These methods reference models (xlsr, base, large) that are not initialized in __init__ # The main analyze_audio() method uses the IndicWav2Vec Hindi model instead def generate_target_transcript(self, audio_file: str) -> str: """Generate expected transcript - Legacy method (uses IndicWav2Vec Hindi model)""" try: audio, sr = librosa.load(audio_file, sr=16000) transcript, _, _ = self._transcribe_with_timestamps(audio) return transcript except Exception as e: logger.error(f"Target transcript generation failed: {e}") return "" def transcribe_and_detect(self, audio_file: str, proper_transcript: str) -> Dict: """Transcribe audio and detect stuttering patterns - Legacy method""" try: audio, _ = librosa.load(audio_file, sr=16000) transcript, _, _ = self._transcribe_with_timestamps(audio) # Find stuttered sequences stuttered_chars = 

            # Calculate mismatch percentage
            total_mismatched = sum(len(segment) for segment in stuttered_chars)
            mismatch_percentage = (total_mismatched / len(proper_transcript)) * 100 if len(proper_transcript) > 0 else 0
            mismatch_percentage = min(round(mismatch_percentage), 100)

            return {
                'transcription': transcript,
                'stuttered_chars': stuttered_chars,
                'mismatch_percentage': mismatch_percentage
            }
        except Exception as e:
            logger.error(f"Transcription failed: {e}")
            return {
                'transcription': '',
                'stuttered_chars': [],
                'mismatch_percentage': 0
            }

    def calculate_stutter_timestamps(self, audio_file: str, proper_transcript: str) -> Tuple[float, List[Tuple[float, float]]]:
        """Calculate stutter timestamps - Legacy method (uses analyze_audio instead)"""
        try:
            # Use main analyze_audio method
            result = self.analyze_audio(audio_file, proper_transcript)

            # Extract timestamps from result
            timestamps = []
            for event in result.get('stutter_timestamps', []):
                timestamps.append((event['start'], event['end']))

            ctc_score = result.get('ctc_loss_score', 0.0)
            return float(ctc_score), timestamps
        except Exception as e:
            logger.error(f"Timestamp calculation failed: {e}")
            return 0.0, []

    def find_max_common_characters(self, transcription1: str, transcript2: str) -> str:
        """Longest Common Subsequence algorithm"""
        m, n = len(transcription1), len(transcript2)
        lcs_matrix = [[0] * (n + 1) for _ in range(m + 1)]

        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if transcription1[i - 1] == transcript2[j - 1]:
                    lcs_matrix[i][j] = lcs_matrix[i - 1][j - 1] + 1
                else:
                    lcs_matrix[i][j] = max(lcs_matrix[i - 1][j], lcs_matrix[i][j - 1])

        # Backtrack to find LCS
        lcs_characters = []
        i, j = m, n
        while i > 0 and j > 0:
            if transcription1[i - 1] == transcript2[j - 1]:
                lcs_characters.append(transcription1[i - 1])
                i -= 1
                j -= 1
            elif lcs_matrix[i - 1][j] > lcs_matrix[i][j - 1]:
                i -= 1
            else:
                j -= 1

        lcs_characters.reverse()
        return ''.join(lcs_characters)

    def find_sequences_not_in_common(self, transcription1: str, proper_transcript: str) -> List[str]:
        """Find stuttered character sequences"""
        common_characters = self.find_max_common_characters(transcription1, proper_transcript)
        sequences = []
        sequence = ""
        i, j = 0, 0
        while i < len(transcription1) and j < len(common_characters):
            if transcription1[i] == common_characters[j]:
                if sequence:
                    sequences.append(sequence)
                    sequence = ""
                i += 1
                j += 1
            else:
                sequence += transcription1[i]
                i += 1
        if sequence:
            sequences.append(sequence)
        return sequences
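
    # Worked example for the two LCS helpers above (illustrative values only):
    # given an ASR output of "hehello" against the reference "hello", the longest
    # common subsequence is "hello", so find_sequences_not_in_common() returns
    # ["he"], the repeated fragment that transcribe_and_detect() reports in
    # 'stuttered_chars' and counts toward 'mismatch_percentage' (2/5 = 40%).
    #
    #   detector.find_max_common_characters("hehello", "hello")    # -> "hello"
    #   detector.find_sequences_not_in_common("hehello", "hello")  # -> ["he"]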
    def _calculate_total_duration(self, timestamps: List[Tuple[float, float]]) -> float:
        """Calculate total stuttering duration"""
        return sum(end - start for start, end in timestamps)

    def _calculate_frequency(self, timestamps: List[Tuple[float, float]], audio_file: str) -> float:
        """Calculate stutters per minute"""
        try:
            audio_duration = librosa.get_duration(path=audio_file)
            if audio_duration > 0:
                return (len(timestamps) / audio_duration) * 60
            return 0.0
        except Exception:
            return 0.0

    def _determine_severity(self, mismatch_percentage: float) -> str:
        """Determine severity level"""
        if mismatch_percentage < 10:
            return 'none'
        elif mismatch_percentage < 25:
            return 'mild'
        elif mismatch_percentage < 50:
            return 'moderate'
        else:
            return 'severe'

    def _calculate_confidence(self, transcription_result: Dict, ctc_loss: float) -> float:
        """Calculate confidence score for the analysis"""
        # Lower mismatch and lower CTC loss = higher confidence
        mismatch_factor = 1 - (transcription_result['mismatch_percentage'] / 100)
        loss_factor = max(0, 1 - (ctc_loss / 10))  # Normalize loss
        confidence = (mismatch_factor + loss_factor) / 2
        return round(min(max(confidence, 0.0), 1.0), 2)


# Model loader is now in a separate module: model_loader.py
# This follows clean architecture principles - separation of concerns
# Import using: from diagnosis.ai_engine.model_loader import get_stutter_detector
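

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not called by the application).
# Assumptions: HF_TOKEN is exported, the module is run in package context
# (e.g. `python -m diagnosis.ai_engine.detect_stuttering`) so the relative
# import of .features resolves, and "sample_recording.wav" plus the reference
# transcript are placeholder values. Production code should prefer the cached
# factory noted above (model_loader.get_stutter_detector) over direct
# instantiation, so the model is loaded once per process.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    detector = AdvancedStutterDetector()       # loads IndicWav2Vec Hindi on init
    report = detector.analyze_audio(
        "sample_recording.wav",                # placeholder path to a speech recording
        proper_transcript="<reference text>",  # optional target transcript
        language="hindi",
    )

    # A few of the keys returned by analyze_audio()
    print(report["actual_transcript"])
    print(report["severity"], report["confidence_score"], report["stutter_frequency"])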