anfastech committed on
Commit
13a1b12
·
1 Parent(s): 900bd70

Changing the model from [facebook/mms-1b-all] to [ai4bharat/indicwav2vec-hindi]

Browse files
diagnosis/ai_engine/detect_stuttering.py CHANGED
@@ -22,7 +22,7 @@ from sklearn.ensemble import IsolationForest
22
  logger = logging.getLogger(__name__)
23
 
24
  # === CONFIGURATION ===
25
- MODEL_ID = "facebook/mms-1b-all"
26
  LID_MODEL_ID = "facebook/mms-lid-126"
27
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
28
 
@@ -138,17 +138,14 @@ class AdvancedStutterDetector:
138
  def __init__(self):
139
  logger.info(f"🚀 Initializing Advanced AI Engine on {DEVICE}...")
140
  try:
141
- # Wav2Vec2 Model Loading
142
  self.processor = AutoProcessor.from_pretrained(MODEL_ID)
143
  self.model = Wav2Vec2ForCTC.from_pretrained(
144
  MODEL_ID,
145
- torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
146
- target_lang="eng",
147
- ignore_mismatched_sizes=True
148
  ).to(DEVICE)
149
  self.model.eval()
150
- self.loaded_adapters = set()
151
- self._init_common_adapters()
152
 
153
  # Anomaly Detection Model (for outlier stutter events)
154
  self.anomaly_detector = IsolationForest(
@@ -162,12 +159,9 @@ class AdvancedStutterDetector:
162
  raise
163
 
164
  def _init_common_adapters(self):
165
- """Preload common language adapters"""
166
- for code in ['eng', 'hin']:
167
- try:
168
- self.model.load_adapter(code)
169
- self.loaded_adapters.add(code)
170
- except: pass
171
 
172
  def _detect_language_robust(self, audio_path: str) -> str:
173
  """Detect language using MMS LID model"""
@@ -190,18 +184,12 @@ class AdvancedStutterDetector:
190
  return 'eng'
191
 
192
  def _activate_adapter(self, lang_code: str):
193
- """Activate language adapter for MMS model"""
194
- if lang_code not in self.loaded_adapters:
195
- try:
196
- self.model.load_adapter(lang_code)
197
- self.loaded_adapters.add(lang_code)
198
- except Exception as e:
199
- logger.warning(f"Failed to load adapter {lang_code}: {e}")
200
-
201
- try:
202
- self.model.set_adapter(lang_code)
203
- except Exception as e:
204
- logger.warning(f"Failed to activate adapter {lang_code}: {e}")
205
 
206
  def _extract_comprehensive_features(self, audio: np.ndarray, sr: int, audio_path: str) -> Dict[str, Any]:
207
  """Extract multi-modal acoustic features"""
@@ -666,10 +654,11 @@ class AdvancedStutterDetector:
666
  start_time = time.time()
667
 
668
  # === STEP 1: Language Detection & Setup ===
 
669
  if language == 'auto':
670
  lang_code = self._detect_language_robust(audio_path)
671
  else:
672
- lang_code = INDIAN_LANGUAGES.get(language.lower(), 'eng')
673
  self._activate_adapter(lang_code)
674
 
675
  # === STEP 2: Audio Loading & Preprocessing ===
@@ -775,16 +764,16 @@ class AdvancedStutterDetector:
775
  'energy_entropy': float(np.mean(features['energy_entropy']))
776
  },
777
  'analysis_duration_seconds': round(time.time() - start_time, 2),
778
- 'model_version': f'advanced-research-v2-{lang_code}'
779
  }
780
 
781
 
782
  # Legacy methods - kept for backward compatibility but may not work without additional model initialization
783
  # These methods reference models (xlsr, base, large) that are not initialized in __init__
784
- # The main analyze_audio() method uses the MMS model instead
785
 
786
  def generate_target_transcript(self, audio_file: str) -> str:
787
- """Generate expected transcript - Legacy method (uses main MMS model)"""
788
  try:
789
  audio, sr = librosa.load(audio_file, sr=16000)
790
  transcript, _, _ = self._transcribe_with_timestamps(audio)
 
22
  logger = logging.getLogger(__name__)
23
 
24
  # === CONFIGURATION ===
25
+ MODEL_ID = "ai4bharat/indicwav2vec-hindi"
26
  LID_MODEL_ID = "facebook/mms-lid-126"
27
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
28
 
 
138
  def __init__(self):
139
  logger.info(f"🚀 Initializing Advanced AI Engine on {DEVICE}...")
140
  try:
141
+ # Wav2Vec2 Model Loading - IndicWav2Vec Hindi Model
142
  self.processor = AutoProcessor.from_pretrained(MODEL_ID)
143
  self.model = Wav2Vec2ForCTC.from_pretrained(
144
  MODEL_ID,
145
+ torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
 
 
146
  ).to(DEVICE)
147
  self.model.eval()
148
+ self.loaded_adapters = set() # Keep for backward compatibility but not used with indicwav2vec
 
149
 
150
  # Anomaly Detection Model (for outlier stutter events)
151
  self.anomaly_detector = IsolationForest(
 
159
  raise
160
 
161
  def _init_common_adapters(self):
162
+ """Preload common language adapters - Not applicable for indicwav2vec-hindi"""
163
+ # IndicWav2Vec Hindi model is pre-trained for Hindi, no adapters needed
164
+ pass
 
 
 
165
 
166
  def _detect_language_robust(self, audio_path: str) -> str:
167
  """Detect language using MMS LID model"""
 
184
  return 'eng'
185
 
186
  def _activate_adapter(self, lang_code: str):
187
+ """Activate language adapter - Not applicable for indicwav2vec-hindi"""
188
+ # IndicWav2Vec Hindi model is pre-trained for Hindi, no adapter switching needed
189
+ # Log for debugging but no action required
190
+ if lang_code != 'hin':
191
+ logger.info(f"Note: Using Hindi-specific model (indicwav2vec-hindi), language code '{lang_code}' requested but model is optimized for Hindi")
192
+ pass
 
 
 
 
 
 
193
 
194
  def _extract_comprehensive_features(self, audio: np.ndarray, sr: int, audio_path: str) -> Dict[str, Any]:
195
  """Extract multi-modal acoustic features"""
 
654
  start_time = time.time()
655
 
656
  # === STEP 1: Language Detection & Setup ===
657
+ # Note: indicwav2vec-hindi is optimized for Hindi, but can handle other languages
658
  if language == 'auto':
659
  lang_code = self._detect_language_robust(audio_path)
660
  else:
661
+ lang_code = INDIAN_LANGUAGES.get(language.lower(), 'hin') # Default to Hindi for indicwav2vec
662
  self._activate_adapter(lang_code)
663
 
664
  # === STEP 2: Audio Loading & Preprocessing ===
 
764
  'energy_entropy': float(np.mean(features['energy_entropy']))
765
  },
766
  'analysis_duration_seconds': round(time.time() - start_time, 2),
767
+ 'model_version': f'indicwav2vec-hindi-v1-{lang_code}'
768
  }
769
 
770
 
771
  # Legacy methods - kept for backward compatibility but may not work without additional model initialization
772
  # These methods reference models (xlsr, base, large) that are not initialized in __init__
773
+ # The main analyze_audio() method uses the IndicWav2Vec Hindi model instead
774
 
775
  def generate_target_transcript(self, audio_file: str) -> str:
776
+ """Generate expected transcript - Legacy method (uses IndicWav2Vec Hindi model)"""
777
  try:
778
  audio, sr = librosa.load(audio_file, sr=16000)
779
  transcript, _, _ = self._transcribe_with_timestamps(audio)