Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions packaging/omlx_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
NSURL,
)

from .config import ServerConfig
from .config import ServerConfig, resolve_local_server_base_url
from .server_manager import PortConflict, ServerManager, ServerStatus

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -1702,7 +1702,9 @@ def _fetch_stats(self):
"""
try:
api_key = self.config.get_server_api_key()
base_url = f"http://127.0.0.1:{self.config.port}"
base_url = resolve_local_server_base_url(
self.config.get_server_bind_host(), self.config.port
)

if not api_key:
self._cached_stats = None
Expand Down Expand Up @@ -1899,7 +1901,9 @@ def _open_with_auto_login(self, redirect_path: str):
if self.server_manager.status != ServerStatus.RUNNING:
return

base_url = f"http://127.0.0.1:{self.config.port}"
base_url = resolve_local_server_base_url(
self.config.get_server_bind_host(), self.config.port
)
api_key = self.config.get_server_api_key()

if api_key:
Expand Down
74 changes: 70 additions & 4 deletions packaging/omlx_app/config.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,66 @@
"""Configuration management for oMLX menubar app."""

import ipaddress
import json
import logging
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Optional
from typing import List, Optional, Tuple

import requests

logger = logging.getLogger(__name__)


def resolve_local_server_base_url(bind_host: Optional[str], port: int) -> str:
"""HTTP base URL for reaching the server from this machine.

When the server binds to ``0.0.0.0`` or ``::``, loopback is used for health
and admin requests. When it binds to a specific address (LAN, Tailscale,
etc.), that address is used so local checks succeed.
"""
raw = (bind_host or "").strip()
if not raw or raw in ("0.0.0.0", "::"):
host_for_url = "127.0.0.1"
elif raw in ("127.0.0.1", "localhost"):
host_for_url = "127.0.0.1"
elif raw == "::1":
host_for_url = "[::1]"
else:
try:
ip = ipaddress.ip_address(raw)
if ip.version == 6:
host_for_url = f"[{raw}]"
else:
host_for_url = raw
except ValueError:
host_for_url = raw
return f"http://{host_for_url}:{port}"


def resolve_local_server_health_url(bind_host: Optional[str], port: int) -> str:
"""Full ``/health`` URL for local monitoring."""
return f"{resolve_local_server_base_url(bind_host, port)}/health"


def tcp_probe_connection_targets(
bind_host: Optional[str], port: int
) -> List[Tuple[str, int]]:
"""``(host, port)`` pairs to try for ``socket.create_connection`` port checks."""
raw = (bind_host or "").strip()
if not raw or raw in ("0.0.0.0", "::"):
return [("127.0.0.1", port)]
if raw in ("127.0.0.1", "localhost"):
return [("127.0.0.1", port)]
if raw == "::1":
return [("::1", port)]
try:
ip = ipaddress.ip_address(raw)
return [(str(ip), port)]
except ValueError:
return [(raw, port)]


def get_app_support_dir() -> Path:
"""Get the Application Support directory for oMLX."""
app_support = Path.home() / "Library" / "Application Support" / "oMLX"
Expand Down Expand Up @@ -116,7 +166,7 @@ def update_server_api_key_runtime(self, new_api_key: str) -> bool:

Returns True if successful, False if server is not reachable or auth fails.
"""
base_url = f"http://127.0.0.1:{self.port}"
base_url = resolve_local_server_base_url(self.get_server_bind_host(), self.port)
current_key = self.get_server_api_key()

try:
Expand Down Expand Up @@ -156,11 +206,26 @@ def update_server_api_key_runtime(self, new_api_key: str) -> bool:
logger.debug(f"Failed to update API key on running server: {e}")
return False

def get_server_bind_host(self) -> str:
"""Return ``server.host`` from ``{base_path}/settings.json`` (may be empty)."""
settings_file = Path(self.base_path).expanduser() / "settings.json"
if not settings_file.exists():
return ""
try:
with open(settings_file) as f:
data = json.load(f)
host = data.get("server", {}).get("host")
if host is None:
return ""
return str(host).strip()
except (json.JSONDecodeError, OSError):
return ""

