"""Full duplex streaming mode for MiniCPM-o 4.5 MLX.

Captures screen video + system audio, processes through the model in
real-time, and outputs text analysis with optional TTS playback.

Architecture:
    [Screen 1fps] + [Audio 16kHz] -> ChunkSynchronizer -> DuplexGenerator -> TTSPlayback
"""

import queue
import threading
import time
from typing import List, Optional, Tuple, Union

import mlx.core as mx
import numpy as np

# Prompt markers for multimodal placeholders (MiniCPM-V / MiniCPM-o tokenizer
# conventions). The index spans between each start/end marker pair are passed
# to the model as ``image_bound`` / ``audio_bound``.
IMAGE_START_TOKEN = "<image>"
IMAGE_END_TOKEN = "</image>"
AUDIO_START_TOKEN = "<|audio_start|>"
AUDIO_END_TOKEN = "<|audio_end|>"
PLACEHOLDER_TOKEN = "<unk>"


class ScreenCapture:
    """Capture a screen region at a fixed rate (default 1fps) using mss.

    Each frame is resized to ``target_size x target_size`` and pushed onto
    ``out_queue`` as ``{"type": "video", "frame": (H, W, 3) float32 in [0, 1],
    "time": float}``. Frames are dropped when the queue is full.
    """

    def __init__(
        self,
        out_queue: queue.Queue,
        region: Optional[tuple] = None,
        fps: float = 1.0,
        target_size: int = 448,
    ):
        # Fail fast: a non-positive fps would divide by zero inside the thread.
        if fps <= 0:
            raise ValueError(f"fps must be positive, got {fps}")
        self.out_queue = out_queue
        self.region = region  # (x, y, w, h) or None for primary monitor
        self.fps = fps
        self.target_size = target_size
        self._stop = threading.Event()
        self._thread: Optional[threading.Thread] = None

    def start(self):
        """Start the daemon capture thread."""
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def stop(self):
        """Signal the capture thread to stop and wait up to 2 s for it."""
        self._stop.set()
        if self._thread:
            self._thread.join(timeout=2)

    def _run(self):
        import mss
        from PIL import Image

        period = 1.0 / self.fps
        with mss.mss() as sct:
            if self.region:
                x, y, w, h = self.region
                monitor = {"left": x, "top": y, "width": w, "height": h}
            else:
                monitor = sct.monitors[1]  # Primary monitor

            while not self._stop.is_set():
                t0 = time.time()
                screenshot = sct.grab(monitor)

                # Convert to PIL Image, resize, convert to float32 in [0, 1].
                img = Image.frombytes("RGB", screenshot.size, screenshot.rgb)
                img = img.resize(
                    (self.target_size, self.target_size), Image.BILINEAR
                )
                frame = np.array(img, dtype=np.float32) / 255.0  # (H, W, 3)

                try:
                    self.out_queue.put_nowait(
                        {"type": "video", "frame": frame, "time": time.time()}
                    )
                except queue.Full:
                    pass  # Drop frame if queue full

                # Sleep for the remainder of the period; Event.wait lets
                # stop() interrupt the sleep immediately.
                sleep_time = period - (time.time() - t0)
                if sleep_time > 0:
                    self._stop.wait(sleep_time)


class AudioCapture:
    """Capture system audio at 16kHz using sounddevice.

    Uses BlackHole virtual audio device for system audio loopback on macOS.
    Produces mono float32 audio chunks of ``chunk_seconds`` (default 1 s),
    pushed onto ``out_queue`` as ``{"type": "audio", "data": ndarray,
    "time": float}``. Chunks are dropped when the queue is full.
    """

    def __init__(
        self,
        out_queue: queue.Queue,
        device: Optional[Union[str, int]] = None,
        sample_rate: int = 16000,
        chunk_seconds: float = 1.0,
    ):
        self.out_queue = out_queue
        self.device = device  # Device name (substring match) or index
        self.sample_rate = sample_rate
        self.chunk_seconds = chunk_seconds
        self.chunk_samples = int(sample_rate * chunk_seconds)
        self._stop = threading.Event()
        self._thread: Optional[threading.Thread] = None

    def start(self):
        """Start the daemon capture thread."""
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def stop(self):
        """Signal the capture thread to stop and wait up to 2 s for it."""
        self._stop.set()
        if self._thread:
            self._thread.join(timeout=2)

    def _find_device(self):
        """Find audio device by name.

        Returns:
            The device index, or None to use the default input device (when
            no device was requested or no input device name matches).
        """
        import sounddevice as sd

        if self.device is None:
            return None  # Use default
        if isinstance(self.device, int):
            return self.device

        wanted = self.device.lower()
        for i, d in enumerate(sd.query_devices()):
            if wanted in d["name"].lower() and d["max_input_channels"] > 0:
                return i

        print(f"Warning: Audio device '{self.device}' not found, using default.")
        return None

    def _run(self):
        import sounddevice as sd

        device_id = self._find_device()
        buffer = np.array([], dtype=np.float32)
        # The stream callback runs on a separate audio thread; every read and
        # rebind of ``buffer`` must hold this lock, otherwise samples appended
        # between the consumer's slice and its reassignment are lost.
        lock = threading.Lock()

        def callback(indata, frames, time_info, status):
            nonlocal buffer
            # ``status`` (overflow/underflow flags) is deliberately ignored:
            # capture is best-effort.
            mono = indata.mean(axis=1) if indata.ndim > 1 else indata.flatten()
            with lock:
                buffer = np.concatenate([buffer, mono])

        try:
            with sd.InputStream(
                device=device_id,
                channels=1,
                samplerate=self.sample_rate,
                blocksize=1024,
                callback=callback,
            ):
                while not self._stop.is_set():
                    chunk = None
                    with lock:
                        if len(buffer) >= self.chunk_samples:
                            chunk = buffer[: self.chunk_samples].copy()
                            buffer = buffer[self.chunk_samples :]
                    if chunk is None:
                        self._stop.wait(0.05)
                        continue
                    try:
                        self.out_queue.put_nowait(
                            {
                                "type": "audio",
                                "data": chunk,
                                "time": time.time(),
                            }
                        )
                    except queue.Full:
                        pass  # Drop chunk if consumer is slow
        except Exception as e:
            print(f"Audio capture error: {e}")


class ChunkSynchronizer:
    """Synchronize video frames and audio into 1-second chunks.

    Pairs the latest video frame with each audio chunk and runs mel
    processing on the audio. Emits ``{"video_frame": ndarray or None,
    "mel_chunk": ..., "time": float}`` onto ``sync_queue`` whenever the mel
    processor yields a chunk.
    """

    def __init__(
        self,
        raw_queue: queue.Queue,
        sync_queue: queue.Queue,
        mel_processor,
    ):
        self.raw_queue = raw_queue
        self.sync_queue = sync_queue
        self.mel_processor = mel_processor
        self._stop = threading.Event()
        self._thread: Optional[threading.Thread] = None
        # Most recent screen frame; None until the first frame arrives.
        self._latest_frame: Optional[np.ndarray] = None

    def start(self):
        """Start the daemon synchronizer thread."""
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def stop(self):
        """Signal the synchronizer thread to stop and wait up to 2 s for it."""
        self._stop.set()
        if self._thread:
            self._thread.join(timeout=2)

    def _run(self):
        while not self._stop.is_set():
            try:
                item = self.raw_queue.get(timeout=0.1)
            except queue.Empty:
                continue

            if item["type"] == "video":
                self._latest_frame = item["frame"]
            elif item["type"] == "audio":
                self.mel_processor.add_audio(item["data"])
                mel_chunk = self.mel_processor.get_mel_chunk()
                if mel_chunk is None:
                    continue
                try:
                    self.sync_queue.put_nowait(
                        {
                            "video_frame": self._latest_frame,
                            "mel_chunk": mel_chunk,
                            "time": item["time"],
                        }
                    )
                except queue.Full:
                    pass  # Drop if consumer is slow


