Archime committed
Commit 7b84154 · 1 Parent(s): 010aaff

impl CanaryConfig from ui

Files changed (2):
  1. app.py +31 -10
  2. app/canary_speech_engine.py +114 -5
app.py CHANGED
@@ -52,25 +52,39 @@ reset_all_active_session_hash_code()
 
 theme,css_style = get_custom_theme()
 
-from omegaconf import OmegaConf
-cfg = OmegaConf.load('app/config.yaml')
 # logger.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
-from app.canary_speech_engine import CanarySpeechEngine
+from app.canary_speech_engine import CanarySpeechEngine,CanaryConfig
 from app.silero_vad_engine import Silero_Vad_Engine
 from app.streaming_audio_processor import StreamingAudioProcessor,StreamingAudioProcessorConfig
 
 
-asr_model = nemo_asr.models.ASRModel.from_pretrained(cfg.pretrained_name)
-canary_speech_engine = CanarySpeechEngine(asr_model,cfg)
-silero_vad_engine = Silero_Vad_Engine()
+asr_model = nemo_asr.models.ASRModel.from_pretrained("nvidia/canary-1b-v2")
 streaming_audio_processor_config = StreamingAudioProcessorConfig(
     read_size=4000,
     silence_threshold_chunks=1
 )
-streamer = StreamingAudioProcessor(speech_engine=canary_speech_engine,vad_engine=silero_vad_engine,cfg=streaming_audio_processor_config)
+
 @spaces.GPU
-def task(session_id: str):
+def task(session_id: str,
+         task_type, lang_source, lang_target,
+         chunk_secs, left_context_secs, right_context_secs,
+         streaming_policy, alignatt_thr, waitk_lagging,
+         exclude_sink_frames, xatt_scores_layer, hallucinations_detector
+         ):
     """Continuously read and delete .npz chunks while task is active."""
+    yield f"initializing the CanarySpeechEngine and Silero_Vad_Engine\n\n"
+    # initialize the CanarySpeechEngine and Silero_Vad_Engine
+    conf = CanaryConfig.from_params(
+        task_type, lang_source, lang_target,
+        chunk_secs, left_context_secs, right_context_secs,
+        streaming_policy, alignatt_thr, waitk_lagging,
+        exclude_sink_frames, xatt_scores_layer, hallucinations_detector
+    )
+    canary_speech_engine = CanarySpeechEngine(asr_model,conf)
+    silero_vad_engine = Silero_Vad_Engine()
+    streamer = StreamingAudioProcessor(speech_engine=canary_speech_engine,vad_engine=silero_vad_engine,cfg=streaming_audio_processor_config)
+    yield f"initialized the CanarySpeechEngine and Silero_Vad_Engine\n\n"
+    yield f"Task started for session {session_id}\n\n"
     active_flag = get_active_task_flag_file(session_id)
     with open(active_flag, "w") as f:
         f.write("1")
@@ -319,6 +333,7 @@ with gr.Blocks(theme=theme, css=css_style) as demo:
         interactive=False,
         visible=False
     )
+
     stop_stream_button = gr.Button("Stop Streaming", visible=False)
 
     transcription_output = gr.Textbox(
@@ -365,9 +380,15 @@ with gr.Blocks(theme=theme, css=css_style) as demo:
 
     accumulated = ""
     yield f"Starting {task_type.lower()}...\n\n",gr.update(visible=False),gr.update(visible=True)
-
+
     # Loop over the `task()` generator
-    for msg in task(session_hash_code):
+    for msg in task(
+        session_hash_code,
+        task_type, lang_source, lang_target,
+        chunk_secs, left_context_secs, right_context_secs,
+        streaming_policy, alignatt_thr, waitk_lagging,
+        exclude_sink_frames, xatt_scores_layer, hallucinations_detector
+    ):
        accumulated += msg
        yield accumulated,gr.update(visible=False),gr.update(visible=True)
 
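For context, the settings that were previously read from app/config.yaml are now forwarded from the UI into task(). A minimal sketch of driving the new generator, assuming the scope of app.py where task() is defined; every literal below is an illustrative placeholder (mirroring the CanaryConfig defaults), not a value taken from this commit:

# Illustrative only: example values for the new task() signature.
for msg in task(
    "demo-session",               # session_id (placeholder)
    "Transcription", "fr", "fr",  # task_type, lang_source, lang_target
    1.0, 20.0, 0.5,               # chunk_secs, left_context_secs, right_context_secs
    "alignatt", 8.0, 2,           # streaming_policy, alignatt_thr, waitk_lagging
    8, -2, True,                  # exclude_sink_frames, xatt_scores_layer, hallucinations_detector
):
    print(msg, end="")            # each yielded string is a status or transcription update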
app/canary_speech_engine.py CHANGED
@@ -32,6 +32,115 @@ from app.logger_config import (
 )
 
 
+from dataclasses import dataclass
+from typing import Optional, Literal
+
+@dataclass
+class CanaryConfig:
+    chunk_secs: float = 1.0
+    left_context_secs: float = 20.0
+    right_context_secs: float = 0.5
+    cuda: Optional[bool] = None
+    allow_mps: bool = True
+    compute_dtype: Optional[str] = None
+    matmul_precision: str = "high"
+    batch_size= 1
+    decoding: dict = None
+    streaming_policy: str = "alignatt"
+    alignatt_thr: float = 8.0
+    waitk_lagging: int = 2
+    exclude_sink_frames: int = 8
+    xatt_scores_layer: int = -2
+    max_tokens_per_alignatt_step: int = 30
+    max_generation_length: int = 512
+    use_avgpool_for_alignatt: bool = False
+    hallucinations_detector: bool = True
+
+    prompt: dict = None
+    pnc: str = "no"
+    task: str = "asr"
+    source_lang: str = "fr"
+    target_lang: str = "fr"
+    timestamps: bool = True
+
+    def __post_init__(self):
+        if self.decoding is None:
+            self.decoding = {
+                "streaming_policy": self.streaming_policy,
+                "alignatt_thr": self.alignatt_thr,
+                "waitk_lagging": self.waitk_lagging,
+                "exclude_sink_frames": self.exclude_sink_frames,
+                "xatt_scores_layer": self.xatt_scores_layer,
+                "max_tokens_per_alignatt_step": self.max_tokens_per_alignatt_step,
+                "max_generation_length": self.max_generation_length,
+                "use_avgpool_for_alignatt": self.use_avgpool_for_alignatt,
+                "hallucinations_detector": self.hallucinations_detector
+            }
+
+        if self.prompt is None:
+            self.prompt = {
+                "pnc": self.pnc,
+                "task": self.task,
+                "source_lang": self.source_lang,
+                "target_lang": self.target_lang,
+                "timestamps": self.timestamps
+            }
+
+    def toOmegaConf(self) -> OmegaConf:
+        """Convert the config to OmegaConf format"""
+        config_dict = {
+            "chunk_secs": self.chunk_secs,
+            "left_context_secs": self.left_context_secs,
+            "right_context_secs": self.right_context_secs,
+            "cuda": self.cuda,
+            "allow_mps": self.allow_mps,
+            "compute_dtype": self.compute_dtype,
+            "matmul_precision": self.matmul_precision,
+            "batch_size": self.batch_size,
+            "decoding": self.decoding,
+            "prompt": self.prompt
+        }
+
+        # Remove None values
+        filtered_dict = {k: v for k, v in config_dict.items() if v is not None}
+
+        return OmegaConf.create(filtered_dict)
+
+    @classmethod
+    def from_params(
+        cls,
+        task_type: str,
+        source_lang: str,
+        target_lang: str,
+        chunk_secs: float = 1.0,
+        left_context_secs: float = 20.0,
+        right_context_secs: float = 0.5,
+        streaming_policy: str = "alignatt",
+        alignatt_thr: float = 8.0,
+        waitk_lagging: int = 2,
+        exclude_sink_frames: int = 8,
+        xatt_scores_layer: int = -2,
+        hallucinations_detector: bool = True
+    ):
+        """Create a CanaryConfig instance from parameters"""
+        # Convert task type to model task
+        task = "asr" if task_type == "Transcription" else "ast"
+
+        return cls(
+            chunk_secs=chunk_secs,
+            left_context_secs=left_context_secs,
+            right_context_secs=right_context_secs,
+            streaming_policy=streaming_policy,
+            alignatt_thr=alignatt_thr,
+            waitk_lagging=waitk_lagging,
+            exclude_sink_frames=exclude_sink_frames,
+            xatt_scores_layer=xatt_scores_layer,
+            hallucinations_detector=hallucinations_detector,
+            task=task,
+            source_lang=source_lang,
+            target_lang=target_lang
+        )
+
 def make_divisible_by(num: int, factor: int) -> int:
     """Make num divisible by factor"""
     return (num // factor) * factor
@@ -42,18 +151,18 @@ class CanarySpeechEngine(IStreamingSpeechEngine):
     Encapsulates the state and logic for streaming audio transcription
     using an internally loaded Canary model.
     """
-    def __init__(self,asr_model, cfg: OmegaConf):
+    def __init__(self,asr_model, cfg: CanaryConfig):
         """
         Initializes the speech engine and loads the ASR model.
 
         Args:
             cfg: An OmegaConf object containing 'model' and 'streaming' configs.
         """
-        self.cfg = cfg # Store the full config
+        self.cfg = cfg.toOmegaConf() # Store the full config
 
         # Setup device and dtype from config
-        self.map_location = get_inference_device(cuda=self.cfg.cuda, allow_mps=self.cfg.allow_mps)
-        self.compute_dtype = get_inference_dtype(self.cfg.compute_dtype, device=self.map_location)
+        self.map_location = get_inference_device(cuda=None, allow_mps=self.cfg.allow_mps)
+        self.compute_dtype = get_inference_dtype(None, device=self.map_location)
         logging.info(f"Inference will be on device: {self.map_location} with dtype: {self.compute_dtype}")
 
         # Load the model internally
@@ -71,7 +180,7 @@ class CanarySpeechEngine(IStreamingSpeechEngine):
 
     def _setup_model(self,asr_model, model_cfg: OmegaConf, map_location: str):
         """Loads the pretrained ASR model and configures it for inference."""
-        logging.info(f"Loading model {model_cfg.pretrained_name}...")
+        logging.info(f"Loading model ...")
         start_time = time.time()
         try:
             asr_model = asr_model.to(map_location)
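
A minimal, self-contained usage sketch of the new config path (argument values are illustrative; from_params and toOmegaConf are the methods added in this commit):

# Illustrative values only: build a CanaryConfig the way the UI path now does,
# then inspect the OmegaConf view that CanarySpeechEngine consumes.
from app.canary_speech_engine import CanaryConfig

conf = CanaryConfig.from_params(
    "Translation", "fr", "en",    # any task_type other than "Transcription" maps to task="ast"
    chunk_secs=1.0, left_context_secs=20.0, right_context_secs=0.5,
    streaming_policy="alignatt", alignatt_thr=8.0, waitk_lagging=2,
    exclude_sink_frames=8, xatt_scores_layer=-2, hallucinations_detector=True,
)
cfg = conf.toOmegaConf()
print(cfg.prompt.task, cfg.prompt.source_lang, cfg.prompt.target_lang)  # ast fr en
print(cfg.decoding.streaming_policy, cfg.decoding.alignatt_thr)         # alignatt 8.0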