anfastech committed
Commit 439ae4d · 1 Parent(s): 220224d

Fix: changes to app.py and detect_stuttering.py

Files changed (2)
  1. app.py +7 -3
  2. diagnosis/ai_engine/detect_stuttering.py +59 -7
app.py CHANGED
@@ -3,7 +3,7 @@ import logging
 import os
 import sys
 from pathlib import Path
-from fastapi import FastAPI, UploadFile, File, HTTPException
+from fastapi import FastAPI, UploadFile, File, Form, HTTPException
 from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware

@@ -70,7 +70,7 @@ async def health_check():
 @app.post("/analyze")
 async def analyze_audio(
     audio: UploadFile = File(...),
-    transcript: str = ""
+    transcript: str = Form("")
 ):
     """
     Analyze audio file for stuttering

@@ -102,10 +102,14 @@ async def analyze_audio
         logger.info(f"📂 Saved to: {temp_file} ({len(content) / 1024 / 1024:.2f} MB)")

         # Analyze
-        logger.info(f"🔄 Analyzing audio with transcript: '{transcript[:50]}...'")
+        logger.info(f"🔄 Analyzing audio with transcript: '{transcript[:50] if transcript else '(empty)'}...'")
         result = detector.analyze_audio(temp_file, transcript)

+        # Log transcript values from result
+        actual = result.get('actual_transcript', '')
+        target = result.get('target_transcript', '')
         logger.info(f"✅ Analysis complete: severity={result['severity']}, mismatch={result['mismatch_percentage']}%")
+        logger.info(f"📝 Result transcripts - Actual: '{actual[:100]}' (len: {len(actual)}), Target: '{target[:100]}' (len: {len(target)})")
         return result

     except HTTPException:
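Note on the Form("") change: with the old signature `transcript: str = ""`, FastAPI binds a bare string parameter to the query string, so a transcript field sent in the multipart body alongside the file upload never reaches the handler. `Form("")` binds it to the multipart form field instead. A minimal client sketch, assuming a local deployment (the URL and file name are placeholders):

import requests

# Placeholder URL and audio file; adjust to the actual deployment.
url = "http://localhost:8000/analyze"
with open("sample.wav", "rb") as f:
    response = requests.post(
        url,
        files={"audio": ("sample.wav", f, "audio/wav")},  # matches audio: UploadFile = File(...)
        data={"transcript": "expected target text"},      # matches transcript: str = Form("")
    )
print(response.json())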
diagnosis/ai_engine/detect_stuttering.py CHANGED
@@ -155,6 +155,14 @@ class AdvancedStutterDetector:
             torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
         ).to(DEVICE)
         self.model.eval()
+
+        # Debug: Log processor structure
+        logger.info(f"📋 Processor type: {type(self.processor)}")
+        logger.info(f"📋 Processor attributes: {[attr for attr in dir(self.processor) if not attr.startswith('_')]}")
+        if hasattr(self.processor, 'tokenizer'):
+            logger.info(f"📋 Tokenizer type: {type(self.processor.tokenizer)}")
+        if hasattr(self.processor, 'feature_extractor'):
+            logger.info(f"📋 Feature extractor type: {type(self.processor.feature_extractor)}")
         self.loaded_adapters = set()  # Keep for backward compatibility but not used with indicwav2vec

         # Anomaly Detection Model (for outlier stutter events)

@@ -320,8 +328,45 @@ class AdvancedStutterDetector:
             logits = outputs.logits
             predicted_ids = torch.argmax(logits, dim=-1)

-            # Decode transcript
-            transcript = self.processor.batch_decode(predicted_ids)[0]
+            # Decode transcript - IndicWav2Vec uses tokenizer for decoding
+            transcript = ""
+            try:
+                # Method 1: Try using processor's tokenizer directly
+                if hasattr(self.processor, 'tokenizer'):
+                    transcript = self.processor.tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
+                    logger.info(f"📝 Decoded via tokenizer: '{transcript}' (length: {len(transcript)})")
+                # Method 2: Try batch_decode if tokenizer not available
+                elif hasattr(self.processor, 'batch_decode'):
+                    transcript = self.processor.batch_decode(predicted_ids)[0]
+                    logger.info(f"📝 Decoded via batch_decode: '{transcript}' (length: {len(transcript)})")
+                # Method 3: Try accessing tokenizer through processor.feature_extractor or processor attributes
+                else:
+                    # Check if processor wraps a tokenizer
+                    for attr in ['tokenizer', '_tokenizer', 'decoder']:
+                        if hasattr(self.processor, attr):
+                            tokenizer = getattr(self.processor, attr)
+                            if hasattr(tokenizer, 'decode'):
+                                transcript = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
+                                logger.info(f"📝 Decoded via {attr}: '{transcript}' (length: {len(transcript)})")
+                                break
+
+                # Clean up transcript - remove special tokens and normalize
+                if transcript:
+                    transcript = transcript.strip()
+                    # Remove common special tokens if present
+                    transcript = transcript.replace('<pad>', '').replace('<s>', '').replace('</s>', '').replace('|', ' ').strip()
+                    # Normalize whitespace
+                    transcript = ' '.join(transcript.split())
+
+            except Exception as decode_error:
+                logger.error(f"⚠️ Decode error: {decode_error}", exc_info=True)
+                transcript = ""
+
+            # Ensure transcript is not None
+            if not transcript:
+                transcript = ""
+                logger.warning("⚠️ Empty transcript generated - model may not have produced valid output")
+                logger.warning(f"⚠️ Predicted IDs shape: {predicted_ids.shape}, sample values: {predicted_ids[0][:10].tolist() if predicted_ids.numel() > 0 else 'empty'}")

             # Estimate word timestamps (simplified - frame-level alignment)
             frame_duration = 0.02  # 20ms per frame

@@ -329,9 +374,9 @@
             audio_duration = len(audio) / 16000

             # Simple word-level timestamps (would need proper alignment for production)
-            words = transcript.split()
+            words = transcript.split() if transcript else []
             word_timestamps = []
-            time_per_word = audio_duration / max(len(words), 1)
+            time_per_word = audio_duration / max(len(words), 1) if words else 0

             for i, word in enumerate(words):
                 word_timestamps.append({

@@ -342,7 +387,7 @@

             return transcript, word_timestamps, logits
         except Exception as e:
-            logger.error(f"Transcription failed: {e}")
+            logger.error(f"❌ Transcription failed: {e}", exc_info=True)
             return "", [], torch.zeros((1, 100, 32))  # Dummy return

     def _calculate_uncertainty(self, logits: torch.Tensor) -> Tuple[float, List[Dict]]:

@@ -686,6 +731,7 @@ class AdvancedStutterDetector:

         # === STEP 4: Wav2Vec2 Transcription & Uncertainty ===
         transcript, word_timestamps, logits = self._transcribe_with_timestamps(audio)
+        logger.info(f"📝 Main transcription result: '{transcript}' (length: {len(transcript)}, words: {len(word_timestamps)})")
         entropy_score, low_conf_regions = self._calculate_uncertainty(logits)

         # === STEP 5: Speaking Rate Estimation ===

@@ -759,9 +805,15 @@
             metrics['severity_score'] = max(metrics['severity_score'], 5.0)

         # === STEP 9: Return Comprehensive Report ===
+        # Ensure transcripts are not None
+        actual_transcript = transcript if transcript else ""
+        target_transcript = proper_transcript if proper_transcript else transcript if transcript else ""
+
+        logger.info(f"📝 Final return - Actual: '{actual_transcript}' (len: {len(actual_transcript)}), Target: '{target_transcript}' (len: {len(target_transcript)})")
+
         return {
-            'actual_transcript': transcript,
-            'target_transcript': transcript,
+            'actual_transcript': actual_transcript,
+            'target_transcript': target_transcript,
             'mismatched_chars': [f"{r['time']}s" for r in low_conf_regions],
             'mismatch_percentage': metrics['severity_score'],
             'ctc_loss_score': round(entropy_score, 4),
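The word-timing logic is unchanged in spirit: both versions split the audio duration evenly across the decoded words rather than aligning them (the in-code comment flags this as non-production). A worked instance with made-up values; the dict keys are illustrative, since the diff truncates the fields actually appended:

audio_duration = 4.0                      # seconds, i.e. len(audio) / 16000
transcript = "the quick brown fox jumps"  # hypothetical decode output
words = transcript.split() if transcript else []
time_per_word = audio_duration / max(len(words), 1) if words else 0  # 0.8 s for 5 words

word_timestamps = [
    {"word": w, "start": round(i * time_per_word, 2), "end": round((i + 1) * time_per_word, 2)}
    for i, w in enumerate(words)
]
# e.g. {'word': 'the', 'start': 0.0, 'end': 0.8} ... {'word': 'jumps', 'start': 3.2, 'end': 4.0}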