From b14d4d2b8006acfe537a77a7116c396af46769c3 Mon Sep 17 00:00:00 2001 From: dnwpark Date: Thu, 23 Oct 2025 15:22:18 -0700 Subject: [PATCH 1/6] Fix union and coalesce expressions not decoding to the correct type. --- gel/_internal/_codegen/_models/_pydantic.py | 39 ++++++ gel/_internal/_qbmodel/_abstract/__init__.py | 6 + gel/_internal/_qbmodel/_abstract/_methods.py | 139 +++++++++++++++++-- gel/_internal/_typing_dispatch.py | 2 + gel/models/pydantic.py | 4 + 5 files changed, 178 insertions(+), 12 deletions(-) diff --git a/gel/_internal/_codegen/_models/_pydantic.py b/gel/_internal/_codegen/_models/_pydantic.py index 58a69891..b61f4dc2 100644 --- a/gel/_internal/_codegen/_models/_pydantic.py +++ b/gel/_internal/_codegen/_models/_pydantic.py @@ -6170,6 +6170,45 @@ def resolve( f"# type: ignore [assignment, misc, unused-ignore]" ) + if function.schemapath in { + SchemaPath('std', 'UNION'), + SchemaPath('std', 'IF'), + SchemaPath('std', '??'), + }: + # Special case for the UNION, IF and ?? operators + # Produce a union type instead of just taking the first + # valid type. + # + # See gel: edb.compiler.func.compile_operator + create_union = self.import_name( + BASE_IMPL, "create_optional_union" + ) + + tvars: list[str] = [] + for param, path in sources: + if ( + param.name in required_generic_params + or param.name in optional_generic_params + ): + pn = param_vars[param.name] + tvar = f"__t_{pn}__" + + resolve(pn, path, tvar) + tvars.append(tvar) + + self.write( + f"{gtvar} = {tvars[0]} " + f"# type: ignore [assignment, misc, unused-ignore]" + ) + for tvar in tvars[1:]: + self.write( + f"{gtvar} = {create_union}({gtvar}, {tvar}) " + f"# type: ignore [" + f"assignment, misc, unused-ignore]" + ) + + continue + # Try to infer generic type from required params first for param, path in sources: if param.name in required_generic_params: diff --git a/gel/_internal/_qbmodel/_abstract/__init__.py b/gel/_internal/_qbmodel/_abstract/__init__.py index 41fb17a8..4500a0f1 100644 --- a/gel/_internal/_qbmodel/_abstract/__init__.py +++ b/gel/_internal/_qbmodel/_abstract/__init__.py @@ -68,6 +68,9 @@ from ._methods import ( BaseGelModel, BaseGelModelIntersection, + BaseGelModelUnion, + create_optional_union, + create_union, ) @@ -138,6 +141,7 @@ "ArrayMeta", "BaseGelModel", "BaseGelModelIntersection", + "BaseGelModelUnion", "ComputedLinkSet", "ComputedLinkWithPropsSet", "ComputedMultiLinkDescriptor", @@ -181,6 +185,8 @@ "TupleMeta", "UUIDImpl", "copy_or_ref_lprops", + "create_optional_union", + "create_union", "empty_set_if_none", "field_descriptor", "get_base_scalars_backed_by_py_type", diff --git a/gel/_internal/_qbmodel/_abstract/_methods.py b/gel/_internal/_qbmodel/_abstract/_methods.py index bcba7d83..fe413936 100644 --- a/gel/_internal/_qbmodel/_abstract/_methods.py +++ b/gel/_internal/_qbmodel/_abstract/_methods.py @@ -18,8 +18,9 @@ from gel._internal import _qb from gel._internal._schemapath import ( - TypeNameIntersection, TypeNameExpr, + TypeNameIntersection, + TypeNameUnion, ) from gel._internal import _type_expression from gel._internal._xmethod import classonlymethod @@ -270,6 +271,17 @@ class BaseGelModelIntersectionBacklinks( rhs: ClassVar[type[AbstractGelObjectBacklinksModel]] +class BaseGelModelUnion( + BaseGelModel, + _type_expression.Union, + Generic[_T_Lhs, _T_Rhs], +): + __gel_type_class__: ClassVar[type] + + lhs: ClassVar[type[AbstractGelModel]] + rhs: ClassVar[type[AbstractGelModel]] + + T = TypeVar('T') U = TypeVar('U') @@ -318,6 +330,17 @@ def combine_dicts( return result +def _order_base_types(lhs: type, rhs: type) -> tuple[type, ...]: + if lhs == rhs: + return (lhs,) + elif issubclass(lhs, rhs): + return (lhs, rhs) + elif issubclass(rhs, lhs): + return (rhs, lhs) + else: + return (lhs, rhs) + + _type_intersection_cache: weakref.WeakKeyDictionary[ type[AbstractGelModel], weakref.WeakKeyDictionary[ @@ -430,17 +453,6 @@ def object( return result -def _order_base_types(lhs: type, rhs: type) -> tuple[type, ...]: - if lhs == rhs: - return (lhs,) - elif issubclass(lhs, rhs): - return (lhs, rhs) - elif issubclass(rhs, lhs): - return (rhs, lhs) - else: - return (lhs, rhs) - - def create_intersection_backlinks( lhs_backlinks: type[AbstractGelObjectBacklinksModel], rhs_backlinks: type[AbstractGelObjectBacklinksModel], @@ -500,3 +512,106 @@ def create_intersection_backlinks( ) return backlinks + + +_type_union_cache: weakref.WeakKeyDictionary[ + type[AbstractGelModel], + weakref.WeakKeyDictionary[ + type[AbstractGelModel], + type[BaseGelModelUnion[AbstractGelModel, AbstractGelModel]], + ], +] = weakref.WeakKeyDictionary() + + +def create_optional_union( + lhs: type[_T_Lhs] | None, + rhs: type[_T_Rhs] | None, +) -> type[BaseGelModelUnion[_T_Lhs, _T_Rhs] | AbstractGelModel] | None: + if lhs is None: + return rhs + elif rhs is None: + return lhs + else: + return create_union(lhs, rhs) + + +def create_union( + lhs: type[_T_Lhs], + rhs: type[_T_Rhs], +) -> type[BaseGelModelUnion[_T_Lhs, _T_Rhs]]: + """Create a runtime union type which acts like a GelModel.""" + + if (lhs_entry := _type_union_cache.get(lhs)) and ( + rhs_entry := lhs_entry.get(rhs) + ): + return rhs_entry # type: ignore[return-value] + + # Combine pointer reflections from args + ptr_reflections: dict[str, _qb.GelPointerReflection] = { + p_name: p_refl + for p_name, p_refl in lhs.__gel_reflection__.pointers.items() + if p_name in rhs.__gel_reflection__.pointers + } + + # Create type reflection for union type + class __gel_reflection__(_qb.GelObjectTypeExprMetadata.__gel_reflection__): # noqa: N801 + expr_object_types: set[type[AbstractGelModel]] = getattr( + lhs.__gel_reflection__, 'expr_object_types', {lhs} + ) | getattr(rhs.__gel_reflection__, 'expr_object_types', {rhs}) + + type_name = TypeNameUnion( + args=( + lhs.__gel_reflection__.type_name, + rhs.__gel_reflection__.type_name, + ) + ) + + pointers = ptr_reflections + + @classmethod + def object( + cls, + ) -> Any: + raise NotImplementedError( + "Type expressions schema objects are inaccessible" + ) + + # Create the resulting union type + result = type( + f"({lhs.__name__} | {rhs.__name__})", + (BaseGelModelUnion,), + { + 'lhs': lhs, + 'rhs': rhs, + '__gel_reflection__': __gel_reflection__, + "__gel_proxied_dunders__": frozenset( + { + "__backlinks__", + } + ), + }, + ) + + # Generate field descriptors. + descriptors: dict[str, ModelFieldDescriptor] = { + p_name: field_descriptor(result, p_name, l_path_alias.__gel_origin__) + for p_name, p_refl in lhs.__gel_reflection__.pointers.items() + if ( + hasattr(lhs, p_name) + and (l_path_alias := getattr(lhs, p_name, None)) is not None + and isinstance(l_path_alias, _qb.PathAlias) + ) + if ( + hasattr(rhs, p_name) + and (r_path_alias := getattr(rhs, p_name, None)) is not None + and isinstance(r_path_alias, _qb.PathAlias) + ) + } + for p_name, descriptor in descriptors.items(): + setattr(result, p_name, descriptor) + + if lhs not in _type_union_cache: + _type_union_cache[lhs] = weakref.WeakKeyDictionary() + _type_union_cache[lhs][rhs] = result + + return result diff --git a/gel/_internal/_typing_dispatch.py b/gel/_internal/_typing_dispatch.py index d74bc23f..77ee9d22 100644 --- a/gel/_internal/_typing_dispatch.py +++ b/gel/_internal/_typing_dispatch.py @@ -70,6 +70,8 @@ def _issubclass(lhs: Any, tp: Any, fn: Any) -> bool: if issubclass(lhs, _type_expression.Intersection): return any(_issubclass(c, tp, fn) for c in (lhs.lhs, lhs.rhs)) + elif issubclass(lhs, _type_expression.Union): + return all(_issubclass(c, tp, fn) for c in (lhs.lhs, lhs.rhs)) if _typing_inspect.is_generic_alias(tp): origin = typing.get_origin(tp) diff --git a/gel/models/pydantic.py b/gel/models/pydantic.py index 2ce9801a..2d7db0e6 100644 --- a/gel/models/pydantic.py +++ b/gel/models/pydantic.py @@ -76,6 +76,8 @@ PyTypeScalarConstraint, RangeMeta, UUIDImpl, + create_optional_union, + create_union, empty_set_if_none, ) @@ -215,6 +217,8 @@ "classonlymethod", "computed_field", "construct_infix_op_chain", + "create_optional_union", + "create_union", "dispatch_overload", "empty_set_if_none", ) From 33ecd3613ad42cf460537e36283b746e606961ef Mon Sep 17 00:00:00 2001 From: dnwpark Date: Mon, 3 Nov 2025 16:28:32 -0800 Subject: [PATCH 2/6] Add tests. --- tests/test_qb.py | 115 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) diff --git a/tests/test_qb.py b/tests/test_qb.py index 9cd5a0aa..eb25008b 100644 --- a/tests/test_qb.py +++ b/tests/test_qb.py @@ -2600,6 +2600,121 @@ def test_qb_backlinks_error_01(self): with self.assertRaisesRegex(ValueError, "unsupported query type"): self.client.query(query) + def test_qb_backlinks_error_02(self): + # Unions don't have backlinks + from models.orm_qb import default, std + + with self.assertRaisesRegex( + AttributeError, "has no attribute '__backlinks__'" + ): + std.union(default.Inh_ABC, default.Inh_AB_AC).__backlinks__ + + def test_qb_std_coalesce_scalar_01(self): + from models.orm_qb import std + + query = std.coalesce(1, 2) + result = self.client.query(query) + + self.assertEqual(result, [1]) + + def test_qb_std_coalesce_scalar_02(self): + from models.orm_qb import std + + query = std.coalesce(None, 2) + result = self.client.query(query) + + self.assertEqual(result, [2]) + + def test_qb_std_coalesce_object_01(self): + from models.orm_qb import default, std + + inh_a_objs = { + obj.a: obj + for obj in self.client.query(default.Inh_A.select(a=True)) + } + + query = std.coalesce(default.Inh_AB_AC, default.Inh_ABC) + result = self.client.query(query) + self._assertListEqualUnordered([inh_a_objs[17]], result) + + def test_qb_std_coalesce_object_02(self): + from models.orm_qb import default, std + + inh_a_objs = { + obj.a: obj + for obj in self.client.query(default.Inh_A.select(a=True)) + } + + query = std.coalesce(None, default.Inh_ABC) + result = self.client.query(query) + + self._assertListEqualUnordered([inh_a_objs[13]], result) + + def test_qb_std_coalesce_object_03(self): + from models.orm_qb import default, std + + inh_a_objs = { + obj.a: obj + for obj in self.client.query(default.Inh_A.select(a=True)) + } + + query = std.coalesce( + default.Inh_AB.is_(default.Inh_AC), default.Inh_ABC + ) + result = self.client.query(query) + self._assertListEqualUnordered([inh_a_objs[17]], result) + + def test_qb_std_union_scalar_01(self): + from models.orm_qb import std + + query = std.union(1, 2) + result = self.client.query(query) + self._assertListEqualUnordered(result, [1, 2]) + + def test_qb_std_union_scalar_02(self): + from models.orm_qb import std + + query = std.union(1, [2, 3]) + result = self.client.query(query) + self._assertListEqualUnordered(result, [1, 2, 3]) + + def test_qb_std_union_scalar_03(self): + from models.orm_qb import std + + query = std.union([1, 2], [2, 3]) + result = self.client.query(query) + self._assertListEqualUnordered(result, [1, 2, 2, 3]) + + def test_qb_std_union_object_01(self): + from models.orm_qb import default, std + + inh_a_objs = { + obj.a: obj + for obj in self.client.query(default.Inh_A.select(a=True)) + } + + query = std.union(default.Inh_ABC, default.Inh_AB_AC).select('*') + result = self.client.query(query) + self._assertListEqualUnordered( + [inh_a_objs[13], inh_a_objs[17]], result + ) + + def test_qb_std_union_object_02(self): + from models.orm_qb import default, std + + inh_a_objs = { + obj.a: obj + for obj in self.client.query(default.Inh_A.select(a=True)) + } + + query = std.union( + default.Inh_ABC, default.Inh_AB.is_(default.Inh_AC) + ).select('*') + result = self.client.query(query) + self._assertListEqualUnordered( + [inh_a_objs[13], inh_a_objs[17]], result + ) + class TestQueryBuilderModify(tb.ModelTestCase): """This test suite is for data manipulation using QB.""" From 9d3e0122e2c5951723a0f551f99847f030949a9f Mon Sep 17 00:00:00 2001 From: dnwpark Date: Wed, 5 Nov 2025 13:50:14 -0800 Subject: [PATCH 3/6] Add TernaryOp. --- gel/_internal/_qb/__init__.py | 2 ++ gel/_internal/_qb/_expressions.py | 51 +++++++++++++++++++++++++++++++ gel/models/pydantic.py | 2 ++ 3 files changed, 55 insertions(+) diff --git a/gel/_internal/_qb/__init__.py b/gel/_internal/_qb/__init__.py index cc5074dd..293d19ee 100644 --- a/gel/_internal/_qb/__init__.py +++ b/gel/_internal/_qb/__init__.py @@ -54,6 +54,7 @@ ShapeOp, Splat, StringLiteral, + TernaryOp, UnaryOp, UpdateStmt, Variable, @@ -171,6 +172,7 @@ "Splat", "Stmt", "StringLiteral", + "TernaryOp", "UnaryOp", "UpdateStmt", "VarAlias", diff --git a/gel/_internal/_qb/_expressions.py b/gel/_internal/_qb/_expressions.py index 6ba61f25..e62de7ad 100644 --- a/gel/_internal/_qb/_expressions.py +++ b/gel/_internal/_qb/_expressions.py @@ -425,6 +425,57 @@ def __edgeql_expr__(self, *, ctx: ScopeContext) -> str: return f"{left}[{right}]" +@dataclass(kw_only=True, frozen=True) +class TernaryOp(TypedExpr): + lexpr: Expr + op_1: _edgeql.Token + mexpr: Expr + op_2: _edgeql.Token + rexpr: Expr + + def __init__( + self, + *, + lexpr: ExprCompatible, + op_1: _edgeql.Token | str, + mexpr: ExprCompatible, + op_2: _edgeql.Token | str, + rexpr: ExprCompatible, + type_: TypeNameExpr, + ) -> None: + object.__setattr__(self, "lexpr", edgeql_qb_expr(lexpr)) + if not isinstance(op_1, _edgeql.Token): + op_1 = _edgeql.Token.from_str(op_1) + object.__setattr__(self, "op_1", op_1) + object.__setattr__(self, "mexpr", edgeql_qb_expr(mexpr)) + if not isinstance(op_2, _edgeql.Token): + op_2 = _edgeql.Token.from_str(op_2) + object.__setattr__(self, "op_2", op_2) + object.__setattr__(self, "rexpr", edgeql_qb_expr(rexpr)) + super().__init__(type_=type_) + + def subnodes(self) -> Iterable[Node]: + return (self.lexpr, self.mexpr, self.rexpr) + + def __edgeql_expr__(self, *, ctx: ScopeContext) -> str: + left = edgeql(self.lexpr, ctx=ctx) + if self.lexpr.precedence <= self.precedence: + left = f"({left})" + middle = edgeql(self.mexpr, ctx=ctx) + if self.mexpr.precedence <= self.precedence: + middle = f"({middle})" + right = edgeql(self.rexpr, ctx=ctx) + if self.mexpr.precedence <= self.precedence: + right = f"({right})" + return f"{left} {self.op_1} {middle} {self.op_2} {right}" + + @property + def precedence(self) -> _edgeql.Precedence: + return max( + _edgeql.PRECEDENCE[self.op_1], _edgeql.PRECEDENCE[self.op_2] + ) + + @dataclass(kw_only=True, frozen=True) class FuncCall(TypedExpr): fname: str diff --git a/gel/models/pydantic.py b/gel/models/pydantic.py index 2d7db0e6..b1cd9d93 100644 --- a/gel/models/pydantic.py +++ b/gel/models/pydantic.py @@ -45,6 +45,7 @@ PathAlias, SetLiteral, StringLiteral, + TernaryOp, UnaryOp, construct_infix_op_chain, ) @@ -205,6 +206,7 @@ "SchemaPath", "SetLiteral", "StringLiteral", + "TernaryOp", "TimeDeltaImpl", "TimeImpl", "Tuple", From 67e28555ce0c659b46488bf5aa61182d436ea3e8 Mon Sep 17 00:00:00 2001 From: dnwpark Date: Wed, 5 Nov 2025 13:57:24 -0800 Subject: [PATCH 4/6] Add codegen for ternary operators. --- gel/_internal/_codegen/_models/_pydantic.py | 73 +++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/gel/_internal/_codegen/_models/_pydantic.py b/gel/_internal/_codegen/_models/_pydantic.py index b61f4dc2..444d89ac 100644 --- a/gel/_internal/_codegen/_models/_pydantic.py +++ b/gel/_internal/_codegen/_models/_pydantic.py @@ -2021,6 +2021,14 @@ def process(self, mod: IntrospectedModule) -> None: if op.schemapath.parent == self.canonical_modpath ] ) + self.write_non_magic_ternary_operators( + [ + op + for op in self._operators.other_ops + if op.schemapath.parent == self.canonical_modpath + if op.operator_kind == reflection.OperatorKind.Ternary + ] + ) self.write_globals(mod["globals"]) def reexport_module(self, mod: GeneratedSchemaModule) -> None: @@ -3439,6 +3447,71 @@ def param_getter( excluded_param_types=excluded_param_types, ) + def write_non_magic_ternary_operators( + self, + ops: list[reflection.Operator], + ) -> bool: + # Filter to unary operators without Python magic method equivalents + ternary_ops = [op for op in ops if op.py_magic is None] + if not ternary_ops: + return False + else: + self._write_callables( + ternary_ops, + style="function", + type_ignore=("override", "unused-ignore"), + node_ctor=self._write_ternary_op_func_node_ctor, + ) + return True + + def _write_ternary_op_func_node_ctor( + self, + op: reflection.Operator, + ) -> None: + """Generate the query node constructor for a ternary operator function. + + Creates the code that builds a TernaryOp query node for ternary + operator functions. Unlike method versions, this takes the operand from + function arguments and applies special type casting for tuple + parameters. + + Args: + op: The operator reflection object containing metadata + """ + node_cls = self.import_name(BASE_IMPL, "TernaryOp") + expr_compat = self.import_name(BASE_IMPL, "ExprCompatible") + cast_ = self.import_name("typing", "cast") + + op_1: str + op_2: str + if op.schemapath == SchemaPath("std", "IF"): + op_1 = op.schemapath.name + op_2 = '"ELSE"' + + else: + raise NotImplementedError(f"Unknown operator {op.schemapath}") + + if_true = "__args__[0]" + condition = "__args__[1]" + if_false = "__args__[2]" + # Tuple parameters need ExprCompatible casting + # due to a possible mypy bug. + if reflection.is_tuple_type(op.params[0].get_type(self._types)): + if_true = f"{cast_}({expr_compat!r}, {if_true})" + if reflection.is_tuple_type(op.params[2].get_type(self._types)): + if_false = f"{cast_}({expr_compat!r}, {if_false})" + + args = [ + f"lexpr={if_true}", + f'op_1="{op_1}"', # Gel operator name (e.g., "IF") + f"mexpr={condition}", + f"op_2={op_2}", + f"rexpr={if_false}", + "type_=__rtype__.__gel_reflection__.type_name", # Result type info + ] + + self.write(self.format_list(f"{node_cls}({{list}}),", args)) + def _partition_nominal_overloads( self, callables: Iterable[_Callable_T], From 1f504d935bcb970d1a79e0b11e3694c6dc43923f Mon Sep 17 00:00:00 2001 From: dnwpark Date: Wed, 5 Nov 2025 17:43:55 -0800 Subject: [PATCH 5/6] Add hack for bool overload. --- gel/_internal/_codegen/_models/_pydantic.py | 56 ++++++++++++++++++++- gel/_internal/_reflection/__init__.py | 2 + 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/gel/_internal/_codegen/_models/_pydantic.py b/gel/_internal/_codegen/_models/_pydantic.py index 444d89ac..f7baeb9b 100644 --- a/gel/_internal/_codegen/_models/_pydantic.py +++ b/gel/_internal/_codegen/_models/_pydantic.py @@ -3722,9 +3722,52 @@ def _write_potentially_overlapping_overloads( # SEE ABOVE: This is what we actually want. # key=lambda o: (generality_key(o), o.edgeql_signature), # noqa: ERA001, E501 ) + base_generic_overload: dict[_Callable_T, _Callable_T] = {} for overload in overloads: overload_signatures[overload] = {} + + if overload.schemapath == SchemaPath('std', 'IF'): + # HACK: Pretend the base overload of std::IF is generic on + # anyobject. + # + # The base overload of std::IF is + # (anytype, std::bool, anytype) -> anytype + # + # However, this causes an overlap with overloading for bool + # arguments since + # (anytype, builtin.bool, anytype) -> anytype + # overlaps with + # (std::bool, builtin.bool, std::bool) -> std::bool + # + # We resolve this by generating the specializations for anytype + # but using anyobject as the base generic type. + + def anytype_to_anyobject( + refl_type: reflection.Type, + default: reflection.Type | reflection.TypeRef, + ) -> reflection.Type | reflection.TypeRef: + if isinstance(refl_type, reflection.PseudoType): + return self._types_by_name["anyobject"] + return default + + base_generic_overload[overload] = dataclasses.replace( + overload, + params=[ + dataclasses.replace( + param, + type=anytype_to_anyobject( + param.get_type(self._types), param.type + ), + ) + for param in overload.params + ], + return_type=anytype_to_anyobject( + overload.get_return_type(self._types), + overload.return_type, + ), + ) + for param in param_getter(overload): param_overload_map[param.key].add(overload) param_type = param.get_type(self._types) @@ -3732,6 +3775,14 @@ def _write_potentially_overlapping_overloads( if param.kind is reflection.CallableParamKind.Variadic: if reflection.is_array_type(param_type): param_type = param_type.get_element_type(self._types) + + if ( + overload.schemapath == SchemaPath('std', 'IF') + and param_type.is_pseudo + ): + # Also generate the base signature using anyobject + param_type = self._types_by_name["anyobject"] + # Start with the base parameter type overload_signatures[overload][param.key] = [param_type] @@ -3843,7 +3894,10 @@ def specialization_sort_key(t: reflection.Type) -> int: for overload in overloads: if overload_specs := overloads_specializations.get(overload): expanded_overloads.extend(overload_specs) - expanded_overloads.append(overload) + if overload in base_generic_overload: + expanded_overloads.append(base_generic_overload[overload]) + else: + expanded_overloads.append(overload) overloads = expanded_overloads overload_order = {overload: i for i, overload in enumerate(overloads)} diff --git a/gel/_internal/_reflection/__init__.py b/gel/_internal/_reflection/__init__.py index 21f6462c..c63b5baa 100644 --- a/gel/_internal/_reflection/__init__.py +++ b/gel/_internal/_reflection/__init__.py @@ -68,6 +68,7 @@ ScalarType, TupleType, Type, + TypeRef, compare_type_generality, fetch_types, is_abstract_type, @@ -126,6 +127,7 @@ "Type", "TypeKind", "TypeModifier", + "TypeRef", "compare_callable_generality", "compare_type_generality", "fetch_branch_state", From f43250bbf065a143d1b03b41925fa386fecf3887 Mon Sep 17 00:00:00 2001 From: dnwpark Date: Wed, 5 Nov 2025 17:43:59 -0800 Subject: [PATCH 6/6] Add tests. --- tests/test_qb.py | 253 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 253 insertions(+) diff --git a/tests/test_qb.py b/tests/test_qb.py index eb25008b..6d81cfa8 100644 --- a/tests/test_qb.py +++ b/tests/test_qb.py @@ -2715,6 +2715,259 @@ def test_qb_std_union_object_02(self): [inh_a_objs[13], inh_a_objs[17]], result ) + def test_qb_std_if_else_scalar_01a(self): + from models.orm_qb import std + + lhs = [1, 2, 3] + rhs = [4, 5, 6] + + query = std.if_else(lhs, True, rhs) + result = self.client.query(query) + self._assertListEqualUnordered([1, 2, 3], result) + + query = std.if_else(lhs, False, rhs) + result = self.client.query(query) + self._assertListEqualUnordered([4, 5, 6], result) + + query = std.if_else(lhs, std.bool(True), rhs) + result = self.client.query(query) + self._assertListEqualUnordered([1, 2, 3], result) + + query = std.if_else(lhs, std.bool(False), rhs) + result = self.client.query(query) + self._assertListEqualUnordered([4, 5, 6], result) + + def test_qb_std_if_else_scalar_01b(self): + from models.orm_qb import std + + lhs = [1, 2, 3] + rhs = [std.int64(4), std.int64(5), std.int64(6)] + + query = std.if_else(lhs, True, rhs) + result = self.client.query(query) + self._assertListEqualUnordered([1, 2, 3], result) + + query = std.if_else(lhs, False, rhs) + result = self.client.query(query) + self._assertListEqualUnordered([4, 5, 6], result) + + query = std.if_else(lhs, std.bool(True), rhs) + result = self.client.query(query) + self._assertListEqualUnordered([1, 2, 3], result) + + query = std.if_else(lhs, std.bool(False), rhs) + result = self.client.query(query) + self._assertListEqualUnordered([4, 5, 6], result) + + def test_qb_std_if_else_scalar_01c(self): + from models.orm_qb import std + + lhs = [std.int64(1), std.int64(2), std.int64(3)] + rhs = [4, 5, 6] + + query = std.if_else(lhs, True, rhs) + result = self.client.query(query) + self._assertListEqualUnordered([1, 2, 3], result) + + query = std.if_else(lhs, False, rhs) + result = self.client.query(query) + self._assertListEqualUnordered([4, 5, 6], result) + + query = std.if_else(lhs, std.bool(True), rhs) + result = self.client.query(query) + self._assertListEqualUnordered([1, 2, 3], result) + + query = std.if_else(lhs, std.bool(False), rhs) + result = self.client.query(query) + self._assertListEqualUnordered([4, 5, 6], result) + + def test_qb_std_if_else_scalar_01d(self): + from models.orm_qb import std + + lhs = [std.int64(1), std.int64(2), std.int64(3)] + rhs = [std.int64(4), std.int64(5), std.int64(6)] + + query = std.if_else(lhs, True, rhs) + result = self.client.query(query) + self._assertListEqualUnordered([1, 2, 3], result) + + query = std.if_else(lhs, False, rhs) + result = self.client.query(query) + self._assertListEqualUnordered([4, 5, 6], result) + + query = std.if_else(lhs, std.bool(True), rhs) + result = self.client.query(query) + self._assertListEqualUnordered([1, 2, 3], result) + + query = std.if_else(lhs, std.bool(False), rhs) + result = self.client.query(query) + self._assertListEqualUnordered([4, 5, 6], result) + + def test_qb_std_if_else_scalar_02a(self): + from models.orm_qb import std + + lhs = [True, True, True] + rhs = [False, False, False] + + query = std.if_else(lhs, True, rhs) + result = self.client.query(query) + self._assertListEqualUnordered([True, True, True], result) + + query = std.if_else(lhs, False, rhs) + result = self.client.query(query) + self._assertListEqualUnordered([False, False, False], result) + + query = std.if_else(lhs, std.bool(True), rhs) + result = self.client.query(query) + self._assertListEqualUnordered([True, True, True], result) + + query = std.if_else(lhs, std.bool(False), rhs) + result = self.client.query(query) + self._assertListEqualUnordered([False, False, False], result) + + def test_qb_std_if_else_scalar_02b(self): + from models.orm_qb import std + + lhs = [True, True, True] + rhs = [std.bool(False), std.bool(False), std.bool(False)] + + query = std.if_else(lhs, True, rhs) + result = self.client.query(query) + self._assertListEqualUnordered([True, True, True], result) + + query = std.if_else(lhs, False, rhs) + result = self.client.query(query) + self._assertListEqualUnordered([False, False, False], result) + + query = std.if_else(lhs, std.bool(True), rhs) + result = self.client.query(query) + self._assertListEqualUnordered([True, True, True], result) + + query = std.if_else(lhs, std.bool(False), rhs) + result = self.client.query(query) + self._assertListEqualUnordered([False, False, False], result) + + def test_qb_std_if_else_scalar_02c(self): + from models.orm_qb import std + + lhs = [std.bool(True), std.bool(True), std.bool(True)] + rhs = [False, False, False] + + query = std.if_else(lhs, True, rhs) + result = self.client.query(query) + self._assertListEqualUnordered([True, True, True], result) + + query = std.if_else(lhs, False, rhs) + result = self.client.query(query) + self._assertListEqualUnordered([False, False, False], result) + + query = std.if_else(lhs, std.bool(True), rhs) + result = self.client.query(query) + self._assertListEqualUnordered([True, True, True], result) + + query = std.if_else(lhs, std.bool(False), rhs) + result = self.client.query(query) + self._assertListEqualUnordered([False, False, False], result) + + def test_qb_std_if_else_scalar_02d(self): + from models.orm_qb import std + + lhs = [std.bool(True), std.bool(True), std.bool(True)] + rhs = [std.bool(False), std.bool(False), std.bool(False)] + + query = std.if_else(lhs, True, rhs) + result = self.client.query(query) + self._assertListEqualUnordered([True, True, True], result) + + query = std.if_else(lhs, False, rhs) + result = self.client.query(query) + self._assertListEqualUnordered([False, False, False], result) + + query = std.if_else(lhs, std.bool(True), rhs) + result = self.client.query(query) + self._assertListEqualUnordered([True, True, True], result) + + query = std.if_else(lhs, std.bool(False), rhs) + result = self.client.query(query) + self._assertListEqualUnordered([False, False, False], result) + + def test_qb_std_if_else_object_01(self): + from models.orm_qb import default, std + + inh_a_objs = { + obj.a: obj + for obj in self.client.query(default.Inh_A.select(a=True)) + } + + query = std.if_else(default.Inh_ABC, True, default.Inh_AB_AC) + result = self.client.query(query) + self._assertListEqualUnordered([inh_a_objs[13]], result) + + query = std.if_else(default.Inh_ABC, False, default.Inh_AB_AC) + result = self.client.query(query) + self._assertListEqualUnordered([inh_a_objs[17]], result) + + query = std.if_else(default.Inh_ABC, std.bool(True), default.Inh_AB_AC) + result = self.client.query(query) + self._assertListEqualUnordered([inh_a_objs[13]], result) + + query = std.if_else( + default.Inh_ABC, std.bool(False), default.Inh_AB_AC + ) + result = self.client.query(query) + self._assertListEqualUnordered([inh_a_objs[17]], result) + + def test_qb_std_if_else_object_02(self): + from models.orm_qb import default, std + + query = default.Inh_A.select( + a=lambda x: std.if_else(x.a, x.a < 10, x.a + 100) + ) + result = self.client.query(query) + + self._assertObjectsWithFields( + result, + "a", + [ + ( + default.Inh_A, + { + "a": 1, + }, + ), + ( + default.Inh_AB, + { + "a": 4, + }, + ), + ( + default.Inh_AC, + { + "a": 7, + }, + ), + ( + default.Inh_ABC, + { + "a": 113, + }, + ), + ( + default.Inh_AB_AC, + { + "a": 117, + }, + ), + ( + default.Inh_AXA, + { + "a": 1101, + }, + ), + ], + ) + class TestQueryBuilderModify(tb.ModelTestCase): """This test suite is for data manipulation using QB."""