Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions docs/transports.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ curl -X POST http://localhost:8000/act \
### ZMQ

```bash
# Server (GPU machine)
tether serve ./my_export/ --transport zmq --port 5555
# Local dev server. Non-loopback ZMQ binds require Secure ZMQ below.
tether serve ./my_export/ --transport zmq --host 127.0.0.1 --port 5555
```

```python
Expand All @@ -52,6 +52,9 @@ robot.execute(actions[0]) # first action in the chunk

Production ZMQ deployments should use CURVE authentication/encryption and a
control token for operational endpoints such as `ping` and `kill`.
`tether serve --transport zmq` refuses non-loopback binds unless both are
configured. On isolated lab networks only, operators can pass
`--zmq-insecure-ok` to bypass this check and explicitly accept the risk.

Generate one server keypair and one client keypair:

Expand Down
25 changes: 24 additions & 1 deletion src/tether/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1745,8 +1745,16 @@ def serve(
zmq_control_token: str = typer.Option(
"",
"--zmq-control-token",
envvar="TETHER_ZMQ_CONTROL_TOKEN",
help="Token required for ZMQ control endpoints such as ping and kill. "
"Pass the same value to ZmqRuntimeClient(auth_token=...).",
"Pass the same value to ZmqRuntimeClient(auth_token=...). "
"Can also be supplied via TETHER_ZMQ_CONTROL_TOKEN.",
),
zmq_insecure_ok: bool = typer.Option(
False,
"--zmq-insecure-ok",
help="Allow ZMQ to bind to a non-loopback host without CURVE and control "
"auth. Use only on isolated lab networks.",
),
device: str = typer.Option("cuda", help="Device: cuda or cpu"),
providers: str = typer.Option(
Expand Down Expand Up @@ -2763,6 +2771,19 @@ def _run_mcp_http():
if transport == "zmq":
console.print("[bold green]Starting ZMQ server...[/bold green]")
from tether.runtime.transports.zmq.factory import create_zmq_server
from tether.runtime.transports.zmq.security import validate_zmq_bind_security

try:
validate_zmq_bind_security(
host=host,
curve_enabled=bool(zmq_server_cert and zmq_client_cert_dir),
control_auth_enabled=bool(zmq_control_token),
allow_insecure=zmq_insecure_ok,
)
except ValueError as exc:
err_console.print(f"[red]{exc}[/red]", markup=False)
raise typer.Exit(1) from exc

zmq_server = create_zmq_server(
app_instance,
host=host,
Expand All @@ -2776,6 +2797,8 @@ def _run_mcp_http():
composed.append("[cyan]curve=on[/cyan]")
if zmq_control_token:
composed.append("[cyan]control-auth=on[/cyan]")
if zmq_insecure_ok:
composed.append("[yellow]zmq-insecure-ok[/yellow]")
console.print(f"[dim]Features: {' + '.join(composed)}[/dim]")
zmq_server.run()
elif transport == "http":
Expand Down
4 changes: 3 additions & 1 deletion src/tether/runtime/transports/zmq/policy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""
from __future__ import annotations

import hmac
import logging
import time
from dataclasses import dataclass
Expand Down Expand Up @@ -168,7 +169,8 @@ def _configure_curve(
def _authorize_control_request(self, request: dict[str, Any]) -> None:
if self._control_token is None:
return
if request.get("auth_token") != self._control_token:
token = request.get("auth_token")
if not isinstance(token, str) or not hmac.compare_digest(token, self._control_token):
raise PermissionError("ZMQ control endpoint requires a valid auth token")

def _handle_ping(self) -> dict:
Expand Down
56 changes: 52 additions & 4 deletions src/tether/runtime/transports/zmq/security.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Shared ZMQ transport security helpers."""
from __future__ import annotations

import ipaddress
from pathlib import Path

import zmq.auth
Expand All @@ -9,7 +10,7 @@
def load_curve_key(value: str | bytes | Path, *, secret: bool) -> bytes:
"""Load a Z85 CURVE key from a raw value or pyzmq certificate file."""
if isinstance(value, bytes):
return value
return _validate_curve_key(value, secret=secret)

raw_value = str(value)
path = Path(raw_value).expanduser()
Expand All @@ -19,9 +20,56 @@ def load_curve_key(value: str | bytes | Path, *, secret: bool) -> bytes:
if key is None:
kind = "secret" if secret else "public"
raise ValueError(f"CURVE certificate {path} does not contain a {kind} key")
return key
return _validate_curve_key(key, secret=secret)

return raw_value.encode("ascii")
return _validate_curve_key(raw_value.encode("ascii"), secret=secret)


__all__ = ["load_curve_key"]
def is_loopback_bind(host: str) -> bool:
    """Return True when a ZMQ bind host is local-only.

    Surrounding whitespace and IPv6 brackets (``[::1]``) are stripped before
    classification.

    Args:
        host: Bind host as supplied by the operator: an IP literal, the
            ``localhost`` name, or a wildcard such as ``*``.

    Returns:
        True for ``localhost`` (in any letter case) and for IP literals in a
        loopback range; False for wildcards, the empty string, any other IP
        address, and any value that is not an IP literal.
    """
    candidate = host.strip().strip("[]")
    # Hostnames are case-insensitive, so "LOCALHOST" is as local as "localhost".
    if candidate.casefold() == "localhost":
        return True
    # Wildcard binds listen on every interface and are never local-only.
    if candidate in {"", "*", "0.0.0.0", "::"}:
        return False
    try:
        return ipaddress.ip_address(candidate).is_loopback
    except ValueError:
        # Not an IP literal (e.g. some other name): treat it as reachable so
        # the caller errs on the side of requiring security.
        return False


def validate_zmq_bind_security(
    *,
    host: str,
    curve_enabled: bool,
    control_auth_enabled: bool,
    allow_insecure: bool = False,
) -> None:
    """Raise ``ValueError`` for a non-loopback ZMQ bind that lacks full security.

    The bind is accepted when the host is loopback-only, when both CURVE and
    control auth are enabled, or when ``allow_insecure`` is set.

    Args:
        host: Bind host to classify via ``is_loopback_bind``.
        curve_enabled: Whether CURVE is configured.
        control_auth_enabled: Whether a control token is configured.
        allow_insecure: Explicit operator override that skips the rejection.

    Raises:
        ValueError: If none of the acceptance conditions hold; the message
            names each missing security component.
    """
    fully_secured = curve_enabled and control_auth_enabled
    if is_loopback_bind(host) or fully_secured or allow_insecure:
        return

    # Pair each security flag with the label used in the error message so the
    # missing pieces are listed in a fixed order.
    requirements = (
        (curve_enabled, "CURVE certificates"),
        (control_auth_enabled, "a ZMQ control token"),
    )
    missing_text = " and ".join(label for enabled, label in requirements if not enabled)
    raise ValueError(
        f"Refusing insecure ZMQ bind on host {host!r}: configure {missing_text}, "
        "bind to 127.0.0.1, or pass --zmq-insecure-ok for an isolated lab network."
    )


def _validate_curve_key(key: bytes, *, secret: bool) -> bytes:
if len(key) != 40:
kind = "secret" if secret else "public"
raise ValueError(f"CURVE {kind} key must be 40 Z85 bytes, got {len(key)}")
return key


__all__ = ["is_loopback_bind", "load_curve_key", "validate_zmq_bind_security"]
3 changes: 1 addition & 2 deletions tests/test_zmq_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@

import threading
import time
from unittest.mock import MagicMock

import msgpack
import numpy as np
import pytest
import zmq

from tether.runtime.transports.zmq.factory import create_zmq_server
Expand Down Expand Up @@ -119,6 +117,7 @@ def test_cli_transport_flag_exists():
result = runner.invoke(app, ["serve", "--help"])
assert "--transport" in result.output
assert "zmq" in result.output
assert "--zmq-insecure-ok" in result.output


def test_cli_transport_invalid_rejected():
Expand Down
60 changes: 60 additions & 0 deletions tests/test_zmq_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Tests for ZMQ transport security helpers."""
from __future__ import annotations

import pytest

from tether.runtime.transports.zmq.security import (
is_loopback_bind,
load_curve_key,
validate_zmq_bind_security,
)


@pytest.mark.parametrize(
    "host",
    ["127.0.0.1", "::1", "[::1]", "localhost"],
)
def test_is_loopback_bind_accepts_local_only_hosts(host: str) -> None:
    """Loopback IP literals, bracketed IPv6 loopback and ``localhost`` count as local."""
    is_local = is_loopback_bind(host)
    assert is_local


@pytest.mark.parametrize(
    "host",
    ["0.0.0.0", "::", "*", "192.168.1.10"],
)
def test_is_loopback_bind_rejects_network_hosts(host: str) -> None:
    """Wildcard binds and a LAN address are not classified as loopback."""
    is_local = is_loopback_bind(host)
    assert not is_local


def test_validate_zmq_bind_security_allows_loopback_without_auth() -> None:
    """A loopback bind is accepted with both CURVE and control auth disabled."""
    unsecured_loopback = {
        "host": "127.0.0.1",
        "curve_enabled": False,
        "control_auth_enabled": False,
    }
    # Passing means no exception is raised.
    validate_zmq_bind_security(**unsecured_loopback)


def test_validate_zmq_bind_security_requires_curve_and_control_token() -> None:
    """A wildcard bind with no security is rejected, naming both missing pieces."""
    unsecured_wildcard = {
        "host": "0.0.0.0",
        "curve_enabled": False,
        "control_auth_enabled": False,
    }
    expected_message = "CURVE certificates.*control token"
    with pytest.raises(ValueError, match=expected_message):
        validate_zmq_bind_security(**unsecured_wildcard)


@pytest.mark.parametrize(
    ("curve_enabled", "control_auth_enabled", "missing"),
    [
        (True, False, "control token"),
        (False, True, "CURVE certificates"),
    ],
)
def test_validate_zmq_bind_security_rejects_partial_security(
    curve_enabled: bool,
    control_auth_enabled: bool,
    missing: str,
) -> None:
    """A non-loopback bind with only one of CURVE / control auth is rejected.

    Both halves are exercised so that neither mechanism alone is ever treated
    as sufficient; the error must name the component that is missing.
    """
    with pytest.raises(ValueError, match=missing):
        validate_zmq_bind_security(
            host="192.168.1.10",
            curve_enabled=curve_enabled,
            control_auth_enabled=control_auth_enabled,
        )


def test_validate_zmq_bind_security_allows_explicit_insecure_override() -> None:
    """``allow_insecure=True`` lets an unsecured wildcard bind through."""
    overridden_wildcard = {
        "host": "0.0.0.0",
        "curve_enabled": False,
        "control_auth_enabled": False,
        "allow_insecure": True,
    }
    # Passing means no exception is raised.
    validate_zmq_bind_security(**overridden_wildcard)


def test_load_curve_key_rejects_invalid_raw_key_length() -> None:
    """A raw key string that is not 40 bytes long is rejected with a length error."""
    raw_key = "too-short"
    with pytest.raises(ValueError, match="40 Z85 bytes"):
        load_curve_key(raw_key, secret=False)
Loading