diff --git a/infrahub_sdk/spec/object.py b/infrahub_sdk/spec/object.py index 0df3f95c..456c9f1e 100644 --- a/infrahub_sdk/spec/object.py +++ b/infrahub_sdk/spec/object.py @@ -7,6 +7,7 @@ from ..exceptions import ObjectValidationError, ValidationError from ..schema import GenericSchemaAPI, RelationshipKind, RelationshipSchema +from ..utils import is_valid_uuid from ..yaml import InfrahubFile, InfrahubFileKind from .models import InfrahubObjectParameters from .processors.factory import DataProcessorFactory @@ -33,6 +34,36 @@ def validate_list_of_objects(value: list[Any]) -> bool: return all(isinstance(item, dict) for item in value) +def normalize_hfid_reference(value: str | list[str]) -> str | list[str]: + """Normalize a reference value to HFID format. + + Only call this function when the peer schema has human_friendly_id defined. + + Args: + value: Either a string (ID or single-component HFID) or a list of strings (multi-component HFID). + + Returns: + - If value is already a list: returns it unchanged as list[str] + - If value is a valid UUID string: returns it unchanged as str (will be treated as an ID) + - If value is a non-UUID string: wraps it in a list as list[str] (single-component HFID) + """ + if isinstance(value, list): + return value + if is_valid_uuid(value): + return value + return [value] + + +def normalize_hfid_references(values: list[str | list[str]]) -> list[str | list[str]]: + """Normalize a list of reference values to HFID format. + + Only call this function when the peer schema has human_friendly_id defined. + + Each string that is not a valid UUID will be wrapped in a list to treat it as a single-component HFID. + """ + return [normalize_hfid_reference(v) for v in values] + + class RelationshipDataFormat(str, Enum): UNKNOWN = "unknown" @@ -51,6 +82,12 @@ class RelationshipInfo(BaseModel): peer_rel: RelationshipSchema | None = None reason_relationship_not_valid: str | None = None format: RelationshipDataFormat = RelationshipDataFormat.UNKNOWN + peer_human_friendly_id: list[str] | None = None + + @property + def peer_has_hfid(self) -> bool: + """Indicate if the peer schema has a human-friendly ID defined.""" + return bool(self.peer_human_friendly_id) @property def is_bidirectional(self) -> bool: @@ -119,6 +156,7 @@ async def get_relationship_info( info.peer_kind = value["kind"] peer_schema = await client.schema.get(kind=info.peer_kind, branch=branch) + info.peer_human_friendly_id = peer_schema.human_friendly_id try: info.peer_rel = peer_schema.get_matching_relationship( @@ -444,10 +482,12 @@ async def create_node( # - if the relationship is bidirectional and is mandatory on the other side, then we need to create this object First # - if the relationship is bidirectional and is not mandatory on the other side, then we need should create the related object First # - if the relationship is not bidirectional, then we need to create the related object First - if rel_info.is_reference and isinstance(value, list): - clean_data[key] = value - elif rel_info.format == RelationshipDataFormat.ONE_REF and isinstance(value, str): - clean_data[key] = [value] + if rel_info.format == RelationshipDataFormat.MANY_REF and isinstance(value, list): + # Cardinality-many reference: normalize string HFIDs to list format if peer has HFID defined + clean_data[key] = normalize_hfid_references(value) if rel_info.peer_has_hfid else value + elif rel_info.format == RelationshipDataFormat.ONE_REF: + # Cardinality-one reference: normalize string to HFID list if peer has HFID, else pass as-is + clean_data[key] = normalize_hfid_reference(value) if rel_info.peer_has_hfid else value elif not rel_info.is_reference and rel_info.is_bidirectional and rel_info.is_mandatory: remaining_rels.append(key) elif not rel_info.is_reference and not rel_info.is_mandatory: diff --git a/pyproject.toml b/pyproject.toml index cab7c7c4..854fc39d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -328,25 +328,6 @@ max-complexity = 17 "ARG002", # Unused method argument ] -################################################################################################## -# ANN001 ignores - broken down for incremental cleanup # -# Remove each section as type annotations are added to that directory # -################################################################################################## - -# tests/unit/sdk/ - 478 errors total -"tests/unit/sdk/test_node.py" = ["ANN001"] # 206 errors -"tests/unit/sdk/test_client.py" = ["ANN001"] # 85 errors -"tests/unit/sdk/test_schema.py" = ["ANN001"] # 36 errors -"tests/unit/sdk/test_artifact.py" = ["ANN001"] # 27 errors -"tests/unit/sdk/test_hierarchical_nodes.py" = ["ANN001"] # 26 errors -"tests/unit/sdk/test_task.py" = ["ANN001"] # 21 errors -"tests/unit/sdk/test_store.py" = ["ANN001"] # 12 errors -"tests/unit/sdk/spec/test_object.py" = ["ANN001"] # 11 errors -"tests/unit/sdk/conftest.py" = ["ANN001"] # 11 errors -"tests/unit/sdk/test_diff_summary.py" = ["ANN001"] # 9 errors -"tests/unit/sdk/test_object_store.py" = ["ANN001"] # 7 errors -"tests/unit/sdk/graphql/test_query.py" = ["ANN001"] # 7 errors - # tests/integration/ "tests/integration/test_infrahub_client.py" = ["PLR0904"] "tests/integration/test_infrahub_client_sync.py" = ["PLR0904"] diff --git a/tests/fixtures/schema_01.json b/tests/fixtures/schema_01.json index 344ebeab..c2fab38a 100644 --- a/tests/fixtures/schema_01.json +++ b/tests/fixtures/schema_01.json @@ -242,7 +242,10 @@ "label": null, "inherit_from": [], "branch": "aware", - "default_filter": "name__value" + "default_filter": "name__value", + "human_friendly_id": [ + "name__value" + ] }, { "name": "Location", diff --git a/tests/unit/sdk/conftest.py b/tests/unit/sdk/conftest.py index 55f2be7f..8fb9ecf2 100644 --- a/tests/unit/sdk/conftest.py +++ b/tests/unit/sdk/conftest.py @@ -103,7 +103,7 @@ def replace_annotation(annotation: str) -> str: @pytest.fixture def replace_async_parameter_annotations( - replace_async_return_annotation, + replace_async_return_annotation: Callable[[str], str], ) -> Callable[[Mapping[str, Parameter]], list[tuple[str, str]]]: """Allows for comparison between sync and async parameter annotations.""" @@ -130,7 +130,7 @@ def replace_annotation(annotation: str) -> str: @pytest.fixture def replace_sync_parameter_annotations( - replace_sync_return_annotation, + replace_sync_return_annotation: Callable[[str], str], ) -> Callable[[Mapping[str, Parameter]], list[tuple[str, str]]]: """Allows for comparison between sync and async parameter annotations.""" @@ -1501,7 +1501,7 @@ async def mock_repositories_query_no_pagination(httpx_mock: HTTPXMock) -> HTTPXM @pytest.fixture async def mock_query_repository_all_01_no_pagination( - httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01 + httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01: HTTPXMock ) -> HTTPXMock: response = { "data": { @@ -1601,7 +1601,7 @@ async def mock_repositories_query(httpx_mock: HTTPXMock) -> HTTPXMock: @pytest.fixture async def mock_query_repository_page1_1( - httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01 + httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01: HTTPXMock ) -> HTTPXMock: response = { "data": { @@ -1641,7 +1641,9 @@ async def mock_query_repository_page1_1( @pytest.fixture -async def mock_query_corenode_page1_1(httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_02) -> HTTPXMock: +async def mock_query_corenode_page1_1( + httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_02: HTTPXMock +) -> HTTPXMock: response = { "data": { "CoreNode": { @@ -1676,14 +1678,16 @@ async def mock_query_corenode_page1_1(httpx_mock: HTTPXMock, client: InfrahubCli @pytest.fixture -async def mock_query_repository_count(httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01) -> HTTPXMock: +async def mock_query_repository_count( + httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01: HTTPXMock +) -> HTTPXMock: httpx_mock.add_response(method="POST", json={"data": {"CoreRepository": {"count": 5}}}, is_reusable=True) return httpx_mock @pytest.fixture async def mock_query_repository_page1_empty( - httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01 + httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01: HTTPXMock ) -> HTTPXMock: response: dict = {"data": {"CoreRepository": {"edges": []}}} @@ -1698,7 +1702,7 @@ async def mock_query_repository_page1_empty( @pytest.fixture async def mock_query_repository_page1_2( - httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01 + httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01: HTTPXMock ) -> HTTPXMock: response = { "data": { @@ -1748,7 +1752,7 @@ async def mock_query_repository_page1_2( @pytest.fixture async def mock_query_repository_page2_2( - httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01 + httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01: HTTPXMock ) -> HTTPXMock: response = { "data": { @@ -2512,7 +2516,7 @@ async def mock_schema_query_ipam(httpx_mock: HTTPXMock) -> HTTPXMock: @pytest.fixture async def mock_query_location_batch_count( - httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01 + httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01: HTTPXMock ) -> HTTPXMock: response = {"data": {"BuiltinLocation": {"count": 30}}} httpx_mock.add_response(method="POST", url="http://mock/graphql/main", json=response, is_reusable=True) @@ -2520,7 +2524,9 @@ async def mock_query_location_batch_count( @pytest.fixture -async def mock_query_location_batch(httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01) -> HTTPXMock: +async def mock_query_location_batch( + httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01: HTTPXMock +) -> HTTPXMock: for i in range(1, 11): filename = get_fixtures_dir() / "batch" / f"mock_query_location_page{i}.json" response_text = filename.read_text(encoding="UTF-8") diff --git a/tests/unit/sdk/graphql/test_query.py b/tests/unit/sdk/graphql/test_query.py index a01c41d3..0fd4fb72 100644 --- a/tests/unit/sdk/graphql/test_query.py +++ b/tests/unit/sdk/graphql/test_query.py @@ -1,4 +1,5 @@ from enum import Enum +from typing import Any from infrahub_sdk.graphql.query import Mutation, Query @@ -13,7 +14,7 @@ class MyIntEnum(int, Enum): VALUE2 = 24 -def test_query_rendering_no_vars(query_data_no_filter) -> None: +def test_query_rendering_no_vars(query_data_no_filter: dict[str, Any]) -> None: query = Query(query=query_data_no_filter) expected_query = """ @@ -37,7 +38,7 @@ def test_query_rendering_no_vars(query_data_no_filter) -> None: assert query.render() == expected_query -def test_query_rendering_empty_filter(query_data_empty_filter) -> None: +def test_query_rendering_empty_filter(query_data_empty_filter: dict[str, Any]) -> None: query = Query(query=query_data_empty_filter) expected_query = """ @@ -61,7 +62,7 @@ def test_query_rendering_empty_filter(query_data_empty_filter) -> None: assert query.render() == expected_query -def test_query_rendering_with_filters_and_vars(query_data_filters_01) -> None: +def test_query_rendering_with_filters_and_vars(query_data_filters_01: dict[str, Any]) -> None: query = Query(query=query_data_filters_01, variables={"name": str, "enabled": bool}) expected_query = """ @@ -85,7 +86,7 @@ def test_query_rendering_with_filters_and_vars(query_data_filters_01) -> None: assert query.render() == expected_query -def test_query_rendering_with_filters(query_data_filters_02) -> None: +def test_query_rendering_with_filters(query_data_filters_02: dict[str, Any]) -> None: query = Query(query=query_data_filters_02) expected_query = """ @@ -105,7 +106,7 @@ def test_query_rendering_with_filters(query_data_filters_02) -> None: assert query.render() == expected_query -def test_query_rendering_with_filters_convert_enum(query_data_filters_02) -> None: +def test_query_rendering_with_filters_convert_enum(query_data_filters_02: dict[str, Any]) -> None: query = Query(query=query_data_filters_02) expected_query = """ @@ -125,7 +126,7 @@ def test_query_rendering_with_filters_convert_enum(query_data_filters_02) -> Non assert query.render(convert_enum=True) == expected_query -def test_mutation_rendering_no_vars(input_data_01) -> None: +def test_mutation_rendering_no_vars(input_data_01: dict[str, Any]) -> None: query_data = {"ok": None, "object": {"id": None}} query = Mutation(mutation="myobject_create", query=query_data, input_data=input_data_01) @@ -245,7 +246,7 @@ def test_mutation_rendering_enum() -> None: assert query.render() == expected_query -def test_mutation_rendering_with_vars(input_data_01) -> None: +def test_mutation_rendering_with_vars(input_data_01: dict[str, Any]) -> None: query_data = {"ok": None, "object": {"id": None}} variables = {"name": str, "description": str, "number": int} query = Mutation( diff --git a/tests/unit/sdk/spec/test_object.py b/tests/unit/sdk/spec/test_object.py index 1af02ac3..90b248b1 100644 --- a/tests/unit/sdk/spec/test_object.py +++ b/tests/unit/sdk/spec/test_object.py @@ -1,14 +1,23 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, MagicMock, patch import pytest from infrahub_sdk.exceptions import ValidationError -from infrahub_sdk.spec.object import ObjectFile, RelationshipDataFormat, get_relationship_info +from infrahub_sdk.node.related_node import RelatedNode +from infrahub_sdk.spec.object import ( + ObjectFile, + RelationshipDataFormat, + get_relationship_info, + normalize_hfid_reference, +) if TYPE_CHECKING: from infrahub_sdk.client import InfrahubClient + from infrahub_sdk.node import InfrahubNode @pytest.fixture @@ -118,7 +127,7 @@ def location_with_empty_parameters(root_location: dict) -> dict: return location -async def test_validate_object(client: InfrahubClient, schema_query_01_data: dict, location_mexico_01) -> None: +async def test_validate_object(client: InfrahubClient, schema_query_01_data: dict, location_mexico_01: dict) -> None: client.schema.set_cache(schema=schema_query_01_data, branch="main") obj = ObjectFile(location="some/path", content=location_mexico_01) await obj.validate_format(client=client) @@ -127,7 +136,7 @@ async def test_validate_object(client: InfrahubClient, schema_query_01_data: dic async def test_validate_object_bad_syntax01( - client: InfrahubClient, schema_query_01_data: dict, location_bad_syntax01 + client: InfrahubClient, schema_query_01_data: dict, location_bad_syntax01: dict ) -> None: client.schema.set_cache(schema=schema_query_01_data, branch="main") obj = ObjectFile(location="some/path", content=location_bad_syntax01) @@ -137,7 +146,7 @@ async def test_validate_object_bad_syntax01( assert "name" in str(exc.value) -async def test_validate_object_bad_syntax02(client_with_schema_01: InfrahubClient, location_bad_syntax02) -> None: +async def test_validate_object_bad_syntax02(client_with_schema_01: InfrahubClient, location_bad_syntax02: dict) -> None: obj = ObjectFile(location="some/path", content=location_bad_syntax02) with pytest.raises(ValidationError) as exc: await obj.validate_format(client=client_with_schema_01) @@ -145,7 +154,7 @@ async def test_validate_object_bad_syntax02(client_with_schema_01: InfrahubClien assert "notvalidattribute" in str(exc.value) -async def test_validate_object_expansion(client_with_schema_01: InfrahubClient, location_expansion) -> None: +async def test_validate_object_expansion(client_with_schema_01: InfrahubClient, location_expansion: dict) -> None: obj = ObjectFile(location="some/path", content=location_expansion) await obj.validate_format(client=client_with_schema_01) @@ -155,7 +164,7 @@ async def test_validate_object_expansion(client_with_schema_01: InfrahubClient, assert obj.spec.data[4]["name"] == "AMS5" -async def test_validate_no_object_expansion(client_with_schema_01: InfrahubClient, no_location_expansion) -> None: +async def test_validate_no_object_expansion(client_with_schema_01: InfrahubClient, no_location_expansion: dict) -> None: obj = ObjectFile(location="some/path", content=no_location_expansion) await obj.validate_format(client=client_with_schema_01) assert obj.spec.kind == "BuiltinLocation" @@ -165,7 +174,7 @@ async def test_validate_no_object_expansion(client_with_schema_01: InfrahubClien async def test_validate_object_expansion_multiple_ranges( - client_with_schema_01: InfrahubClient, location_expansion_multiple_ranges + client_with_schema_01: InfrahubClient, location_expansion_multiple_ranges: dict ) -> None: obj = ObjectFile(location="some/path", content=location_expansion_multiple_ranges) await obj.validate_format(client=client_with_schema_01) @@ -179,7 +188,7 @@ async def test_validate_object_expansion_multiple_ranges( async def test_validate_object_expansion_multiple_ranges_bad_syntax( - client_with_schema_01: InfrahubClient, location_expansion_multiple_ranges_bad_syntax + client_with_schema_01: InfrahubClient, location_expansion_multiple_ranges_bad_syntax: dict ) -> None: obj = ObjectFile(location="some/path", content=location_expansion_multiple_ranges_bad_syntax) with pytest.raises(ValidationError) as exc: @@ -241,25 +250,209 @@ async def test_get_relationship_info_tags( assert rel_info.format == format -async def test_parameters_top_level(client_with_schema_01: InfrahubClient, location_expansion) -> None: +async def test_parameters_top_level(client_with_schema_01: InfrahubClient, location_expansion: dict) -> None: obj = ObjectFile(location="some/path", content=location_expansion) await obj.validate_format(client=client_with_schema_01) assert obj.spec.parameters.expand_range is True -async def test_parameters_missing(client_with_schema_01: InfrahubClient, location_mexico_01) -> None: +async def test_parameters_missing(client_with_schema_01: InfrahubClient, location_mexico_01: dict) -> None: obj = ObjectFile(location="some/path", content=location_mexico_01) await obj.validate_format(client=client_with_schema_01) assert hasattr(obj.spec.parameters, "expand_range") -async def test_parameters_empty_dict(client_with_schema_01: InfrahubClient, location_with_empty_parameters) -> None: +async def test_parameters_empty_dict( + client_with_schema_01: InfrahubClient, location_with_empty_parameters: dict +) -> None: obj = ObjectFile(location="some/path", content=location_with_empty_parameters) await obj.validate_format(client=client_with_schema_01) assert hasattr(obj.spec.parameters, "expand_range") -async def test_parameters_non_dict(client_with_schema_01: InfrahubClient, location_with_non_dict_parameters) -> None: +async def test_parameters_non_dict( + client_with_schema_01: InfrahubClient, location_with_non_dict_parameters: dict +) -> None: obj = ObjectFile(location="some/path", content=location_with_non_dict_parameters) with pytest.raises(ValidationError): await obj.validate_format(client=client_with_schema_01) + + +@dataclass +class HfidLoadTestCase: + """Test case for HFID normalization in object loading.""" + + name: str + data: list[dict[str, Any]] + expected_primary_tag: str | list[str] | None + expected_tags: list[str] | list[list[str]] | None + + +HFID_NORMALIZATION_TEST_CASES = [ + HfidLoadTestCase( + name="cardinality_one_string_hfid_normalized", + data=[{"name": "Mexico", "type": "Country", "primary_tag": "Important"}], + expected_primary_tag=["Important"], + expected_tags=None, + ), + HfidLoadTestCase( + name="cardinality_one_list_hfid_unchanged", + data=[{"name": "Mexico", "type": "Country", "primary_tag": ["Important"]}], + expected_primary_tag=["Important"], + expected_tags=None, + ), + HfidLoadTestCase( + name="cardinality_one_uuid_unchanged", + data=[{"name": "Mexico", "type": "Country", "primary_tag": "550e8400-e29b-41d4-a716-446655440000"}], + expected_primary_tag="550e8400-e29b-41d4-a716-446655440000", + expected_tags=None, + ), + HfidLoadTestCase( + name="cardinality_many_string_hfids_normalized", + data=[{"name": "Mexico", "type": "Country", "tags": ["Important", "Active"]}], + expected_primary_tag=None, + expected_tags=[["Important"], ["Active"]], + ), + HfidLoadTestCase( + name="cardinality_many_list_hfids_unchanged", + data=[{"name": "Mexico", "type": "Country", "tags": [["Important"], ["Active"]]}], + expected_primary_tag=None, + expected_tags=[["Important"], ["Active"]], + ), + HfidLoadTestCase( + name="cardinality_many_mixed_hfids_normalized", + data=[{"name": "Mexico", "type": "Country", "tags": ["Important", ["namespace", "name"]]}], + expected_primary_tag=None, + expected_tags=[["Important"], ["namespace", "name"]], + ), + HfidLoadTestCase( + name="cardinality_many_uuids_unchanged", + data=[ + { + "name": "Mexico", + "type": "Country", + "tags": ["550e8400-e29b-41d4-a716-446655440000", "6ba7b810-9dad-11d1-80b4-00c04fd430c8"], + } + ], + expected_primary_tag=None, + expected_tags=["550e8400-e29b-41d4-a716-446655440000", "6ba7b810-9dad-11d1-80b4-00c04fd430c8"], + ), +] + + +@pytest.mark.parametrize("test_case", HFID_NORMALIZATION_TEST_CASES, ids=lambda tc: tc.name) +async def test_hfid_normalization_in_object_loading( + client_with_schema_01: InfrahubClient, test_case: HfidLoadTestCase +) -> None: + """Test that HFIDs are normalized correctly based on cardinality and format.""" + + root_location = {"apiVersion": "infrahub.app/v1", "kind": "Object", "spec": {"kind": "BuiltinLocation", "data": []}} + location = { + "apiVersion": root_location["apiVersion"], + "kind": root_location["kind"], + "spec": {"kind": root_location["spec"]["kind"], "data": test_case.data}, + } + + obj = ObjectFile(location="some/path", content=location) + await obj.validate_format(client=client_with_schema_01) + + create_calls: list[dict[str, Any]] = [] + + async def mock_create( + kind: str, + branch: str | None = None, + data: dict | None = None, + **kwargs: Any, # noqa: ANN401 + ) -> InfrahubNode: + create_calls.append({"kind": kind, "data": data}) + original_create = client_with_schema_01.__class__.create + return await original_create(client_with_schema_01, kind=kind, branch=branch, data=data, **kwargs) + + client_with_schema_01.create = mock_create + + with patch("infrahub_sdk.node.InfrahubNode.save", new_callable=AsyncMock): + await obj.process(client=client_with_schema_01) + + assert len(create_calls) == 1 + if test_case.expected_primary_tag is not None: + assert create_calls[0]["data"]["primary_tag"] == test_case.expected_primary_tag + if test_case.expected_tags is not None: + assert create_calls[0]["data"]["tags"] == test_case.expected_tags + + +def test_normalize_hfid_reference_function() -> None: + """Test the normalize_hfid_reference function directly. + + This tests the normalization logic in isolation: + - Non-UUID strings get wrapped in a list (for HFID lookup) + - UUID strings stay as strings (for ID lookup) + - Lists stay unchanged + """ + # Non-UUID string becomes list + assert normalize_hfid_reference("Important") == ["Important"] + + # UUID string stays as string + uuid_value = "550e8400-e29b-41d4-a716-446655440000" + assert normalize_hfid_reference(uuid_value) == uuid_value + + # List stays unchanged + assert normalize_hfid_reference(["namespace", "name"]) == ["namespace", "name"] + assert normalize_hfid_reference(["single"]) == ["single"] + + +@dataclass +class RelatedNodePayloadTestCase: + """Test case for verifying the actual GraphQL payload structure from RelatedNode.""" + + name: str + input_data: str | list[str] + expected_payload: dict[str, Any] + + +RELATED_NODE_PAYLOAD_TEST_CASES = [ + # String (UUID) → {"id": "uuid"} + RelatedNodePayloadTestCase( + name="uuid_string_becomes_id_payload", + input_data="550e8400-e29b-41d4-a716-446655440000", + expected_payload={"id": "550e8400-e29b-41d4-a716-446655440000"}, + ), + # List (HFID) → {"hfid": [...]} + RelatedNodePayloadTestCase( + name="list_becomes_hfid_payload", + input_data=["Important"], + expected_payload={"hfid": ["Important"]}, + ), + # Multi-component HFID list → {"hfid": [...]} + RelatedNodePayloadTestCase( + name="multi_component_hfid_payload", + input_data=["namespace", "name"], + expected_payload={"hfid": ["namespace", "name"]}, + ), +] + + +@pytest.mark.parametrize("test_case", RELATED_NODE_PAYLOAD_TEST_CASES, ids=lambda tc: tc.name) +def test_related_node_graphql_payload(test_case: RelatedNodePayloadTestCase) -> None: + """Test that RelatedNode produces the correct GraphQL payload structure. + + This test verifies the actual {"id": ...} or {"hfid": ...} payload + that gets sent in GraphQL mutations. + """ + # Create mock dependencies + mock_client = MagicMock() + mock_schema = MagicMock() + + # Create RelatedNode with the input data + related_node = RelatedNode( + schema=mock_schema, + name="test_rel", + branch="main", + client=mock_client, + data=test_case.input_data, + ) + + # Generate the input data that would go into GraphQL mutation + payload = related_node._generate_input_data() + + # Verify the payload structure + assert payload == test_case.expected_payload, f"Expected payload {test_case.expected_payload}, got {payload}" diff --git a/tests/unit/sdk/test_artifact.py b/tests/unit/sdk/test_artifact.py index b5d4e656..4ed526d5 100644 --- a/tests/unit/sdk/test_artifact.py +++ b/tests/unit/sdk/test_artifact.py @@ -1,14 +1,26 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + import pytest from infrahub_sdk.exceptions import FeatureNotSupportedError from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync +if TYPE_CHECKING: + from pytest_httpx import HTTPXMock + + from infrahub_sdk import InfrahubClient + from infrahub_sdk.schema import NodeSchemaAPI + + from .conftest import BothClients + client_types = ["standard", "sync"] @pytest.mark.parametrize("client_type", client_types) async def test_node_artifact_generate_raise_featurenotsupported( - client, client_type, location_schema, location_data01 + client: InfrahubClient, client_type: str, location_schema: NodeSchemaAPI, location_data01: dict[str, Any] ) -> None: # node does not inherit from CoreArtifactTarget if client_type == "standard": @@ -23,7 +35,7 @@ async def test_node_artifact_generate_raise_featurenotsupported( @pytest.mark.parametrize("client_type", client_types) async def test_node_artifact_fetch_raise_featurenotsupported( - client, client_type, location_schema, location_data01 + client: InfrahubClient, client_type: str, location_schema: NodeSchemaAPI, location_data01: dict[str, Any] ) -> None: # node does not inherit from CoreArtifactTarget if client_type == "standard": @@ -37,7 +49,9 @@ async def test_node_artifact_fetch_raise_featurenotsupported( @pytest.mark.parametrize("client_type", client_types) -async def test_node_generate_raise_featurenotsupported(client, client_type, location_schema, location_data01) -> None: +async def test_node_generate_raise_featurenotsupported( + client: InfrahubClient, client_type: str, location_schema: NodeSchemaAPI, location_data01: dict[str, Any] +) -> None: # node not of kind CoreArtifactDefinition if client_type == "standard": node = InfrahubNode(client=client, schema=location_schema, data=location_data01) @@ -51,11 +65,11 @@ async def test_node_generate_raise_featurenotsupported(client, client_type, loca @pytest.mark.parametrize("client_type", client_types) async def test_node_artifact_definition_generate( - clients, - client_type, - mock_rest_api_artifact_definition_generate, - artifact_definition_schema, - artifact_definition_data, + clients: BothClients, + client_type: str, + mock_rest_api_artifact_definition_generate: HTTPXMock, + artifact_definition_schema: NodeSchemaAPI, + artifact_definition_data: dict[str, Any], ) -> None: if client_type == "standard": node = InfrahubNode(client=clients.standard, schema=artifact_definition_schema, data=artifact_definition_data) @@ -67,7 +81,11 @@ async def test_node_artifact_definition_generate( @pytest.mark.parametrize("client_type", client_types) async def test_node_artifact_fetch( - clients, client_type, mock_rest_api_artifact_fetch, device_schema, device_data + clients: BothClients, + client_type: str, + mock_rest_api_artifact_fetch: HTTPXMock, + device_schema: NodeSchemaAPI, + device_data: dict[str, Any], ) -> None: if client_type == "standard": node = InfrahubNode(client=clients.standard, schema=device_schema, data=device_data) @@ -86,7 +104,11 @@ async def test_node_artifact_fetch( @pytest.mark.parametrize("client_type", client_types) async def test_node_artifact_generate( - clients, client_type, mock_rest_api_artifact_generate, device_schema, device_data + clients: BothClients, + client_type: str, + mock_rest_api_artifact_generate: HTTPXMock, + device_schema: NodeSchemaAPI, + device_data: dict[str, Any], ) -> None: if client_type == "standard": node = InfrahubNode(client=clients.standard, schema=device_schema, data=device_data) diff --git a/tests/unit/sdk/test_client.py b/tests/unit/sdk/test_client.py index 2d5de0ad..e9cce23e 100644 --- a/tests/unit/sdk/test_client.py +++ b/tests/unit/sdk/test_client.py @@ -1,14 +1,25 @@ +from __future__ import annotations + import inspect import ssl from pathlib import Path +from typing import TYPE_CHECKING import pytest -from pytest_httpx import HTTPXMock from infrahub_sdk import Config, InfrahubClient, InfrahubClientSync from infrahub_sdk.exceptions import NodeNotFoundError from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync -from tests.unit.sdk.conftest import BothClients + +if TYPE_CHECKING: + from collections.abc import Callable, Mapping + from inspect import Parameter + from typing import Any + + from pytest_httpx import HTTPXMock + + from infrahub_sdk.schema import NodeSchemaAPI + from tests.unit.sdk.conftest import BothClients pytestmark = pytest.mark.httpx_mock(can_send_already_matched_responses=True) @@ -33,7 +44,7 @@ CURRENT_DIRECTORY = Path(__file__).parent -async def test_verify_config_caches_default_ssl_context(monkeypatch) -> None: +async def test_verify_config_caches_default_ssl_context(monkeypatch: pytest.MonkeyPatch) -> None: contexts: list[tuple[str | None, object]] = [] def fake_create_default_context(*args: object, **kwargs: object) -> object: @@ -52,7 +63,7 @@ def fake_create_default_context(*args: object, **kwargs: object) -> object: assert contexts == [(None, first)] -async def test_verify_config_caches_tls_ca_file_context(monkeypatch) -> None: +async def test_verify_config_caches_tls_ca_file_context(monkeypatch: pytest.MonkeyPatch) -> None: contexts: list[tuple[str | None, object]] = [] def fake_create_default_context(*args: object, **kwargs: object) -> object: @@ -81,7 +92,7 @@ def fake_create_default_context(*args: object, **kwargs: object) -> object: ] -async def test_verify_config_respects_tls_insecure(monkeypatch) -> None: +async def test_verify_config_respects_tls_insecure(monkeypatch: pytest.MonkeyPatch) -> None: def fake_create_default_context(*args: object, **kwargs: object) -> object: raise AssertionError("create_default_context should not be called when TLS is insecure") @@ -95,7 +106,7 @@ def fake_create_default_context(*args: object, **kwargs: object) -> object: assert verify_value.verify_mode == ssl.CERT_NONE -async def test_verify_config_uses_custom_tls_context(monkeypatch) -> None: +async def test_verify_config_uses_custom_tls_context(monkeypatch: pytest.MonkeyPatch) -> None: def fake_create_default_context(*args: object, **kwargs: object) -> object: raise AssertionError("create_default_context should not be called when custom context is provided") @@ -121,11 +132,11 @@ async def test_method_sanity() -> None: @pytest.mark.parametrize("method", async_client_methods) async def test_validate_method_signature( - method, - replace_async_return_annotation, - replace_sync_return_annotation, - replace_async_parameter_annotations, - replace_sync_parameter_annotations, + method: str, + replace_async_return_annotation: Callable[[str], str], + replace_sync_return_annotation: Callable[[str], str], + replace_async_parameter_annotations: Callable[[Mapping[str, Parameter]], list[tuple[str, str]]], + replace_sync_parameter_annotations: Callable[[Mapping[str, Parameter]], list[tuple[str, str]]], ) -> None: async_method = getattr(InfrahubClient, method) sync_method = getattr(InfrahubClientSync, method) @@ -150,7 +161,10 @@ def test_init_with_invalid_address() -> None: async def test_get_repositories( - client: InfrahubClient, mock_branches_list_query, mock_schema_query_02, mock_repositories_query + client: InfrahubClient, + mock_branches_list_query: HTTPXMock, + mock_schema_query_02: HTTPXMock, + mock_repositories_query: HTTPXMock, ) -> None: repos = await client.get_list_repositories() @@ -167,7 +181,7 @@ async def test_get_repositories( @pytest.mark.parametrize("client_type", client_types) -async def test_method_count(clients, mock_query_repository_count, client_type) -> None: +async def test_method_count(clients: BothClients, mock_query_repository_count: HTTPXMock, client_type: str) -> None: if client_type == "standard": count = await clients.standard.count(kind="CoreRepository") else: @@ -177,7 +191,9 @@ async def test_method_count(clients, mock_query_repository_count, client_type) - @pytest.mark.parametrize("client_type", client_types) -async def test_method_count_with_filter(clients, mock_query_repository_count, client_type) -> None: +async def test_method_count_with_filter( + clients: BothClients, mock_query_repository_count: HTTPXMock, client_type: str +) -> None: if client_type == "standard": count = await clients.standard.count(kind="CoreRepository", name__value="test") else: @@ -187,7 +203,9 @@ async def test_method_count_with_filter(clients, mock_query_repository_count, cl @pytest.mark.parametrize("client_type", client_types) -async def test_method_get_version(clients, mock_query_infrahub_version, client_type) -> None: +async def test_method_get_version( + clients: BothClients, mock_query_infrahub_version: HTTPXMock, client_type: str +) -> None: if client_type == "standard": version = await clients.standard.get_version() else: @@ -197,7 +215,7 @@ async def test_method_get_version(clients, mock_query_infrahub_version, client_t @pytest.mark.parametrize("client_type", client_types) -async def test_method_get_user(clients, mock_query_infrahub_user, client_type) -> None: +async def test_method_get_user(clients: BothClients, mock_query_infrahub_user: HTTPXMock, client_type: str) -> None: if client_type == "standard": user = await clients.standard.get_user() else: @@ -208,7 +226,9 @@ async def test_method_get_user(clients, mock_query_infrahub_user, client_type) - @pytest.mark.parametrize("client_type", client_types) -async def test_method_get_user_permissions(clients, mock_query_infrahub_user, client_type) -> None: +async def test_method_get_user_permissions( + clients: BothClients, mock_query_infrahub_user: HTTPXMock, client_type: str +) -> None: if client_type == "standard": groups = await clients.standard.get_user_permissions() else: @@ -219,7 +239,9 @@ async def test_method_get_user_permissions(clients, mock_query_infrahub_user, cl @pytest.mark.parametrize("client_type", client_types) -async def test_method_all_with_limit(clients, mock_query_repository_page1_2, client_type) -> None: +async def test_method_all_with_limit( + clients: BothClients, mock_query_repository_page1_2: HTTPXMock, client_type: str +) -> None: if client_type == "standard": repos = await clients.standard.all(kind="CoreRepository", populate_store=False, limit=3) assert clients.standard.store.count() == 0 @@ -237,7 +259,10 @@ async def test_method_all_with_limit(clients, mock_query_repository_page1_2, cli @pytest.mark.parametrize("client_type", client_types) async def test_method_all_multiple_pages( - clients, mock_query_repository_page1_2, mock_query_repository_page2_2, client_type + clients: BothClients, + mock_query_repository_page1_2: HTTPXMock, + mock_query_repository_page2_2: HTTPXMock, + client_type: str, ) -> None: if client_type == "standard": repos = await clients.standard.all(kind="CoreRepository", populate_store=False) @@ -257,7 +282,11 @@ async def test_method_all_multiple_pages( @pytest.mark.parametrize("client_type, use_parallel", batch_client_types) async def test_method_all_batching( - clients, mock_query_location_batch_count, mock_query_location_batch, client_type, use_parallel + clients: BothClients, + mock_query_location_batch_count: HTTPXMock, + mock_query_location_batch: HTTPXMock, + client_type: str, + use_parallel: bool, ) -> None: if client_type == "standard": locations = await clients.standard.all(kind="BuiltinLocation", populate_store=False, parallel=use_parallel) @@ -276,7 +305,9 @@ async def test_method_all_batching( @pytest.mark.parametrize("client_type", client_types) -async def test_method_all_single_page(clients, mock_query_repository_page1_1, client_type) -> None: +async def test_method_all_single_page( + clients: BothClients, mock_query_repository_page1_1: HTTPXMock, client_type: str +) -> None: if client_type == "standard": repos = await clients.standard.all(kind="CoreRepository", populate_store=False) assert clients.standard.store.count() == 0 @@ -294,7 +325,9 @@ async def test_method_all_single_page(clients, mock_query_repository_page1_1, cl @pytest.mark.parametrize("client_type", client_types) -async def test_method_all_generic(clients, mock_query_corenode_page1_1, client_type) -> None: +async def test_method_all_generic( + clients: BothClients, mock_query_corenode_page1_1: HTTPXMock, client_type: str +) -> None: if client_type == "standard": nodes = await clients.standard.all(kind="CoreNode") else: @@ -306,7 +339,9 @@ async def test_method_all_generic(clients, mock_query_corenode_page1_1, client_t @pytest.mark.parametrize("client_type", client_types) -async def test_method_get_by_id(httpx_mock: HTTPXMock, clients, mock_schema_query_01, client_type) -> None: +async def test_method_get_by_id( + httpx_mock: HTTPXMock, clients: BothClients, mock_schema_query_01: HTTPXMock, client_type: str +) -> None: response = { "data": { "CoreRepository": { @@ -354,7 +389,9 @@ async def test_method_get_by_id(httpx_mock: HTTPXMock, clients, mock_schema_quer @pytest.mark.parametrize("client_type", client_types) -async def test_method_get_by_hfid(httpx_mock: HTTPXMock, clients, mock_schema_query_01, client_type) -> None: +async def test_method_get_by_hfid( + httpx_mock: HTTPXMock, clients: BothClients, mock_schema_query_01: HTTPXMock, client_type: str +) -> None: response = { "data": { "CoreRepository": { @@ -403,7 +440,9 @@ async def test_method_get_by_hfid(httpx_mock: HTTPXMock, clients, mock_schema_qu @pytest.mark.parametrize("client_type", client_types) -async def test_method_get_by_default_filter(httpx_mock: HTTPXMock, clients, mock_schema_query_01, client_type) -> None: +async def test_method_get_by_default_filter( + httpx_mock: HTTPXMock, clients: BothClients, mock_schema_query_01: HTTPXMock, client_type: str +) -> None: response = { "data": { "CoreRepository": { @@ -449,7 +488,9 @@ async def test_method_get_by_default_filter(httpx_mock: HTTPXMock, clients, mock @pytest.mark.parametrize("client_type", client_types) -async def test_method_get_by_name(httpx_mock: HTTPXMock, clients, mock_schema_query_01, client_type) -> None: +async def test_method_get_by_name( + httpx_mock: HTTPXMock, clients: BothClients, mock_schema_query_01: HTTPXMock, client_type: str +) -> None: response = { "data": { "CoreRepository": { @@ -486,7 +527,7 @@ async def test_method_get_by_name(httpx_mock: HTTPXMock, clients, mock_schema_qu @pytest.mark.parametrize("client_type", client_types) async def test_method_get_not_found( - httpx_mock: HTTPXMock, clients, mock_query_repository_page1_empty, client_type + httpx_mock: HTTPXMock, clients: BothClients, mock_query_repository_page1_empty: HTTPXMock, client_type: str ) -> None: with pytest.raises(NodeNotFoundError): if client_type == "standard": @@ -497,7 +538,7 @@ async def test_method_get_not_found( @pytest.mark.parametrize("client_type", client_types) async def test_method_get_not_found_none( - httpx_mock: HTTPXMock, clients, mock_query_repository_page1_empty, client_type + httpx_mock: HTTPXMock, clients: BothClients, mock_query_repository_page1_empty: HTTPXMock, client_type: str ) -> None: if client_type == "standard": response = await clients.standard.get( @@ -512,10 +553,10 @@ async def test_method_get_not_found_none( @pytest.mark.parametrize("client_type", client_types) async def test_method_get_found_many( httpx_mock: HTTPXMock, - clients, - mock_schema_query_01, - mock_query_repository_page1_1, - client_type, + clients: BothClients, + mock_schema_query_01: HTTPXMock, + mock_query_repository_page1_1: HTTPXMock, + client_type: str, ) -> None: with pytest.raises(IndexError): if client_type == "standard": @@ -525,7 +566,9 @@ async def test_method_get_found_many( @pytest.mark.parametrize("client_type", client_types) -async def test_method_filters_many(httpx_mock: HTTPXMock, clients, mock_query_repository_page1_1, client_type) -> None: +async def test_method_filters_many( + httpx_mock: HTTPXMock, clients: BothClients, mock_query_repository_page1_1: HTTPXMock, client_type: str +) -> None: if client_type == "standard": repos = await clients.standard.filters( kind="CoreRepository", @@ -572,7 +615,7 @@ async def test_method_filters_many(httpx_mock: HTTPXMock, clients, mock_query_re @pytest.mark.parametrize("client_type", client_types) async def test_method_filters_empty( - httpx_mock: HTTPXMock, clients, mock_query_repository_page1_empty, client_type + httpx_mock: HTTPXMock, clients: BothClients, mock_query_repository_page1_empty: HTTPXMock, client_type: str ) -> None: if client_type == "standard": repos = await clients.standard.filters( @@ -597,11 +640,11 @@ async def test_method_filters_empty( async def test_allocate_next_ip_address( httpx_mock: HTTPXMock, mock_schema_query_ipam: HTTPXMock, - clients, - ipaddress_pool_schema, - ipam_ipprefix_schema, - ipam_ipprefix_data, - client_type, + clients: BothClients, + ipaddress_pool_schema: NodeSchemaAPI, + ipam_ipprefix_schema: NodeSchemaAPI, + ipam_ipprefix_data: dict[str, Any], + client_type: str, ) -> None: httpx_mock.add_response( method="POST", @@ -698,11 +741,11 @@ async def test_allocate_next_ip_address( async def test_allocate_next_ip_prefix( httpx_mock: HTTPXMock, mock_schema_query_ipam: HTTPXMock, - clients, - ipprefix_pool_schema, - ipam_ipprefix_schema, - ipam_ipprefix_data, - client_type, + clients: BothClients, + ipprefix_pool_schema: NodeSchemaAPI, + ipam_ipprefix_schema: NodeSchemaAPI, + ipam_ipprefix_data: dict[str, Any], + client_type: str, ) -> None: httpx_mock.add_response( method="POST", @@ -818,7 +861,7 @@ async def test_allocate_next_ip_prefix( @pytest.mark.parametrize("client_type", client_types) -async def test_query_echo(httpx_mock: HTTPXMock, echo_clients, client_type) -> None: +async def test_query_echo(httpx_mock: HTTPXMock, echo_clients: BothClients, client_type: str) -> None: httpx_mock.add_response( method="POST", json={"data": {"BuiltinTag": {"edges": []}}}, diff --git a/tests/unit/sdk/test_diff_summary.py b/tests/unit/sdk/test_diff_summary.py index 49a6d128..fddd388b 100644 --- a/tests/unit/sdk/test_diff_summary.py +++ b/tests/unit/sdk/test_diff_summary.py @@ -95,7 +95,7 @@ async def mock_diff_tree_query(httpx_mock: HTTPXMock, client: InfrahubClient) -> @pytest.mark.parametrize("client_type", client_types) -async def test_diffsummary(clients: BothClients, mock_diff_tree_query, client_type) -> None: +async def test_diffsummary(clients: BothClients, mock_diff_tree_query: HTTPXMock, client_type: str) -> None: if client_type == "standard": node_diffs = await clients.standard.get_diff_summary( branch="branch2", @@ -241,7 +241,7 @@ async def mock_diff_tree_with_metadata(httpx_mock: HTTPXMock, client: InfrahubCl @pytest.mark.parametrize("client_type", client_types) -async def test_get_diff_tree(clients: BothClients, mock_diff_tree_with_metadata, client_type) -> None: +async def test_get_diff_tree(clients: BothClients, mock_diff_tree_with_metadata: HTTPXMock, client_type: str) -> None: """Test get_diff_tree returns complete DiffTreeData with metadata.""" if client_type == "standard": diff_tree = await clients.standard.get_diff_tree( @@ -298,7 +298,7 @@ async def mock_diff_tree_none(httpx_mock: HTTPXMock, client: InfrahubClient) -> @pytest.mark.parametrize("client_type", client_types) -async def test_get_diff_tree_none(clients: BothClients, mock_diff_tree_none, client_type) -> None: +async def test_get_diff_tree_none(clients: BothClients, mock_diff_tree_none: HTTPXMock, client_type: str) -> None: """Test get_diff_tree returns None when no diff exists.""" if client_type == "standard": diff_tree = await clients.standard.get_diff_tree( @@ -343,7 +343,9 @@ async def mock_diff_tree_with_params(httpx_mock: HTTPXMock, client: InfrahubClie @pytest.mark.parametrize("client_type", client_types) -async def test_get_diff_tree_with_parameters(clients: BothClients, mock_diff_tree_with_params, client_type) -> None: +async def test_get_diff_tree_with_parameters( + clients: BothClients, mock_diff_tree_with_params: HTTPXMock, client_type: str +) -> None: """Test get_diff_tree with name and time range parameters.""" from_time = datetime(2025, 11, 14, 12, 0, 0, tzinfo=timezone.utc) to_time = datetime(2025, 11, 14, 18, 0, 0, tzinfo=timezone.utc) @@ -373,7 +375,7 @@ async def test_get_diff_tree_with_parameters(clients: BothClients, mock_diff_tre @pytest.mark.parametrize("client_type", client_types) -async def test_get_diff_tree_time_validation(clients: BothClients, client_type) -> None: +async def test_get_diff_tree_time_validation(clients: BothClients, client_type: str) -> None: """Test get_diff_tree raises error when from_time > to_time.""" from_time = datetime(2025, 11, 14, 18, 0, 0, tzinfo=timezone.utc) to_time = datetime(2025, 11, 14, 12, 0, 0, tzinfo=timezone.utc) # Earlier than from_time diff --git a/tests/unit/sdk/test_hierarchical_nodes.py b/tests/unit/sdk/test_hierarchical_nodes.py index c8a8ed79..3165effe 100644 --- a/tests/unit/sdk/test_hierarchical_nodes.py +++ b/tests/unit/sdk/test_hierarchical_nodes.py @@ -52,7 +52,7 @@ async def hierarchical_schema() -> NodeSchemaAPI: @pytest.mark.parametrize("client_type", ["standard", "sync"]) async def test_hierarchical_node_has_hierarchy_support( - client: InfrahubClient, client_sync: InfrahubClientSync, hierarchical_schema, client_type + client: InfrahubClient, client_sync: InfrahubClientSync, hierarchical_schema: NodeSchemaAPI, client_type: str ) -> None: """Test that hierarchical nodes are properly detected and support parent/children/ancestors/descendants.""" if client_type == "standard": @@ -67,7 +67,7 @@ async def test_hierarchical_node_has_hierarchy_support( @pytest.mark.parametrize("client_type", ["standard", "sync"]) async def test_hierarchical_node_has_all_hierarchical_fields( - client: InfrahubClient, client_sync: InfrahubClientSync, hierarchical_schema, client_type + client: InfrahubClient, client_sync: InfrahubClientSync, hierarchical_schema: NodeSchemaAPI, client_type: str ) -> None: """Test that hierarchical nodes have parent, children, ancestors and descendants attributes.""" if client_type == "standard": @@ -110,7 +110,7 @@ async def test_hierarchical_node_has_all_hierarchical_fields( @pytest.mark.parametrize("client_type", ["standard", "sync"]) async def test_hierarchical_node_with_parent_data( - client: InfrahubClient, client_sync: InfrahubClientSync, hierarchical_schema, client_type + client: InfrahubClient, client_sync: InfrahubClientSync, hierarchical_schema: NodeSchemaAPI, client_type: str ) -> None: """Test that hierarchical nodes can be initialized with parent data.""" data = { @@ -133,7 +133,7 @@ async def test_hierarchical_node_with_parent_data( @pytest.mark.parametrize("client_type", ["standard", "sync"]) async def test_hierarchical_node_with_children_data( - client: InfrahubClient, client_sync: InfrahubClientSync, hierarchical_schema, client_type + client: InfrahubClient, client_sync: InfrahubClientSync, hierarchical_schema: NodeSchemaAPI, client_type: str ) -> None: """Test that hierarchical nodes can be initialized with children data.""" data = { @@ -164,7 +164,7 @@ async def test_hierarchical_node_with_children_data( @pytest.mark.parametrize("client_type", ["standard", "sync"]) async def test_hierarchical_node_with_ancestors_data( - client: InfrahubClient, client_sync: InfrahubClientSync, hierarchical_schema, client_type + client: InfrahubClient, client_sync: InfrahubClientSync, hierarchical_schema: NodeSchemaAPI, client_type: str ) -> None: """Test that hierarchical nodes can be initialized with ancestors data.""" data = { @@ -195,7 +195,7 @@ async def test_hierarchical_node_with_ancestors_data( @pytest.mark.parametrize("client_type", ["standard", "sync"]) async def test_hierarchical_node_with_descendants_data( - client: InfrahubClient, client_sync: InfrahubClientSync, hierarchical_schema, client_type + client: InfrahubClient, client_sync: InfrahubClientSync, hierarchical_schema: NodeSchemaAPI, client_type: str ) -> None: """Test that hierarchical nodes can be initialized with descendants data.""" data = { @@ -229,7 +229,7 @@ async def test_hierarchical_node_with_descendants_data( @pytest.mark.parametrize("client_type", ["standard", "sync"]) async def test_non_hierarchical_node_no_hierarchical_fields( - client: InfrahubClient, client_sync: InfrahubClientSync, location_schema, client_type + client: InfrahubClient, client_sync: InfrahubClientSync, location_schema: NodeSchemaAPI, client_type: str ) -> None: """Test that non-hierarchical nodes don't have parent/children/ancestors/descendants.""" if client_type == "standard": @@ -254,7 +254,9 @@ async def test_non_hierarchical_node_no_hierarchical_fields( _ = node.descendants -async def test_hierarchical_node_query_generation_includes_parent(client: InfrahubClient, hierarchical_schema) -> None: +async def test_hierarchical_node_query_generation_includes_parent( + client: InfrahubClient, hierarchical_schema: NodeSchemaAPI +) -> None: """Test that query generation includes parent when requested.""" # Pre-populate schema cache to avoid fetching from server cache_data = { @@ -277,7 +279,7 @@ async def test_hierarchical_node_query_generation_includes_parent(client: Infrah async def test_hierarchical_node_query_generation_includes_children( - client: InfrahubClient, hierarchical_schema + client: InfrahubClient, hierarchical_schema: NodeSchemaAPI ) -> None: """Test that query generation includes children when requested.""" # Pre-populate schema cache to avoid fetching from server @@ -302,7 +304,7 @@ async def test_hierarchical_node_query_generation_includes_children( async def test_hierarchical_node_query_generation_includes_ancestors( - client: InfrahubClient, hierarchical_schema + client: InfrahubClient, hierarchical_schema: NodeSchemaAPI ) -> None: """Test that query generation includes ancestors when requested.""" # Pre-populate schema cache to avoid fetching from server @@ -327,7 +329,7 @@ async def test_hierarchical_node_query_generation_includes_ancestors( async def test_hierarchical_node_query_generation_includes_descendants( - client: InfrahubClient, hierarchical_schema + client: InfrahubClient, hierarchical_schema: NodeSchemaAPI ) -> None: """Test that query generation includes descendants when requested.""" # Pre-populate schema cache to avoid fetching from server @@ -352,7 +354,7 @@ async def test_hierarchical_node_query_generation_includes_descendants( async def test_hierarchical_node_query_generation_prefetch_relationships( - client: InfrahubClient, hierarchical_schema + client: InfrahubClient, hierarchical_schema: NodeSchemaAPI ) -> None: """Test that query generation includes all hierarchical fields with prefetch_relationships=True.""" # Pre-populate schema cache to avoid fetching from server @@ -374,7 +376,9 @@ async def test_hierarchical_node_query_generation_prefetch_relationships( assert "descendants" in query_data -async def test_hierarchical_node_query_generation_exclude(client: InfrahubClient, hierarchical_schema) -> None: +async def test_hierarchical_node_query_generation_exclude( + client: InfrahubClient, hierarchical_schema: NodeSchemaAPI +) -> None: """Test that query generation respects exclude for hierarchical fields.""" # Pre-populate schema cache to avoid fetching from server cache_data = { @@ -396,7 +400,7 @@ async def test_hierarchical_node_query_generation_exclude(client: InfrahubClient def test_hierarchical_node_sync_query_generation_includes_parent( - client_sync: InfrahubClientSync, hierarchical_schema + client_sync: InfrahubClientSync, hierarchical_schema: NodeSchemaAPI ) -> None: """Test that sync query generation includes parent when requested.""" # Set schema in cache to avoid HTTP request @@ -420,7 +424,7 @@ def test_hierarchical_node_sync_query_generation_includes_parent( def test_hierarchical_node_sync_query_generation_includes_children( - client_sync: InfrahubClientSync, hierarchical_schema + client_sync: InfrahubClientSync, hierarchical_schema: NodeSchemaAPI ) -> None: """Test that sync query generation includes children when requested.""" # Set schema in cache to avoid HTTP request @@ -445,7 +449,7 @@ def test_hierarchical_node_sync_query_generation_includes_children( def test_hierarchical_node_sync_query_generation_includes_ancestors( - client_sync: InfrahubClientSync, hierarchical_schema + client_sync: InfrahubClientSync, hierarchical_schema: NodeSchemaAPI ) -> None: """Test that sync query generation includes ancestors when requested.""" # Set schema in cache to avoid HTTP request @@ -470,7 +474,7 @@ def test_hierarchical_node_sync_query_generation_includes_ancestors( def test_hierarchical_node_sync_query_generation_includes_descendants( - client_sync: InfrahubClientSync, hierarchical_schema + client_sync: InfrahubClientSync, hierarchical_schema: NodeSchemaAPI ) -> None: """Test that sync query generation includes descendants when requested.""" # Set schema in cache to avoid HTTP request @@ -495,7 +499,7 @@ def test_hierarchical_node_sync_query_generation_includes_descendants( async def test_hierarchical_node_no_infinite_recursion_with_children( - client: InfrahubClient, hierarchical_schema + client: InfrahubClient, hierarchical_schema: NodeSchemaAPI ) -> None: """Test that including children does not cause infinite recursion.""" # Pre-populate schema cache to avoid fetching from server @@ -521,7 +525,7 @@ async def test_hierarchical_node_no_infinite_recursion_with_children( def test_hierarchical_node_sync_no_infinite_recursion_with_children( - client_sync: InfrahubClientSync, hierarchical_schema + client_sync: InfrahubClientSync, hierarchical_schema: NodeSchemaAPI ) -> None: """Test that including children does not cause infinite recursion in sync mode.""" # Set schema in cache to avoid HTTP request diff --git a/tests/unit/sdk/test_node.py b/tests/unit/sdk/test_node.py index 635abc7e..74434a92 100644 --- a/tests/unit/sdk/test_node.py +++ b/tests/unit/sdk/test_node.py @@ -22,8 +22,13 @@ from infrahub_sdk.node.related_node import RelatedNode, RelatedNodeSync if TYPE_CHECKING: + from collections.abc import Callable, Mapping + from inspect import Parameter + from typing import Any + from pytest_httpx import HTTPXMock + from infrahub_sdk import InfrahubClient, InfrahubClientSync from infrahub_sdk.schema import GenericSchema, NodeSchemaAPI from tests.unit.sdk.conftest import BothClients @@ -58,7 +63,7 @@ ] -async def set_builtin_tag_schema_cache(client) -> None: +async def set_builtin_tag_schema_cache(client: InfrahubClient | InfrahubClientSync) -> None: # Set tag schema in cache to avoid needed to request the server. builtin_tag_schema = { "version": "1.0", @@ -95,11 +100,11 @@ def test_identify_unsafe_graphql_value(value: str) -> None: @pytest.mark.parametrize("method", async_node_methods) async def test_validate_method_signature( - method, - replace_async_parameter_annotations, - replace_sync_parameter_annotations, - replace_async_return_annotation, - replace_sync_return_annotation, + method: str, + replace_async_parameter_annotations: Callable[[Mapping[str, Parameter]], list[tuple[str, str]]], + replace_sync_parameter_annotations: Callable[[Mapping[str, Parameter]], list[tuple[str, str]]], + replace_async_return_annotation: Callable[[str], str], + replace_sync_return_annotation: Callable[[str], str], ) -> None: EXCLUDE_PARAMETERS = ["client"] async_method = getattr(InfrahubNode, method) @@ -132,7 +137,7 @@ def test_parse_human_friendly_id(hfid: str, expected_kind: str, expected_hfid: l @pytest.mark.parametrize("client_type", client_types) -async def test_init_node_no_data(client, location_schema, client_type: str) -> None: +async def test_init_node_no_data(client: InfrahubClient, location_schema: NodeSchemaAPI, client_type: str) -> None: if client_type == "standard": node = InfrahubNode(client=client, schema=location_schema) else: @@ -145,7 +150,7 @@ async def test_init_node_no_data(client, location_schema, client_type: str) -> N @pytest.mark.parametrize("client_type", client_types) -async def test_node_hfid(client, schema_with_hfid, client_type: str) -> None: +async def test_node_hfid(client: InfrahubClient, schema_with_hfid: dict[str, NodeSchemaAPI], client_type: str) -> None: location_data = {"name": {"value": "JFK1"}, "description": {"value": "JFK Airport"}, "type": {"value": "SITE"}} if client_type == "standard": location = InfrahubNode(client=client, schema=schema_with_hfid["location"], data=location_data) @@ -168,7 +173,7 @@ async def test_node_hfid(client, schema_with_hfid, client_type: str) -> None: @pytest.mark.parametrize("client_type", client_types) -async def test_init_node_data_user(client, location_schema: NodeSchemaAPI, client_type: str) -> None: +async def test_init_node_data_user(client: InfrahubClient, location_schema: NodeSchemaAPI, client_type: str) -> None: data = { "name": {"value": "JFK1"}, "description": {"value": "JFK Airport"}, @@ -186,7 +191,9 @@ async def test_init_node_data_user(client, location_schema: NodeSchemaAPI, clien @pytest.mark.parametrize("client_type", client_types) -async def test_init_node_data_user_with_relationships(client, location_schema: NodeSchemaAPI, client_type: str) -> None: +async def test_init_node_data_user_with_relationships( + client: InfrahubClient, location_schema: NodeSchemaAPI, client_type: str +) -> None: data = { "name": {"value": "JFK1"}, "description": {"value": "JFK Airport"}, @@ -226,7 +233,7 @@ async def test_init_node_data_user_with_relationships(client, location_schema: N ], ) async def test_init_node_data_user_with_relationships_using_related_node( - client, location_schema: NodeSchemaAPI, client_type: str, rel_data + client: InfrahubClient, location_schema: NodeSchemaAPI, client_type: str, rel_data: dict[str, Any] ) -> None: rel_schema = location_schema.get_relationship(name="primary_tag") if client_type == "standard": @@ -273,7 +280,12 @@ async def test_init_node_data_user_with_relationships_using_related_node( @pytest.mark.parametrize("property_test", property_tests) @pytest.mark.parametrize("client_type", client_types) async def test_init_node_data_graphql( - client, location_schema: NodeSchemaAPI, location_data01, location_data01_property, client_type: str, property_test + client: InfrahubClient, + location_schema: NodeSchemaAPI, + location_data01: dict[str, Any], + location_data01_property: dict[str, Any], + client_type: str, + property_test: str, ) -> None: location_data = location_data01 if property_test == WITHOUT_PROPERTY else location_data01_property @@ -531,7 +543,7 @@ async def test_query_data_node(clients: BothClients, location_schema: NodeSchema @pytest.mark.parametrize("client_type", client_types) async def test_query_data_with_prefetch_relationships_property( - clients: BothClients, mock_schema_query_02, client_type: str + clients: BothClients, mock_schema_query_02: HTTPXMock, client_type: str ) -> None: if client_type == "standard": location_schema: GenericSchema = await clients.standard.schema.get(kind="BuiltinLocation") # type: ignore[annotation-unchecked] @@ -667,7 +679,7 @@ async def test_query_data_with_prefetch_relationships_property( @pytest.mark.parametrize("client_type", client_types) async def test_query_data_with_prefetch_relationships( - clients: BothClients, mock_schema_query_02, client_type: str + clients: BothClients, mock_schema_query_02: HTTPXMock, client_type: str ) -> None: if client_type == "standard": location_schema: GenericSchema = await clients.standard.schema.get(kind="BuiltinLocation") # type: ignore[annotation-unchecked] @@ -719,7 +731,7 @@ async def test_query_data_with_prefetch_relationships( @pytest.mark.parametrize("client_type", client_types) async def test_query_data_node_with_prefetch_relationships_property( - clients: BothClients, mock_schema_query_02, client_type: str + clients: BothClients, mock_schema_query_02: HTTPXMock, client_type: str ) -> None: if client_type == "standard": location_schema: GenericSchema = await clients.standard.schema.get(kind="BuiltinLocation") # type: ignore[assignment] @@ -795,7 +807,7 @@ async def test_query_data_node_with_prefetch_relationships_property( @pytest.mark.parametrize("client_type", client_types) async def test_query_data_node_with_prefetch_relationships( - clients: BothClients, mock_schema_query_02, client_type: str + clients: BothClients, mock_schema_query_02: HTTPXMock, client_type: str ) -> None: if client_type == "standard": location_schema: GenericSchema = await clients.standard.schema.get(kind="BuiltinLocation") # type: ignore[assignment] @@ -834,7 +846,9 @@ async def test_query_data_node_with_prefetch_relationships( @pytest.mark.parametrize("client_type", client_types) -async def test_query_data_generic_property(clients: BothClients, mock_schema_query_02, client_type: str) -> None: +async def test_query_data_generic_property( + clients: BothClients, mock_schema_query_02: HTTPXMock, client_type: str +) -> None: if client_type == "standard": corenode_schema: GenericSchema = await clients.standard.schema.get(kind="CoreNode") # type: ignore[assignment] node = InfrahubNode(client=clients.standard, schema=corenode_schema) @@ -862,7 +876,7 @@ async def test_query_data_generic_property(clients: BothClients, mock_schema_que @pytest.mark.parametrize("client_type", client_types) async def test_query_data_generic_fragment_property( - clients: BothClients, mock_schema_query_02, client_type: str + clients: BothClients, mock_schema_query_02: HTTPXMock, client_type: str ) -> None: if client_type == "standard": corenode_schema: GenericSchema = await clients.standard.schema.get(kind="CoreNode") # type: ignore[assignment] @@ -1007,7 +1021,9 @@ async def test_query_data_generic_fragment_property( @pytest.mark.parametrize("client_type", client_types) -async def test_query_data_generic_fragment(clients: BothClients, mock_schema_query_02, client_type: str) -> None: +async def test_query_data_generic_fragment( + clients: BothClients, mock_schema_query_02: HTTPXMock, client_type: str +) -> None: if client_type == "standard": corenode_schema: GenericSchema = await clients.standard.schema.get(kind="CoreNode") # type: ignore[assignment] node = InfrahubNode(client=clients.standard, schema=corenode_schema) @@ -1068,8 +1084,8 @@ async def test_query_data_generic_fragment(clients: BothClients, mock_schema_que @pytest.mark.parametrize("client_type", client_types) async def test_query_data_include_property( - client, - client_sync, + client: InfrahubClient, + client_sync: InfrahubClientSync, location_schema: NodeSchemaAPI, client_type: str, ) -> None: @@ -1198,8 +1214,8 @@ async def test_query_data_include_property( @pytest.mark.parametrize("client_type", client_types) async def test_query_data_include( - client, - client_sync, + client: InfrahubClient, + client_sync: InfrahubClientSync, location_schema: NodeSchemaAPI, client_type: str, ) -> None: @@ -1257,7 +1273,9 @@ async def test_query_data_include( @pytest.mark.parametrize("client_type", client_types) -async def test_query_data_exclude_property(client, location_schema: NodeSchemaAPI, client_type: str) -> None: +async def test_query_data_exclude_property( + client: InfrahubClient, location_schema: NodeSchemaAPI, client_type: str +) -> None: if client_type == "standard": node = InfrahubNode(client=client, schema=location_schema) data = await node.generate_query_data(exclude=["description", "primary_tag"], property=True) @@ -1316,7 +1334,7 @@ async def test_query_data_exclude_property(client, location_schema: NodeSchemaAP @pytest.mark.parametrize("client_type", client_types) -async def test_query_data_exclude(client, location_schema: NodeSchemaAPI, client_type: str) -> None: +async def test_query_data_exclude(client: InfrahubClient, location_schema: NodeSchemaAPI, client_type: str) -> None: if client_type == "standard": node = InfrahubNode(client=client, schema=location_schema) data = await node.generate_query_data(exclude=["description", "primary_tag"]) @@ -1347,7 +1365,7 @@ async def test_query_data_exclude(client, location_schema: NodeSchemaAPI, client @pytest.mark.parametrize("client_type", client_types) -async def test_create_input_data(client, location_schema: NodeSchemaAPI, client_type: str) -> None: +async def test_create_input_data(client: InfrahubClient, location_schema: NodeSchemaAPI, client_type: str) -> None: data = {"name": {"value": "JFK1"}, "description": {"value": "JFK Airport"}, "type": {"value": "SITE"}} if client_type == "standard": @@ -1365,7 +1383,9 @@ async def test_create_input_data(client, location_schema: NodeSchemaAPI, client_ @pytest.mark.parametrize("client_type", client_types) -async def test_create_input_data_with_dropdown(client, location_schema_with_dropdown, client_type: str) -> None: +async def test_create_input_data_with_dropdown( + client: InfrahubClient, location_schema_with_dropdown: NodeSchemaAPI, client_type: str +) -> None: """Validate input data including dropdown field""" data = { "name": {"value": "JFK1"}, @@ -1425,7 +1445,9 @@ async def test_update_input_data_existing_node_with_optional_relationship( @pytest.mark.parametrize("client_type", client_types) -async def test_create_input_data__with_relationships_02(client, location_schema, client_type: str) -> None: +async def test_create_input_data__with_relationships_02( + client: InfrahubClient, location_schema: NodeSchemaAPI, client_type: str +) -> None: """Validate input data with variables that needs replacements""" data = { "name": {"value": "JFK1"}, @@ -1459,7 +1481,9 @@ async def test_create_input_data__with_relationships_02(client, location_schema, @pytest.mark.parametrize("client_type", client_types) -async def test_create_input_data__with_relationships_01(client, location_schema, client_type: str) -> None: +async def test_create_input_data__with_relationships_01( + client: InfrahubClient, location_schema: NodeSchemaAPI, client_type: str +) -> None: data = { "name": {"value": "JFK1"}, "description": {"value": "JFK Airport"}, @@ -1485,7 +1509,9 @@ async def test_create_input_data__with_relationships_01(client, location_schema, @pytest.mark.parametrize("client_type", client_types) -async def test_create_input_data_with_relationships_02(clients: BothClients, rfile_schema, client_type: str) -> None: +async def test_create_input_data_with_relationships_02( + clients: BothClients, rfile_schema: NodeSchemaAPI, client_type: str +) -> None: data = { "name": {"value": "rfile01", "is_protected": True, "source": "ffffffff", "owner": "ffffffff"}, "template_path": {"value": "mytemplate.j2"}, @@ -1524,7 +1550,9 @@ async def test_create_input_data_with_relationships_02(clients: BothClients, rfi @pytest.mark.parametrize("client_type", client_types) -async def test_create_input_data_with_relationships_03(clients: BothClients, rfile_schema, client_type: str) -> None: +async def test_create_input_data_with_relationships_03( + clients: BothClients, rfile_schema: NodeSchemaAPI, client_type: str +) -> None: data = { "name": {"value": "rfile01", "is_protected": True, "source": "ffffffff"}, "template_path": {"value": "mytemplate.j2"}, @@ -1558,11 +1586,11 @@ async def test_create_input_data_with_relationships_03(clients: BothClients, rfi @pytest.mark.parametrize("client_type", client_types) async def test_create_input_data_with_relationships_03_for_update_include_unmodified( clients: BothClients, - rfile_schema, - rfile_userdata01, - rfile_userdata01_property, + rfile_schema: NodeSchemaAPI, + rfile_userdata01: dict[str, Any], + rfile_userdata01_property: dict[str, Any], client_type: str, - property_test, + property_test: str, ) -> None: rfile_userdata = rfile_userdata01 if property_test == WITHOUT_PROPERTY else rfile_userdata01_property @@ -1616,11 +1644,11 @@ async def test_create_input_data_with_relationships_03_for_update_include_unmodi @pytest.mark.parametrize("client_type", client_types) async def test_create_input_data_with_relationships_03_for_update_exclude_unmodified( clients: BothClients, - rfile_schema, - rfile_userdata01, - rfile_userdata01_property, + rfile_schema: NodeSchemaAPI, + rfile_userdata01: dict[str, Any], + rfile_userdata01_property: dict[str, Any], client_type: str, - property_test, + property_test: str, ) -> None: """NOTE: Need to fix this test, the issue is tracked in https://github.com/opsmill/infrahub-sdk-python/issues/214.""" rfile_userdata = rfile_userdata01 if property_test == WITHOUT_PROPERTY else rfile_userdata01_property @@ -1658,7 +1686,9 @@ async def test_create_input_data_with_relationships_03_for_update_exclude_unmodi @pytest.mark.parametrize("client_type", client_types) -async def test_create_input_data_with_IPHost_attribute(client, ipaddress_schema, client_type: str) -> None: +async def test_create_input_data_with_IPHost_attribute( + client: InfrahubClient, ipaddress_schema: NodeSchemaAPI, client_type: str +) -> None: data = {"address": {"value": ipaddress.ip_interface("1.1.1.1/24"), "is_protected": True}} if client_type == "standard": @@ -1672,7 +1702,9 @@ async def test_create_input_data_with_IPHost_attribute(client, ipaddress_schema, @pytest.mark.parametrize("client_type", client_types) -async def test_create_input_data_with_IPNetwork_attribute(client, ipnetwork_schema, client_type: str) -> None: +async def test_create_input_data_with_IPNetwork_attribute( + client: InfrahubClient, ipnetwork_schema: NodeSchemaAPI, client_type: str +) -> None: data = {"network": {"value": ipaddress.ip_network("1.1.1.0/24"), "is_protected": True}} if client_type == "standard": @@ -1688,16 +1720,16 @@ async def test_create_input_data_with_IPNetwork_attribute(client, ipnetwork_sche @pytest.mark.parametrize("property_test", property_tests) @pytest.mark.parametrize("client_type", client_types) async def test_update_input_data__with_relationships_01( - client, - location_schema, - location_data01, - location_data01_property, - tag_schema, - tag_blue_data, - tag_green_data, - tag_red_data, + client: InfrahubClient, + location_schema: NodeSchemaAPI, + location_data01: dict[str, Any], + location_data01_property: dict[str, Any], + tag_schema: NodeSchemaAPI, + tag_blue_data: dict[str, Any], + tag_green_data: dict[str, Any], + tag_red_data: dict[str, Any], client_type: str, - property_test, + property_test: str, ) -> None: location_data = location_data01 if property_test == WITHOUT_PROPERTY else location_data01_property @@ -1752,7 +1784,12 @@ async def test_update_input_data__with_relationships_01( @pytest.mark.parametrize("property_test", property_tests) @pytest.mark.parametrize("client_type", client_types) async def test_update_input_data_with_relationships_02( - client, location_schema, location_data02, location_data02_property, client_type: str, property_test + client: InfrahubClient, + location_schema: NodeSchemaAPI, + location_data02: dict[str, Any], + location_data02_property: dict[str, Any], + client_type: str, + property_test: str, ) -> None: location_data = location_data02 if property_test == WITHOUT_PROPERTY else location_data02_property @@ -1823,7 +1860,12 @@ async def test_update_input_data_with_relationships_02( @pytest.mark.parametrize("property_test", property_tests) @pytest.mark.parametrize("client_type", client_types) async def test_update_input_data_with_relationships_02_exclude_unmodified( - client, location_schema, location_data02, location_data02_property, client_type: str, property_test + client: InfrahubClient, + location_schema: NodeSchemaAPI, + location_data02: dict[str, Any], + location_data02_property: dict[str, Any], + client_type: str, + property_test: str, ) -> None: """NOTE Need to fix this test, issue is tracked in https://github.com/opsmill/infrahub-sdk-python/issues/214.""" location_data = location_data02 if property_test == WITHOUT_PROPERTY else location_data02_property @@ -1861,14 +1903,14 @@ async def test_update_input_data_with_relationships_02_exclude_unmodified( @pytest.mark.parametrize("property_test", property_tests) @pytest.mark.parametrize("client_type", client_types) async def test_update_input_data_empty_relationship( - client, - location_schema, - location_data01, - location_data01_property, - tag_schema, - tag_blue_data, + client: InfrahubClient, + location_schema: NodeSchemaAPI, + location_data01: dict[str, Any], + location_data01_property: dict[str, Any], + tag_schema: NodeSchemaAPI, + tag_blue_data: dict[str, Any], client_type: str, - property_test, + property_test: str, ) -> None: """TODO: investigate why name and type are being returned since they haven't been modified.""" location_data = location_data01 if property_test == WITHOUT_PROPERTY else location_data01_property @@ -1918,12 +1960,12 @@ async def test_update_input_data_empty_relationship( @pytest.mark.parametrize("client_type", client_types) async def test_node_get_relationship_from_store( - client, - location_schema, - location_data01, - tag_schema, - tag_red_data, - tag_blue_data, + client: InfrahubClient, + location_schema: NodeSchemaAPI, + location_data01: dict[str, Any], + tag_schema: NodeSchemaAPI, + tag_red_data: dict[str, Any], + tag_blue_data: dict[str, Any], client_type: str, ) -> None: if client_type == "standard": @@ -1946,7 +1988,9 @@ async def test_node_get_relationship_from_store( @pytest.mark.parametrize("client_type", client_types) -async def test_node_get_relationship_not_in_store(client, location_schema, location_data01, client_type: str) -> None: +async def test_node_get_relationship_not_in_store( + client: InfrahubClient, location_schema: NodeSchemaAPI, location_data01: dict[str, Any], client_type: str +) -> None: if client_type == "standard": node = InfrahubNode(client=client, schema=location_schema, data=location_data01) else: @@ -1962,13 +2006,13 @@ async def test_node_get_relationship_not_in_store(client, location_schema, locat @pytest.mark.parametrize("client_type", client_types) async def test_node_fetch_relationship( httpx_mock: HTTPXMock, - mock_schema_query_01, + mock_schema_query_01: HTTPXMock, clients: BothClients, - location_schema, - location_data01, - tag_schema, - tag_red_data, - tag_blue_data, + location_schema: NodeSchemaAPI, + location_data01: dict[str, Any], + tag_schema: NodeSchemaAPI, + tag_red_data: dict[str, Any], + tag_blue_data: dict[str, Any], client_type: str, ) -> None: response1 = { @@ -2032,7 +2076,9 @@ async def test_node_fetch_relationship( @pytest.mark.parametrize("client_type", client_types) -async def test_node_IPHost_deserialization(client, ipaddress_schema, client_type: str) -> None: +async def test_node_IPHost_deserialization( + client: InfrahubClient, ipaddress_schema: NodeSchemaAPI, client_type: str +) -> None: data = { "id": "aaaaaaaaaaaaaa", "address": { @@ -2049,7 +2095,9 @@ async def test_node_IPHost_deserialization(client, ipaddress_schema, client_type @pytest.mark.parametrize("client_type", client_types) -async def test_node_IPNetwork_deserialization(client, ipnetwork_schema, client_type: str) -> None: +async def test_node_IPNetwork_deserialization( + client: InfrahubClient, ipnetwork_schema: NodeSchemaAPI, client_type: str +) -> None: data = { "id": "aaaaaaaaaaaaaa", "network": { @@ -2068,10 +2116,10 @@ async def test_node_IPNetwork_deserialization(client, ipnetwork_schema, client_t @pytest.mark.parametrize("client_type", client_types) async def test_get_flat_value( httpx_mock: HTTPXMock, - mock_schema_query_01, + mock_schema_query_01: HTTPXMock, clients: BothClients, - location_schema, - location_data01, + location_schema: NodeSchemaAPI, + location_data01: dict[str, Any], client_type: str, ) -> None: httpx_mock.add_response( @@ -2100,7 +2148,9 @@ async def test_get_flat_value( @pytest.mark.parametrize("client_type", client_types) -async def test_node_extract(clients: BothClients, location_schema, location_data01, client_type: str) -> None: +async def test_node_extract( + clients: BothClients, location_schema: NodeSchemaAPI, location_data01: dict[str, Any], client_type: str +) -> None: params = {"identifier": "id", "name": "name__value", "description": "description__value"} if client_type == "standard": node = InfrahubNode(client=clients.standard, schema=location_schema, data=location_data01) @@ -2121,9 +2171,9 @@ async def test_node_extract(clients: BothClients, location_schema, location_data @pytest.mark.parametrize("client_type", client_types) async def test_read_only_attr( - client, - address_schema, - address_data, + client: InfrahubClient, + address_schema: NodeSchemaAPI, + address_data: dict[str, Any], client_type: str, ) -> None: if client_type == "standard": @@ -2143,7 +2193,9 @@ async def test_read_only_attr( @pytest.mark.parametrize("client_type", client_types) -async def test_relationships_excluded_input_data(client, location_schema, client_type: str) -> None: +async def test_relationships_excluded_input_data( + client: InfrahubClient, location_schema: NodeSchemaAPI, client_type: str +) -> None: data = { "name": {"value": "JFK1"}, "description": {"value": "JFK Airport"}, @@ -2161,7 +2213,12 @@ async def test_relationships_excluded_input_data(client, location_schema, client @pytest.mark.parametrize("client_type", client_types) async def test_create_input_data_with_resource_pool_relationship( - client, ipaddress_pool_schema, ipam_ipprefix_schema, simple_device_schema, ipam_ipprefix_data, client_type: str + client: InfrahubClient, + ipaddress_pool_schema: NodeSchemaAPI, + ipam_ipprefix_schema: NodeSchemaAPI, + simple_device_schema: NodeSchemaAPI, + ipam_ipprefix_data: dict[str, Any], + client_type: str, ) -> None: if client_type == "standard": ip_prefix = InfrahubNode(client=client, schema=ipam_ipprefix_schema, data=ipam_ipprefix_data) @@ -2213,7 +2270,12 @@ async def test_create_input_data_with_resource_pool_relationship( @pytest.mark.parametrize("client_type", client_types) async def test_create_mutation_query_with_resource_pool_relationship( - client, ipaddress_pool_schema, ipam_ipprefix_schema, simple_device_schema, ipam_ipprefix_data, client_type: str + client: InfrahubClient, + ipaddress_pool_schema: NodeSchemaAPI, + ipam_ipprefix_schema: NodeSchemaAPI, + simple_device_schema: NodeSchemaAPI, + ipam_ipprefix_data: dict[str, Any], + client_type: str, ) -> None: if client_type == "standard": ip_prefix = InfrahubNode(client=client, schema=ipam_ipprefix_schema, data=ipam_ipprefix_data) @@ -2269,9 +2331,9 @@ async def test_get_pool_allocated_resources( httpx_mock: HTTPXMock, mock_schema_query_ipam: HTTPXMock, clients: BothClients, - ipaddress_pool_schema, - ipam_ipprefix_schema, - ipam_ipprefix_data, + ipaddress_pool_schema: NodeSchemaAPI, + ipam_ipprefix_schema: NodeSchemaAPI, + ipam_ipprefix_data: dict[str, Any], client_type: str, ) -> None: httpx_mock.add_response( @@ -2367,9 +2429,9 @@ async def test_get_pool_allocated_resources( async def test_get_pool_resources_utilization( httpx_mock: HTTPXMock, clients: BothClients, - ipaddress_pool_schema, - ipam_ipprefix_schema, - ipam_ipprefix_data, + ipaddress_pool_schema: NodeSchemaAPI, + ipam_ipprefix_schema: NodeSchemaAPI, + ipam_ipprefix_data: dict[str, Any], client_type: str, ) -> None: httpx_mock.add_response( @@ -2433,7 +2495,9 @@ async def test_get_pool_resources_utilization( @pytest.mark.parametrize("client_type", client_types) -async def test_from_graphql(clients: BothClients, mock_schema_query_01, location_data01, client_type: str) -> None: +async def test_from_graphql( + clients: BothClients, mock_schema_query_01: HTTPXMock, location_data01: dict[str, Any], client_type: str +) -> None: if client_type == "standard": schema = await clients.standard.schema.get(kind="BuiltinLocation", branch="main") node = await InfrahubNode.from_graphql( @@ -2659,7 +2723,7 @@ async def test_process_relationships_recursive_deep_nesting( class TestRelatedNodeIsFromProfile: - def test_is_from_profile_when_source_is_profile(self, location_schema) -> None: + def test_is_from_profile_when_source_is_profile(self, location_schema: NodeSchemaAPI) -> None: data = { "node": {"id": "test-id", "display_label": "test-tag", "__typename": "BuiltinTag"}, "properties": { @@ -2671,7 +2735,7 @@ def test_is_from_profile_when_source_is_profile(self, location_schema) -> None: related_node = RelatedNodeBase(branch="main", schema=location_schema.relationships[0], data=data) assert related_node.is_from_profile - def test_is_from_profile_when_source_is_not_profile(self, location_schema) -> None: + def test_is_from_profile_when_source_is_not_profile(self, location_schema: NodeSchemaAPI) -> None: data = { "node": {"id": "test-id", "display_label": "test-tag", "__typename": "BuiltinTag"}, "properties": { @@ -2683,7 +2747,7 @@ def test_is_from_profile_when_source_is_not_profile(self, location_schema) -> No related_node = RelatedNodeBase(branch="main", schema=location_schema.relationships[0], data=data) assert not related_node.is_from_profile - def test_is_from_profile_when_source_not_queried(self, location_schema) -> None: + def test_is_from_profile_when_source_not_queried(self, location_schema: NodeSchemaAPI) -> None: data = { "node": {"id": "test-id", "display_label": "test-tag", "__typename": "BuiltinTag"}, "properties": {"is_protected": False, "owner": None, "source": None}, @@ -2691,18 +2755,20 @@ def test_is_from_profile_when_source_not_queried(self, location_schema) -> None: related_node = RelatedNodeBase(branch="main", schema=location_schema.relationships[0], data=data) assert not related_node.is_from_profile - def test_is_from_profile_when_no_properties(self, location_schema) -> None: + def test_is_from_profile_when_no_properties(self, location_schema: NodeSchemaAPI) -> None: data = {"node": {"id": "test-id", "display_label": "test-tag", "__typename": "BuiltinTag"}} related_node = RelatedNodeBase(branch="main", schema=location_schema.relationships[0], data=data) assert not related_node.is_from_profile class TestRelationshipManagerIsFromProfile: - def test_is_from_profile_when_no_peers(self, location_schema) -> None: + def test_is_from_profile_when_no_peers(self, location_schema: NodeSchemaAPI) -> None: manager = RelationshipManagerBase(name="tags", branch="main", schema=location_schema.relationships[0]) assert not manager.is_from_profile - def test_is_from_profile_when_all_peers_from_profile(self, client, location_schema) -> None: + def test_is_from_profile_when_all_peers_from_profile( + self, client: InfrahubClient, location_schema: NodeSchemaAPI + ) -> None: data = { "count": 2, "edges": [ @@ -2729,7 +2795,9 @@ def test_is_from_profile_when_all_peers_from_profile(self, client, location_sche ) assert manager.is_from_profile - def test_is_from_profile_when_any_peer_not_from_profile(self, client, location_schema) -> None: + def test_is_from_profile_when_any_peer_not_from_profile( + self, client: InfrahubClient, location_schema: NodeSchemaAPI + ) -> None: data = { "count": 2, "edges": [ @@ -2756,7 +2824,9 @@ def test_is_from_profile_when_any_peer_not_from_profile(self, client, location_s ) assert not manager.is_from_profile - def test_is_from_profile_when_any_peer_has_unknown_source(self, client, location_schema) -> None: + def test_is_from_profile_when_any_peer_has_unknown_source( + self, client: InfrahubClient, location_schema: NodeSchemaAPI + ) -> None: data = { "count": 2, "edges": [ diff --git a/tests/unit/sdk/test_object_store.py b/tests/unit/sdk/test_object_store.py index a43a5cd3..0e981d5c 100644 --- a/tests/unit/sdk/test_object_store.py +++ b/tests/unit/sdk/test_object_store.py @@ -4,6 +4,7 @@ from pytest_httpx import HTTPXMock from infrahub_sdk.object_store import ObjectStore, ObjectStoreSync +from tests.unit.sdk.conftest import BothClients async_methods = [method for method in dir(ObjectStore) if not method.startswith("_")] sync_methods = [method for method in dir(ObjectStoreSync) if not method.startswith("_")] @@ -44,7 +45,7 @@ async def test_method_sanity() -> None: @pytest.mark.parametrize("method", async_methods) -async def test_validate_method_signature(method) -> None: +async def test_validate_method_signature(method: str) -> None: async_method = getattr(ObjectStore, method) sync_method = getattr(ObjectStoreSync, method) async_sig = inspect.signature(async_method) @@ -54,7 +55,7 @@ async def test_validate_method_signature(method) -> None: @pytest.mark.parametrize("client_type", client_types) -async def test_object_store_get(client_type, clients, mock_get_object_store_01) -> None: +async def test_object_store_get(client_type: str, clients: BothClients, mock_get_object_store_01: HTTPXMock) -> None: client = getattr(clients, client_type) if client_type == "standard": @@ -66,7 +67,9 @@ async def test_object_store_get(client_type, clients, mock_get_object_store_01) @pytest.mark.parametrize("client_type", client_types) -async def test_object_store_upload(client_type, clients, mock_upload_object_store_01) -> None: +async def test_object_store_upload( + client_type: str, clients: BothClients, mock_upload_object_store_01: HTTPXMock +) -> None: client = getattr(clients, client_type) if client_type == "standard": diff --git a/tests/unit/sdk/test_schema.py b/tests/unit/sdk/test_schema.py index 54713fa8..05302b11 100644 --- a/tests/unit/sdk/test_schema.py +++ b/tests/unit/sdk/test_schema.py @@ -1,6 +1,7 @@ import inspect from io import StringIO from unittest import mock +from unittest.mock import MagicMock import pytest from pytest_httpx import HTTPXMock @@ -33,7 +34,7 @@ async def test_method_sanity() -> None: @pytest.mark.parametrize("method", async_schema_methods) -async def test_validate_method_signature(method) -> None: +async def test_validate_method_signature(method: str) -> None: async_method = getattr(InfrahubSchema, method) sync_method = getattr(InfrahubSchemaSync, method) async_sig = inspect.signature(async_method) @@ -43,7 +44,7 @@ async def test_validate_method_signature(method) -> None: @pytest.mark.parametrize("client_type", client_types) -async def test_fetch_schema(mock_schema_query_01, client_type) -> None: +async def test_fetch_schema(mock_schema_query_01: HTTPXMock, client_type: str) -> None: if client_type == "standard": client = InfrahubClient(config=Config(address="http://mock", insert_tracker=True)) nodes = await client.schema.fetch(branch="main") @@ -89,7 +90,7 @@ async def test_fetch_schema_conditional_refresh(mock_schema_query_01: HTTPXMock, @pytest.mark.parametrize("client_type", client_types) -async def test_schema_data_validation(rfile_schema, client_type) -> None: +async def test_schema_data_validation(rfile_schema: NodeSchemaAPI, client_type: str) -> None: if client_type == "standard": client = InfrahubClient(config=Config(address="http://mock", insert_tracker=True)) else: @@ -110,7 +111,10 @@ async def test_schema_data_validation(rfile_schema, client_type) -> None: @pytest.mark.parametrize("client_type", client_types) async def test_add_dropdown_option( - clients, client_type, mock_schema_query_01, mock_query_mutation_schema_dropdown_add + clients: BothClients, + client_type: str, + mock_schema_query_01: HTTPXMock, + mock_query_mutation_schema_dropdown_add: None, ) -> None: if client_type == "standard": await clients.standard.schema.add_dropdown_option("BuiltinTag", "status", "something") @@ -120,7 +124,10 @@ async def test_add_dropdown_option( @pytest.mark.parametrize("client_type", client_types) async def test_remove_dropdown_option( - clients, client_type, mock_schema_query_01, mock_query_mutation_schema_dropdown_remove + clients: BothClients, + client_type: str, + mock_schema_query_01: HTTPXMock, + mock_query_mutation_schema_dropdown_remove: None, ) -> None: if client_type == "standard": await clients.standard.schema.remove_dropdown_option("BuiltinTag", "status", "active") @@ -129,7 +136,9 @@ async def test_remove_dropdown_option( @pytest.mark.parametrize("client_type", client_types) -async def test_add_enum_option(clients, client_type, mock_schema_query_01, mock_query_mutation_schema_enum_add) -> None: +async def test_add_enum_option( + clients: BothClients, client_type: str, mock_schema_query_01: HTTPXMock, mock_query_mutation_schema_enum_add: None +) -> None: if client_type == "standard": await clients.standard.schema.add_enum_option("BuiltinTag", "mode", "hard") else: @@ -138,7 +147,10 @@ async def test_add_enum_option(clients, client_type, mock_schema_query_01, mock_ @pytest.mark.parametrize("client_type", client_types) async def test_remove_enum_option( - clients, client_type, mock_schema_query_01, mock_query_mutation_schema_enum_remove + clients: BothClients, + client_type: str, + mock_schema_query_01: HTTPXMock, + mock_query_mutation_schema_enum_remove: None, ) -> None: if client_type == "standard": await clients.standard.schema.remove_enum_option("BuiltinTag", "mode", "easy") @@ -147,7 +159,9 @@ async def test_remove_enum_option( @pytest.mark.parametrize("client_type", client_types) -async def test_add_dropdown_option_raises(clients, client_type, mock_schema_query_01) -> None: +async def test_add_dropdown_option_raises( + clients: BothClients, client_type: str, mock_schema_query_01: HTTPXMock +) -> None: if client_type == "standard": with pytest.raises(SchemaNotFoundError): await clients.standard.schema.add_dropdown_option("DoesNotExist", "atribute", "option") @@ -161,7 +175,7 @@ async def test_add_dropdown_option_raises(clients, client_type, mock_schema_quer @pytest.mark.parametrize("client_type", client_types) -async def test_add_enum_option_raises(clients, client_type, mock_schema_query_01) -> None: +async def test_add_enum_option_raises(clients: BothClients, client_type: str, mock_schema_query_01: HTTPXMock) -> None: if client_type == "standard": with pytest.raises(SchemaNotFoundError): await clients.standard.schema.add_enum_option("DoesNotExist", "atribute", "option") @@ -175,7 +189,9 @@ async def test_add_enum_option_raises(clients, client_type, mock_schema_query_01 @pytest.mark.parametrize("client_type", client_types) -async def test_remove_dropdown_option_raises(clients, client_type, mock_schema_query_01) -> None: +async def test_remove_dropdown_option_raises( + clients: BothClients, client_type: str, mock_schema_query_01: HTTPXMock +) -> None: if client_type == "standard": with pytest.raises(SchemaNotFoundError): await clients.standard.schema.remove_dropdown_option("DoesNotExist", "atribute", "option") @@ -189,7 +205,9 @@ async def test_remove_dropdown_option_raises(clients, client_type, mock_schema_q @pytest.mark.parametrize("client_type", client_types) -async def test_remove_enum_option_raises(clients, client_type, mock_schema_query_01) -> None: +async def test_remove_enum_option_raises( + clients: BothClients, client_type: str, mock_schema_query_01: HTTPXMock +) -> None: if client_type == "standard": with pytest.raises(SchemaNotFoundError): await clients.standard.schema.remove_enum_option("DoesNotExist", "atribute", "option") @@ -337,7 +355,7 @@ async def test_infrahub_repository_config_dups() -> None: "attributes": [{"name": "name", "kind": "Text"}, {"name": "status", "kind": "Dropdown"}], }, ) -async def test_display_schema_load_errors_details_dropdown(mock_get_node) -> None: +async def test_display_schema_load_errors_details_dropdown(mock_get_node: MagicMock) -> None: """Validate error message with details when loading schema.""" error = { "detail": [ @@ -370,7 +388,7 @@ async def test_display_schema_load_errors_details_dropdown(mock_get_node) -> Non "attributes": [{"name": "name", "kind": "Text"}, {"name": "status", "kind": "Dropdown"}], }, ) -async def test_display_schema_load_errors_details_namespace(mock_get_node) -> None: +async def test_display_schema_load_errors_details_namespace(mock_get_node: MagicMock) -> None: """Validate error message with details when loading schema.""" error = { "detail": [ @@ -425,7 +443,9 @@ async def test_display_schema_load_errors_details_namespace(mock_get_node) -> No ], }, ) -async def test_display_schema_load_errors_details_when_error_is_in_attribute_or_relationship(mock_get_node) -> None: +async def test_display_schema_load_errors_details_when_error_is_in_attribute_or_relationship( + mock_get_node: MagicMock, +) -> None: """Validate error message with details when loading schema and errors are in attribute or relationship.""" error = { "detail": [ diff --git a/tests/unit/sdk/test_store.py b/tests/unit/sdk/test_store.py index d8af340b..83644aae 100644 --- a/tests/unit/sdk/test_store.py +++ b/tests/unit/sdk/test_store.py @@ -1,14 +1,23 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import pytest from infrahub_sdk.exceptions import NodeInvalidError, NodeNotFoundError from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync from infrahub_sdk.store import NodeStore, NodeStoreSync +if TYPE_CHECKING: + from infrahub_sdk.schema import NodeSchemaAPI + + from .conftest import BothClients + client_types = ["standard", "sync"] @pytest.mark.parametrize("client_type", client_types) -def test_node_store_set(client_type, clients, schema_with_hfid) -> None: +def test_node_store_set(client_type: str, clients: BothClients, schema_with_hfid: dict[str, NodeSchemaAPI]) -> None: if client_type == "standard": client = clients.standard store = NodeStore(default_branch="main") @@ -33,7 +42,7 @@ def test_node_store_set(client_type, clients, schema_with_hfid) -> None: @pytest.mark.parametrize("client_type", client_types) -def test_node_store_set_no_hfid(client_type, clients, location_schema) -> None: +def test_node_store_set_no_hfid(client_type: str, clients: BothClients, location_schema: NodeSchemaAPI) -> None: if client_type == "standard": client = clients.standard store = NodeStore(default_branch="main") @@ -67,7 +76,7 @@ def test_node_store_set_no_hfid(client_type, clients, location_schema) -> None: @pytest.mark.parametrize("client_type", client_types) -def test_node_store_get(client_type, clients, location_schema) -> None: +def test_node_store_get(client_type: str, clients: BothClients, location_schema: NodeSchemaAPI) -> None: if client_type == "standard": client = clients.standard store = NodeStore(default_branch="main") @@ -111,7 +120,9 @@ def test_node_store_get(client_type, clients, location_schema) -> None: @pytest.mark.parametrize("client_type", client_types) -def test_node_store_get_with_hfid(client_type, clients, schema_with_hfid) -> None: +def test_node_store_get_with_hfid( + client_type: str, clients: BothClients, schema_with_hfid: dict[str, NodeSchemaAPI] +) -> None: if client_type == "standard": client = clients.standard store = NodeStore(default_branch="main") diff --git a/tests/unit/sdk/test_task.py b/tests/unit/sdk/test_task.py index 5ae73590..dd029003 100644 --- a/tests/unit/sdk/test_task.py +++ b/tests/unit/sdk/test_task.py @@ -1,4 +1,7 @@ +from __future__ import annotations + from datetime import datetime, timezone +from typing import TYPE_CHECKING import pytest @@ -6,11 +9,16 @@ from infrahub_sdk.task.manager import InfraHubTaskManagerBase from infrahub_sdk.task.models import Task, TaskFilter, TaskState +if TYPE_CHECKING: + from pytest_httpx import HTTPXMock + + from tests.unit.sdk.conftest import BothClients + client_types = ["standard", "sync"] @pytest.mark.parametrize("client_type", client_types) -async def test_method_all(clients, mock_query_tasks_01, client_type) -> None: +async def test_method_all(clients: BothClients, mock_query_tasks_01: HTTPXMock, client_type: str) -> None: if client_type == "standard": tasks = await clients.standard.task.all() else: @@ -21,7 +29,7 @@ async def test_method_all(clients, mock_query_tasks_01, client_type) -> None: @pytest.mark.parametrize("client_type", client_types) -async def test_method_all_full(clients, mock_query_tasks_01, client_type) -> None: +async def test_method_all_full(clients: BothClients, mock_query_tasks_01: HTTPXMock, client_type: str) -> None: if client_type == "standard": tasks = await clients.standard.task.all(include_logs=True, include_related_nodes=True) else: @@ -62,7 +70,7 @@ async def test_generate_count_query() -> None: @pytest.mark.parametrize("client_type", client_types) -async def test_method_filters(clients, mock_query_tasks_02_main, client_type) -> None: +async def test_method_filters(clients: BothClients, mock_query_tasks_02_main: HTTPXMock, client_type: str) -> None: if client_type == "standard": tasks = await clients.standard.task.filter(filter=TaskFilter(branch="main")) else: @@ -73,7 +81,7 @@ async def test_method_filters(clients, mock_query_tasks_02_main, client_type) -> @pytest.mark.parametrize("client_type", client_types) -async def test_method_get_too_many(clients, mock_query_tasks_02_main, client_type) -> None: +async def test_method_get_too_many(clients: BothClients, mock_query_tasks_02_main: HTTPXMock, client_type: str) -> None: with pytest.raises(TooManyTasksError): if client_type == "standard": await clients.standard.task.get(id="a60f4431-6a43-451e-8f42-9ec5db9a9370") @@ -82,7 +90,7 @@ async def test_method_get_too_many(clients, mock_query_tasks_02_main, client_typ @pytest.mark.parametrize("client_type", client_types) -async def test_method_get_not_found(clients, mock_query_tasks_empty, client_type) -> None: +async def test_method_get_not_found(clients: BothClients, mock_query_tasks_empty: HTTPXMock, client_type: str) -> None: with pytest.raises(TaskNotFoundError): if client_type == "standard": await clients.standard.task.get(id="a60f4431-6a43-451e-8f42-9ec5db9a9370") @@ -91,7 +99,7 @@ async def test_method_get_not_found(clients, mock_query_tasks_empty, client_type @pytest.mark.parametrize("client_type", client_types) -async def test_method_get(clients, mock_query_tasks_03, client_type) -> None: +async def test_method_get(clients: BothClients, mock_query_tasks_03: HTTPXMock, client_type: str) -> None: if client_type == "standard": task = await clients.standard.task.get(id="a60f4431-6a43-451e-8f42-9ec5db9a9370") else: @@ -102,7 +110,7 @@ async def test_method_get(clients, mock_query_tasks_03, client_type) -> None: @pytest.mark.parametrize("client_type", client_types) -async def test_method_get_full(clients, mock_query_tasks_05, client_type) -> None: +async def test_method_get_full(clients: BothClients, mock_query_tasks_05: HTTPXMock, client_type: str) -> None: if client_type == "standard": task = await clients.standard.task.get(id="32116fcd-9071-43a7-9f14-777901020b5b") else: