Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions shuvoice/asr_sherpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,11 @@ def _looks_like_parakeet_model(cls, config: Config) -> bool:
model_name = str(getattr(config, "sherpa_model_name", "")).strip().lower()
model_dir = str(getattr(config, "sherpa_model_dir", "") or "").strip().lower()
candidates = [model_name, model_dir]
return any(marker in candidate for candidate in candidates for marker in cls._PARAKEET_MODEL_MARKERS)
return any(
marker in candidate
for candidate in candidates
for marker in cls._PARAKEET_MODEL_MARKERS
)

@staticmethod
def _cuda_provider_available() -> tuple[bool, str]:
Expand Down Expand Up @@ -523,7 +527,9 @@ def _parakeet_streaming_model_compatible(cls, config: Config) -> tuple[bool, str

configured = getattr(config, "sherpa_model_dir", None)
model_name = str(getattr(config, "sherpa_model_name", "") or "").strip() or None
model_dir = Path(configured).expanduser() if configured else cls._default_model_dir(model_name)
model_dir = (
Path(configured).expanduser() if configured else cls._default_model_dir(model_name)
)

if not model_dir.is_dir():
return True, f"model directory not present yet ({model_dir})"
Expand Down
3 changes: 1 addition & 2 deletions shuvoice/cli/commands/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ def config_set(key: str, value: str) -> int:
if value_norm not in _ALLOWED_FINAL_INJECTION_MODES:
allowed = ", ".join(sorted(_ALLOWED_FINAL_INJECTION_MODES))
print(
"ERROR: typing_final_injection_mode must be one of: "
f"{allowed}",
f"ERROR: typing_final_injection_mode must be one of: {allowed}",
file=sys.stderr,
)
return 1
Expand Down
8 changes: 6 additions & 2 deletions shuvoice/cli/commands/preflight.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def check_tts_backend_stack() -> str:
errors = backend_cls.dependency_errors()
if config.tts_backend == "elevenlabs":
# API key presence is validated via check_tts_api_key using the configured env name.
errors = [err for err in errors if "API key" not in err and "ELEVENLABS_API_KEY" not in err]
errors = [
err for err in errors if "API key" not in err and "ELEVENLABS_API_KEY" not in err
]
if errors:
raise RuntimeError("; ".join(errors))
return f"{config.tts_backend} deps OK"
Expand Down Expand Up @@ -176,7 +178,9 @@ def check_asr_stack() -> str:
if config.preserve_clipboard or config.tts_enabled:
add_check("wl-paste binary", check_binary("wl-paste"))
else:
checks.append(("wl-paste binary", True, "skipped (preserve_clipboard=false, tts_enabled=false)"))
checks.append(
("wl-paste binary", True, "skipped (preserve_clipboard=false, tts_enabled=false)")
)
add_check("gtk4-layer-shell library", check_layer_shell)
add_check("Output mode", check_output_mode)

Expand Down
9 changes: 2 additions & 7 deletions shuvoice/cli/commands/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,7 @@ def run_setup(

# Evaluate startup diagnostics on a copy so we can report effective runtime
# values (for example provider fallback) without mutating caller config.
cfg_for_checks = Config(
**{name: getattr(config, name) for name in Config.config_field_names()}
)
cfg_for_checks = Config(**{name: getattr(config, name) for name in Config.config_field_names()})

startup_warnings = backend_cls.startup_warnings(cfg_for_checks, apply_fixes=True)
if startup_warnings:
Expand All @@ -141,10 +139,7 @@ def run_setup(
looks_like_parakeet = bool(detector(cfg_for_checks))
print(f"[INFO] Sherpa Parakeet model: {'yes' if looks_like_parakeet else 'no'}")
if looks_like_parakeet:
print(
"[INFO] Sherpa Parakeet runnable: "
f"{'yes' if not startup_errors else 'no'}"
)
print(f"[INFO] Sherpa Parakeet runnable: {'yes' if not startup_errors else 'no'}")

if startup_errors:
print("\n[FAIL] Backend runtime compatibility")
Expand Down
10 changes: 3 additions & 7 deletions shuvoice/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,7 @@ def __post_init__(self):

self.sherpa_decode_mode = str(self.sherpa_decode_mode).strip().lower()
if self.sherpa_decode_mode not in {"auto", "streaming", "offline_instant"}:
raise ValueError(
"sherpa_decode_mode must be one of: auto, streaming, offline_instant"
)
raise ValueError("sherpa_decode_mode must be one of: auto, streaming, offline_instant")

if not isinstance(self.sherpa_enable_parakeet_streaming, bool):
raise ValueError("sherpa_enable_parakeet_streaming must be true or false")
Expand Down Expand Up @@ -560,8 +558,7 @@ def _apply_instant_mode_profile(self) -> None:
# audio is accumulated and decoded in one shot on key release.
# Log the resolved mode for diagnostics.
log.info(
"instant_mode enabled with Sherpa offline_instant mode "
"(model: %s)",
"instant_mode enabled with Sherpa offline_instant mode (model: %s)",
self.sherpa_model_name,
)
else:
Expand Down Expand Up @@ -696,8 +693,7 @@ def load(cls) -> "Config":
)
if derived_mode_from_legacy:
log.info(
"Migrated legacy use_clipboard_for_final to "
"typing_final_injection_mode=%s",
"Migrated legacy use_clipboard_for_final to typing_final_injection_mode=%s",
flat.get("typing_final_injection_mode"),
)
except Exception: # noqa: BLE001
Expand Down
4 changes: 3 additions & 1 deletion shuvoice/tts_elevenlabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ def list_voices(self) -> list[VoiceInfo]:
except TimeoutError as exc:
raise RuntimeError("ElevenLabs voice list request timed out") from exc
except OSError as exc:
raise RuntimeError(f"ElevenLabs voice list request failed: {type(exc).__name__}") from exc
raise RuntimeError(
f"ElevenLabs voice list request failed: {type(exc).__name__}"
) from exc
except json.JSONDecodeError as exc:
raise RuntimeError("Invalid ElevenLabs voice list response") from exc

Expand Down
8 changes: 6 additions & 2 deletions shuvoice/tts_overlay.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,9 @@ def _render(self) -> None:
self._preview_label.set_text(self._preview_text)

if self._pause_btn is not None:
self._pause_btn.set_label("▶ Resume" if self._state == TTS_OVERLAY_PAUSED else "⏸ Pause")
self._pause_btn.set_label(
"▶ Resume" if self._state == TTS_OVERLAY_PAUSED else "⏸ Pause"
)

def _on_pause_clicked(self, _button: Gtk.Button) -> None:
if self._state == TTS_OVERLAY_PAUSED:
Expand Down Expand Up @@ -229,7 +231,9 @@ def show(self) -> None:
def hide(self) -> None:
GLib.idle_add(self._do_hide)

def set_state(self, state: str, *, preview_text: str = "", error_message: str | None = None) -> None:
def set_state(
self, state: str, *, preview_text: str = "", error_message: str | None = None
) -> None:
GLib.idle_add(self._do_set_state, state, preview_text, error_message)

def set_voices(self, voices: list[VoiceInfo], selected_voice_id: str | None = None) -> None:
Expand Down
5 changes: 1 addition & 4 deletions shuvoice/typer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,7 @@ def _run(self, args: list[str], op: str, attempts: int | None = None) -> bool:

@staticmethod
def _backspace_args(count: int) -> list[str]:
args = ["wtype"]
for _ in range(count):
args.extend(["-k", "BackSpace"])
return args
return ["wtype"] + ["-k", "BackSpace"] * count

def _send_backspaces(self, count: int, op: str) -> bool:
if count <= 0:
Expand Down
8 changes: 2 additions & 6 deletions shuvoice/wizard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,7 @@ def _build_asr_page(self) -> Gtk.Widget:
self._sherpa_profile_help.set_margin_bottom(4)
page.append(self._sherpa_profile_help)

self._sherpa_streaming_radio = Gtk.CheckButton(
label="Streaming (Zipformer Kroko model)"
)
self._sherpa_streaming_radio = Gtk.CheckButton(label="Streaming (Zipformer Kroko model)")
self._sherpa_streaming_radio.add_css_class("wizard-radio")
self._sherpa_streaming_radio.connect(
"toggled",
Expand Down Expand Up @@ -658,9 +656,7 @@ def _on_finish(self, button):
),
}
if self._asr_backend == "sherpa":
write_kwargs["sherpa_enable_parakeet_streaming"] = (
sherpa_enable_parakeet_streaming
)
write_kwargs["sherpa_enable_parakeet_streaming"] = sherpa_enable_parakeet_streaming

write_config(
self._asr_backend,
Expand Down
7 changes: 4 additions & 3 deletions shuvoice/wizard/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ def _check_parakeet_streaming_compatibility(

if not backend_cls.capabilities.supports_model_download:
_emit(1.0, "Model download skipped (lazy backend)")
return "skipped", _with_provider_note("Selected backend downloads models lazily at runtime.")
return "skipped", _with_provider_note(
"Selected backend downloads models lazily at runtime."
)

missing = backend_cls.dependency_errors()
if missing:
Expand Down Expand Up @@ -269,8 +271,7 @@ def finish_setup(
typing_final_injection_mode=typing_final_injection_mode,
)
model_message = (
f"{model_message} "
"Applied fallback profile: Streaming (Zipformer default model)."
f"{model_message} Applied fallback profile: Streaming (Zipformer default model)."
)

write_marker()
Expand Down
6 changes: 2 additions & 4 deletions shuvoice/wizard_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
def _is_parakeet_sherpa_model_name(model_name: str) -> bool:
return "parakeet" in str(model_name).strip().lower()


# Keybind presets for push-to-talk setup.
# (id, display_label, hyprland_bind_key_spec, description)
# hyprland_bind_key_spec is the "MODS, KEY" portion for bind/bindr lines.
Expand Down Expand Up @@ -611,10 +612,7 @@ def write_config(
injection_mode = str(typing_final_injection_mode).strip().lower()
if injection_mode not in _FINAL_INJECTION_MODE_SET:
allowed = ", ".join(sorted(_FINAL_INJECTION_MODE_SET))
raise ValueError(
"typing_final_injection_mode must be one of: "
f"{allowed}"
)
raise ValueError(f"typing_final_injection_mode must be one of: {allowed}")

if asr_backend == "sherpa":
sherpa_cuda_available = _detect_sherpa_cuda_provider()
Expand Down
4 changes: 3 additions & 1 deletion tests/test_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,9 @@ def fake_download_model(cls, model_name=None, model_dir=None, **_):


def test_sherpa_startup_errors_block_parakeet_streaming_by_default():
cfg = Config(asr_backend="sherpa", sherpa_model_name="sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8")
cfg = Config(
asr_backend="sherpa", sherpa_model_name="sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8"
)

sherpa_cls = get_backend_class("sherpa")
errors = sherpa_cls.startup_errors(cfg)
Expand Down
49 changes: 26 additions & 23 deletions tests/test_asr_sherpa_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,27 +454,30 @@ def capture_from_transducer(**kwargs):
def test_online_parakeet_streaming_fails_fast_without_window_size_metadata(
self, tmp_path: Path
):
model_dir = tmp_path / "sherpa-model"
model_dir.mkdir(parents=True, exist_ok=True)
(model_dir / "tokens.txt").write_text("<blk>\na\n")
(model_dir / "encoder.onnx").write_bytes(b"onnx")
for name in ("decoder.onnx", "joiner.onnx"):
(model_dir / name).write_bytes(b"onnx")
from unittest.mock import MagicMock, patch

with patch.dict("sys.modules", {"sherpa_onnx": MagicMock()}):
model_dir = tmp_path / "sherpa-model"
model_dir.mkdir(parents=True, exist_ok=True)
(model_dir / "tokens.txt").write_text("<blk>\na\n")
(model_dir / "encoder.onnx").write_bytes(b"onnx")
for name in ("decoder.onnx", "joiner.onnx"):
(model_dir / name).write_bytes(b"onnx")

cfg = Config(
asr_backend="sherpa",
sherpa_model_name="sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8",
sherpa_model_dir=str(model_dir),
sherpa_decode_mode="streaming",
sherpa_enable_parakeet_streaming=True,
)
backend = create_backend("sherpa", cfg)
backend._model_files = {
"tokens": model_dir / "tokens.txt",
"encoder": model_dir / "encoder.onnx",
"decoder": model_dir / "decoder.onnx",
"joiner": model_dir / "joiner.onnx",
}

cfg = Config(
asr_backend="sherpa",
sherpa_model_name="sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8",
sherpa_model_dir=str(model_dir),
sherpa_decode_mode="streaming",
sherpa_enable_parakeet_streaming=True,
)
backend = create_backend("sherpa", cfg)
backend._model_files = {
"tokens": model_dir / "tokens.txt",
"encoder": model_dir / "encoder.onnx",
"decoder": model_dir / "decoder.onnx",
"joiner": model_dir / "joiner.onnx",
}

with pytest.raises(RuntimeError, match="window_size"):
backend._load_online_recognizer()
with pytest.raises(RuntimeError, match="window_size"):
backend._load_online_recognizer()
4 changes: 3 additions & 1 deletion tests/test_backend_capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ def test_sherpa_startup_guard_allows_non_parakeet_streaming(tmp_path: Path):
assert sherpa_cls.startup_errors(cfg) == []


def test_sherpa_startup_warning_cuda_fallback_applies_in_streaming_mode(monkeypatch, tmp_path: Path):
def test_sherpa_startup_warning_cuda_fallback_applies_in_streaming_mode(
monkeypatch, tmp_path: Path
):
model_dir = _make_sherpa_model_dir(tmp_path)
cfg = Config(
asr_backend="sherpa",
Expand Down
10 changes: 7 additions & 3 deletions tests/test_cli_commands_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def test_list_audio_devices_success_filters_input_devices(monkeypatch, capsys):


def test_list_audio_devices_error(monkeypatch, capsys):
fake_sd = types.SimpleNamespace(query_devices=lambda: (_ for _ in ()).throw(RuntimeError("boom")))
fake_sd = types.SimpleNamespace(
query_devices=lambda: (_ for _ in ()).throw(RuntimeError("boom"))
)
monkeypatch.setitem(sys.modules, "sounddevice", fake_sd)

assert audio_cmd.list_audio_devices() == 1
Expand Down Expand Up @@ -99,7 +101,9 @@ def test_config_validate_success(monkeypatch, capsys):


def test_config_validate_error(monkeypatch, capsys):
monkeypatch.setattr(config_cmd, "load_raw", lambda _path: (_ for _ in ()).throw(RuntimeError("bad")))
monkeypatch.setattr(
config_cmd, "load_raw", lambda _path: (_ for _ in ()).throw(RuntimeError("bad"))
)

assert config_cmd.config_validate() == 1
err = capsys.readouterr().err
Expand Down Expand Up @@ -129,7 +133,7 @@ def test_config_effective_error(monkeypatch, capsys):

def test_config_set_updates_typing_final_injection_mode(monkeypatch, tmp_path, capsys):
config_file = tmp_path / "config.toml"
config_file.write_text("[typing]\ntyping_final_injection_mode = \"auto\"\n", encoding="utf-8")
config_file.write_text('[typing]\ntyping_final_injection_mode = "auto"\n', encoding="utf-8")

monkeypatch.setattr(config_cmd.Config, "config_path", classmethod(lambda cls: config_file))

Expand Down
4 changes: 3 additions & 1 deletion tests/test_typer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ def test_commit_final_auto_mode_prefers_direct_when_watchers_detected(monkeypatc

monkeypatch.setattr(typer, "_detect_clipboard_watchers", lambda: True)
monkeypatch.setattr(typer, "update_partial", lambda text: events.append(("update", text)))
monkeypatch.setattr(typer, "_capture_clipboard", lambda: events.append("capture") or (True, "x"))
monkeypatch.setattr(
typer, "_capture_clipboard", lambda: events.append("capture") or (True, "x")
)
monkeypatch.setattr(typer, "_backspace_partial", lambda: events.append("backspace") or True)
monkeypatch.setattr(
typer,
Expand Down
4 changes: 1 addition & 3 deletions tests/test_wizard_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,7 @@ def test_on_finish_passes_parakeet_streaming_profile_to_write_config():

with (
patch("shuvoice.wizard.write_config") as write_config,
patch(
"shuvoice.wizard.maybe_download_model", return_value=("skipped", "noop")
),
patch("shuvoice.wizard.maybe_download_model", return_value=("skipped", "noop")),
patch("shuvoice.wizard.write_marker"),
):
WelcomeWizard._on_finish(wizard, None)
Expand Down