class DuplexGenerator:
    """Main processing loop for full duplex streaming.

    Dequeues synchronized chunks, runs model inference, generates text
    responses, and optionally queues TTS audio for playback.

    Callbacks (assign after construction):
        on_text(text: str): called with each non-empty decoded response.
        on_status(status: dict): called after each processed chunk with
            chunk index, context mode, cached token count, latency (ms) and
            peak memory (GB).
    """

    # Vision patch size used to derive a frame's (h, w) patch grid.
    PATCH_SIZE = 14
    # Approximate audio tokens for a 1-second chunk after pooling.
    AUDIO_TOKENS_PER_CHUNK = 10

    def __init__(
        self,
        model,
        processor,
        sync_queue: queue.Queue,
        tts_queue: Optional[queue.Queue] = None,
        temperature: float = 0.0,
        max_tokens_per_chunk: int = 50,
        enable_tts: bool = False,
    ):
        self.model = model
        self.processor = processor
        self.sync_queue = sync_queue
        self.tts_queue = tts_queue
        self.temperature = temperature
        self.max_tokens = max_tokens_per_chunk
        self.enable_tts = enable_tts
        self._stop = threading.Event()
        self._thread: Optional[threading.Thread] = None
        self.ctx = None
        self.chunk_count = 0
        self.on_text = None  # callback(text: str)
        self.on_status = None  # callback(status: dict)

    def start(self):
        """Start the daemon generation thread."""
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def stop(self):
        """Signal the generation thread to stop and wait up to 5 s for it."""
        self._stop.set()
        if self._thread:
            self._thread.join(timeout=5)

    @staticmethod
    def _find_bounds(
        ids: List[int], start_id: int, end_id: int
    ) -> List[Tuple[int, int]]:
        """Return ``(start, end)`` spans strictly between start/end markers.

        ``start`` is the index just after a ``start_id`` token and ``end`` is
        the index of the matching ``end_id`` token (i.e. end-exclusive).
        """
        bounds = []
        start_idx = None
        for i, tok in enumerate(ids):
            if tok == start_id:
                start_idx = i + 1
            elif tok == end_id and start_idx is not None:
                bounds.append((start_idx, i))
                start_idx = None
        return bounds

    def _build_chunk_prompt(self, has_video: bool, has_audio: bool):
        """Build prompt tokens for one streaming chunk.

        The user turn holds an image placeholder (``query_num`` placeholder
        tokens between image markers) and/or an audio placeholder
        (``AUDIO_TOKENS_PER_CHUNK`` placeholder tokens between audio markers),
        then a fixed instruction and the assistant turn header.

        Args:
            has_video: include the image placeholder.
            has_audio: include the audio placeholder.

        Returns:
            dict with ``input_ids`` (mx.array, shape (1, L)), and
            ``image_bound`` / ``audio_bound`` (lists of end-exclusive
            ``(start, end)`` index pairs, or None when none were found).
        """
        tokenizer = self.processor.tokenizer

        parts = ["<|im_start|>user\n"]
        if has_video:
            # One placeholder per resampler query for the image.
            n_img_tokens = self.model.config.query_num
            parts.append(
                IMAGE_START_TOKEN + PLACEHOLDER_TOKEN * n_img_tokens + IMAGE_END_TOKEN
            )
        if has_audio:
            parts.append(
                AUDIO_START_TOKEN
                + PLACEHOLDER_TOKEN * self.AUDIO_TOKENS_PER_CHUNK
                + AUDIO_END_TOKEN
            )
        parts.append("\nDescribe what you see and hear.<|im_end|>\n")
        parts.append("<|im_start|>assistant\n")

        tokenized = tokenizer("".join(parts), return_tensors="np")
        input_ids = mx.array(tokenized["input_ids"])
        ids_list = tokenized["input_ids"][0].tolist()

        image_bound = []
        if has_video:
            image_bound = self._find_bounds(
                ids_list,
                tokenizer.convert_tokens_to_ids(IMAGE_START_TOKEN),
                tokenizer.convert_tokens_to_ids(IMAGE_END_TOKEN),
            )

        audio_bound = []
        if has_audio:
            audio_bound = self._find_bounds(
                ids_list,
                tokenizer.convert_tokens_to_ids(AUDIO_START_TOKEN),
                tokenizer.convert_tokens_to_ids(AUDIO_END_TOKEN),
            )

        return {
            "input_ids": input_ids,
            "image_bound": image_bound or None,
            "audio_bound": audio_bound or None,
        }

    def _prepare_video_frame(self, frame: np.ndarray):
        """Prepare a video frame for model input.

        Args:
            frame: (H, W, 3) float32 frame (ScreenCapture produces [0, 1]).

        Returns:
            (pixel_values, tgt_sizes, patch_attention_mask) where
            pixel_values is (1, H, W, 3), tgt_sizes is [[H // 14, W // 14]]
            and the mask is all-True over every patch.
        """
        # NOTE(review): no mean/std normalization or patch re-layout happens
        # here; confirm the vision tower expects raw [0, 1] NHWC pixels.
        pv = mx.array(frame[np.newaxis, ...])  # (1, H, W, 3)

        h_patches = frame.shape[0] // self.PATCH_SIZE  # 32 for 448px
        w_patches = frame.shape[1] // self.PATCH_SIZE  # 32 for 448px
        tgt_sizes = mx.array([[h_patches, w_patches]], dtype=mx.int32)

        total_patches = h_patches * w_patches
        patch_attention_mask = mx.ones((1, total_patches), dtype=mx.bool_)
        return pv, tgt_sizes, patch_attention_mask

    def _process_chunk(self, chunk: dict):
        """Run one synchronized chunk through the model and emit results."""
        t0 = time.time()
        self.chunk_count += 1

        video_frame = chunk.get("video_frame")
        mel_chunk = chunk.get("mel_chunk")
        has_video = video_frame is not None
        has_audio = mel_chunk is not None
        if not has_video and not has_audio:
            return

        prompt = self._build_chunk_prompt(has_video, has_audio)

        pixel_values = tgt_sizes = patch_attention_mask = None
        if has_video:
            pixel_values, tgt_sizes, patch_attention_mask = (
                self._prepare_video_frame(video_frame)
            )

        logits = self.model.process_streaming_chunk(
            ctx=self.ctx,
            video_frame=pixel_values,
            audio_chunk=mel_chunk,
            prompt_tokens=prompt["input_ids"],
            image_bound=prompt["image_bound"],
            audio_bound=prompt["audio_bound"],
            tgt_sizes=tgt_sizes,
            patch_attention_mask=patch_attention_mask,
        )

        tokens = self.model.streaming_generate(
            ctx=self.ctx,
            logits=logits,
            tokenizer=self.processor.tokenizer,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
        )
        elapsed = time.time() - t0

        if tokens:
            text = self.processor.tokenizer.decode(tokens, skip_special_tokens=True)
            if self.on_text and text.strip():
                self.on_text(text.strip())

            if self.enable_tts and self.tts_queue is not None:
                try:
                    self.tts_queue.put_nowait({"tokens": tokens, "text": text})
                except queue.Full:
                    # Playback is behind (or its thread exited); drop this
                    # utterance rather than killing the generation thread.
                    pass

        if self.on_status:
            self.on_status(
                {
                    "chunk": self.chunk_count,
                    "mode": self.ctx.mode,
                    "cache_tokens": self.ctx.total_tokens,
                    "latency_ms": int(elapsed * 1000),
                    "mem_gb": mx.get_peak_memory() / 1e9,
                }
            )

    def _run(self):
        # Streaming context shared by every model call for this session.
        self.ctx = self.model.init_streaming()
        self.chunk_count = 0

        while not self._stop.is_set():
            try:
                chunk = self.sync_queue.get(timeout=0.5)
            except queue.Empty:
                continue
            try:
                self._process_chunk(chunk)
            except Exception as e:
                # Keep the live loop running; otherwise the thread would die
                # while the capture threads and main loop keep going.
                print(f"Generation error (chunk {self.chunk_count}): {e}")


