diff --git a/alembic/versions/c8d3e5f1a204_add_allowed_subnets_to_devices.py b/alembic/versions/c8d3e5f1a204_add_allowed_subnets_to_devices.py new file mode 100644 index 0000000..91d64e7 --- /dev/null +++ b/alembic/versions/c8d3e5f1a204_add_allowed_subnets_to_devices.py @@ -0,0 +1,27 @@ +"""add allowed_subnets to devices + +Revision ID: c8d3e5f1a204 +Revises: d8b3a1f06e57 +Create Date: 2026-06-05 10:30:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. +revision: str = 'c8d3e5f1a204' +down_revision: Union[str, None] = 'd8b3a1f06e57' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column('devices', sa.Column('allowed_subnets', postgresql.JSON(astext_type=sa.Text()), nullable=False, server_default='[]')) + + +def downgrade() -> None: + op.drop_column('devices', 'allowed_subnets') diff --git a/tests/e2e/test_admin_devices.py b/tests/e2e/test_admin_devices.py index b44a262..57be43e 100644 --- a/tests/e2e/test_admin_devices.py +++ b/tests/e2e/test_admin_devices.py @@ -236,4 +236,60 @@ async def test_config_dialog_shows_wg_config(page: Page, test_user): await expect(page.get_by_role("button", name="Download .conf")).to_be_visible() # QR code should be rendered - await expect(page.locator(".q-dialog img")).to_be_visible(timeout=5_000) \ No newline at end of file + await expect(page.locator(".q-dialog img")).to_be_visible(timeout=5_000) + + +async def test_create_device_with_relay_subnets(page: Page, test_user): + """Admin creates a device with relay subnets for site-to-site VPN.""" + await _go_to_admin_devices(page) + await page.get_by_role("button", name="Add Device").click() + await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000) + + await page.locator("input[aria-label='Device Name']").fill("site-gateway") + await page.locator("input[aria-label='Description (optional)']").fill("Site-to-site gateway") + + # Scroll to relay configuration section + await page.evaluate("window.scrollTo(0, document.body.scrollHeight)") + + # Fill in relay subnets + await page.locator(".q-dialog input[aria-label='Routed Subnets (optional)']").fill("192.168.1.0/24, 10.20.0.0/16") + + await page.get_by_role("button", name="Create").click() + + # Should see config dialog + await expect(page.get_by_text("Config for site-gateway")).to_be_visible(timeout=10_000) + await page.get_by_role("button", name="Close").click() + await page.wait_for_timeout(500) + + # Verify device was created with relay subnets in DB + async with async_session() as session: + result = await session.execute( + select(Device).where(Device.name == "site-gateway") + .order_by(Device.inserted_at.desc()).limit(1) + ) + device = result.scalar_one() + assert device.allowed_subnets == ["192.168.1.0/24", "10.20.0.0/16"] + assert device.description == "Site-to-site gateway" + + +async def test_create_device_with_invalid_subnet_rejected(page: Page, test_user): + """An invalid relay subnet must be rejected — no device is created.""" + await _go_to_admin_devices(page) + await page.get_by_role("button", name="Add Device").click() + await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000) + + await page.locator("input[aria-label='Device Name']").fill("bad-subnet-gw") + await page.evaluate("window.scrollTo(0, document.body.scrollHeight)") + # Not a valid CIDR — must be rejected before any nft/route command runs. + await page.locator(".q-dialog input[aria-label='Routed Subnets (optional)']").fill("not-a-subnet") + + await page.get_by_role("button", name="Create").click() + + # No config dialog appears and the create dialog stays open (validation failed). + await expect(page.get_by_text("Config for bad-subnet-gw")).not_to_be_visible(timeout=3_000) + await expect(page.get_by_text("New Device")).to_be_visible() + + # And nothing was persisted. + async with async_session() as session: + result = await session.execute(select(Device).where(Device.name == "bad-subnet-gw")) + assert result.first() is None \ No newline at end of file diff --git a/tests/e2e/test_devices.py b/tests/e2e/test_devices.py index 805910a..b5c85f6 100644 --- a/tests/e2e/test_devices.py +++ b/tests/e2e/test_devices.py @@ -30,3 +30,20 @@ async def test_add_device_requires_name(page: Page, test_user: UserModel): await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000) await page.get_by_role("button", name="Create").click() await expect(page.get_by_text("Device name is required")).to_be_visible(timeout=5_000) + + +async def test_user_device_dialog_has_no_relay_field(page: Page, test_user: UserModel): + """Relay/site-to-site subnets are admin-only — the field must not appear on + the end-user device dialog.""" + await login(page) + await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000) + + await page.get_by_role("button", name="Add Device").click() + await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000) + + # The admin-only relay subnet input must be absent here. + await expect(page.locator("input[aria-label='Routed Subnets (optional)']")).to_have_count(0) + + +# Note: relay subnet acceptance coverage (create + invalid rejection) lives in +# tests/e2e/test_admin_devices.py, since the capability is admin-only. diff --git a/tests/test_firewall.py b/tests/test_firewall.py index 24b04ce..5a193d4 100644 --- a/tests/test_firewall.py +++ b/tests/test_firewall.py @@ -38,3 +38,90 @@ def test_build_rule_expr_single_port(): def test_build_rule_expr_no_port(): expr = _build_rule_expr("0.0.0.0/0", "accept", port_type=None, port_range=None) assert expr == "ip daddr 0.0.0.0/0 accept" + + +# --- Tests for relay subnet support --- + + +async def test_add_device_jump_rule_with_allowed_subnets(): + """Test that add_device_jump_rule creates rules for tunnel IPs and relay subnets.""" + from unittest.mock import AsyncMock, patch + from wiregui.services.firewall import add_device_jump_rule + + with patch("wiregui.services.firewall._nft_batch") as mock_nft: + mock_nft.return_value = None + + await add_device_jump_rule( + user_id="test-user-id", + device_ipv4="10.3.2.5", + device_ipv6="fd00::3:2:5", + allowed_subnets=["192.168.1.0/24", "10.20.0.0/16", "fd00:1::/64"] + ) + + # Verify nft_batch was called with correct commands + mock_nft.assert_called_once() + commands = mock_nft.call_args[0][0] + + # Should have 5 rules: 2 tunnel IPs + 3 subnets + assert len(commands) == 5 + assert any("ip saddr 10.3.2.5 jump" in cmd for cmd in commands) + assert any("ip6 saddr fd00::3:2:5 jump" in cmd for cmd in commands) + assert any("ip saddr 192.168.1.0/24 jump" in cmd for cmd in commands) + assert any("ip saddr 10.20.0.0/16 jump" in cmd for cmd in commands) + assert any("ip6 saddr fd00:1::/64 jump" in cmd for cmd in commands) + + +async def test_add_device_jump_rule_ipv4_subnet_only(): + """Test add_device_jump_rule with only IPv4 relay subnet.""" + from unittest.mock import AsyncMock, patch + from wiregui.services.firewall import add_device_jump_rule + + with patch("wiregui.services.firewall._nft_batch") as mock_nft: + mock_nft.return_value = None + + await add_device_jump_rule( + user_id="test-user-id", + device_ipv4="10.3.2.5", + device_ipv6=None, + allowed_subnets=["192.168.1.0/24"] + ) + + commands = mock_nft.call_args[0][0] + assert len(commands) == 2 + assert any("ip saddr 10.3.2.5 jump" in cmd for cmd in commands) + assert any("ip saddr 192.168.1.0/24 jump" in cmd for cmd in commands) + + +async def test_rebuild_all_rules_with_allowed_subnets(): + """Test that rebuild_all_rules includes relay subnets in jump rules.""" + from unittest.mock import patch + from wiregui.services.firewall import rebuild_all_rules + + with patch("wiregui.services.firewall._nft_batch") as mock_nft, \ + patch("wiregui.services.firewall._list_user_chains") as mock_list: + mock_nft.return_value = None + mock_list.return_value = set() + + await rebuild_all_rules([{ + "user_id": "user-123", + "devices": [ + { + "ipv4": "10.3.2.5", + "ipv6": "fd00::3:2:5", + "allowed_subnets": ["192.168.1.0/24", "10.20.0.0/16"] + } + ], + "rules": [] + }]) + + # Verify nft_batch was called + mock_nft.assert_called_once() + commands = mock_nft.call_args[0][0] + + # Check that jump rules include both tunnel IPs and relay subnets + forward_rules = [cmd for cmd in commands if "forward" in cmd and "jump" in cmd] + assert len(forward_rules) == 4 # 2 tunnel IPs + 2 subnets + assert any("ip saddr 10.3.2.5 jump" in cmd for cmd in forward_rules) + assert any("ip6 saddr fd00::3:2:5 jump" in cmd for cmd in forward_rules) + assert any("ip saddr 192.168.1.0/24 jump" in cmd for cmd in forward_rules) + assert any("ip saddr 10.20.0.0/16 jump" in cmd for cmd in forward_rules) diff --git a/tests/test_services.py b/tests/test_services.py index c93ecab..d942ff1 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -4,7 +4,7 @@ from wiregui.models.device import Device from wiregui.models.rule import Rule -from wiregui.services.events import on_device_created, on_device_deleted, on_device_updated, on_rule_created +from wiregui.services.events import on_device_created, on_device_deleted, on_device_updated, on_rule_created, _device_allowed_ips def _make_device(**kwargs) -> Device: @@ -20,6 +20,51 @@ def _make_device(**kwargs) -> Device: return Device(**defaults) +# --- _device_allowed_ips tests --- + + +def test_device_allowed_ips_basic(): + """Test _device_allowed_ips returns tunnel IPs with /32 and /128.""" + device = _make_device() + ips = _device_allowed_ips(device) + assert ips == ["10.3.2.5/32", "fd00::3:2:5/128"] + + +def test_device_allowed_ips_with_relay_subnets(): + """Test _device_allowed_ips includes relay subnets.""" + device = _make_device(allowed_subnets=["192.168.1.0/24", "10.20.0.0/16"]) + ips = _device_allowed_ips(device) + assert ips == ["10.3.2.5/32", "fd00::3:2:5/128", "192.168.1.0/24", "10.20.0.0/16"] + + +def test_device_allowed_ips_ipv4_only(): + """Test _device_allowed_ips with only IPv4.""" + device = _make_device(ipv6=None) + ips = _device_allowed_ips(device) + assert ips == ["10.3.2.5/32"] + + +def test_device_allowed_ips_ipv6_only(): + """Test _device_allowed_ips with only IPv6.""" + device = _make_device(ipv4=None) + ips = _device_allowed_ips(device) + assert ips == ["fd00::3:2:5/128"] + + +def test_device_allowed_ips_relay_only(): + """Test _device_allowed_ips with only relay subnets (no tunnel IPs).""" + device = _make_device(ipv4=None, ipv6=None, allowed_subnets=["192.168.1.0/24"]) + ips = _device_allowed_ips(device) + assert ips == ["192.168.1.0/24"] + + +def test_device_allowed_ips_empty(): + """Test _device_allowed_ips with no IPs or subnets.""" + device = _make_device(ipv4=None, ipv6=None, allowed_subnets=[]) + ips = _device_allowed_ips(device) + assert ips == [] + + # --- Events (WG disabled) --- @@ -55,6 +100,63 @@ async def test_on_device_created_handles_wg_error(mock_wg, mock_fw, mock_setting await on_device_created(device) +@patch("wiregui.services.events.get_settings") +@patch("wiregui.services.events.firewall") +@patch("wiregui.services.events.wireguard") +async def test_on_device_created_with_relay_subnets(mock_wg, mock_fw, mock_settings): + """Test that device creation with relay subnets passes correct allowed_ips to WireGuard, adds routes, and configures firewall.""" + mock_settings.return_value.wg_enabled = True + mock_wg.add_peer = AsyncMock() + mock_wg.add_routes = AsyncMock() + mock_fw.add_user_chain = AsyncMock() + mock_fw.add_device_jump_rule = AsyncMock() + + device = _make_device(allowed_subnets=["192.168.1.0/24", "10.20.0.0/16"]) + await on_device_created(device) + + # Verify WireGuard peer was added with tunnel IPs + relay subnets + mock_wg.add_peer.assert_awaited_once_with( + public_key="pk-test", + allowed_ips=["10.3.2.5/32", "fd00::3:2:5/128", "192.168.1.0/24", "10.20.0.0/16"], + preshared_key="psk-test", + ) + + # Verify routes were added for relay subnets + mock_wg.add_routes.assert_awaited_once_with(["192.168.1.0/24", "10.20.0.0/16"]) + + # Verify firewall jump rule was added with relay subnets + mock_fw.add_device_jump_rule.assert_awaited_once_with( + "00000000-0000-0000-0000-000000000000", + "10.3.2.5", + "fd00::3:2:5", + ["192.168.1.0/24", "10.20.0.0/16"], + ) + + +@patch("wiregui.services.events.get_settings") +@patch("wiregui.services.events.wireguard") +async def test_on_device_deleted_prunes_orphaned_routes(mock_wg, mock_settings, monkeypatch): + """Deleting a device reconciles routes against the remaining DB devices, so its + orphaned subnets are pruned while subnets still used elsewhere are kept.""" + mock_settings.return_value.wg_enabled = True + mock_wg.remove_peer = AsyncMock() + mock_wg.sync_routes = AsyncMock() + + # Remaining devices (after this delete) still route 10.20.0.0/16 but not the + # deleted device's 192.168.1.0/24. + async def fake_remaining(): + return {"10.20.0.0/16"} + + monkeypatch.setattr("wiregui.services.events._all_relay_subnets", fake_remaining) + + device = _make_device(allowed_subnets=["192.168.1.0/24", "10.20.0.0/16"]) + await on_device_deleted(device) + + # sync_routes is called with the *remaining* expected set — 192.168.1.0/24 is + # therefore pruned, 10.20.0.0/16 is preserved. + mock_wg.sync_routes.assert_awaited_once_with({"10.20.0.0/16"}) + + # --- Rule events --- diff --git a/tests/test_utils.py b/tests/test_utils.py index dbb4b2b..7cdd0c2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,11 +8,45 @@ from wiregui.models.device import Device from wiregui.models.rule import Rule from wiregui.models.user import User -from wiregui.utils.network import allocate_ipv4, allocate_ipv6 +from wiregui.utils.network import allocate_ipv4, allocate_ipv6, parse_subnet_list from wiregui.utils.ordering import assign_priorities from wiregui.utils.wg_conf import build_client_config +# --- Relay subnet parsing/validation --- + + +def test_parse_subnet_list_valid_v4_and_v6(): + assert parse_subnet_list("192.168.1.0/24, 10.20.0.0/16, fd00:1::/64") == [ + "192.168.1.0/24", "10.20.0.0/16", "fd00:1::/64", + ] + + +def test_parse_subnet_list_empty_and_whitespace(): + assert parse_subnet_list("") == [] + assert parse_subnet_list(" , ,") == [] + + +def test_parse_subnet_list_normalizes_host_bits(): + # Host bits are masked off so the value is a clean network. + assert parse_subnet_list("192.168.1.5/24") == ["192.168.1.0/24"] + + +def test_parse_subnet_list_bare_ip_becomes_host_route(): + assert parse_subnet_list("10.0.0.1") == ["10.0.0.1/32"] + + +def test_parse_subnet_list_rejects_invalid_cidr(): + with pytest.raises(ValueError): + parse_subnet_list("not-a-subnet") + + +def test_parse_subnet_list_rejects_injection_attempt(): + # A value crafted to inject an extra nft command must be rejected, not passed through. + with pytest.raises(ValueError): + parse_subnet_list("10.0.0.0/24 jump evil") + + # --- Rule priority ordering --- diff --git a/tests/test_wireguard_extended.py b/tests/test_wireguard_extended.py index ab848df..e64abf9 100644 --- a/tests/test_wireguard_extended.py +++ b/tests/test_wireguard_extended.py @@ -7,6 +7,10 @@ set_private_key, set_listen_port, configure_interface, + add_routes, + remove_routes, + get_interface_routes, + sync_routes, ) @@ -111,4 +115,149 @@ async def test_configure_interface_sets_key_and_port(mock_session_cls, mock_run) args = mock_run.call_args[0][0] assert args[0:3] == ["wg", "set", "wg-test"] assert "private-key" in args - assert "listen-port" in args \ No newline at end of file + assert "listen-port" in args + + +# ========== add_routes ========== + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +async def test_add_routes_ipv4(mock_run): + """add_routes should call ip route add for IPv4 subnets.""" + mock_run.return_value = "" + await add_routes(["192.168.1.0/24", "10.20.0.0/16"], iface="wg-test") + + assert mock_run.await_count == 2 + calls = [c[0][0] for c in mock_run.call_args_list] + assert calls[0] == ["ip", "-4", "route", "add", "192.168.1.0/24", "dev", "wg-test"] + assert calls[1] == ["ip", "-4", "route", "add", "10.20.0.0/16", "dev", "wg-test"] + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +async def test_add_routes_ipv6(mock_run): + """add_routes should call ip -6 route add for IPv6 subnets.""" + mock_run.return_value = "" + await add_routes(["fd00:1::/64", "fd00:2::/48"], iface="wg-test") + + assert mock_run.await_count == 2 + calls = [c[0][0] for c in mock_run.call_args_list] + assert calls[0] == ["ip", "-6", "route", "add", "fd00:1::/64", "dev", "wg-test"] + assert calls[1] == ["ip", "-6", "route", "add", "fd00:2::/48", "dev", "wg-test"] + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +async def test_add_routes_mixed(mock_run): + """add_routes should handle mixed IPv4 and IPv6.""" + mock_run.return_value = "" + await add_routes(["192.168.1.0/24", "fd00:1::/64"], iface="wg-test") + + assert mock_run.await_count == 2 + calls = [c[0][0] for c in mock_run.call_args_list] + assert calls[0] == ["ip", "-4", "route", "add", "192.168.1.0/24", "dev", "wg-test"] + assert calls[1] == ["ip", "-6", "route", "add", "fd00:1::/64", "dev", "wg-test"] + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +async def test_add_routes_empty_list(mock_run): + """add_routes with empty list should not call ip route.""" + await add_routes([], iface="wg-test") + mock_run.assert_not_awaited() + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +async def test_add_routes_already_exists(mock_run): + """add_routes should not fail if route already exists.""" + mock_run.side_effect = RuntimeError("RTNETLINK answers: File exists") + # Should not raise + await add_routes(["192.168.1.0/24"], iface="wg-test") + mock_run.assert_awaited_once() + + +# ========== remove_routes ========== + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +async def test_remove_routes_ipv4(mock_run): + """remove_routes should call ip route del for IPv4 subnets.""" + mock_run.return_value = "" + await remove_routes(["192.168.1.0/24", "10.20.0.0/16"], iface="wg-test") + + assert mock_run.await_count == 2 + calls = [c[0][0] for c in mock_run.call_args_list] + assert calls[0] == ["ip", "-4", "route", "del", "192.168.1.0/24", "dev", "wg-test"] + assert calls[1] == ["ip", "-4", "route", "del", "10.20.0.0/16", "dev", "wg-test"] + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +async def test_remove_routes_ipv6(mock_run): + """remove_routes should call ip -6 route del for IPv6 subnets.""" + mock_run.return_value = "" + await remove_routes(["fd00:1::/64"], iface="wg-test") + + mock_run.assert_awaited_once_with(["ip", "-6", "route", "del", "fd00:1::/64", "dev", "wg-test"]) + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +async def test_remove_routes_not_found(mock_run): + """remove_routes should not fail if route doesn't exist.""" + mock_run.side_effect = RuntimeError("RTNETLINK answers: No such process") + # Should not raise + await remove_routes(["192.168.1.0/24"], iface="wg-test") + mock_run.assert_awaited_once() + + +# ========== get_interface_routes ========== + + +@patch("wiregui.services.wireguard.get_settings") +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +async def test_get_interface_routes_parses_and_normalizes(mock_run, mock_settings): + """Parses `ip route show dev` output, skips default, normalizes host routes.""" + mock_settings.return_value.wg_interface = "wg-test" + mock_run.side_effect = [ + "10.60.1.0/24 proto kernel scope link src 10.60.1.1\n192.168.1.0/24 scope link\n10.0.0.5", + "fd00::/106 proto kernel scope link", + ] + routes = await get_interface_routes(iface="wg-test") + assert routes == {"10.60.1.0/24", "192.168.1.0/24", "10.0.0.5/32", "fd00::/106"} + + +# ========== sync_routes ========== + + +@patch("wiregui.services.wireguard.remove_routes", new_callable=AsyncMock) +@patch("wiregui.services.wireguard.add_routes", new_callable=AsyncMock) +@patch("wiregui.services.wireguard.get_interface_routes", new_callable=AsyncMock) +@patch("wiregui.services.wireguard.get_settings") +async def test_sync_routes_adds_missing_and_prunes_orphans(mock_settings, mock_get, mock_add, mock_remove): + """sync_routes adds expected-but-missing, removes orphans, and never touches tunnel nets.""" + mock_settings.return_value.wg_interface = "wg-test" + mock_settings.return_value.wg_ipv4_network = "10.60.1.0/24" + mock_settings.return_value.wg_ipv6_network = "fd00::/106" + # Interface currently has: tunnel net, one expected subnet, one orphan. + mock_get.return_value = {"10.60.1.0/24", "192.168.1.0/24", "10.99.0.0/16"} + + await sync_routes(["192.168.1.0/24", "10.20.0.0/16"], iface="wg-test") + + mock_add.assert_awaited_once() + assert mock_add.call_args[0][0] == ["10.20.0.0/16"] + mock_remove.assert_awaited_once() + # Orphan removed; tunnel network preserved. + assert mock_remove.call_args[0][0] == ["10.99.0.0/16"] + + +@patch("wiregui.services.wireguard.remove_routes", new_callable=AsyncMock) +@patch("wiregui.services.wireguard.add_routes", new_callable=AsyncMock) +@patch("wiregui.services.wireguard.get_interface_routes", new_callable=AsyncMock) +@patch("wiregui.services.wireguard.get_settings") +async def test_sync_routes_noop_when_in_sync(mock_settings, mock_get, mock_add, mock_remove): + """No add/remove calls when the interface already matches the expected set.""" + mock_settings.return_value.wg_interface = "wg-test" + mock_settings.return_value.wg_ipv4_network = "10.60.1.0/24" + mock_settings.return_value.wg_ipv6_network = "fd00::/106" + mock_get.return_value = {"10.60.1.0/24", "192.168.1.0/24"} + + await sync_routes(["192.168.1.0/24"], iface="wg-test") + + mock_add.assert_not_awaited() + mock_remove.assert_not_awaited() \ No newline at end of file diff --git a/wiregui/models/device.py b/wiregui/models/device.py index a7e09da..7d527d7 100644 --- a/wiregui/models/device.py +++ b/wiregui/models/device.py @@ -34,6 +34,9 @@ class Device(SQLModel, table=True): # Assigned tunnel addresses ipv4: str | None = Field(default=None, unique=True) ipv6: str | None = Field(default=None, unique=True) + + # Additional subnets this peer routes (for site-to-site / relay configuration) + allowed_subnets: list[str] = Field(default_factory=list, sa_column=Column(JSON, default=[])) # Peer stats (updated periodically from WireGuard) remote_ip: str | None = None diff --git a/wiregui/pages/admin/devices.py b/wiregui/pages/admin/devices.py index 34a1d1c..aae78f0 100644 --- a/wiregui/pages/admin/devices.py +++ b/wiregui/pages/admin/devices.py @@ -17,7 +17,7 @@ from wiregui.pages.layout import layout from wiregui.services.events import on_device_created, on_device_deleted, on_device_updated from wiregui.utils.crypto import generate_keypair, generate_preshared_key -from wiregui.utils.network import allocate_ipv4, allocate_ipv6 +from wiregui.utils.network import allocate_ipv4, allocate_ipv6, parse_subnet_list from wiregui.utils.server_key import get_server_public_key from wiregui.utils.wg_conf import build_client_config @@ -132,6 +132,7 @@ async def create_device(): if not create_use_default_keepalive.value and create_keepalive.value else None), allowed_ips=([s.strip() for s in create_allowed_ips.value.split(",") if s.strip()] if not create_use_default_ips.value and create_allowed_ips.value else []), + allowed_subnets=parse_subnet_list(create_allowed_subnets.value or ""), ) session.add(device) await session.commit() @@ -175,6 +176,7 @@ def _reset_create_form(): create_endpoint.value = _defaults["endpoint"] create_mtu.value = _defaults["mtu"] create_keepalive.value = _defaults["keepalive"] + create_allowed_subnets.value = "" # --- Edit device --- edit_device_id = {"value": None} @@ -198,6 +200,7 @@ async def open_edit(device_id: str): edit_mtu.value = str(device.mtu) if device.mtu else "" edit_keepalive.value = str(device.persistent_keepalive) if device.persistent_keepalive else "" edit_allowed_ips.value = ", ".join(device.allowed_ips) if device.allowed_ips else "" + edit_allowed_subnets.value = ", ".join(device.allowed_subnets) if device.allowed_subnets else "" edit_dialog.open() async def save_edit(): @@ -228,6 +231,8 @@ async def save_edit(): device.persistent_keepalive = int(edit_keepalive.value) if edit_keepalive.value else None if not device.use_default_allowed_ips: device.allowed_ips = [s.strip() for s in edit_allowed_ips.value.split(",") if s.strip()] + + device.allowed_subnets = parse_subnet_list(edit_allowed_subnets.value or "") session.add(device) await session.commit() @@ -337,6 +342,11 @@ def on_admin_row_click(e): create_use_default_keepalive = ui.switch("Use default Keepalive", value=True) create_keepalive = ui.input("Persistent Keepalive", value=_defaults["keepalive"]).props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_keepalive, "value", backward=lambda v: not v) + ui.separator().classes("q-my-sm") + ui.label("Relay / Site-to-Site Configuration").classes("text-subtitle2") + ui.label("Additional subnets this device routes (comma-separated CIDRs, e.g., 192.168.1.0/24)").classes("text-caption text-grey") + create_allowed_subnets = ui.input("Routed Subnets (optional)").props("outlined dense").classes("w-full") + with ui.row().classes("w-full justify-end q-mt-sm"): ui.button("Cancel", on_click=create_dialog.close).props("flat") ui.button("Create", on_click=create_device).props("color=primary") @@ -368,6 +378,11 @@ def on_admin_row_click(e): edit_use_default_keepalive = ui.switch("Use default Keepalive", value=True) edit_keepalive = ui.input("Persistent Keepalive").props("outlined dense").classes("w-full").bind_enabled_from(edit_use_default_keepalive, "value", backward=lambda v: not v) + ui.separator().classes("q-my-sm") + ui.label("Relay / Site-to-Site Configuration").classes("text-subtitle2") + ui.label("Additional subnets this device routes (comma-separated CIDRs, e.g., 192.168.1.0/24)").classes("text-caption text-grey") + edit_allowed_subnets = ui.input("Routed Subnets (optional)").props("outlined dense").classes("w-full") + with ui.row().classes("w-full justify-end q-mt-sm"): ui.button("Cancel", on_click=edit_dialog.close).props("flat") ui.button("Save", on_click=save_edit).props("color=primary") diff --git a/wiregui/services/events.py b/wiregui/services/events.py index 59a28c6..8b5778f 100644 --- a/wiregui/services/events.py +++ b/wiregui/services/events.py @@ -11,13 +11,27 @@ from wiregui.services import firewall, wireguard +async def _all_relay_subnets() -> set[str]: + """Union of relay subnets across every device currently in the database.""" + from sqlmodel import select as sel + + async with async_session() as session: + devices = (await session.execute(sel(Device))).scalars().all() + subnets: set[str] = set() + for d in devices: + subnets.update(d.allowed_subnets or []) + return subnets + + def _device_allowed_ips(device: Device) -> list[str]: - """Build the allowed-ips list for a device peer (its tunnel addresses).""" + """Build the allowed-ips list for a device peer (its tunnel addresses + relay subnets).""" ips = [] if device.ipv4: ips.append(f"{device.ipv4}/32") if device.ipv6: ips.append(f"{device.ipv6}/128") + if device.allowed_subnets: + ips.extend(device.allowed_subnets) return ips @@ -25,7 +39,7 @@ def _device_allowed_ips(device: Device) -> list[str]: async def on_device_created(device: Device) -> None: - """Configure WireGuard peer and firewall after a new device is created.""" + """Configure WireGuard peer, routes, and firewall after a new device is created.""" settings = get_settings() if not settings.wg_enabled: return @@ -37,30 +51,48 @@ async def on_device_created(device: Device) -> None: ) except Exception as e: logger.error("Failed to add WG peer for device {}: {}", device.name, e) + + # Add routes for relay subnets + if device.allowed_subnets: + try: + await wireguard.add_routes(device.allowed_subnets) + except Exception as e: + logger.error("Failed to add routes for device {}: {}", device.name, e) try: # Ensure user chain exists before adding jump rules await firewall.add_user_chain(str(device.user_id)) await firewall.add_device_jump_rule( - str(device.user_id), device.ipv4, device.ipv6, + str(device.user_id), + device.ipv4, + device.ipv6, + device.allowed_subnets, ) except Exception as e: logger.error("Failed to add firewall jump rule for device {}: {}", device.name, e) async def on_device_deleted(device: Device) -> None: - """Remove WireGuard peer after a device is deleted.""" + """Remove WireGuard peer and routes after a device is deleted.""" if not get_settings().wg_enabled: return try: await wireguard.remove_peer(public_key=device.public_key) except Exception as e: logger.error("Failed to remove WG peer for device {}: {}", device.name, e) + + # Prune routes for subnets this device used that no remaining device routes. + # sync_routes reconciles against the current DB, so shared subnets are kept. + if device.allowed_subnets: + try: + await wireguard.sync_routes(await _all_relay_subnets()) + except Exception as e: + logger.error("Failed to prune routes for device {}: {}", device.name, e) # Firewall jump rules are cleaned up on next rebuild async def on_device_updated(device: Device) -> None: - """Update WireGuard peer after a device is modified.""" + """Update WireGuard peer, routes, and firewall after a device is modified.""" if not get_settings().wg_enabled: return try: @@ -71,6 +103,19 @@ async def on_device_updated(device: Device) -> None: ) except Exception as e: logger.error("Failed to update WG peer for device {}: {}", device.name, e) + + # Reconcile routes against the full DB so subnets removed in this edit are + # pruned and newly added ones are created (sync_routes handles both directions). + try: + await wireguard.sync_routes(await _all_relay_subnets()) + except Exception as e: + logger.error("Failed to sync routes for device {}: {}", device.name, e) + + # Rebuild firewall rules for this user to update allowed_subnets + try: + await _rebuild_user_chain(str(device.user_id)) + except Exception as e: + logger.error("Failed to rebuild firewall rules for device {}: {}", device.name, e) # --- Rule events --- @@ -131,7 +176,10 @@ async def _rebuild_user_chain(user_id: str) -> None: await firewall.rebuild_all_rules([{ "user_id": user_id, - "devices": [{"ipv4": d.ipv4, "ipv6": d.ipv6} for d in devices], + "devices": [ + {"ipv4": d.ipv4, "ipv6": d.ipv6, "allowed_subnets": d.allowed_subnets} + for d in devices + ], "rules": [ {"destination": r.destination, "action": r.action, "port_type": r.port_type, "port_range": r.port_range, diff --git a/wiregui/services/firewall.py b/wiregui/services/firewall.py index d5d0166..03c7f9d 100644 --- a/wiregui/services/firewall.py +++ b/wiregui/services/firewall.py @@ -101,10 +101,24 @@ async def remove_user_chain(user_id: str) -> None: logger.debug("Remove user chain {}: {}", chain, e) -async def add_device_jump_rule(user_id: str, device_ipv4: str | None, device_ipv6: str | None) -> None: - """Add jump rules in the forward chain to route device traffic to the user chain.""" +async def add_device_jump_rule( + user_id: str, + device_ipv4: str | None, + device_ipv6: str | None, + allowed_subnets: list[str] | None = None, +) -> None: + """Add jump rules in the forward chain to route device traffic to the user chain. + + Args: + user_id: User ID for the chain + device_ipv4: Device tunnel IPv4 address + device_ipv6: Device tunnel IPv6 address + allowed_subnets: Additional relay subnets this device routes + """ chain = _user_chain_name(user_id) commands = [] + + # Add jump rules for tunnel IPs if device_ipv4: commands.append( f"add rule inet {TABLE_NAME} forward ip saddr {device_ipv4} jump {chain}" @@ -113,9 +127,25 @@ async def add_device_jump_rule(user_id: str, device_ipv4: str | None, device_ipv commands.append( f"add rule inet {TABLE_NAME} forward ip6 saddr {device_ipv6} jump {chain}" ) + + # Add jump rules for relay subnets + if allowed_subnets: + for subnet in allowed_subnets: + if ":" in subnet: + # IPv6 + commands.append( + f"add rule inet {TABLE_NAME} forward ip6 saddr {subnet} jump {chain}" + ) + else: + # IPv4 + commands.append( + f"add rule inet {TABLE_NAME} forward ip saddr {subnet} jump {chain}" + ) + if commands: await _nft_batch(commands) - logger.debug("Jump rules added for device {}/{} -> {}", device_ipv4, device_ipv6, chain) + logger.debug("Jump rules added for device {}/{} + {} subnets -> {}", + device_ipv4, device_ipv6, len(allowed_subnets or []), chain) async def apply_rule(user_id: str, destination: str, action: str, port_type: str | None = None, port_range: str | None = None) -> None: @@ -133,7 +163,7 @@ async def rebuild_all_rules(users_devices_rules: list[dict]) -> None: Args: users_devices_rules: list of dicts with keys: - user_id, devices (list of {ipv4, ipv6}), rules (list of {destination, action, port_type, port_range}) + user_id, devices (list of {ipv4, ipv6, allowed_subnets}), rules (list of {destination, action, port_type, port_range}) """ # Discover existing user_ chains so we can remove orphans existing_user_chains = await _list_user_chains() @@ -167,10 +197,20 @@ async def rebuild_all_rules(users_devices_rules: list[dict]) -> None: user_id = entry["user_id"] chain = _user_chain_name(user_id) for dev in entry.get("devices", []): + # Add jump rules for tunnel IPs if dev.get("ipv4"): commands.append(f"add rule inet {TABLE_NAME} forward ip saddr {dev['ipv4']} jump {chain}") if dev.get("ipv6"): commands.append(f"add rule inet {TABLE_NAME} forward ip6 saddr {dev['ipv6']} jump {chain}") + + # Add jump rules for relay subnets + for subnet in dev.get("allowed_subnets", []): + if ":" in subnet: + # IPv6 + commands.append(f"add rule inet {TABLE_NAME} forward ip6 saddr {subnet} jump {chain}") + else: + # IPv4 + commands.append(f"add rule inet {TABLE_NAME} forward ip saddr {subnet} jump {chain}") # Remove orphaned user chains (must happen after forward chain is flushed # so there are no remaining jump references to these chains) diff --git a/wiregui/services/wireguard.py b/wiregui/services/wireguard.py index 6d29160..36f30cf 100644 --- a/wiregui/services/wireguard.py +++ b/wiregui/services/wireguard.py @@ -3,12 +3,18 @@ import asyncio from dataclasses import dataclass, field from datetime import datetime +from ipaddress import ip_network from loguru import logger from wiregui.config import get_settings +def _normalize_cidr(value: str) -> str: + """Normalize a CIDR/host into canonical network form (e.g. 10.0.0.1 -> 10.0.0.1/32).""" + return str(ip_network(value, strict=False)) + + @dataclass class PeerInfo: public_key: str @@ -186,3 +192,118 @@ async def get_peers(iface: str | None = None) -> list[PeerInfo]: tx_bytes=tx_bytes, )) return peers + + +async def add_routes(subnets: list[str], iface: str | None = None) -> None: + """Add IP routes for relay subnets through the WireGuard interface. + + Args: + subnets: List of CIDR subnets to route through WireGuard + iface: WireGuard interface name (defaults to config value) + """ + if not subnets: + return + + settings = get_settings() + iface = iface or settings.wg_interface + + for subnet in subnets: + try: + if ":" in subnet: + # IPv6 + await _run(["ip", "-6", "route", "add", subnet, "dev", iface]) + else: + # IPv4 + await _run(["ip", "-4", "route", "add", subnet, "dev", iface]) + logger.debug("Route added: {} via {}", subnet, iface) + except RuntimeError as e: + # Route might already exist, log but don't fail + if "File exists" not in str(e): + logger.warning("Failed to add route for {}: {}", subnet, e) + + +async def remove_routes(subnets: list[str], iface: str | None = None) -> None: + """Remove IP routes for relay subnets. + + Args: + subnets: List of CIDR subnets to remove routes for + iface: WireGuard interface name (defaults to config value) + """ + if not subnets: + return + + settings = get_settings() + iface = iface or settings.wg_interface + + for subnet in subnets: + try: + if ":" in subnet: + # IPv6 + await _run(["ip", "-6", "route", "del", subnet, "dev", iface]) + else: + # IPv4 + await _run(["ip", "-4", "route", "del", subnet, "dev", iface]) + logger.debug("Route removed: {} via {}", subnet, iface) + except RuntimeError as e: + # Route might not exist, log but don't fail + logger.debug("Failed to remove route for {}: {}", subnet, e) + + +async def get_interface_routes(iface: str | None = None) -> set[str]: + """Return the set of normalized CIDR destinations currently routed via the interface.""" + settings = get_settings() + iface = iface or settings.wg_interface + + routes: set[str] = set() + for family in ("-4", "-6"): + try: + output = await _run(["ip", family, "route", "show", "dev", iface]) + except RuntimeError: + continue + for line in output.splitlines(): + line = line.strip() + if not line or line.startswith("default"): + continue + dest = line.split()[0] + try: + routes.add(_normalize_cidr(dest)) + except ValueError: + continue + return routes + + +async def sync_routes(expected_subnets, iface: str | None = None) -> None: + """Make the interface's relay routes exactly match ``expected_subnets``. + + Adds missing routes and removes orphaned ones (e.g. left behind when a device + or one of its subnets is removed), so the kernel routing table converges to the + database. The WireGuard tunnel networks are never touched. + """ + settings = get_settings() + iface = iface or settings.wg_interface + + # Never remove the tunnel network routes managed by interface setup. + protected: set[str] = set() + for net in (settings.wg_ipv4_network, settings.wg_ipv6_network): + try: + protected.add(_normalize_cidr(net)) + except ValueError: + pass + + expected: set[str] = set() + for subnet in expected_subnets: + try: + expected.add(_normalize_cidr(subnet)) + except ValueError: + logger.warning("Skipping invalid relay subnet during route sync: {}", subnet) + + actual = await get_interface_routes(iface) + to_add = sorted(expected - actual) + to_remove = sorted((actual - expected) - protected) + + if to_add: + await add_routes(to_add, iface) + if to_remove: + await remove_routes(to_remove, iface) + if to_add or to_remove: + logger.info("Synced relay routes on {}: +{} -{}", iface, len(to_add), len(to_remove)) diff --git a/wiregui/tasks/reconcile.py b/wiregui/tasks/reconcile.py index 347cdaa..9a56fed 100644 --- a/wiregui/tasks/reconcile.py +++ b/wiregui/tasks/reconcile.py @@ -36,6 +36,8 @@ async def reconcile() -> None: ips.append(f"{device.ipv4}/32") if device.ipv6: ips.append(f"{device.ipv6}/128") + if device.allowed_subnets: + ips.extend(device.allowed_subnets) try: await wireguard.add_peer( public_key=device.public_key, @@ -60,6 +62,9 @@ async def reconcile() -> None: # Rebuild all firewall rules from DB await _reconcile_firewall(devices, rules) + + # Reconcile routes for relay subnets + await _reconcile_routes(devices) async def _reconcile_firewall(devices: list[Device], rules: list[Rule]) -> None: @@ -77,7 +82,10 @@ async def _reconcile_firewall(devices: list[Device], rules: list[Rule]) -> None: entries.append({ "user_id": uid, - "devices": [{"ipv4": d.ipv4, "ipv6": d.ipv6} for d in user_devices], + "devices": [ + {"ipv4": d.ipv4, "ipv6": d.ipv6, "allowed_subnets": d.allowed_subnets} + for d in user_devices + ], "rules": [ {"destination": r.destination, "action": r.action, "port_type": r.port_type, "port_range": r.port_range, @@ -90,3 +98,20 @@ async def _reconcile_firewall(devices: list[Device], rules: list[Rule]) -> None: await firewall.rebuild_all_rules(entries) except Exception as e: logger.error("Reconcile: firewall rebuild failed: {}", e) + + +async def _reconcile_routes(devices: list[Device]) -> None: + """Make relay subnet routes exactly match the database. + + Adds missing routes and removes orphans (e.g. subnets left over from devices + or subnets that were removed while the app was down), converging the kernel + routing table to DB state. + """ + expected: set[str] = set() + for device in devices: + expected.update(device.allowed_subnets or []) + + try: + await wireguard.sync_routes(expected) + except Exception as e: + logger.error("Reconcile: route sync failed: {}", e) diff --git a/wiregui/utils/network.py b/wiregui/utils/network.py index 3bb505b..6580f58 100644 --- a/wiregui/utils/network.py +++ b/wiregui/utils/network.py @@ -1,7 +1,7 @@ """IP address allocation for WireGuard tunnel addresses.""" import random -from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network +from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network, ip_network from loguru import logger from sqlalchemy.ext.asyncio import AsyncSession @@ -10,6 +10,26 @@ from wiregui.models.device import Device +def parse_subnet_list(raw: str) -> list[str]: + """Parse a comma-separated string of CIDR subnets into validated, normalized values. + + Each entry must be a valid IPv4/IPv6 network; host bits are masked off + (e.g. ``192.168.1.5/24`` -> ``192.168.1.0/24``). Raises ``ValueError`` on any + invalid entry. Validation is mandatory because these values are interpolated + into ``nft`` rules and ``ip route`` commands — rejecting non-CIDR input both + prevents command injection and stops one bad entry from failing the whole + firewall rebuild. + """ + subnets: list[str] = [] + for part in raw.split(","): + candidate = part.strip() + if not candidate: + continue + # ip_network raises ValueError for anything that isn't a clean CIDR. + subnets.append(str(ip_network(candidate, strict=False))) + return subnets + + async def allocate_ipv4(session: AsyncSession, network_cidr: str) -> str: """Find an available IPv4 address in the given CIDR range.""" network = IPv4Network(network_cidr, strict=False)