Changing the model from [facebook/mms-1b-all] to [ai4bharat/indicwav2vec-hindi]
Browse files
diagnosis/ai_engine/detect_stuttering.py
CHANGED
|
@@ -22,7 +22,7 @@ from sklearn.ensemble import IsolationForest
|
|
| 22 |
logger = logging.getLogger(__name__)
|
| 23 |
|
| 24 |
# === CONFIGURATION ===
|
| 25 |
-
MODEL_ID = "
|
| 26 |
LID_MODEL_ID = "facebook/mms-lid-126"
|
| 27 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 28 |
|
|
@@ -138,17 +138,14 @@ class AdvancedStutterDetector:
|
|
| 138 |
def __init__(self):
|
| 139 |
logger.info(f"🚀 Initializing Advanced AI Engine on {DEVICE}...")
|
| 140 |
try:
|
| 141 |
-
# Wav2Vec2 Model Loading
|
| 142 |
self.processor = AutoProcessor.from_pretrained(MODEL_ID)
|
| 143 |
self.model = Wav2Vec2ForCTC.from_pretrained(
|
| 144 |
MODEL_ID,
|
| 145 |
-
torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
|
| 146 |
-
target_lang="eng",
|
| 147 |
-
ignore_mismatched_sizes=True
|
| 148 |
).to(DEVICE)
|
| 149 |
self.model.eval()
|
| 150 |
-
self.loaded_adapters = set()
|
| 151 |
-
self._init_common_adapters()
|
| 152 |
|
| 153 |
# Anomaly Detection Model (for outlier stutter events)
|
| 154 |
self.anomaly_detector = IsolationForest(
|
|
@@ -162,12 +159,9 @@ class AdvancedStutterDetector:
|
|
| 162 |
raise
|
| 163 |
|
| 164 |
def _init_common_adapters(self):
|
| 165 |
-
"""Preload common language adapters"""
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
self.model.load_adapter(code)
|
| 169 |
-
self.loaded_adapters.add(code)
|
| 170 |
-
except: pass
|
| 171 |
|
| 172 |
def _detect_language_robust(self, audio_path: str) -> str:
|
| 173 |
"""Detect language using MMS LID model"""
|
|
@@ -190,18 +184,12 @@ class AdvancedStutterDetector:
|
|
| 190 |
return 'eng'
|
| 191 |
|
| 192 |
def _activate_adapter(self, lang_code: str):
|
| 193 |
-
"""Activate language adapter
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
logger.warning(f"Failed to load adapter {lang_code}: {e}")
|
| 200 |
-
|
| 201 |
-
try:
|
| 202 |
-
self.model.set_adapter(lang_code)
|
| 203 |
-
except Exception as e:
|
| 204 |
-
logger.warning(f"Failed to activate adapter {lang_code}: {e}")
|
| 205 |
|
| 206 |
def _extract_comprehensive_features(self, audio: np.ndarray, sr: int, audio_path: str) -> Dict[str, Any]:
|
| 207 |
"""Extract multi-modal acoustic features"""
|
|
@@ -666,10 +654,11 @@ class AdvancedStutterDetector:
|
|
| 666 |
start_time = time.time()
|
| 667 |
|
| 668 |
# === STEP 1: Language Detection & Setup ===
|
|
|
|
| 669 |
if language == 'auto':
|
| 670 |
lang_code = self._detect_language_robust(audio_path)
|
| 671 |
else:
|
| 672 |
-
lang_code = INDIAN_LANGUAGES.get(language.lower(), '
|
| 673 |
self._activate_adapter(lang_code)
|
| 674 |
|
| 675 |
# === STEP 2: Audio Loading & Preprocessing ===
|
|
@@ -775,16 +764,16 @@ class AdvancedStutterDetector:
|
|
| 775 |
'energy_entropy': float(np.mean(features['energy_entropy']))
|
| 776 |
},
|
| 777 |
'analysis_duration_seconds': round(time.time() - start_time, 2),
|
| 778 |
-
'model_version': f'
|
| 779 |
}
|
| 780 |
|
| 781 |
|
| 782 |
# Legacy methods - kept for backward compatibility but may not work without additional model initialization
|
| 783 |
# These methods reference models (xlsr, base, large) that are not initialized in __init__
|
| 784 |
-
# The main analyze_audio() method uses the
|
| 785 |
|
| 786 |
def generate_target_transcript(self, audio_file: str) -> str:
|
| 787 |
-
"""Generate expected transcript - Legacy method (uses
|
| 788 |
try:
|
| 789 |
audio, sr = librosa.load(audio_file, sr=16000)
|
| 790 |
transcript, _, _ = self._transcribe_with_timestamps(audio)
|
|
|
|
| 22 |
# Module-level logger, named after this module for hierarchical filtering.
logger = logging.getLogger(__name__)

# === CONFIGURATION ===
# ASR checkpoint: monolingual Hindi wav2vec2 model from AI4Bharat
# (replaces the earlier multilingual facebook/mms-1b-all checkpoint).
MODEL_ID = "ai4bharat/indicwav2vec-hindi"
# Spoken language-identification checkpoint used by _detect_language_robust().
LID_MODEL_ID = "facebook/mms-lid-126"
# Run on GPU when available; models/tensors are moved with .to(DEVICE).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 28 |
|
|
|
|
| 138 |
def __init__(self):
|
| 139 |
logger.info(f"🚀 Initializing Advanced AI Engine on {DEVICE}...")
|
| 140 |
try:
|
| 141 |
+
# Wav2Vec2 Model Loading - IndicWav2Vec Hindi Model
|
| 142 |
self.processor = AutoProcessor.from_pretrained(MODEL_ID)
|
| 143 |
self.model = Wav2Vec2ForCTC.from_pretrained(
|
| 144 |
MODEL_ID,
|
| 145 |
+
torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
|
|
|
|
|
|
|
| 146 |
).to(DEVICE)
|
| 147 |
self.model.eval()
|
| 148 |
+
self.loaded_adapters = set() # Keep for backward compatibility but not used with indicwav2vec
|
|
|
|
| 149 |
|
| 150 |
# Anomaly Detection Model (for outlier stutter events)
|
| 151 |
self.anomaly_detector = IsolationForest(
|
|
|
|
| 159 |
raise
|
| 160 |
|
| 161 |
def _init_common_adapters(self):
|
| 162 |
+
"""Preload common language adapters - Not applicable for indicwav2vec-hindi"""
|
| 163 |
+
# IndicWav2Vec Hindi model is pre-trained for Hindi, no adapters needed
|
| 164 |
+
pass
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
def _detect_language_robust(self, audio_path: str) -> str:
|
| 167 |
"""Detect language using MMS LID model"""
|
|
|
|
| 184 |
return 'eng'
|
| 185 |
|
| 186 |
def _activate_adapter(self, lang_code: str):
    """Log (but otherwise ignore) an adapter-activation request.

    The ai4bharat/indicwav2vec-hindi checkpoint is a monolingual Hindi
    model, so there are no MMS-style language adapters to switch. This
    method is kept so callers written for the adapter-based MMS model
    keep working unchanged.

    Args:
        lang_code: Language code requested by the caller (e.g. 'hin',
            'eng'), as produced by _detect_language_robust() or the
            INDIAN_LANGUAGES lookup.
    """
    # Surface a hint when a non-Hindi language is requested, since this
    # model cannot actually switch languages. Lazy %-formatting defers
    # string interpolation until the record is actually emitted.
    # (Removed a dead trailing `pass` that followed these statements.)
    if lang_code != 'hin':
        logger.info(
            "Note: Using Hindi-specific model (indicwav2vec-hindi), "
            "language code '%s' requested but model is optimized for Hindi",
            lang_code,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
def _extract_comprehensive_features(self, audio: np.ndarray, sr: int, audio_path: str) -> Dict[str, Any]:
|
| 195 |
"""Extract multi-modal acoustic features"""
|
|
|
|
| 654 |
start_time = time.time()
|
| 655 |
|
| 656 |
# === STEP 1: Language Detection & Setup ===
|
| 657 |
+
# Note: indicwav2vec-hindi is optimized for Hindi, but can handle other languages
|
| 658 |
if language == 'auto':
|
| 659 |
lang_code = self._detect_language_robust(audio_path)
|
| 660 |
else:
|
| 661 |
+
lang_code = INDIAN_LANGUAGES.get(language.lower(), 'hin') # Default to Hindi for indicwav2vec
|
| 662 |
self._activate_adapter(lang_code)
|
| 663 |
|
| 664 |
# === STEP 2: Audio Loading & Preprocessing ===
|
|
|
|
| 764 |
'energy_entropy': float(np.mean(features['energy_entropy']))
|
| 765 |
},
|
| 766 |
'analysis_duration_seconds': round(time.time() - start_time, 2),
|
| 767 |
+
'model_version': f'indicwav2vec-hindi-v1-{lang_code}'
|
| 768 |
}
|
| 769 |
|
| 770 |
|
| 771 |
# Legacy methods - kept for backward compatibility but may not work without additional model initialization
|
| 772 |
# These methods reference models (xlsr, base, large) that are not initialized in __init__
|
| 773 |
+
# The main analyze_audio() method uses the IndicWav2Vec Hindi model instead
|
| 774 |
|
| 775 |
def generate_target_transcript(self, audio_file: str) -> str:
|
| 776 |
+
"""Generate expected transcript - Legacy method (uses IndicWav2Vec Hindi model)"""
|
| 777 |
try:
|
| 778 |
audio, sr = librosa.load(audio_file, sr=16000)
|
| 779 |
transcript, _, _ = self._transcribe_with_timestamps(audio)
|