Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions alembic/versions/c8d3e5f1a204_add_allowed_subnets_to_devices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""add allowed_subnets to devices

Revision ID: c8d3e5f1a204
Revises: d8b3a1f06e57
Create Date: 2026-06-05 10:30:00.000000

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
revision: str = 'c8d3e5f1a204'
down_revision: Union[str, None] = 'd8b3a1f06e57'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
op.add_column('devices', sa.Column('allowed_subnets', postgresql.JSON(astext_type=sa.Text()), nullable=False, server_default='[]'))


def downgrade() -> None:
op.drop_column('devices', 'allowed_subnets')
58 changes: 57 additions & 1 deletion tests/e2e/test_admin_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,4 +236,60 @@ async def test_config_dialog_shows_wg_config(page: Page, test_user):
await expect(page.get_by_role("button", name="Download .conf")).to_be_visible()

# QR code should be rendered
await expect(page.locator(".q-dialog img")).to_be_visible(timeout=5_000)
await expect(page.locator(".q-dialog img")).to_be_visible(timeout=5_000)


async def test_create_device_with_relay_subnets(page: Page, test_user):
"""Admin creates a device with relay subnets for site-to-site VPN."""
await _go_to_admin_devices(page)
await page.get_by_role("button", name="Add Device").click()
await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000)

await page.locator("input[aria-label='Device Name']").fill("site-gateway")
await page.locator("input[aria-label='Description (optional)']").fill("Site-to-site gateway")

# Scroll to relay configuration section
await page.evaluate("window.scrollTo(0, document.body.scrollHeight)")

# Fill in relay subnets
await page.locator(".q-dialog input[aria-label='Routed Subnets (optional)']").fill("192.168.1.0/24, 10.20.0.0/16")

await page.get_by_role("button", name="Create").click()

# Should see config dialog
await expect(page.get_by_text("Config for site-gateway")).to_be_visible(timeout=10_000)
await page.get_by_role("button", name="Close").click()
await page.wait_for_timeout(500)

# Verify device was created with relay subnets in DB
async with async_session() as session:
result = await session.execute(
select(Device).where(Device.name == "site-gateway")
.order_by(Device.inserted_at.desc()).limit(1)
)
device = result.scalar_one()
assert device.allowed_subnets == ["192.168.1.0/24", "10.20.0.0/16"]
assert device.description == "Site-to-site gateway"


async def test_create_device_with_invalid_subnet_rejected(page: Page, test_user):
"""An invalid relay subnet must be rejected — no device is created."""
await _go_to_admin_devices(page)
await page.get_by_role("button", name="Add Device").click()
await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000)

await page.locator("input[aria-label='Device Name']").fill("bad-subnet-gw")
await page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
# Not a valid CIDR — must be rejected before any nft/route command runs.
await page.locator(".q-dialog input[aria-label='Routed Subnets (optional)']").fill("not-a-subnet")

await page.get_by_role("button", name="Create").click()

# No config dialog appears and the create dialog stays open (validation failed).
await expect(page.get_by_text("Config for bad-subnet-gw")).not_to_be_visible(timeout=3_000)
await expect(page.get_by_text("New Device")).to_be_visible()

# And nothing was persisted.
async with async_session() as session:
result = await session.execute(select(Device).where(Device.name == "bad-subnet-gw"))
assert result.first() is None
17 changes: 17 additions & 0 deletions tests/e2e/test_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,20 @@ async def test_add_device_requires_name(page: Page, test_user: UserModel):
await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000)
await page.get_by_role("button", name="Create").click()
await expect(page.get_by_text("Device name is required")).to_be_visible(timeout=5_000)


async def test_user_device_dialog_has_no_relay_field(page: Page, test_user: UserModel):
"""Relay/site-to-site subnets are admin-only — the field must not appear on
the end-user device dialog."""
await login(page)
await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000)

await page.get_by_role("button", name="Add Device").click()
await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000)

# The admin-only relay subnet input must be absent here.
await expect(page.locator("input[aria-label='Routed Subnets (optional)']")).to_have_count(0)


# Note: relay subnet acceptance coverage (create + invalid rejection) lives in
# tests/e2e/test_admin_devices.py, since the capability is admin-only.
87 changes: 87 additions & 0 deletions tests/test_firewall.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,90 @@ def test_build_rule_expr_single_port():
def test_build_rule_expr_no_port():
expr = _build_rule_expr("0.0.0.0/0", "accept", port_type=None, port_range=None)
assert expr == "ip daddr 0.0.0.0/0 accept"


# --- Tests for relay subnet support ---


async def test_add_device_jump_rule_with_allowed_subnets():
"""Test that add_device_jump_rule creates rules for tunnel IPs and relay subnets."""
from unittest.mock import AsyncMock, patch
from wiregui.services.firewall import add_device_jump_rule

with patch("wiregui.services.firewall._nft_batch") as mock_nft:
mock_nft.return_value = None

