From 18967905c8509d34e8c22b1ebecd84b957f223da Mon Sep 17 00:00:00 2001 From: Domenico Andreoli Date: Thu, 10 Nov 2022 11:14:59 +0100 Subject: [PATCH] Add IP version constraint --- geneve/constraints.py | 16 ++++++++++++++-- geneve/events_emitter_eql.py | 2 ++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/geneve/constraints.py b/geneve/constraints.py index 3b9ab3ac..4afb3c87 100644 --- a/geneve/constraints.py +++ b/geneve/constraints.py @@ -302,11 +302,12 @@ def solve_date_constraints(self, field, value, constraints, left_attempts): left_attempts -= 1 return {"value": value, "left_attempts": left_attempts} - @solver("ip", "==", "!=", "in", "not in") + @solver("ip", "==", "!=", "in", "not in", "version") def solve_ip_constraints(self, field, value, constraints, left_attempts): include_nets = set() exclude_nets = set() exclude_addrs = set() + ip_versions = [] for k, v, *_ in constraints: if k == "==": @@ -352,6 +353,16 @@ def solve_ip_constraints(self, field, value, constraints, left_attempts): exclude_nets.add(ipaddress.ip_network(str(v))) except ValueError: raise ValueError(f"Not an IP network: {str(v)}") + elif k == "version": + if type(v) is tuple: + if len(v) > 1: + raise ValueError(f"Too many arguments for version of '{field}': {v}") + v = v[0] + v = int(v) + if v not in (4, 6): + raise ValueError(f"Not an valid IP version: {v}") + if v not in ip_versions: + ip_versions = sorted(ip_versions + [v]) if include_nets & exclude_nets: intersecting_nets = ", ".join(str(net) for net in sorted(include_nets & exclude_nets)) @@ -368,7 +379,8 @@ def solve_ip_constraints(self, field, value, constraints, left_attempts): else: exclude_nets = ", ".join(str(v) for v in sorted(exclude_nets)) raise ConflictError(f"cannot be in any of nets ({exclude_nets})", field) - ip_versions = sorted(ip.version for ip in include_nets | exclude_nets | exclude_addrs) or [4] + if not ip_versions: + ip_versions = sorted(ip.version for ip in include_nets | exclude_nets | exclude_addrs) or [4] include_nets = sorted(include_nets, key=lambda x: (x.version, x)) while left_attempts and ( value in (None, []) diff --git a/geneve/events_emitter_eql.py b/geneve/events_emitter_eql.py index 6bbace36..a4421f8b 100644 --- a/geneve/events_emitter_eql.py +++ b/geneve/events_emitter_eql.py @@ -201,6 +201,8 @@ def cc_function_call(node: eql.ast.FunctionCall, negate: bool) -> Root: return cc_function(node, negate, "in") elif fn_name == "_cardinality": return cc_function(node, negate, "cardinality") + elif fn_name == "_version": + return cc_function(node, negate, "version") else: raise NotImplementedError(f"Unsupported function: {node.name}")