diff --git a/lightdb/core.py b/lightdb/core.py index ffdd7bc..1a89d89 100644 --- a/lightdb/core.py +++ b/lightdb/core.py @@ -34,6 +34,18 @@ def __init__(self, location: str) -> None: LightDB._current_db = self + def __enter__(self) -> "LightDB": + """Support usage as a context manager + + Returns: + ``LightDB``: The database instance + """ + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Save the database when exiting the context manager""" + self.save() + @classmethod def current(cls) -> "LightDB": """Returns the current instance of the LightDB class @@ -65,6 +77,7 @@ def _load(self) -> Dict[str, Any]: def save(self) -> None: """Save the current state of the database to a JSON file""" + self.location.parent.mkdir(parents=True, exist_ok=True) with self.location.open("w", encoding="utf-8") as file: json.dump(self, file, ensure_ascii=False, indent=4) @@ -99,16 +112,28 @@ def get(self, key: str, default: Union[_VT, _T] = None) -> Union[_VT, _T]: """ return super().get(key, default) + @overload def pop(self, key: str) -> Any: + ... + + @overload + def pop(self, key: str, default: Any) -> Any: + ... + + def pop(self, key: str, *args) -> Any: """Remove a key-value pair from the database Params: key (``str``): The key to remove + default (``Any``, optional): The value to return if the key doesn`t exist. + If not provided and the key is missing, a ``KeyError`` is raised. + Returns: - ``Any``: The removed key-value pair + ``Any``: The value associated with the removed key, or ``default`` if the key + doesn`t exist and a default was provided """ - return super().pop(key) + return super().pop(key, *args) def reset(self) -> None: """Reset the database""" diff --git a/lightdb/fields.py b/lightdb/fields.py index 52f2120..c59c059 100644 --- a/lightdb/fields.py +++ b/lightdb/fields.py @@ -1,6 +1,6 @@ """A file containing the implementation of the Field class for data validation and storage""" -from typing import Any, List, Dict, Optional, get_origin, get_args, TYPE_CHECKING +from typing import Any, List, Dict, Optional, Union, get_origin, get_args, TYPE_CHECKING from .exceptions import ValidationError from .query import Condition @@ -22,7 +22,7 @@ def __init__( """Initializes a new instance of the field with the provided arguments Params: - annotation (``str``, optional): The name of the field + name (``str``, optional): The name of the field annotation (``Any``, optional): The type of the field @@ -55,7 +55,33 @@ def validate(self, value: Any = None) -> None: origin = get_origin(expected_type) args = get_args(expected_type) - if origin is None: + # Handle Union types (including Optional[X] which is Union[X, None]) + if origin is Union: + non_none_args = [a for a in args if a is not type(None)] + if value is None and type(None) in args: + return + for allowed_type in non_none_args: + allowed_origin = get_origin(allowed_type) + if allowed_origin is None: + if isinstance(value, allowed_type): + return + else: + # Recurse into nested generic type inside the Union + inner_field = Field(name=self.name, annotation=allowed_type) + try: + inner_field.validate(value) + return + except ValidationError: + pass + allowed_names = " | ".join( + getattr(a, "__name__", str(a)) for a in non_none_args + ) + raise ValidationError( + f"Expected value of type `{allowed_names}` for field `{self.name}`, " + f"got `{type(value).__name__}`" + ) + + elif origin is None: if not isinstance(value, expected_type): raise ValidationError(f"Expected value of type `{expected_type.__name__}` for field `{self.name}`, got `{type(value).__name__}`") diff --git a/lightdb/models.py b/lightdb/models.py index b4e3396..7e40fed 100644 --- a/lightdb/models.py +++ b/lightdb/models.py @@ -27,7 +27,7 @@ def __new__(mcs, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any], **kw attrs["__table__"] = table - if not attrs.get("__db__"): + if attrs.get("__db__") is None: attrs["__db__"] = LightDB.current() annotations: Dict[str, Any] = attrs.get("__annotations__", {}) @@ -156,11 +156,11 @@ def save(self) -> None: def delete(self) -> None: """Deletes the current instance of the model from the database""" rows = self.__db__.get(self.__table__, []) + updated_rows = [item for item in rows if item["_id"] != self._fields_map["_id"].value] - for item in rows: - if item["_id"] == self._fields_map["_id"].value: - rows.remove(item) - self.__db__.save() + if len(updated_rows) != len(rows): + self.__db__[self.__table__] = updated_rows + self.__db__.save() @classmethod def filter(cls: Type[MODEL], *args, **kwargs) -> List[MODEL]: diff --git a/tests/test_fields.py b/tests/test_fields.py index e40aa7d..ef59ae3 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1,5 +1,5 @@ import pytest -from typing import List +from typing import List, Optional from lightdb.fields import Field from lightdb.exceptions import ValidationError @@ -27,3 +27,23 @@ def test_field_validation_list(): with pytest.raises(ValidationError): field.validate([1, "string", 3]) + + +def test_field_validation_optional_allows_none(): + field = Field(name="test", annotation=Optional[str]) + field.validate(None) + field.validate("hello") + + +def test_field_validation_optional_rejects_wrong_type(): + field = Field(name="test", annotation=Optional[str]) + with pytest.raises(ValidationError): + field.validate(42) + + +def test_field_validation_optional_int(): + field = Field(name="count", annotation=Optional[int]) + field.validate(None) + field.validate(5) + with pytest.raises(ValidationError): + field.validate("not_an_int") diff --git a/tests/test_lightdb.py b/tests/test_lightdb.py index d7a1a08..d7b0dd1 100644 --- a/tests/test_lightdb.py +++ b/tests/test_lightdb.py @@ -1,5 +1,6 @@ import os import pytest +import tempfile from pathlib import Path @@ -37,3 +38,37 @@ def test_lightdb_reset(db: LightDB): db.set("key", "value") db.reset() assert db.get("key") is None + + +def test_lightdb_pop_with_default(db: LightDB): + db.set("key", "value") + assert db.pop("key") == "value" + assert db.pop("missing", None) is None + assert db.pop("missing", "fallback") == "fallback" + + +def test_lightdb_pop_raises_without_default(db: LightDB): + with pytest.raises(KeyError): + db.pop("nonexistent") + + +def test_lightdb_save_creates_parent_dirs(): + with tempfile.TemporaryDirectory() as tmpdir: + nested_path = os.path.join(tmpdir, "a", "b", "c", "db.json") + db = LightDB(nested_path) + db.set("key", "value") + db.save() + assert os.path.exists(nested_path) + + db2 = LightDB(nested_path) + assert db2.get("key") == "value" + + +def test_lightdb_context_manager(): + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "ctx_db.json") + with LightDB(path) as db: + db.set("key", "value") + + db2 = LightDB(path) + assert db2.get("key") == "value" diff --git a/tests/test_models.py b/tests/test_models.py index cab07e3..e087aa4 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -68,3 +68,37 @@ def test_model_all(user_model: MODEL): results = user_model.all() assert len(results) == 2 + + +def test_model_delete_correct_entry(user_model: MODEL): + john = user_model.create(name="John", age=30) + jane = user_model.create(name="Jane", age=25) + john.delete() + + remaining = user_model.all() + assert len(remaining) == 1 + assert remaining[0].name == "Jane" + + +def test_model_optional_field(): + import os + from typing import Optional + + test_db_location = "test_optional_db.json" + from lightdb.core import LightDB + + db = LightDB(test_db_location) + + class Profile(Model, table="profiles"): + username: str + bio: Optional[str] = None + + try: + p = Profile.create(username="Alice") + assert p.bio is None + + p2 = Profile.create(username="Bob", bio="Hello!") + assert p2.bio == "Hello!" + finally: + if os.path.exists(test_db_location): + os.remove(test_db_location)