await add_device_jump_rule(
user_id="test-user-id",
device_ipv4="10.3.2.5",
device_ipv6="fd00::3:2:5",
allowed_subnets=["192.168.1.0/24", "10.20.0.0/16", "fd00:1::/64"]
)

# Verify nft_batch was called with correct commands
mock_nft.assert_called_once()
commands = mock_nft.call_args[0][0]

# Should have 5 rules: 2 tunnel IPs + 3 subnets
assert len(commands) == 5
assert any("ip saddr 10.3.2.5 jump" in cmd for cmd in commands)
assert any("ip6 saddr fd00::3:2:5 jump" in cmd for cmd in commands)
assert any("ip saddr 192.168.1.0/24 jump" in cmd for cmd in commands)
assert any("ip saddr 10.20.0.0/16 jump" in cmd for cmd in commands)
assert any("ip6 saddr fd00:1::/64 jump" in cmd for cmd in commands)


async def test_add_device_jump_rule_ipv4_subnet_only():
"""Test add_device_jump_rule with only IPv4 relay subnet."""
from unittest.mock import AsyncMock, patch
from wiregui.services.firewall import add_device_jump_rule

with patch("wiregui.services.firewall._nft_batch") as mock_nft:
mock_nft.return_value = None

await add_device_jump_rule(
user_id="test-user-id",
device_ipv4="10.3.2.5",
device_ipv6=None,
allowed_subnets=["192.168.1.0/24"]
)

commands = mock_nft.call_args[0][0]
assert len(commands) == 2
assert any("ip saddr 10.3.2.5 jump" in cmd for cmd in commands)
assert any("ip saddr 192.168.1.0/24 jump" in cmd for cmd in commands)


async def test_rebuild_all_rules_with_allowed_subnets():
"""Test that rebuild_all_rules includes relay subnets in jump rules."""
from unittest.mock import patch
from wiregui.services.firewall import rebuild_all_rules

with patch("wiregui.services.firewall._nft_batch") as mock_nft, \
patch("wiregui.services.firewall._list_user_chains") as mock_list:
mock_nft.return_value = None
mock_list.return_value = set()

await rebuild_all_rules([{
"user_id": "user-123",
"devices": [
{
"ipv4": "10.3.2.5",
"ipv6": "fd00::3:2:5",
"allowed_subnets": ["192.168.1.0/24", "10.20.0.0/16"]
}
],
"rules": []
}])

# Verify nft_batch was called
mock_nft.assert_called_once()
commands = mock_nft.call_args[0][0]

# Check that jump rules include both tunnel IPs and relay subnets
forward_rules = [cmd for cmd in commands if "forward" in cmd and "jump" in cmd]
assert len(forward_rules) == 4 # 2 tunnel IPs + 2 subnets
assert any("ip saddr 10.3.2.5 jump" in cmd for cmd in forward_rules)
assert any("ip6 saddr fd00::3:2:5 jump" in cmd for cmd in forward_rules)
assert any("ip saddr 192.168.1.0/24 jump" in cmd for cmd in forward_rules)
assert any("ip saddr 10.20.0.0/16 jump" in cmd for cmd in forward_rules)
104 changes: 103 additions & 1 deletion tests/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from wiregui.models.device import Device
from wiregui.models.rule import Rule
from wiregui.services.events import on_device_created, on_device_deleted, on_device_updated, on_rule_created
from wiregui.services.events import on_device_created, on_device_deleted, on_device_updated, on_rule_created, _device_allowed_ips


def _make_device(**kwargs) -> Device:
Expand All @@ -20,6 +20,51 @@ def _make_device(**kwargs) -> Device:
return Device(**defaults)


# --- _device_allowed_ips tests ---


def test_device_allowed_ips_basic():
"""Test _device_allowed_ips returns tunnel IPs with /32 and /128."""
device = _make_device()
ips = _device_allowed_ips(device)
assert ips == ["10.3.2.5/32", "fd00::3:2:5/128"]


def test_device_allowed_ips_with_relay_subnets():
"""Test _device_allowed_ips includes relay subnets."""
device = _make_device(allowed_subnets=["192.168.1.0/24", "10.20.0.0/16"])
ips = _device_allowed_ips(device)
assert ips == ["10.3.2.5/32", "fd00::3:2:5/128", "192.168.1.0/24", "10.20.0.0/16"]


def test_device_allowed_ips_ipv4_only():
"""Test _device_allowed_ips with only IPv4."""
device = _make_device(ipv6=None)
ips = _device_allowed_ips(device)
assert ips == ["10.3.2.5/32"]


def test_device_allowed_ips_ipv6_only():
"""Test _device_allowed_ips with only IPv6."""
device = _make_device(ipv4=None)
ips = _device_allowed_ips(device)
assert ips == ["fd00::3:2:5/128"]


