Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repos:

- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.11.9
rev: v0.14.10
hooks:
# Run the linter.
- id: ruff
Expand Down
2 changes: 1 addition & 1 deletion infrahub_sdk/ctl/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ async def report(
git_files_changed = await check_git_files_changed(client, branch=branch_name)

proposed_changes = await client.filters(
kind=CoreProposedChange, # type: ignore[type-abstract]
kind=CoreProposedChange,
source_branch__value=branch_name,
include=["created_by"],
prefetch_relationships=True,
Expand Down
12 changes: 7 additions & 5 deletions infrahub_sdk/node/related_node.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

from ..exceptions import Error
from ..protocols_base import CoreNodeBase
Expand All @@ -11,7 +11,7 @@
if TYPE_CHECKING:
from ..client import InfrahubClient, InfrahubClientSync
from ..schema import RelationshipSchemaAPI
from .node import InfrahubNode, InfrahubNodeSync
from .node import InfrahubNode, InfrahubNodeBase, InfrahubNodeSync


class RelatedNodeBase:
Expand All @@ -34,7 +34,7 @@ def __init__(self, branch: str, schema: RelationshipSchemaAPI, data: Any | dict,
self._properties_object = PROPERTIES_OBJECT
self._properties = self._properties_flag + self._properties_object

self._peer = None
self._peer: InfrahubNodeBase | CoreNodeBase | None = None
self._id: str | None = None
self._hfid: list[str] | None = None
self._display_label: str | None = None
Expand All @@ -43,8 +43,10 @@ def __init__(self, branch: str, schema: RelationshipSchemaAPI, data: Any | dict,
self._source_typename: str | None = None
self._relationship_metadata: RelationshipMetadata | None = None

if isinstance(data, (CoreNodeBase)):
self._peer = data
# Check for InfrahubNodeBase instances using duck-typing (_schema attribute)
# to avoid circular imports, or CoreNodeBase instances
if isinstance(data, CoreNodeBase) or hasattr(data, "_schema"):
self._peer = cast("InfrahubNodeBase | CoreNodeBase", data)
for prop in self._properties:
setattr(self, prop, None)
self._relationship_metadata = None
Expand Down
27 changes: 15 additions & 12 deletions infrahub_sdk/protocols_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,7 @@ class AnyAttributeOptional(Attribute):
value: float | None


@runtime_checkable
class CoreNodeBase(Protocol):
class CoreNodeBase:
_schema: MainSchemaTypes
_internal_id: str
id: str # NOTE this is incorrect, should be str | None
Expand All @@ -189,23 +188,28 @@ def get_human_friendly_id(self) -> list[str] | None: ...

def get_human_friendly_id_as_string(self, include_kind: bool = False) -> str | None: ...

def get_kind(self) -> str: ...
def get_kind(self) -> str:
raise NotImplementedError()

def get_all_kinds(self) -> list[str]: ...
def get_all_kinds(self) -> list[str]:
raise NotImplementedError()

def get_branch(self) -> str: ...
def get_branch(self) -> str:
raise NotImplementedError()

def is_ip_prefix(self) -> bool: ...
def is_ip_prefix(self) -> bool:
raise NotImplementedError()

def is_ip_address(self) -> bool: ...
def is_ip_address(self) -> bool:
raise NotImplementedError()

def is_resource_pool(self) -> bool: ...
def is_resource_pool(self) -> bool:
raise NotImplementedError()

def get_raw_graphql_data(self) -> dict | None: ...


@runtime_checkable
class CoreNode(CoreNodeBase, Protocol):
class CoreNode(CoreNodeBase):
async def save(
self,
allow_upsert: bool = False,
Expand All @@ -229,8 +233,7 @@ async def add_relationships(self, relation_to_update: str, related_nodes: list[s
async def remove_relationships(self, relation_to_update: str, related_nodes: list[str]) -> None: ...


@runtime_checkable
class CoreNodeSync(CoreNodeBase, Protocol):
class CoreNodeSync(CoreNodeBase):
def save(
self,
allow_upsert: bool = False,
Expand Down
5 changes: 3 additions & 2 deletions infrahub_sdk/schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
ValidationError,
)
from ..graphql import Mutation
from ..protocols_base import CoreNodeBase
from ..queries import SCHEMA_HASH_SYNC_STATUS
from .main import (
AttributeSchema,
Expand Down Expand Up @@ -207,14 +208,14 @@ def _get_schema_name(schema: type[SchemaType | SchemaTypeSync] | str) -> str:
if isinstance(schema, str):
return schema

if hasattr(schema, "_is_runtime_protocol") and getattr(schema, "_is_runtime_protocol", None):
if issubclass(schema, CoreNodeBase):
if inspect.iscoroutinefunction(schema.save):
return schema.__name__
if schema.__name__[-4:] == "Sync":
return schema.__name__[:-4]
return schema.__name__

raise ValueError("schema must be a protocol or a string")
raise ValueError("schema must be a CoreNode subclass or a string")

@staticmethod
def _parse_schema_response(response: httpx.Response, branch: str) -> MutableMapping[str, Any]:
Expand Down
48 changes: 44 additions & 4 deletions infrahub_sdk/spec/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from ..exceptions import ObjectValidationError, ValidationError
from ..schema import GenericSchemaAPI, RelationshipKind, RelationshipSchema
from ..utils import is_valid_uuid
from ..yaml import InfrahubFile, InfrahubFileKind
from .models import InfrahubObjectParameters
from .processors.factory import DataProcessorFactory
Expand All @@ -33,6 +34,36 @@ def validate_list_of_objects(value: list[Any]) -> bool:
return all(isinstance(item, dict) for item in value)


def normalize_hfid_reference(value: str | list[str]) -> str | list[str]:
    """Normalize a single reference value to HFID format.

    Intended only for peers whose schema defines ``human_friendly_id``.

    Args:
        value: An ID string, a single-component HFID string, or a
            multi-component HFID given as a list of strings.

    Returns:
        The value unchanged when it is already a list (multi-component HFID)
        or a valid UUID string (treated as a node ID); otherwise the string
        wrapped in a one-element list (single-component HFID).
    """
    # Only non-UUID strings need rewriting: they are single-component HFIDs
    # and must be wrapped so downstream code can treat every HFID as a list.
    if isinstance(value, str) and not is_valid_uuid(value):
        return [value]
    return value


def normalize_hfid_references(values: list[str | list[str]]) -> list[str | list[str]]:
    """Normalize every reference value in ``values`` to HFID format.

    Intended only for peers whose schema defines ``human_friendly_id``.
    Each string that is not a valid UUID is wrapped in a one-element list so
    it is treated as a single-component HFID; UUID strings and lists pass
    through unchanged.
    """
    return list(map(normalize_hfid_reference, values))


class RelationshipDataFormat(str, Enum):
UNKNOWN = "unknown"

Expand All @@ -51,6 +82,12 @@ class RelationshipInfo(BaseModel):
peer_rel: RelationshipSchema | None = None
reason_relationship_not_valid: str | None = None
format: RelationshipDataFormat = RelationshipDataFormat.UNKNOWN
peer_human_friendly_id: list[str] | None = None

@property
def peer_has_hfid(self) -> bool:
    """Indicate if the peer schema has a human-friendly ID defined."""
    # Treat both None and an empty list as "no HFID defined".
    hfid = self.peer_human_friendly_id
    return hfid is not None and len(hfid) > 0

@property
def is_bidirectional(self) -> bool:
Expand Down Expand Up @@ -119,6 +156,7 @@ async def get_relationship_info(
info.peer_kind = value["kind"]

peer_schema = await client.schema.get(kind=info.peer_kind, branch=branch)
info.peer_human_friendly_id = peer_schema.human_friendly_id

try:
info.peer_rel = peer_schema.get_matching_relationship(
Expand Down Expand Up @@ -444,10 +482,12 @@ async def create_node(
# - if the relationship is bidirectional and is mandatory on the other side, then we need to create this object First
# - if the relationship is bidirectional and is not mandatory on the other side, then we need should create the related object First
# - if the relationship is not bidirectional, then we need to create the related object First
if rel_info.is_reference and isinstance(value, list):
clean_data[key] = value
elif rel_info.format == RelationshipDataFormat.ONE_REF and isinstance(value, str):
clean_data[key] = [value]
if rel_info.format == RelationshipDataFormat.MANY_REF and isinstance(value, list):
# Cardinality-many reference: normalize string HFIDs to list format if peer has HFID defined
clean_data[key] = normalize_hfid_references(value) if rel_info.peer_has_hfid else value
elif rel_info.format == RelationshipDataFormat.ONE_REF:
# Cardinality-one reference: normalize string to HFID list if peer has HFID, else pass as-is
clean_data[key] = normalize_hfid_reference(value) if rel_info.peer_has_hfid else value
elif not rel_info.is_reference and rel_info.is_bidirectional and rel_info.is_mandatory:
remaining_rels.append(key)
elif not rel_info.is_reference and not rel_info.is_mandatory:
Expand Down
2 changes: 1 addition & 1 deletion infrahub_sdk/testing/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ async def wait_for_sync_to_complete(
) -> bool:
for _ in range(retries):
repo = await client.get(
kind=CoreGenericRepository, # type: ignore[type-abstract]
kind=CoreGenericRepository,
name__value=self.name,
branch=branch or self.initial_branch,
)
Expand Down
19 changes: 0 additions & 19 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -328,25 +328,6 @@ max-complexity = 17
"ARG002", # Unused method argument
]

##################################################################################################
# ANN001 ignores - broken down for incremental cleanup #
# Remove each section as type annotations are added to that directory #
##################################################################################################

# tests/unit/sdk/ - 478 errors total
"tests/unit/sdk/test_node.py" = ["ANN001"] # 206 errors
"tests/unit/sdk/test_client.py" = ["ANN001"] # 85 errors
"tests/unit/sdk/test_schema.py" = ["ANN001"] # 36 errors
"tests/unit/sdk/test_artifact.py" = ["ANN001"] # 27 errors
"tests/unit/sdk/test_hierarchical_nodes.py" = ["ANN001"] # 26 errors
"tests/unit/sdk/test_task.py" = ["ANN001"] # 21 errors
"tests/unit/sdk/test_store.py" = ["ANN001"] # 12 errors
"tests/unit/sdk/spec/test_object.py" = ["ANN001"] # 11 errors
"tests/unit/sdk/conftest.py" = ["ANN001"] # 11 errors
"tests/unit/sdk/test_diff_summary.py" = ["ANN001"] # 9 errors
"tests/unit/sdk/test_object_store.py" = ["ANN001"] # 7 errors
"tests/unit/sdk/graphql/test_query.py" = ["ANN001"] # 7 errors

# tests/integration/
"tests/integration/test_infrahub_client.py" = ["PLR0904"]
"tests/integration/test_infrahub_client_sync.py" = ["PLR0904"]
Expand Down
5 changes: 4 additions & 1 deletion tests/fixtures/schema_01.json
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,10 @@
"label": null,
"inherit_from": [],
"branch": "aware",
"default_filter": "name__value"
"default_filter": "name__value",
"human_friendly_id": [
"name__value"
]
},
{
"name": "Location",
Expand Down
28 changes: 17 additions & 11 deletions tests/unit/sdk/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def replace_annotation(annotation: str) -> str:

@pytest.fixture
def replace_async_parameter_annotations(
replace_async_return_annotation,
replace_async_return_annotation: Callable[[str], str],
) -> Callable[[Mapping[str, Parameter]], list[tuple[str, str]]]:
"""Allows for comparison between sync and async parameter annotations."""

Expand All @@ -130,7 +130,7 @@ def replace_annotation(annotation: str) -> str:

@pytest.fixture
def replace_sync_parameter_annotations(
replace_sync_return_annotation,
replace_sync_return_annotation: Callable[[str], str],
) -> Callable[[Mapping[str, Parameter]], list[tuple[str, str]]]:
"""Allows for comparison between sync and async parameter annotations."""

Expand Down Expand Up @@ -1501,7 +1501,7 @@ async def mock_repositories_query_no_pagination(httpx_mock: HTTPXMock) -> HTTPXM

@pytest.fixture
async def mock_query_repository_all_01_no_pagination(
httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01
httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01: HTTPXMock
) -> HTTPXMock:
response = {
"data": {
Expand Down Expand Up @@ -1601,7 +1601,7 @@ async def mock_repositories_query(httpx_mock: HTTPXMock) -> HTTPXMock:

@pytest.fixture
async def mock_query_repository_page1_1(
httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01
httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01: HTTPXMock
) -> HTTPXMock:
response = {
"data": {
Expand Down Expand Up @@ -1641,7 +1641,9 @@ async def mock_query_repository_page1_1(


@pytest.fixture
async def mock_query_corenode_page1_1(httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_02) -> HTTPXMock:
async def mock_query_corenode_page1_1(
httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_02: HTTPXMock
) -> HTTPXMock:
response = {
"data": {
"CoreNode": {
Expand Down Expand Up @@ -1676,14 +1678,16 @@ async def mock_query_corenode_page1_1(httpx_mock: HTTPXMock, client: InfrahubCli


@pytest.fixture
async def mock_query_repository_count(httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01) -> HTTPXMock:
async def mock_query_repository_count(
httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01: HTTPXMock
) -> HTTPXMock:
httpx_mock.add_response(method="POST", json={"data": {"CoreRepository": {"count": 5}}}, is_reusable=True)
return httpx_mock


@pytest.fixture
async def mock_query_repository_page1_empty(
httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01
httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01: HTTPXMock
) -> HTTPXMock:
response: dict = {"data": {"CoreRepository": {"edges": []}}}

Expand All @@ -1698,7 +1702,7 @@ async def mock_query_repository_page1_empty(

@pytest.fixture
async def mock_query_repository_page1_2(
httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01
httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01: HTTPXMock
) -> HTTPXMock:
response = {
"data": {
Expand Down Expand Up @@ -1748,7 +1752,7 @@ async def mock_query_repository_page1_2(

@pytest.fixture
async def mock_query_repository_page2_2(
httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01
httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01: HTTPXMock
) -> HTTPXMock:
response = {
"data": {
Expand Down Expand Up @@ -2512,15 +2516,17 @@ async def mock_schema_query_ipam(httpx_mock: HTTPXMock) -> HTTPXMock:

@pytest.fixture
async def mock_query_location_batch_count(
httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01
httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01: HTTPXMock
) -> HTTPXMock:
response = {"data": {"BuiltinLocation": {"count": 30}}}
httpx_mock.add_response(method="POST", url="http://mock/graphql/main", json=response, is_reusable=True)
return httpx_mock


@pytest.fixture
async def mock_query_location_batch(httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01) -> HTTPXMock:
async def mock_query_location_batch(
httpx_mock: HTTPXMock, client: InfrahubClient, mock_schema_query_01: HTTPXMock
) -> HTTPXMock:
for i in range(1, 11):
filename = get_fixtures_dir() / "batch" / f"mock_query_location_page{i}.json"
response_text = filename.read_text(encoding="UTF-8")
Expand Down
Loading