class TTSPlayback:
    """Dequeue TTS tokens, convert to audio, and play back.

    Uses Token2wav vocoder for audio synthesis and sounddevice for playback.
    If the vocoder or soundfile is unavailable, the thread prints a notice
    and exits without consuming the queue.
    """

    def __init__(self, tts_queue: queue.Queue, sample_rate: int = 24000):
        self.tts_queue = tts_queue
        self.sample_rate = sample_rate
        self._stop = threading.Event()
        self._thread: Optional[threading.Thread] = None
        self._vocoder = None

    def start(self):
        """Start the daemon playback thread."""
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def stop(self):
        """Signal the playback thread to stop and wait up to 2 s for it."""
        self._stop.set()
        if self._thread:
            self._thread.join(timeout=2)

    def _run(self):
        import io

        import sounddevice as sd

        # Try loading vocoder
        try:
            from stepaudio2 import Token2wav

            self._vocoder = Token2wav()
        except ImportError:
            print("TTSPlayback: Token2wav not available, TTS disabled.")
            return

        # Import once up front instead of on every queued item.
        try:
            import soundfile as sf
        except ImportError:
            print("TTSPlayback: soundfile not available, TTS disabled.")
            return

        while not self._stop.is_set():
            try:
                item = self.tts_queue.get(timeout=0.5)
            except queue.Empty:
                continue

            tokens = item.get("tokens", [])
            if not tokens:
                continue

            # NOTE(review): DuplexGenerator enqueues the LLM's text token ids
            # (the same ids it decodes to text); confirm Token2wav accepts
            # these rather than audio codec tokens.
            try:
                wav_bytes = self._vocoder(tokens, None)
                waveform, sr = sf.read(io.BytesIO(wav_bytes))
                sd.play(waveform, sr, blocking=False)
            except Exception as e:
                print(f"TTS playback error: {e}")


def run_live_mode(model, processor, args):
    """Run full duplex streaming mode until Ctrl+C.

    Args:
        model: loaded MiniCPM-o model
        processor: tokenizer/processor
        args: argparse namespace; optional attributes read are
            capture_region ("x,y,w,h"), audio_device, tts, temp, max_tokens.
    """
    from mlx_vlm.models.minicpmo.audio import StreamingMelProcessor

    print("Starting live streaming mode...")
    print("Press Ctrl+C to stop.\n")

    enable_tts = getattr(args, "tts", False)

    # Create queues
    raw_queue = queue.Queue(maxsize=30)
    sync_queue = queue.Queue(maxsize=10)
    tts_queue = queue.Queue(maxsize=10) if enable_tts else None

    # Create mel processor
    mel_processor = StreamingMelProcessor(sample_rate=16000)

    # Parse capture region
    region = None
    capture_region = getattr(args, "capture_region", None)
    if capture_region:
        parts = capture_region.split(",")
        if len(parts) == 4:
            region = tuple(int(p) for p in parts)
        else:
            print(
                f"Warning: ignoring capture region '{capture_region}' "
                "(expected x,y,w,h); capturing primary monitor."
            )

    # Create threads
    screen = ScreenCapture(raw_queue, region=region, fps=1.0)
    audio_dev = getattr(args, "audio_device", "BlackHole")
    audio = AudioCapture(raw_queue, device=audio_dev, sample_rate=16000)
    sync = ChunkSynchronizer(raw_queue, sync_queue, mel_processor)
    generator = DuplexGenerator(
        model,
        processor,
        sync_queue,
        tts_queue=tts_queue,
        temperature=getattr(args, "temp", 0.0),
        max_tokens_per_chunk=getattr(args, "max_tokens", 50),
        enable_tts=enable_tts,
    )
    tts_playback = TTSPlayback(tts_queue) if tts_queue is not None else None

    # Set up callbacks
    def on_text(text):
        print(f"[{generator.chunk_count}] {text}")

    def on_status(status):
        print(
            f"  >> chunk={status['chunk']} mode={status['mode']} "
            f"cache={status['cache_tokens']}tok "
            f"latency={status['latency_ms']}ms "
            f"mem={status['mem_gb']:.1f}GB",
            flush=True,
        )

    generator.on_text = on_text
    generator.on_status = on_status

    # Start all threads
    screen.start()
    audio.start()
    sync.start()
    generator.start()
    if tts_playback:
        tts_playback.start()

    print("Live mode active. Capturing screen + audio...\n")

    try:
        while True:
            time.sleep(0.5)
    except KeyboardInterrupt:
        print("\nStopping live mode...")
    finally:
        screen.stop()
        audio.stop()
        sync.stop()
        generator.stop()
        if tts_playback:
            tts_playback.stop()
        print("Live mode stopped.")