Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions lightdb/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ def __init__(self, location: str) -> None:

LightDB._current_db = self

def __enter__(self) -> "LightDB":
"""Support usage as a context manager

Returns:
``LightDB``: The database instance
"""
return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
"""Save the database when exiting the context manager"""
self.save()

@classmethod
def current(cls) -> "LightDB":
"""Returns the current instance of the LightDB class
Expand Down Expand Up @@ -65,6 +77,7 @@ def _load(self) -> Dict[str, Any]:

def save(self) -> None:
"""Save the current state of the database to a JSON file"""
self.location.parent.mkdir(parents=True, exist_ok=True)
with self.location.open("w", encoding="utf-8") as file:
json.dump(self, file, ensure_ascii=False, indent=4)

Expand Down Expand Up @@ -99,16 +112,28 @@ def get(self, key: str, default: Union[_VT, _T] = None) -> Union[_VT, _T]:
"""
return super().get(key, default)

@overload
def pop(self, key: str) -> Any:
...

@overload
def pop(self, key: str, default: Any) -> Any:
...

def pop(self, key: str, *args) -> Any:
"""Remove a key-value pair from the database

Params:
key (``str``): The key to remove

default (``Any``, optional): The value to return if the key doesn`t exist.
If not provided and the key is missing, a ``KeyError`` is raised.

Returns:
``Any``: The removed key-value pair
``Any``: The value associated with the removed key, or ``default`` if the key
doesn`t exist and a default was provided
"""
return super().pop(key)
return super().pop(key, *args)

def reset(self) -> None:
"""Reset the database"""
Expand Down
32 changes: 29 additions & 3 deletions lightdb/fields.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""A file containing the implementation of the Field class for data validation and storage"""

from typing import Any, List, Dict, Optional, get_origin, get_args, TYPE_CHECKING
from typing import Any, List, Dict, Optional, Union, get_origin, get_args, TYPE_CHECKING

from .exceptions import ValidationError
from .query import Condition
Expand All @@ -22,7 +22,7 @@ def __init__(
"""Initializes a new instance of the field with the provided arguments

Params:
annotation (``str``, optional): The name of the field
name (``str``, optional): The name of the field

annotation (``Any``, optional): The type of the field

Expand Down Expand Up @@ -55,7 +55,33 @@ def validate(self, value: Any = None) -> None:
origin = get_origin(expected_type)
args = get_args(expected_type)

if origin is None:
# Handle Union types (including Optional[X] which is Union[X, None])
if origin is Union:
non_none_args = [a for a in args if a is not type(None)]
if value is None and type(None) in args:
return
for allowed_type in non_none_args:
allowed_origin = get_origin(allowed_type)
if allowed_origin is None:
if isinstance(value, allowed_type):
return
else:
# Recurse into nested generic type inside the Union
inner_field = Field(name=self.name, annotation=allowed_type)
try:
inner_field.validate(value)
return
except ValidationError:
pass
allowed_names = " | ".join(
getattr(a, "__name__", str(a)) for a in non_none_args
)
raise ValidationError(
f"Expected value of type `{allowed_names}` for field `{self.name}`, "
f"got `{type(value).__name__}`"
)

elif origin is None:
if not isinstance(value, expected_type):
raise ValidationError(f"Expected value of type `{expected_type.__name__}` for field `{self.name}`, got `{type(value).__name__}`")

Expand Down
10 changes: 5 additions & 5 deletions lightdb/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __new__(mcs, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any], **kw

attrs["__table__"] = table

if not attrs.get("__db__"):
if attrs.get("__db__") is None:
attrs["__db__"] = LightDB.current()

annotations: Dict[str, Any] = attrs.get("__annotations__", {})
Expand Down Expand Up @@ -156,11 +156,11 @@ def save(self) -> None:
def delete(self) -> None:
"""Deletes the current instance of the model from the database"""
rows = self.__db__.get(self.__table__, [])
updated_rows = [item for item in rows if item["_id"] != self._fields_map["_id"].value]

for item in rows:
if item["_id"] == self._fields_map["_id"].value:
rows.remove(item)
self.__db__.save()
if len(updated_rows) != len(rows):
self.__db__[self.__table__] = updated_rows
self.__db__.save()

@classmethod
def filter(cls: Type[MODEL], *args, **kwargs) -> List[MODEL]:
Expand Down
22 changes: 21 additions & 1 deletion tests/test_fields.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from typing import List
from typing import List, Optional

from lightdb.fields import Field
from lightdb.exceptions import ValidationError
Expand Down Expand Up @@ -27,3 +27,23 @@ def test_field_validation_list():

with pytest.raises(ValidationError):
field.validate([1, "string", 3])


def test_field_validation_optional_allows_none():
field = Field(name="test", annotation=Optional[str])
field.validate(None)
field.validate("hello")


def test_field_validation_optional_rejects_wrong_type():
field = Field(name="test", annotation=Optional[str])
with pytest.raises(ValidationError):
field.validate(42)


def test_field_validation_optional_int():
field = Field(name="count", annotation=Optional[int])
field.validate(None)
field.validate(5)
with pytest.raises(ValidationError):
field.validate("not_an_int")
35 changes: 35 additions & 0 deletions tests/test_lightdb.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import pytest
import tempfile

from pathlib import Path

Expand Down Expand Up @@ -37,3 +38,37 @@ def test_lightdb_reset(db: LightDB):
db.set("key", "value")
db.reset()
assert db.get("key") is None


def test_lightdb_pop_with_default(db: LightDB):
db.set("key", "value")
assert db.pop("key") == "value"
assert db.pop("missing", None) is None
assert db.pop("missing", "fallback") == "fallback"


def test_lightdb_pop_raises_without_default(db: LightDB):
with pytest.raises(KeyError):
db.pop("nonexistent")


def test_lightdb_save_creates_parent_dirs():
with tempfile.TemporaryDirectory() as tmpdir:
nested_path = os.path.join(tmpdir, "a", "b", "c", "db.json")
db = LightDB(nested_path)
db.set("key", "value")
db.save()
assert os.path.exists(nested_path)

db2 = LightDB(nested_path)
assert db2.get("key") == "value"


def test_lightdb_context_manager():
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "ctx_db.json")
with LightDB(path) as db:
db.set("key", "value")

db2 = LightDB(path)
assert db2.get("key") == "value"
34 changes: 34 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,37 @@ def test_model_all(user_model: MODEL):

results = user_model.all()
assert len(results) == 2


def test_model_delete_correct_entry(user_model: MODEL):
john = user_model.create(name="John", age=30)
jane = user_model.create(name="Jane", age=25)
john.delete()

remaining = user_model.all()
assert len(remaining) == 1
assert remaining[0].name == "Jane"


def test_model_optional_field():
import os
from typing import Optional

test_db_location = "test_optional_db.json"
from lightdb.core import LightDB

db = LightDB(test_db_location)

class Profile(Model, table="profiles"):
username: str
bio: Optional[str] = None

try:
p = Profile.create(username="Alice")
assert p.bio is None

p2 = Profile.create(username="Bob", bio="Hello!")
assert p2.bio == "Hello!"
finally:
if os.path.exists(test_db_location):
os.remove(test_db_location)