import gc
import time
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import torch
from omegaconf import OmegaConf

import nemo.collections.asr as nemo_asr
from nemo.collections.asr.models.aed_multitask_models import lens_to_mask
from nemo.collections.asr.parts.submodules.aed_decoding import (
    GreedyBatchedStreamingAEDComputer,
    return_decoder_input_ids,
)
from nemo.collections.asr.parts.submodules.multitask_decoding import (
    AEDStreamingDecodingConfig,
    MultiTaskDecodingConfig,
)
from nemo.collections.asr.parts.utils.streaming_utils import (
    ContextSize,
    StreamingBatchedAudioBuffer,
)
from nemo.collections.asr.parts.utils.transcribe_utils import (
    get_inference_device,
    get_inference_dtype,
)

from app.interfaces import IStreamingSpeechEngine
from app.logger_config import logger as logging, DEBUG


@dataclass
class CanaryConfig:
    chunk_secs: float = 1.0
    left_context_secs: float = 20.0
    right_context_secs: float = 0.5
    cuda: Optional[bool] = None
    allow_mps: bool = True
    compute_dtype: Optional[str] = None
    matmul_precision: str = "high"
    batch_size: int = 1
    decoding: Optional[dict] = None
    streaming_policy: str = "alignatt"
    alignatt_thr: float = 8.0
    waitk_lagging: int = 2
    exclude_sink_frames: int = 8
    xatt_scores_layer: int = -2
    max_tokens_per_alignatt_step: int = 30
    max_generation_length: int = 512
    use_avgpool_for_alignatt: bool = False
    hallucinations_detector: bool = True
    prompt: Optional[dict] = None
    pnc: str = "no"
    task: str = "asr"
    source_lang: str = "fr"
    target_lang: str = "fr"
    timestamps: bool = True

    def __post_init__(self):
        if self.decoding is None:
            self.decoding = {
                "streaming_policy": self.streaming_policy,
                "alignatt_thr": self.alignatt_thr,
                "waitk_lagging": self.waitk_lagging,
                "exclude_sink_frames": self.exclude_sink_frames,
                "xatt_scores_layer": self.xatt_scores_layer,
                "max_tokens_per_alignatt_step": self.max_tokens_per_alignatt_step,
                "max_generation_length": self.max_generation_length,
                "use_avgpool_for_alignatt": self.use_avgpool_for_alignatt,
                "hallucinations_detector": self.hallucinations_detector,
            }
        if self.prompt is None:
            self.prompt = {
                "pnc": self.pnc,
                "task": self.task,
                "source_lang": self.source_lang,
                "target_lang": self.target_lang,
                "timestamps": self.timestamps,
            }

    def toOmegaConf(self) -> OmegaConf:
        """Convert the config to OmegaConf format."""
        config_dict = {
            "chunk_secs": self.chunk_secs,
            "left_context_secs": self.left_context_secs,
            "right_context_secs": self.right_context_secs,
            "cuda": self.cuda,
            "allow_mps": self.allow_mps,
            "compute_dtype": self.compute_dtype,
            "matmul_precision": self.matmul_precision,
            "batch_size": self.batch_size,
            "decoding": self.decoding,
            "prompt": self.prompt,
        }
        # Remove None values so the OmegaConf only carries explicitly set options
        filtered_dict = {k: v for k, v in config_dict.items() if v is not None}
        return OmegaConf.create(filtered_dict)

    @classmethod
    def from_params(
        cls,
        task_type: str,
        source_lang: str,
        target_lang: str,
        chunk_secs: float = 1.0,
        left_context_secs: float = 20.0,
        right_context_secs: float = 0.5,
        streaming_policy: str = "alignatt",
        alignatt_thr: float = 8.0,
        waitk_lagging: int = 2,
        exclude_sink_frames: int = 8,
        xatt_scores_layer: int = -2,
        hallucinations_detector: bool = True,
    ):
        """Create a CanaryConfig instance from parameters."""
        # Convert task type to model task: transcription keeps the source
        # language, anything else is treated as speech translation (AST).
        task = "asr" if task_type == "Transcription" else "ast"
        target_lang = source_lang if task_type == "Transcription" else target_lang
        return cls(
            chunk_secs=chunk_secs,
            left_context_secs=left_context_secs,
            right_context_secs=right_context_secs,
            streaming_policy=streaming_policy,
            alignatt_thr=alignatt_thr,
            waitk_lagging=waitk_lagging,
            exclude_sink_frames=exclude_sink_frames,
            xatt_scores_layer=xatt_scores_layer,
            hallucinations_detector=hallucinations_detector,
            task=task,
            source_lang=source_lang,
            target_lang=target_lang,
        )


def make_divisible_by(num: int, factor: int) -> int:
    """Round num down so it is divisible by factor."""
    return (num // factor) * factor


class CanarySpeechEngine(IStreamingSpeechEngine):
    """
    Encapsulates the state and logic for streaming audio transcription
    using an internally loaded Canary model.
    """

    def __init__(self, asr_model, cfg: CanaryConfig):
        """
        Initializes the speech engine and prepares the ASR model for streaming.

        Args:
            asr_model: A pre-instantiated Canary (AED multitask) model.
            cfg: A CanaryConfig with streaming, decoding and prompt settings.
        """
        logging.debug(f"Initializing CanarySpeechEngine with config: {cfg}")
        self.cfg = cfg.toOmegaConf()  # Store the full config

        # Setup device and dtype from config (keys may be absent because
        # toOmegaConf drops None values, hence .get with a None fallback)
        self.map_location = get_inference_device(cuda=self.cfg.get("cuda"), allow_mps=self.cfg.allow_mps)
        self.compute_dtype = get_inference_dtype(self.cfg.get("compute_dtype"), device=self.map_location)
        logging.info(f"Inference will be on device: {self.map_location} with dtype: {self.compute_dtype}")

        # Configure the provided model for streaming inference
        asr_model, _ = self._setup_model(asr_model, self.cfg, self.map_location)
        self.asr_model = asr_model

        self.full_transcription = []  # Stores finalized segments
        self._setup_streaming_params()

        # The initial full reset (buffer + decoder)
        self.reset()
        logging.info("CanarySpeechEngine initialized and ready.")
        logging.info(f"Model-adjusted chunk size: {self.context_samples.chunk} samples.")

    def _setup_model(self, asr_model, model_cfg: OmegaConf, map_location: str):
        """Moves the ASR model to the target device and configures it for greedy streaming decoding."""
        logging.info("Loading model ...")
        start_time = time.time()
        try:
            asr_model = asr_model.to(map_location)
            asr_model.eval()

            # Change decoding strategy to greedy for streaming
            if hasattr(asr_model, 'change_decoding_strategy'):
                multitask_decoding = MultiTaskDecodingConfig()
                multitask_decoding.strategy = "greedy"
                asr_model.change_decoding_strategy(multitask_decoding)
                logging.info("Model decoding strategy set to 'greedy'.")

            if map_location == "cuda":
                torch.cuda.synchronize()
            end_time = time.time()
            logging.info("Model loaded successfully.")

            load_time = end_time - start_time
            logging.info("\n" + "=" * 30)
            logging.info(f"Total model load time: {load_time:.2f} seconds")
            logging.info("=" * 30)
            return asr_model, None

        except Exception as e:
            logging.error(f"Error loading model: {e}")
            logging.error("Ensure NeMo is installed (pip install nemo_toolkit['asr'])")
            return None, None

    def _setup_streaming_params(self):
        """Helper to calculate model-specific streaming parameters."""
        model_cfg = self.asr_model.cfg
        audio_sample_rate = model_cfg.preprocessor['sample_rate']
        self.feature_stride_sec = model_cfg.preprocessor['window_stride']
        features_per_sec = 1.0 / self.feature_stride_sec
        self.encoder_subsampling_factor = self.asr_model.encoder.subsampling_factor
        self.features_frame2audio_samples = make_divisible_by(
            int(audio_sample_rate * self.feature_stride_sec),
            factor=self.encoder_subsampling_factor,
        )
        encoder_frame2audio_samples = self.features_frame2audio_samples * self.encoder_subsampling_factor

        # Streaming parameters live at the top level of self.cfg
        streaming_cfg = self.cfg
        self.context_encoder_frames = ContextSize(
            left=int(streaming_cfg.left_context_secs * features_per_sec / self.encoder_subsampling_factor),
            chunk=int(streaming_cfg.chunk_secs * features_per_sec / self.encoder_subsampling_factor),
            right=int(streaming_cfg.right_context_secs * features_per_sec / self.encoder_subsampling_factor),
        )
        self.context_samples = ContextSize(
            left=self.context_encoder_frames.left * encoder_frame2audio_samples,
            chunk=self.context_encoder_frames.chunk * encoder_frame2audio_samples,
            right=self.context_encoder_frames.right * encoder_frame2audio_samples,
        )

    def _reset_decoder_state(self):
        """
        Resets ONLY the decoder state, preserving the audio buffer.
        This prevents slowdowns on long audio streams.
        """
        start_time = time.perf_counter()
        logging.debug("--- Resetting decoder state (audio buffer preserved) ---")

        # Reset tracking for this segment
        self.last_transcription = ""
        self.chunk_count = 0
        batch_size = 1  # Hardcoded for this engine

        # Streaming parameters live at the top level of self.cfg
        streaming_cfg = self.cfg

        # 1. Recreate the initial prompt for the decoder
        self.decoder_input_ids = return_decoder_input_ids(streaming_cfg, self.asr_model)

        # 2. Recreate the "computer" object that manages decoding
        self.decoding_computer = GreedyBatchedStreamingAEDComputer(
            self.asr_model,
            frame_chunk_size=self.context_encoder_frames.chunk,
            decoding_cfg=streaming_cfg.decoding,
        )

        # 3. Recreate an EMPTY state object (model_state)
        self.model_state = GreedyBatchedStreamingAEDComputer.initialize_aed_model_state(
            asr_model=self.asr_model,
            decoder_input_ids=self.decoder_input_ids,
            batch_size=batch_size,
            context_encoder_frames=self.context_encoder_frames,
            chunk_secs=streaming_cfg.chunk_secs,
            right_context_secs=streaming_cfg.right_context_secs,
        )

        # Clear CUDA cache if possible
        if torch.cuda.is_available():
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        end_time = time.perf_counter()
        duration_ms = (end_time - start_time) * 1000  # Convert to milliseconds
        logging.debug(f"--- Decoder reset finished in {duration_ms:.2f} ms ---")

    def reset(self):
        """
        Resets the transcriber's state completely (audio buffer + decoder state).
        Called only on initialization.
        """
        start_time = time.perf_counter()
        logging.debug("--- FULL RESET (Audio Buffer + Decoder State) ---")

        # Operation 1: Reset the decoder (this now includes GC)
        self._reset_decoder_state()

        # Operation 2: Reset the audio buffer
        self.buffer = StreamingBatchedAudioBuffer(
            batch_size=1,  # Hardcoded for this engine
            context_samples=self.context_samples,
            dtype=torch.float32,
            device=self.map_location,
        )

        end_time = time.perf_counter()
        duration_ms = (end_time - start_time) * 1000
        logging.debug(f"--- RESET Complete: took {duration_ms:.2f} ms ---")

    def transcribe_chunk(self, chunk: np.ndarray, is_last_chunk: bool = False) -> Tuple[str, str]:
        """
        Processes a single audio chunk and returns the newly predicted text.

        Returns:
            Tuple[str, str]: (current_transcription: the full transcription for the
            current segment, new_text: the newly appended text since the last chunk)
        """
        start_time = time.perf_counter()
        self.chunk_count += 1

        # Preprocess audio: int16 PCM -> float32 in [-1, 1]
        signal = torch.from_numpy(chunk.astype(np.float32) / 32768.0)
        audio_batch = signal.unsqueeze(0).to(self.map_location)
        audio_batch_lengths = torch.tensor([signal.shape[0]], device=self.map_location)

        # 1. Add the chunk to the persistent buffer
        self.buffer.add_audio_batch_(
            audio_batch,
            audio_lengths=audio_batch_lengths,
            is_last_chunk=is_last_chunk,
            is_last_chunk_batch=torch.tensor([is_last_chunk], device=self.map_location),
        )
        self.model_state.is_last_chunk_batch = torch.tensor([is_last_chunk], device=self.map_location)

        # 2. Pass the buffer to the encoder
        _, encoded_len, enc_states, _ = self.asr_model(
            input_signal=self.buffer.samples,
            input_signal_length=self.buffer.context_size_batch.total(),
        )
        encoder_context_batch = self.buffer.context_size_batch.subsample(
            factor=self.features_frame2audio_samples * self.encoder_subsampling_factor
        )
        encoded_len_no_rc = encoder_context_batch.left + encoder_context_batch.chunk
        encoded_length_corrected = torch.where(self.model_state.is_last_chunk_batch, encoded_len, encoded_len_no_rc)
        encoder_input_mask = lens_to_mask(encoded_length_corrected, enc_states.shape[1]).to(enc_states.dtype)

        # 3. Pass to the decoding computer
        self.model_state = self.decoding_computer(
            encoder_output=enc_states,
            encoder_output_len=encoded_length_corrected,
            encoder_input_mask=encoder_input_mask,
            prev_batched_state=self.model_state,
        )

        # 4. Calculate the new text
        current_tokens = self.model_state.pred_tokens_ids[
            0, self.decoder_input_ids.size(-1): self.model_state.current_context_lengths[0]
        ]
        # OPTIMIZATION: Move tokens to CPU before converting to list
        current_transcription = self.asr_model.tokenizer.ids_to_text(current_tokens.cpu().tolist()).strip()

        # Calculate the NEW text by "subtracting" the old history
        new_text = ""
        if current_transcription.startswith(self.last_transcription):
            new_text = current_transcription[len(self.last_transcription):]
        else:
            # Model corrected itself, send the full new transcription
            new_text = current_transcription

        # Memorize the FULL current transcription as the new history
        if new_text:
            self.last_transcription = current_transcription

        end_time = time.perf_counter()
        duration_ms = (end_time - start_time) * 1000
        # logging.info(f"--- transcribe_chunk: took {duration_ms:.2f} ms ---")

        # Return both the full segment transcription and the new diff
        return current_transcription, new_text

    def finalize_segment(self):
        """
        Finalizes the current transcription segment (e.g., on silence)
        and adds it to the full history.
        """
        if self.last_transcription:
            self.full_transcription.append(self.last_transcription)
        self.last_transcription = ""
        # We must reset the decoder state to start a new segment
        self._reset_decoder_state()

    def get_full_transcription(self) -> str:
        """
        Returns the full accumulated transcription from all finalized segments.
        Does NOT include the currently active (unfinalized) segment.
        """
        return " ".join(self.full_transcription)

    def get_current_segment_text(self) -> str:
        """Returns the text of the segment currently being transcribed."""
        return self.last_transcription
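

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the engine).
# Assumptions not taken from the code above: the "nvidia/canary-1b-flash"
# checkpoint name, the 16 kHz mono int16 WAV file, and loading the model via
# nemo_asr.models.ASRModel.from_pretrained are examples; swap in whatever
# model and audio source your application actually uses.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import soundfile as sf

    # Load a Canary checkpoint (hypothetical name) and build the streaming config
    model = nemo_asr.models.ASRModel.from_pretrained("nvidia/canary-1b-flash")
    cfg = CanaryConfig.from_params(
        task_type="Transcription",
        source_lang="fr",
        target_lang="fr",
    )
    engine = CanarySpeechEngine(model, cfg)

    # Feed the file in chunk_secs-sized int16 chunks, printing new text as it arrives
    audio, sample_rate = sf.read("example_16khz_mono.wav", dtype="int16")
    chunk_samples = int(cfg.chunk_secs * sample_rate)
    for start in range(0, len(audio), chunk_samples):
        chunk = audio[start:start + chunk_samples]
        is_last = start + chunk_samples >= len(audio)
        _, new_text = engine.transcribe_chunk(chunk, is_last_chunk=is_last)
        if new_text:
            print(new_text, end="", flush=True)

    engine.finalize_segment()
    print("\nFull transcription:", engine.get_full_transcription())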