def test_device_allowed_ips_relay_only():
"""Test _device_allowed_ips with only relay subnets (no tunnel IPs)."""
device = _make_device(ipv4=None, ipv6=None, allowed_subnets=["192.168.1.0/24"])
ips = _device_allowed_ips(device)
assert ips == ["192.168.1.0/24"]


def test_device_allowed_ips_empty():
"""Test _device_allowed_ips with no IPs or subnets."""
device = _make_device(ipv4=None, ipv6=None, allowed_subnets=[])
ips = _device_allowed_ips(device)
assert ips == []


# --- Events (WG disabled) ---


Expand Down Expand Up @@ -55,6 +100,63 @@ async def test_on_device_created_handles_wg_error(mock_wg, mock_fw, mock_setting
await on_device_created(device)


@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.firewall")
@patch("wiregui.services.events.wireguard")
async def test_on_device_created_with_relay_subnets(mock_wg, mock_fw, mock_settings):
"""Test that device creation with relay subnets passes correct allowed_ips to WireGuard, adds routes, and configures firewall."""
mock_settings.return_value.wg_enabled = True
mock_wg.add_peer = AsyncMock()
mock_wg.add_routes = AsyncMock()
mock_fw.add_user_chain = AsyncMock()
mock_fw.add_device_jump_rule = AsyncMock()

device = _make_device(allowed_subnets=["192.168.1.0/24", "10.20.0.0/16"])
await on_device_created(device)

# Verify WireGuard peer was added with tunnel IPs + relay subnets
mock_wg.add_peer.assert_awaited_once_with(
public_key="pk-test",
allowed_ips=["10.3.2.5/32", "fd00::3:2:5/128", "192.168.1.0/24", "10.20.0.0/16"],
preshared_key="psk-test",
)

# Verify routes were added for relay subnets
mock_wg.add_routes.assert_awaited_once_with(["192.168.1.0/24", "10.20.0.0/16"])

# Verify firewall jump rule was added with relay subnets
mock_fw.add_device_jump_rule.assert_awaited_once_with(
"00000000-0000-0000-0000-000000000000",
"10.3.2.5",
"fd00::3:2:5",
["192.168.1.0/24", "10.20.0.0/16"],
)


@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.wireguard")
async def test_on_device_deleted_prunes_orphaned_routes(mock_wg, mock_settings, monkeypatch):
"""Deleting a device reconciles routes against the remaining DB devices, so its
orphaned subnets are pruned while subnets still used elsewhere are kept."""
mock_settings.return_value.wg_enabled = True
mock_wg.remove_peer = AsyncMock()
mock_wg.sync_routes = AsyncMock()

# Remaining devices (after this delete) still route 10.20.0.0/16 but not the
# deleted device's 192.168.1.0/24.
async def fake_remaining():
return {"10.20.0.0/16"}

monkeypatch.setattr("wiregui.services.events._all_relay_subnets", fake_remaining)

device = _make_device(allowed_subnets=["192.168.1.0/24", "10.20.0.0/16"])
await on_device_deleted(device)

# sync_routes is called with the *remaining* expected set — 192.168.1.0/24 is
# therefore pruned, 10.20.0.0/16 is preserved.
mock_wg.sync_routes.assert_awaited_once_with({"10.20.0.0/16"})


# --- Rule events ---


Expand Down
36 changes: 35 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,45 @@
from wiregui.models.device import Device
from wiregui.models.rule import Rule
from wiregui.models.user import User
from wiregui.utils.network import allocate_ipv4, allocate_ipv6
from wiregui.utils.network import allocate_ipv4, allocate_ipv6, parse_subnet_list
from wiregui.utils.ordering import assign_priorities
from wiregui.utils.wg_conf import build_client_config


# --- Relay subnet parsing/validation ---


def test_parse_subnet_list_valid_v4_and_v6():
assert parse_subnet_list("192.168.1.0/24, 10.20.0.0/16, fd00:1::/64") == [
"192.168.1.0/24", "10.20.0.0/16", "fd00:1::/64",
]


def test_parse_subnet_list_empty_and_whitespace():
assert parse_subnet_list("") == []
assert parse_subnet_list(" , ,") == []


def test_parse_subnet_list_normalizes_host_bits():
# Host bits are masked off so the value is a clean network.
assert parse_subnet_list("192.168.1.5/24") == ["192.168.1.0/24"]


def test_parse_subnet_list_bare_ip_becomes_host_route():
assert parse_subnet_list("10.0.0.1") == ["10.0.0.1/32"]


def test_parse_subnet_list_rejects_invalid_cidr():
with pytest.raises(ValueError):
parse_subnet_list("not-a-subnet")


def test_parse_subnet_list_rejects_injection_attempt():
# A value crafted to inject an extra nft command must be rejected, not passed through.
with pytest.raises(ValueError):
parse_subnet_list("10.0.0.0/24 jump evil")


# --- Rule priority ordering ---


Expand Down
Loading
Loading