Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions infrahub_sdk/store.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import inspect
import warnings
from typing import TYPE_CHECKING, Literal, overload

from infrahub_sdk.protocols_base import CoreNodeBase

from .exceptions import NodeInvalidError, NodeNotFoundError
from .node.parsers import parse_human_friendly_id

Expand All @@ -16,8 +19,15 @@ def get_schema_name(schema: type[SchemaType | SchemaTypeSync] | str | None = Non
if isinstance(schema, str):
return schema

if hasattr(schema, "_is_runtime_protocol") and schema._is_runtime_protocol: # type: ignore[union-attr]
return schema.__name__ # type: ignore[union-attr]
if schema is None:
return None

if issubclass(schema, CoreNodeBase):
if inspect.iscoroutinefunction(schema.save):
return schema.__name__
if schema.__name__[-4:] == "Sync":
return schema.__name__[:-4]
return schema.__name__

return None

Expand Down
8 changes: 7 additions & 1 deletion tests/unit/sdk/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from infrahub_sdk.exceptions import NodeInvalidError, NodeNotFoundError
from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync
from infrahub_sdk.store import NodeStore, NodeStoreSync
from infrahub_sdk.protocols import BuiltinIPAddressSync, BuiltinIPPrefix
from infrahub_sdk.store import NodeStore, NodeStoreSync, get_schema_name

if TYPE_CHECKING:
from infrahub_sdk.schema import NodeSchemaAPI
Expand Down Expand Up @@ -157,3 +158,8 @@ def test_node_store_get_with_hfid(
store.get(kind="BuiltinLocation", key="anotherkey")
with pytest.raises(NodeNotFoundError):
store.get(key="anotherkey")


def test_store_get_schema_name() -> None:
assert get_schema_name(schema=BuiltinIPPrefix) == BuiltinIPPrefix.__name__
assert get_schema_name(schema=BuiltinIPAddressSync) == BuiltinIPAddressSync.__name__[:-4]