From 988adcbdd9a95b132fcba8b0069c8bec7dc87956 Mon Sep 17 00:00:00 2001 From: future3000 <46719676+future3OOO@users.noreply.github.com> Date: Sat, 2 May 2026 14:38:15 +1200 Subject: [PATCH] fix(engine): Stabilize dictation capture and release paste Prevent mouse-hold dictation from transcribing overlapping VAD batches before release, and harden Windows microphone fallback so default devices are portable while explicit devices fail closed. Co-authored-by: Cursor --- README.md | 4 +- dictation_tool/engine.py | 94 ++++++++----- dictation_tool/io.py | 250 +++++++++++++++++++++++++++------- tests/test_engine_advanced.py | 53 ++++++- tests/test_io.py | 87 +++++++++--- 5 files changed, 382 insertions(+), 106 deletions(-) diff --git a/README.md b/README.md index 0b72a47..4169c0c 100644 --- a/README.md +++ b/README.md @@ -73,8 +73,8 @@ Latency β‰ˆ 200-500 ms on an RTX 3080. Designed for 20-30 s dictation bursts. #### πŸ’¨ Option B β€” Maximum speed (medium.en + prompt tricks) -medium.en delivers β‰ˆ 5-20 ms interface latency while staying surprisingly -accurate when paired with a good prompt and a larger beam. +medium.en delivers β‰ˆ 5-200 ms interface latency while staying surprisingly +accurate when paired with a preset and a larger beam. ## πŸ“§ Fast e-mail workflow β€” preset **email** diff --git a/dictation_tool/engine.py b/dictation_tool/engine.py index 189cfd4..fe151a6 100644 --- a/dictation_tool/engine.py +++ b/dictation_tool/engine.py @@ -13,6 +13,7 @@ β€’ Extended punctuation map (β€œat sign” β†’ @) β€’ JSONL profiler, adaptive batching, back-pressure, dynamic ring growth """ + from __future__ import annotations import asyncio @@ -44,8 +45,8 @@ if sys.platform == "win32": try: import win32clipboard as _wc # type: ignore - import win32con as _wcon # type: ignore - except ImportError: # pywin32 not installed + import win32con as _wcon # type: ignore + except ImportError: # pywin32 not installed _wc = _wcon = None else: _wc = _wcon = None @@ -82,6 +83,7 @@ def _paste_retry() -> None: time.sleep(0.02) LOGGER.warning("Auto-paste ultimately failed") + # ══════════════════════════════════ Audio ring buffer ═════════════════════════ class _Ring: """Lock-free power-of-two ring for int16 audio.""" @@ -103,7 +105,7 @@ def push(self, chunk: np.ndarray) -> None: if not n: return cap = self._view.shape[0] - if n >= cap: # keep only the last samples + if n >= cap: # keep only the last samples self._view[:] = chunk[-cap:] self._head = self._tail = 0 self._full = True @@ -134,7 +136,12 @@ def pop(self) -> np.ndarray: @property def size(self) -> int: - return self._view.shape[0] if self._full else (self._head - self._tail) & self._mask + return ( + self._view.shape[0] + if self._full + else (self._head - self._tail) & self._mask + ) + # ══════════════════════════════════ Punctuation map ═══════════════════════════ class _Punct: @@ -180,6 +187,7 @@ def __call__(self, text: str) -> str: out = self._spc_after.sub(r"\1", out) return out.strip() + # ═════════════════════════ Thread-safe context ════════════════════════════════ class _Context: """Rolling prompt history (mutex-guarded).""" @@ -198,6 +206,7 @@ def prompt(self) -> str: with self._lock: return " ".join(self._buf) + ". " if self._buf else "" + # ═══════════════════ Adaptive batch controller ════════════════════════════════ class _BatchCtl: def __init__(self, cfg: Config) -> None: @@ -215,10 +224,14 @@ def feed(self, samples: int) -> None: avg = self._seen // self._chunks self.min_samples = max(8_000, min(int(avg * 0.8) // 16 * 16, 24_000)) self.max_chunks = max(4, min(ceil(self.min_samples / avg), 10)) - LOGGER.info("Adaptive batch tuned β†’ %d samples | %d chunks", - self.min_samples, self.max_chunks) + LOGGER.info( + "Adaptive batch tuned β†’ %d samples | %d chunks", + self.min_samples, + self.max_chunks, + ) self._lock = True + # ═════════════════════ Clipboard wrapper class ════════════════════════════════ class _Clipboard: """Thread-pool clipboard copy; never blocks event loop.""" @@ -245,11 +258,16 @@ async def copy(self, text: str) -> None: await loop.run_in_executor(self._pool, pyperclip.copy, text) self._last = (text, time.time()) + # ───────────────────────── command cleanup ────────────────────────── _CMD_SUBS: tuple[tuple[re.Pattern, str], ...] = ( # single line break – eat optional punctuation / spaces after the cue - (re.compile(r"\b(?:new\s+line|line\s*break|newline)\b[ \t]*[.,!?;:]?[ \t]*", - re.I), "\n"), + ( + re.compile( + r"\b(?:new\s+line|line\s*break|newline)\b[ \t]*[.,!?;:]?[ \t]*", re.I + ), + "\n", + ), # blank line (paragraph) – same idea, but keep the double LF (re.compile(r"\bnew\s+paragraph\b[ \t]*[.,!?;:]?[ \t]*", re.I), "\n\n"), (re.compile(r"\bbullet\s+point\b", re.I), "\nβ€’ "), @@ -257,7 +275,9 @@ async def copy(self, text: str) -> None: (re.compile(r'[\u201C\u201D"`\uFFFD]+'), ""), ) -_SPACES_AROUND_DOT_AT = re.compile(r"[ \t\u00A0\u1680\u2000-\u200A\u202F\u205F\u3000]*([@.])[ \t\u00A0\u1680\u2000-\u200A\u202F\u205F\u3000]*") +_SPACES_AROUND_DOT_AT = re.compile( + r"[ \t\u00A0\u1680\u2000-\u200A\u202F\u205F\u3000]*([@.])[ \t\u00A0\u1680\u2000-\u200A\u202F\u205F\u3000]*" +) _SIGNOFFS = ("kind regards", "best regards", "regards", "cheers") SIGNOFF_PAT = re.compile( @@ -267,15 +287,17 @@ async def copy(self, text: str) -> None: # Greeting lines that should end with a comma before a blank line _GREETING_BREAK = re.compile( - r'(?i)(^|\n)(\s*(?:hi|hello|hey|kia(?:\s+ora)?|dear)\b[^\n]*?)\s*\n\n' + r"(?i)(^|\n)(\s*(?:hi|hello|hey|kia(?:\s+ora)?|dear)\b[^\n]*?)\s*\n\n" ) + # ══════════════════════════ Dictation Engine ══════════════════════════════════ class DictationEngine: """Microphone β†’ (VAD) β†’ Whisper β†’ clipboard.""" - _EMAIL_RE = re.compile(r"\b([\w.-]+)\s+at(?:\s+sign)?\s+([\w.-]+)\s+dot\s+com\b", - re.I) + _EMAIL_RE = re.compile( + r"\b([\w.-]+)\s+at(?:\s+sign)?\s+([\w.-]+)\s+dot\s+com\b", re.I + ) _URL_DOT = re.compile(r"\b([a-zA-Z0-9_-]+)\s+\.\s+([a-zA-Z0-9_-]+)") # ───────────────────────── init ────────────────────────── @@ -422,43 +444,43 @@ async def _transcribe(self, audio: np.ndarray) -> str: txt = self._EMAIL_RE.sub(r"\1@\2.com", self._URL_DOT.sub(r"\1.\2", txt)) for pat, rep in _CMD_SUBS: txt = pat.sub(rep, txt) - + # ── NEW: collapse duplicate commas (word "comma" + real comma) ───────── - txt = re.sub(r',\s*,+', ',', txt) - + txt = re.sub(r",\s*,+", ",", txt) + # ── NEW: if a comma sneaks in *before* the paragraph break, make it a '.' ─ - txt = re.sub(r',\s*\n\n', '.\n\n', txt) - + txt = re.sub(r",\s*\n\n", ".\n\n", txt) + # ── turn greeting + blank line into "Greeting," txt = _GREETING_BREAK.sub( lambda m: f"{m.group(1)}{m.group(2).rstrip(' ,.!?;:')},\n\n", txt, ) - + # ── NEW: add full stop before paragraph break when *no* punctuation spoken ─ - txt = re.sub(r'([^\s.,!?;:])\s*\n\n', r'\1.\n\n', txt) - + txt = re.sub(r"([^\s.,!?;:])\s*\n\n", r"\1.\n\n", txt) + # final space trim around @ and . - txt = _SPACES_AROUND_DOT_AT.sub(r'\1', txt) - + txt = _SPACES_AROUND_DOT_AT.sub(r"\1", txt) + # safety-pass: remove spaces or tabs (NOT new-lines) that may survive - txt = re.sub(r'@[ \t]+', '@', txt) # john @ gmail β†’ john@gmail - txt = re.sub(r'\.[ \t]+', '.', txt) # gmail . com β†’ gmail.com - + txt = re.sub(r"@[ \t]+", "@", txt) # john @ gmail β†’ john@gmail + txt = re.sub(r"\.[ \t]+", ".", txt) # gmail . com β†’ gmail.com + # ------------------------------------------------------------------------- # 8. sentence-/signature-polish (run *after* all previous tweaks) # ------------------------------------------------------------------------- - + # 8-a normalise common e-mail sign-offs txt = SIGNOFF_PAT.sub(lambda m: f"{m.group(1)}{m.group(2).title()},\n", txt) - + # 8-b capitalise first alphabetical char of every logical line txt = re.sub( r"(^|\n)([β€’ \t]*)([a-z])", lambda m: m.group(1) + m.group(2) + m.group(3).upper(), txt, ) - + return txt # ──────────────────── shadow helper ────────────────────── @@ -468,9 +490,8 @@ def _add_to_shadow(self, chunk: np.ndarray) -> None: # ──────────────────── trigger install ───────────────────── def _install_triggers(self) -> None: def ok() -> bool: - return ( - not self.cfg.dual_trigger_required - or (self._mouse_pressed and is_pressed(self.cfg.hotkey)) + return not self.cfg.dual_trigger_required or ( + self._mouse_pressed and is_pressed(self.cfg.hotkey) ) def toggle(src: str) -> None: @@ -498,7 +519,7 @@ def click(_x: int, _y: int, button: mouse.Button, down: bool) -> None: if down: self._mouse_pressed = True if self.cfg.mouse_hold_to_record: - self._raw_shadow.clear() # start fresh + self._raw_shadow.clear() # start fresh self._mouse_press = time.time() self._holding = False self._hold_timer = threading.Timer( @@ -547,7 +568,9 @@ async def _flush_hold(self) -> str: if self._raw_shadow: segs.append(concatenate(self._raw_shadow)) self._raw_shadow.clear() - if self._vad_gate: + if self._vad_gate: + self._vad_gate.force_flush() + elif self._vad_gate: tail = self._vad_gate.force_flush() if tail is not None and tail.size: segs.append(tail) @@ -590,10 +613,13 @@ async def _run(self) -> None: continue if self._clip_q.qsize() > 8: - await asyncio.sleep(0.02) # back-pressure clipboard + await asyncio.sleep(0.02) # back-pressure clipboard if self.cfg.use_vad: if chunk.size: + if self.cfg.mouse_hold_to_record and self._holding: + continue + batch.append(chunk) samples += chunk.size self._batch_ctl.feed(chunk.size) diff --git a/dictation_tool/io.py b/dictation_tool/io.py index 48bc3b2..89292c4 100644 --- a/dictation_tool/io.py +++ b/dictation_tool/io.py @@ -1,12 +1,12 @@ from __future__ import annotations import asyncio +import platform import queue import threading from collections import deque -from collections.abc import AsyncIterator, Iterable +from collections.abc import AsyncIterator, Callable, Iterable from enum import Enum -from typing import Callable import numpy as np import sounddevice as sd @@ -17,20 +17,22 @@ __all__ = [ "AudioStream", "VADGate", - "VADState", + "VADState", "concatenate", ] class VADState(Enum): """Voice Activity Detection state machine.""" + SILENCE = "silence" SPEECH_DETECTED = "speech_detected" SPEECH_ENDED = "speech_ended" - + class VADGate: """Advanced Voice Activity Detector with pre-buffering and state machine.""" + # ... (no changes to this class) ... def __init__( self, @@ -46,26 +48,26 @@ def __init__( self.sr = sample_rate self.frame_duration_ms = frame_duration_ms self.bytes_per_frame = self.sr * frame_duration_ms // 1000 * 2 # int16 - + # Pre-buffer configuration self.pre_buffer_chunks = pre_buffer_chunks self.post_buffer_chunks = post_buffer_chunks self.consecutive_speech_frames = consecutive_speech_frames self.consecutive_silence_frames = consecutive_silence_frames - + # State tracking self.state = VADState.SILENCE self.speech_frame_count = 0 self.silence_frame_count = 0 - + # Buffers - using deque for O(1) operations self.pre_buffer: deque[bytes] = deque(maxlen=pre_buffer_chunks) self.speech_buffer: list[bytes] = [] self.post_buffer: deque[bytes] = deque(maxlen=post_buffer_chunks) - + # Performance tracking self._frames_processed = 0 - + def reset(self) -> None: """Reset VAD state and buffers.""" self.state = VADState.SILENCE @@ -74,7 +76,7 @@ def reset(self) -> None: self.pre_buffer.clear() self.speech_buffer.clear() self.post_buffer.clear() - + def get_statistics(self) -> dict[str, int | str]: """Get processing statistics for monitoring.""" return { @@ -87,51 +89,51 @@ def get_statistics(self) -> dict[str, int | str]: def __call__(self, chunk: np.ndarray) -> Iterable[np.ndarray]: """Process audio chunk and yield voiced segments with optimal buffering. - + Returns segments only when speech has definitively ended, ensuring complete utterances are captured. """ pcm = chunk.tobytes() results: list[np.ndarray] = [] - + # Process each frame in the chunk for i in range(0, len(pcm), self.bytes_per_frame): frame = pcm[i : i + self.bytes_per_frame] if len(frame) < self.bytes_per_frame: continue - + self._frames_processed += 1 - + # VAD detection try: is_speech = self.vad.is_speech(frame, self.sr) except Exception: # Handle invalid frame gracefully is_speech = False - + # State machine processing segment = self._process_frame(frame, is_speech) if segment is not None and len(segment) > 0: results.append(segment) - + return results - + def _process_frame(self, frame: bytes, is_speech: bool) -> np.ndarray | None: """Process a single frame through the VAD state machine.""" if self.state == VADState.SILENCE: if is_speech: self.speech_frame_count += 1 self.silence_frame_count = 0 - + if self.speech_frame_count >= self.consecutive_speech_frames: # Speech detected! Transition to speech state self.state = VADState.SPEECH_DETECTED - + # Add pre-buffer and current frame to speech buffer self.speech_buffer.extend(self.pre_buffer) self.speech_buffer.append(frame) - - # Clear pre-buffer + + # Clear pre-buffer self.pre_buffer.clear() else: # Not enough consecutive speech frames yet @@ -140,40 +142,40 @@ def _process_frame(self, frame: bytes, is_speech: bool) -> np.ndarray | None: # Continue in silence self.speech_frame_count = 0 self.pre_buffer.append(frame) - + elif self.state == VADState.SPEECH_DETECTED: if is_speech: # Continue speech self.silence_frame_count = 0 self.speech_buffer.append(frame) - - # Clear any post-buffer + + # Clear any post-buffer self.post_buffer.clear() else: # Potential end of speech self.silence_frame_count += 1 self.post_buffer.append(frame) - + if self.silence_frame_count >= self.consecutive_silence_frames: # Speech definitely ended - create segment immediately complete_segment = self._create_segment() self._reset_for_next_utterance() return complete_segment - + return None - + def _create_segment(self) -> np.ndarray: """Create a complete speech segment from buffers.""" if not self.speech_buffer: return np.array([], dtype=np.int16) - + # Combine speech buffer with post-buffer all_frames = self.speech_buffer + list(self.post_buffer) - + # Convert to numpy arrays and concatenate frame_arrays = [np.frombuffer(frame, dtype=np.int16) for frame in all_frames] return np.concatenate(frame_arrays, dtype=np.int16) - + def _reset_for_next_utterance(self) -> None: """Reset state for detecting the next speech utterance.""" self.state = VADState.SILENCE @@ -181,7 +183,7 @@ def _reset_for_next_utterance(self) -> None: self.silence_frame_count = 0 self.speech_buffer.clear() self.post_buffer.clear() - + def force_flush(self) -> np.ndarray | None: """Force flush any pending speech buffer (e.g., on session end).""" if self.speech_buffer: @@ -190,21 +192,28 @@ def force_flush(self) -> np.ndarray | None: return segment return None + class AudioStream: - """Mic reader β†’ optional VAD β†’ async chunk generator.""" + """Mic reader -> optional VAD -> async chunk generator. + + Captures at the device's native sample rate and resamples to + ``sample_rate`` in the audio callback when the two differ. + Resampling uses numpy linear interpolation (~8 us per 10 ms chunk). + """ def __init__( self, sample_rate: int, chunk_ms: int = 10, vad_gate: VADGate | None = None, - input_device: str | None = None, + input_device: str | int | None = None, on_raw_chunk: Callable[[np.ndarray], None] | None = None, ) -> None: """ - `on_raw_chunk` receives every *raw* 10 ms frame (for shadow buffering). + `on_raw_chunk` receives every sample-rate-normalized frame for shadow buffering. """ self._sr = sample_rate + self._chunk_ms = chunk_ms self._frames = int(self._sr * chunk_ms / 1000) self._gate = vad_gate self._input_device = input_device @@ -212,23 +221,152 @@ def __init__( self._q: queue.Queue[np.ndarray] = queue.Queue(maxsize=64) self._stop = threading.Event() + # Set during _open_stream when native rate != target rate + self._native_sr: int = sample_rate + self._resample_idx: np.ndarray | None = None + self._native_arange: np.ndarray | None = None + async def __aenter__(self) -> AudioStream: - device_params = {} - if self._input_device is not None: - device_params['device'] = self._input_device - - self._stream = sd.InputStream( # type: ignore[attr-defined] - samplerate=self._sr, + self._stream = self._open_stream(self._input_device) + self._stream.start() + return self + + def _open_stream(self, device: str | int | None) -> sd.InputStream: + """Open an InputStream, falling back to native-rate + resample. + + Strategy: + 1. Try the requested device (or system default) at target rate. + 2. On Windows, also try every WASAPI input at target rate. + 3. If all fail, open the best candidate at its native rate + and resample each chunk via np.interp (~8 us / 10 ms). + """ + candidates = self._build_device_candidates(device) + + # --- Pass 1: try target sample rate directly --- + last_err: Exception | None = None + for dev in candidates: + try: + stream = self._try_open(dev, self._sr) + self._native_sr = self._sr + self._resample_idx = None + self._log_mic(dev, self._sr, resample=False) + return stream + except sd.PortAudioError as exc: + last_err = exc + LOGGER.debug("Device %s @ %d Hz failed: %s", dev, self._sr, exc) + + # --- Pass 2: open at native rate, resample in callback --- + for dev in candidates: + native_sr = self._device_native_sr(dev) + if native_sr is None or native_sr == self._sr: + continue + try: + native_frames = int(native_sr * self._chunk_ms / 1000) + stream = self._try_open(dev, native_sr, blocksize=native_frames) + self._setup_resampler(native_sr, native_frames) + self._log_mic(dev, native_sr, resample=True) + return stream + except sd.PortAudioError as exc: + last_err = exc + LOGGER.debug( + "Device %s @ %d Hz (native) failed: %s", dev, native_sr, exc + ) + + raise sd.PortAudioError( + f"No usable input device found (last error: {last_err})" + ) + + def _try_open( + self, + device: str | int | None, + sr: int, + blocksize: int | None = None, + ) -> sd.InputStream: + params: dict = {} + if device is not None: + params["device"] = device + return sd.InputStream( + samplerate=sr, channels=1, dtype="int16", - blocksize=self._frames, + blocksize=blocksize or self._frames, callback=self._callback, - **device_params + **params, ) - self._stream.start() - device_info = f" (device: {self._input_device})" if self._input_device else "" - LOGGER.info("πŸŽ™οΈ Mic @ %d Hz%s - press hot-key to toggle", self._sr, device_info) - return self + + def _build_device_candidates( + self, device: str | int | None + ) -> list[int | str | None]: + if device is not None: + return [device] + + candidates: list[int | str | None] = [] + candidates.append(None) + + if platform.system() == "Windows": + try: + input_devices = self._all_input_devices() + except sd.PortAudioError as exc: + LOGGER.debug("Unable to enumerate input devices: %s", exc) + input_devices = [] + for idx in input_devices: + if idx not in candidates: + candidates.append(idx) + return candidates + + def _setup_resampler(self, native_sr: int, native_frames: int) -> None: + target_frames = int(self._sr * self._chunk_ms / 1000) + self._native_sr = native_sr + self._resample_idx = np.linspace(0, native_frames - 1, target_frames) + self._native_arange = np.arange(native_frames, dtype=np.float64) + + def _log_mic(self, dev: str | int | None, sr: int, *, resample: bool) -> None: + label = dev if dev is not None else "default" + if resample: + LOGGER.info( + "Mic @ %d Hz (device: %s) -> resample to %d Hz - press hot-key to toggle", + sr, + label, + self._sr, + ) + else: + LOGGER.info( + "Mic @ %d Hz (device: %s) - press hot-key to toggle", + sr, + label, + ) + + @staticmethod + def _device_native_sr(device: str | int | None) -> int | None: + try: + if device is not None: + info = sd.query_devices(device, "input") + else: + info = sd.query_devices(kind="input") + return int(info["default_samplerate"]) + except Exception: + return None + + @staticmethod + def _all_input_devices() -> list[int]: + """Return all input device indices, WASAPI first, then others.""" + hostapis = sd.query_hostapis() + api_priority = {"WASAPI": 0, "DirectSound": 1, "MME": 2} + devices = sd.query_devices() + inputs = [ + (i, d) + for i, d in enumerate(devices) # type: ignore[arg-type] + if d["max_input_channels"] > 0 + ] + + def sort_key(pair: tuple) -> tuple: + _, d = pair + api_name = hostapis[d["hostapi"]]["name"] + pri = next((v for k, v in api_priority.items() if k in api_name), 10) + return (pri, d.get("index", 0)) + + inputs.sort(key=sort_key) + return [i for i, _ in inputs] async def __aexit__(self, exc_type, exc, tb) -> None: self._stop.set() @@ -238,12 +376,22 @@ async def __aexit__(self, exc_type, exc, tb) -> None: # ── internals ──────────────────────────────────────────────────────── def _callback(self, indata: np.ndarray, _frames: int, *_) -> None: try: + chunk = indata.copy() + if self._resample_idx is not None and self._native_arange is not None: + chunk = ( + np.interp( + self._resample_idx, + self._native_arange, + chunk[:, 0].astype(np.float64), + ) + .astype(np.int16) + .reshape(-1, 1) + ) if self._on_raw_chunk: - # Keep a copy of *every* frame before it goes to VAD - self._on_raw_chunk(indata.copy()) - self._q.put_nowait(indata.copy()) + self._on_raw_chunk(chunk.copy()) + self._q.put_nowait(chunk) except queue.Full: - pass # reader is behind, drop a frame + pass async def chunks(self) -> AsyncIterator[np.ndarray]: loop = asyncio.get_running_loop() @@ -258,4 +406,4 @@ async def chunks(self) -> AsyncIterator[np.ndarray]: def concatenate(chunks: Iterable[np.ndarray]) -> np.ndarray: with timed("concat"): - return np.concatenate(chunks, dtype=np.int16) \ No newline at end of file + return np.concatenate(chunks, dtype=np.int16) diff --git a/tests/test_engine_advanced.py b/tests/test_engine_advanced.py index c266aec..b326d8d 100644 --- a/tests/test_engine_advanced.py +++ b/tests/test_engine_advanced.py @@ -70,7 +70,9 @@ async def test_load_model_flash(self, mock_whisper_cls, mock_torch_compile): await eng._load_model() - mock_whisper_cls.assert_called_once_with("large-v3", device="cuda", compute_type="float16") + mock_whisper_cls.assert_called_once_with( + "large-v3", device="cuda", compute_type="float16" + ) mock_torch_compile.assert_called_once() # Check that torch.compile was called with the original model assert mock_torch_compile.call_args.args[0] is original_model @@ -99,7 +101,9 @@ def test_install_triggers(self, m_listener, m_hotkey): # ── adaptive transcription logic ──────────────────────────── @pytest.mark.asyncio async def test_adaptive_retry(self, mock_whisper_model): - cfg = Config(device="cpu", attention_backend="none", retry_temperatures=(0.0, 0.4)) + cfg = Config( + device="cpu", attention_backend="none", retry_temperatures=(0.0, 0.4) + ) eng = DictationEngine(cfg) eng._model = mock_whisper_model text = await eng._transcribe(np.zeros(16000, dtype=np.int16)) @@ -158,6 +162,51 @@ async def gen(): await eng._run() assert m_stream.call_args.kwargs["vad_gate"] is not None + @pytest.mark.asyncio + async def test_flush_hold_does_not_duplicate_vad_tail(self): + cfg = Config(device="cpu", attention_backend="none", use_vad=True) + eng = DictationEngine(cfg) + shadow = np.ones(1600, dtype=np.int16) + tail = np.ones(800, dtype=np.int16) + captured = {} + + async def transcribe(audio): + captured["samples"] = audio.size + return "hello" + + eng._raw_shadow.append(shadow) + eng._vad_gate = Mock(force_flush=Mock(return_value=tail)) + eng._transcribe = AsyncMock(side_effect=transcribe) + + assert await eng._flush_hold() == "hello" + assert captured["samples"] == shadow.size + eng._vad_gate.force_flush.assert_called_once() + + @patch("dictation_tool.engine.AudioStream") + @pytest.mark.asyncio + async def test_hold_mode_does_not_transcribe_vad_batches_before_release( + self, m_stream + ): + cfg = Config(device="cpu", attention_backend="none", use_vad=True) + eng = DictationEngine(cfg) + eng._holding = True + eng._recording.set() + eng._transcribe = AsyncMock(return_value="should not paste") + + stub = Mock() + stub.__aenter__ = AsyncMock(return_value=stub) + stub.__aexit__ = AsyncMock(return_value=None) + + async def gen(): + eng._terminate.set() + yield np.ones(1600, dtype=np.int16) + + stub.chunks.return_value = gen() + m_stream.return_value = stub + + await eng._run() + eng._transcribe.assert_not_called() + @patch("dictation_tool.engine.AudioStream") @pytest.mark.asyncio async def test_run_without_vad(self, m_stream): diff --git a/tests/test_io.py b/tests/test_io.py index b52a7e5..75ad0cf 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -16,29 +16,46 @@ class TestVADGate: def test_vadgate_initialization(self): """Test VADGate initialization with various parameters.""" - vad = VADGate(sample_rate=16000, aggressiveness=2, padding_ms=200) + vad = VADGate( + sample_rate=16000, + aggressiveness=2, + frame_duration_ms=30, + pre_buffer_chunks=10, + post_buffer_chunks=5, + ) assert vad.sr == 16000 - assert vad.frame_len == 30 + assert vad.frame_duration_ms == 30 assert vad.bytes_per_frame == 960 # 16000 * 30 / 1000 * 2 - assert vad.padding_frames == 6 # 200 / 30 - assert vad.ring.maxlen == 6 + assert vad.pre_buffer_chunks == 10 + assert vad.post_buffer_chunks == 5 + assert vad.pre_buffer.maxlen == 10 + assert vad.post_buffer.maxlen == 5 @given( sample_rate=st.integers(min_value=8000, max_value=48000), aggressiveness=st.integers(min_value=0, max_value=3), - padding_ms=st.integers(min_value=0, max_value=1000), + pre_buffer_chunks=st.integers(min_value=1, max_value=20), + post_buffer_chunks=st.integers(min_value=1, max_value=10), ) - def test_vadgate_parameter_validation(self, sample_rate, aggressiveness, padding_ms): + def test_vadgate_parameter_validation( + self, sample_rate, aggressiveness, pre_buffer_chunks, post_buffer_chunks + ): """Property-based test for VADGate parameter validation.""" - vad = VADGate(sample_rate=sample_rate, aggressiveness=aggressiveness, padding_ms=padding_ms) + vad = VADGate( + sample_rate=sample_rate, + aggressiveness=aggressiveness, + frame_duration_ms=30, + pre_buffer_chunks=pre_buffer_chunks, + post_buffer_chunks=post_buffer_chunks, + ) assert vad.sr == sample_rate - assert vad.frame_len == 30 # Fixed frame length + assert vad.frame_duration_ms == 30 expected_bytes = sample_rate * 30 // 1000 * 2 # int16 = 2 bytes assert vad.bytes_per_frame == expected_bytes - expected_frames = padding_ms // 30 - assert vad.padding_frames == expected_frames + assert vad.pre_buffer_chunks == pre_buffer_chunks + assert vad.post_buffer_chunks == post_buffer_chunks @patch("dictation_tool.io.webrtcvad.Vad") def test_vadgate_speech_detection(self, mock_vad_class): @@ -46,7 +63,7 @@ def test_vadgate_speech_detection(self, mock_vad_class): mock_vad = Mock() mock_vad_class.return_value = mock_vad - vad_gate = VADGate(sample_rate=16000, aggressiveness=2, padding_ms=200) + vad_gate = VADGate(sample_rate=16000, aggressiveness=2) # Test speech detection mock_vad.is_speech.return_value = True @@ -64,7 +81,7 @@ def test_vadgate_no_speech_buffering(self, mock_vad_class): mock_vad = Mock() mock_vad_class.return_value = mock_vad - vad_gate = VADGate(sample_rate=16000, aggressiveness=2, padding_ms=200) + vad_gate = VADGate(sample_rate=16000, aggressiveness=2) # Test no speech detection mock_vad.is_speech.return_value = False @@ -75,8 +92,8 @@ def test_vadgate_no_speech_buffering(self, mock_vad_class): # Should not yield when no speech detected assert len(result) == 0 - # Frames should be buffered in ring - assert len(vad_gate.ring) >= 0 + # Frames should be buffered before speech starts. + assert len(vad_gate.pre_buffer) >= 0 class TestAudioStream: @@ -145,6 +162,42 @@ def test_audio_stream_callback_queue_handling(self): stream._callback(indata, 2, Mock(), Mock()) # No exception should be raised + def test_explicit_input_device_does_not_fall_back_to_other_devices(self): + """Explicit device selection should fail closed rather than record elsewhere.""" + stream = AudioStream(16000, 20, input_device=7) + + with ( + patch("dictation_tool.io.platform.system", return_value="Windows"), + patch.object(stream, "_all_input_devices", return_value=[1, 2, 7]), + ): + assert stream._build_device_candidates(7) == [7] + + def test_default_windows_input_can_fall_back_to_available_devices(self): + """Default input may try other Windows devices for portability.""" + stream = AudioStream(16000, 20) + + with ( + patch("dictation_tool.io.platform.system", return_value="Windows"), + patch("dictation_tool.io.sd.default.device", (3, None)), + patch.object(stream, "_all_input_devices", return_value=[1, 3, 5]), + ): + assert stream._build_device_candidates(None) == [None, 1, 3, 5] + + def test_callback_resamples_to_target_frame_count(self): + """Native-rate fallback still emits target-rate mono int16 chunks.""" + raw_chunks = [] + stream = AudioStream(16000, 10, on_raw_chunk=raw_chunks.append) + stream._setup_resampler(native_sr=48000, native_frames=480) + native_chunk = np.arange(480, dtype=np.int16).reshape(-1, 1) + + stream._callback(native_chunk, 480, None, None) + + queued = stream._q.get_nowait() + assert queued.shape == (160, 1) + assert queued.dtype == np.int16 + assert len(raw_chunks) == 1 + np.testing.assert_array_equal(raw_chunks[0], queued) + @pytest.mark.asyncio async def test_chunks_without_vad(self): """Test chunks method without VAD.""" @@ -262,9 +315,9 @@ def test_concatenate_multiple_arrays(self): @given( arrays=st.lists( - st.lists(st.integers(min_value=-32768, max_value=32767), min_size=1, max_size=100).map( - lambda x: np.array(x, dtype=np.int16) - ), + st.lists( + st.integers(min_value=-32768, max_value=32767), min_size=1, max_size=100 + ).map(lambda x: np.array(x, dtype=np.int16)), min_size=1, max_size=10, )