Fix: changes to app.py and detect_stuttering.py
- app.py +7 -3
- diagnosis/ai_engine/detect_stuttering.py +59 -7
app.py
CHANGED

@@ -3,7 +3,7 @@ import logging
 import os
 import sys
 from pathlib import Path
-from fastapi import FastAPI, UploadFile, File, HTTPException
+from fastapi import FastAPI, UploadFile, File, Form, HTTPException
 from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 
@@ -70,7 +70,7 @@ async def health_check():
 @app.post("/analyze")
 async def analyze_audio(
     audio: UploadFile = File(...),
-    transcript: str = ""
+    transcript: str = Form("")
 ):
     """
     Analyze audio file for stuttering
@@ -102,10 +102,14 @@ async def analyze_audio(
         logger.info(f"Saved to: {temp_file} ({len(content) / 1024 / 1024:.2f} MB)")
 
         # Analyze
-        logger.info(f"Analyzing audio with transcript: '{transcript[:50]}...'")
+        logger.info(f"Analyzing audio with transcript: '{transcript[:50] if transcript else '(empty)'}...'")
         result = detector.analyze_audio(temp_file, transcript)
 
+        # Log transcript values from result
+        actual = result.get('actual_transcript', '')
+        target = result.get('target_transcript', '')
         logger.info(f"✅ Analysis complete: severity={result['severity']}, mismatch={result['mismatch_percentage']}%")
+        logger.info(f"Result transcripts - Actual: '{actual[:100]}' (len: {len(actual)}), Target: '{target[:100]}' (len: {len(target)})")
         return result
 
     except HTTPException:
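Note on the `Form("")` change: in FastAPI, a plain `str` parameter declared next to an `UploadFile = File(...)` is treated as a query parameter, so a `transcript` sent as a multipart form field never reached the handler and silently defaulted to `""`. Declaring it as `Form("")` makes FastAPI read it from the multipart body. A minimal client sketch against the updated endpoint (the base URL and file name here are hypothetical):

```python
# Minimal client sketch for the updated /analyze endpoint.
# Hypothetical values: base URL http://localhost:8000, local file sample.wav.
import requests

with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/analyze",
        files={"audio": ("sample.wav", f, "audio/wav")},  # -> audio: UploadFile = File(...)
        data={"transcript": "expected reading passage"},   # -> transcript: str = Form("")
    )
resp.raise_for_status()
result = resp.json()
print(result["severity"], result["mismatch_percentage"])
```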
diagnosis/ai_engine/detect_stuttering.py
CHANGED

@@ -155,6 +155,14 @@ class AdvancedStutterDetector:
             torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
         ).to(DEVICE)
         self.model.eval()
+
+        # Debug: Log processor structure
+        logger.info(f"Processor type: {type(self.processor)}")
+        logger.info(f"Processor attributes: {[attr for attr in dir(self.processor) if not attr.startswith('_')]}")
+        if hasattr(self.processor, 'tokenizer'):
+            logger.info(f"Tokenizer type: {type(self.processor.tokenizer)}")
+        if hasattr(self.processor, 'feature_extractor'):
+            logger.info(f"Feature extractor type: {type(self.processor.feature_extractor)}")
         self.loaded_adapters = set()  # Keep for backward compatibility but not used with indicwav2vec
 
         # Anomaly Detection Model (for outlier stutter events)
@@ -320,8 +328,45 @@ class AdvancedStutterDetector:
             logits = outputs.logits
             predicted_ids = torch.argmax(logits, dim=-1)
 
-            # Decode transcript
-            transcript =
+            # Decode transcript - IndicWav2Vec uses tokenizer for decoding
+            transcript = ""
+            try:
+                # Method 1: Use the processor's tokenizer directly
+                if hasattr(self.processor, 'tokenizer'):
+                    transcript = self.processor.tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
+                    logger.info(f"Decoded via tokenizer: '{transcript}' (length: {len(transcript)})")
+                # Method 2: Fall back to batch_decode if no tokenizer attribute
+                elif hasattr(self.processor, 'batch_decode'):
+                    transcript = self.processor.batch_decode(predicted_ids)[0]
+                    logger.info(f"Decoded via batch_decode: '{transcript}' (length: {len(transcript)})")
+                # Method 3: Probe other processor attributes that may wrap a tokenizer
+                else:
+                    for attr in ['tokenizer', '_tokenizer', 'decoder']:
+                        if hasattr(self.processor, attr):
+                            tokenizer = getattr(self.processor, attr)
+                            if hasattr(tokenizer, 'decode'):
+                                transcript = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
+                                logger.info(f"Decoded via {attr}: '{transcript}' (length: {len(transcript)})")
+                                break
+
+                # Clean up transcript - remove special tokens and normalize
+                if transcript:
+                    transcript = transcript.strip()
+                    # Remove common special tokens if present
+                    transcript = transcript.replace('<pad>', '').replace('<s>', '').replace('</s>', '').replace('|', ' ').strip()
+                    # Normalize whitespace
+                    transcript = ' '.join(transcript.split())
+
+            except Exception as decode_error:
+                logger.error(f"⚠️ Decode error: {decode_error}", exc_info=True)
+                transcript = ""
+
+            # Ensure transcript is not None
+            if not transcript:
+                transcript = ""
+                logger.warning("⚠️ Empty transcript generated - model may not have produced valid output")
+                logger.warning(f"⚠️ Predicted IDs shape: {predicted_ids.shape}, sample values: {predicted_ids[0][:10].tolist() if predicted_ids.numel() > 0 else 'empty'}")
 
             # Estimate word timestamps (simplified - frame-level alignment)
             frame_duration = 0.02  # 20ms per frame
@@ -329,9 +374,9 @@
             audio_duration = len(audio) / 16000
 
             # Simple word-level timestamps (would need proper alignment for production)
-            words = transcript.split()
+            words = transcript.split() if transcript else []
             word_timestamps = []
-            time_per_word = audio_duration / max(len(words), 1)
+            time_per_word = audio_duration / max(len(words), 1) if words else 0
 
             for i, word in enumerate(words):
                 word_timestamps.append({
@@ -342,7 +387,7 @@
 
             return transcript, word_timestamps, logits
         except Exception as e:
-            logger.error(f"Transcription failed: {e}")
+            logger.error(f"❌ Transcription failed: {e}", exc_info=True)
             return "", [], torch.zeros((1, 100, 32))  # Dummy return
 
     def _calculate_uncertainty(self, logits: torch.Tensor) -> Tuple[float, List[Dict]]:
@@ -686,6 +731,7 @@ class AdvancedStutterDetector:
 
         # === STEP 4: Wav2Vec2 Transcription & Uncertainty ===
         transcript, word_timestamps, logits = self._transcribe_with_timestamps(audio)
+        logger.info(f"Main transcription result: '{transcript}' (length: {len(transcript)}, words: {len(word_timestamps)})")
         entropy_score, low_conf_regions = self._calculate_uncertainty(logits)
 
         # === STEP 5: Speaking Rate Estimation ===
@@ -759,9 +805,15 @@
         metrics['severity_score'] = max(metrics['severity_score'], 5.0)
 
         # === STEP 9: Return Comprehensive Report ===
+        # Ensure transcripts are not None
+        actual_transcript = transcript if transcript else ""
+        target_transcript = proper_transcript if proper_transcript else transcript if transcript else ""
+
+        logger.info(f"Final return - Actual: '{actual_transcript}' (len: {len(actual_transcript)}), Target: '{target_transcript}' (len: {len(target_transcript)})")
+
         return {
-            'actual_transcript':
-            'target_transcript':
+            'actual_transcript': actual_transcript,
+            'target_transcript': target_transcript,
             'mismatched_chars': [f"{r['time']}s" for r in low_conf_regions],
             'mismatch_percentage': metrics['severity_score'],
             'ctc_loss_score': round(entropy_score, 4),
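The decode fallback chain above probes the processor at runtime because the IndicWav2Vec checkpoint may not expose the standard `Wav2Vec2Processor` surface, which is also what the added debug logging in `__init__` is meant to reveal. For a stock processor/model pair, greedy CTC decoding reduces to `batch_decode` on the per-frame argmax. A minimal sketch, assuming a standard HuggingFace checkpoint (the model ID below is a stand-in for illustration, not the project's actual IndicWav2Vec weights):

```python
# Greedy CTC decoding with a stock Wav2Vec2 pair - a minimal sketch.
# facebook/wav2vec2-base-960h is a stand-in checkpoint, not this project's model.
import numpy as np
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").eval()

def transcribe(audio: np.ndarray, sample_rate: int = 16000) -> str:
    # Feature-extract the raw waveform and run the CTC head
    inputs = processor(audio, sampling_rate=sample_rate, return_tensors="pt")
    with torch.no_grad():
        logits = model(inputs.input_values).logits  # (batch, frames, vocab)
    predicted_ids = torch.argmax(logits, dim=-1)    # greedy per-frame argmax
    # batch_decode collapses repeated tokens and strips CTC blanks
    return processor.batch_decode(predicted_ids)[0]
```

If the loaded processor is a bare tokenizer or a custom wrapper instead of this standard pair, only the `hasattr` probes in the patch will show which decode path applies.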
|