def load_server_settings(self) -> dict:
"""Load model_dir and port from server's settings.json.
"""Load model_dir, port, and host from server's settings.json.

Returns:
{"model_dir": str, "port": int} or empty dict if not found
``model_dir``, ``port``, optional ``host``, or empty dict if not found
"""
settings_file = Path(self.base_path).expanduser() / "settings.json"
if not settings_file.exists():
Expand All @@ -172,6 +237,7 @@ def load_server_settings(self) -> dict:
return {
"model_dir": data.get("model", {}).get("model_dir"),
"port": data.get("server", {}).get("port", 8000),
"host": data.get("server", {}).get("host"),
}
except (json.JSONDecodeError, OSError) as e:
return {}
Expand Down
33 changes: 23 additions & 10 deletions packaging/omlx_app/server_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@

import requests

from .config import ServerConfig, get_log_path
from .config import (
ServerConfig,
get_log_path,
resolve_local_server_base_url,
resolve_local_server_health_url,
tcp_probe_connection_targets,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -98,10 +104,14 @@ def _update_status(self, status: ServerStatus, error: Optional[str] = None) -> N
logger.error(f"Status callback error: {e}")

def _get_health_url(self) -> str:
return f"http://127.0.0.1:{self.config.port}/health"
return resolve_local_server_health_url(
self.config.get_server_bind_host(), self.config.port
)

def get_api_url(self) -> str:
return f"http://127.0.0.1:{self.config.port}"
return resolve_local_server_base_url(
self.config.get_server_bind_host(), self.config.port
)

def check_health(self) -> bool:
try:
Expand Down Expand Up @@ -223,13 +233,16 @@ def _cleanup_dead_process(self) -> None:

def _is_port_in_use(self) -> bool:
"""Check if the configured port is already in use."""
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)
s.connect(("127.0.0.1", self.config.port))
return True
except (ConnectionRefusedError, OSError):
return False
if self._find_port_owner_pid() is not None:
return True
bind = self.config.get_server_bind_host()
for host, port in tcp_probe_connection_targets(bind, self.config.port):
try:
with socket.create_connection((host, port), timeout=1):
return True
except OSError:
continue
return False

def _is_omlx_server(self) -> bool:
"""Check if the process on the port is an oMLX server."""
Expand Down
7 changes: 5 additions & 2 deletions packaging/omlx_app/welcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
)
from Foundation import NSData, NSObject

from .config import resolve_local_server_base_url
from .server_manager import PortConflict
from .widgets import PastableSecureTextField

Expand Down Expand Up @@ -653,8 +654,10 @@ def toggleApiKeyVisibility_(self, sender):
@objc.IBAction
def openDashboard_(self, sender):
"""Open the admin dashboard and close welcome window."""
port = self.config.port
webbrowser.open(f"http://127.0.0.1:{port}/admin")
base = resolve_local_server_base_url(
self.config.get_server_bind_host(), self.config.port
)
webbrowser.open(f"{base}/admin")
self.window.close()

@objc.IBAction
Expand Down
82 changes: 81 additions & 1 deletion tests/test_omlx_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,15 @@

# Import the modules under test
sys.path.insert(0, str(Path(__file__).parent.parent / "packaging"))
from omlx_app.config import ServerConfig, get_app_support_dir, get_config_path, get_log_path
from omlx_app.config import (
ServerConfig,
get_app_support_dir,
get_config_path,
get_log_path,
resolve_local_server_base_url,
resolve_local_server_health_url,
tcp_probe_connection_targets,
)
from omlx_app.server_manager import PortConflict, ServerManager, ServerStatus


Expand Down Expand Up @@ -234,6 +242,29 @@ def test_load_server_settings(self, tmp_path: Path):
result = config.load_server_settings()
assert result["model_dir"] == "/server/models"
assert result["port"] == 9000
assert result.get("host") is None

def test_load_server_settings_includes_host(self, tmp_path: Path):
"""Test server host is loaded from settings.json."""
config = ServerConfig(base_path=str(tmp_path))
settings_file = tmp_path / "settings.json"
settings_file.write_text(json.dumps({
"server": {"port": 8000, "host": "100.64.0.2"},
}))

result = config.load_server_settings()
assert result["host"] == "100.64.0.2"

def test_get_server_bind_host(self, tmp_path: Path):
"""Test bind host reader."""
config = ServerConfig(base_path=str(tmp_path))
assert config.get_server_bind_host() == ""

settings_file = tmp_path / "settings.json"
settings_file.write_text(json.dumps({
"server": {"host": "100.64.0.2", "port": 8000},
}))
assert config.get_server_bind_host() == "100.64.0.2"

def test_load_server_settings_missing_file(self, tmp_path: Path):
"""Test load_server_settings returns empty dict when no file."""
Expand Down Expand Up @@ -326,6 +357,38 @@ def test_update_server_api_key_runtime_server_down(self, mock_session_cls, tmp_p
assert result is False


class TestResolveLocalServerUrls:
"""Tests for bind-host-aware local URL resolution."""

def test_unspecified_uses_loopback(self):
assert resolve_local_server_base_url("", 8000) == "http://127.0.0.1:8000"
assert resolve_local_server_base_url(None, 8000) == "http://127.0.0.1:8000"
assert resolve_local_server_base_url("0.0.0.0", 8000) == "http://127.0.0.1:8000"
assert resolve_local_server_base_url("::", 8000) == "http://127.0.0.1:8000"

def test_localhost_aliases(self):
assert resolve_local_server_base_url("127.0.0.1", 9000) == "http://127.0.0.1:9000"
assert resolve_local_server_base_url("localhost", 9000) == "http://127.0.0.1:9000"

def test_custom_ipv4(self):
assert resolve_local_server_base_url("100.64.0.5", 8765) == "http://100.64.0.5:8765"

def test_ipv6_brackets_in_url(self):
assert resolve_local_server_base_url(
"2001:db8::1", 8080
) == "http://[2001:db8::1]:8080"

def test_health_url(self):
assert resolve_local_server_health_url(
"100.64.0.5", 8765
) == "http://100.64.0.5:8765/health"

def test_tcp_probe_targets(self):
assert tcp_probe_connection_targets("", 8000) == [("127.0.0.1", 8000)]
assert tcp_probe_connection_targets("0.0.0.0", 8000) == [("127.0.0.1", 8000)]
assert tcp_probe_connection_targets("100.64.0.5", 8000) == [("100.64.0.5", 8000)]


class TestServerStatus:
"""Tests for ServerStatus enum."""

Expand Down Expand Up @@ -416,11 +479,28 @@ def test_get_health_url(self, manager: ServerManager):
url = manager._get_health_url()
assert url == "http://127.0.0.1:8765/health"

def test_get_health_url_custom_bind(self, manager: ServerManager):
"""Health URL follows server.host when not loopback / all-interfaces."""
base = Path(manager.config.base_path)
base.mkdir(parents=True, exist_ok=True)
(base / "settings.json").write_text(
json.dumps({"server": {"host": "100.64.0.5", "port": 8765}})
)
assert manager._get_health_url() == "http://100.64.0.5:8765/health"

def test_get_api_url(self, manager: ServerManager):
"""Test API URL generation."""
url = manager.get_api_url()
assert url == "http://127.0.0.1:8765"

def test_get_api_url_custom_bind(self, manager: ServerManager):
base = Path(manager.config.base_path)
base.mkdir(parents=True, exist_ok=True)
(base / "settings.json").write_text(
json.dumps({"server": {"host": "100.64.0.5"}})
)
assert manager.get_api_url() == "http://100.64.0.5:8765"

def test_update_config(self, manager: ServerManager):
"""Test config update."""
new_config = ServerConfig(base_path="/new/base", port=9999, model_dir="/new/path")
Expand Down