diff --git a/net_watch_plus.py b/net_watch_plus.py index 838146d..64322b4 100644 --- a/net_watch_plus.py +++ b/net_watch_plus.py @@ -49,7 +49,7 @@ import time import urllib.error import urllib.request -from dataclasses import asdict, dataclass, field +from dataclasses import asdict, dataclass from datetime import datetime, timezone from pathlib import Path from typing import Optional @@ -174,6 +174,18 @@ def _is_private_or_special(ip: str) -> bool: ) +def _is_unspecified_ip(ip: Optional[str]) -> bool: + """True when ip is an unspecified wildcard address such as 0.0.0.0 or ::.""" + if not ip or ip == "*": + return True + try: + return ipaddress.ip_address(ip).is_unspecified + except ValueError: + # Do not classify malformed host strings as unspecified; they should simply + # avoid network enrichment elsewhere via _is_private_or_special(). + return False + + def _utc_now_iso() -> str: return datetime.now(timezone.utc).isoformat(timespec="seconds") @@ -208,10 +220,13 @@ def _collect_via_ss() -> list[Conn]: conns: list[Conn] = [] for line in out.splitlines(): parts = line.split() - if len(parts) < 5: + # ss -tunHp columns are: + # Netid State Recv-Q Send-Q Local Address:Port Peer Address:Port Process + # The local/peer addresses therefore live at indexes 4 and 5. + if len(parts) < 6: continue - proto, state, local, remote = parts[0], parts[1], parts[3], parts[4] - pid, proc = _parse_ss_users("".join(parts[5:]) if len(parts) > 5 else "") + proto, state, local, remote = parts[0], parts[1], parts[4], parts[5] + pid, proc = _parse_ss_users(" ".join(parts[6:]) if len(parts) > 6 else "") conns.append(Conn(proto=proto, local=local, remote=remote, state=state, pid=pid, proc=proc)) return conns @@ -278,12 +293,12 @@ def _collect_via_lsof() -> list[Conn]: def classify_basic(conn: Conn) -> list[str]: """Port- and state-based reasons (no network calls).""" reasons: list[str] = [] - _, port = conn.remote_ip_and_port() + ip, port = conn.remote_ip_and_port() if port and port in RISKY_PORTS: reasons.append(RISKY_PORTS[port]) if conn.state in {"SYN-SENT", "SYN-RECV", "SYN_SENT", "SYN_RECV"}: reasons.append("handshake") - if conn.remote.startswith("0.0.0.0") or conn.remote.startswith("[::]"): + if _is_unspecified_ip(ip): reasons.append("unspecified-remote") return reasons diff --git a/test_net_watch_plus.py b/test_net_watch_plus.py new file mode 100644 index 0000000..e3d1368 --- /dev/null +++ b/test_net_watch_plus.py @@ -0,0 +1,39 @@ +import types +import unittest +from unittest import mock + +import net_watch_plus + + +class NetWatchPlusParserTests(unittest.TestCase): + def test_collect_via_ss_uses_peer_address_column(self) -> None: + ss_output = ( + 'tcp ESTAB 0 0 192.168.1.2:52000 ' + '203.0.113.10:3389 users:(("python3",pid=42,fd=7))\n' + ) + + with mock.patch( + "net_watch_plus.subprocess.run", + return_value=types.SimpleNamespace(stdout=ss_output), + ): + conns = net_watch_plus._collect_via_ss() + + self.assertEqual(len(conns), 1) + conn = conns[0] + self.assertEqual(conn.local, "192.168.1.2:52000") + self.assertEqual(conn.remote, "203.0.113.10:3389") + self.assertEqual(conn.pid, 42) + self.assertEqual(conn.proc, "python3") + self.assertIn("rdp", net_watch_plus.classify_basic(conn)) + + def test_unspecified_remote_detection_handles_wildcards(self) -> None: + self.assertIn( + "unspecified-remote", + net_watch_plus.classify_basic( + net_watch_plus.Conn("udp", "127.0.0.1:1", "*:*", "UNCONN") + ), + ) + + +if __name__ == "__main__": + unittest.main()