From fb75e0622708393f19ee172be2654ee2cbf8d7bb Mon Sep 17 00:00:00 2001 From: Alexey Makridenko Date: Sun, 6 Jul 2025 14:00:10 +0200 Subject: [PATCH 1/4] refactor model metaclass and options for improved configuration handling (#75) --- supadantic/models.py | 63 +++++++++++++++++++++++++++----------------- 1 file changed, 39 insertions(+), 24 deletions(-) diff --git a/supadantic/models.py b/supadantic/models.py index a76987d..058026a 100644 --- a/supadantic/models.py +++ b/supadantic/models.py @@ -2,7 +2,7 @@ import re from abc import ABC from copy import copy -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypeVar from pydantic import BaseModel, model_validator from pydantic._internal._model_construction import ModelMetaclass as PydanticModelMetaclass @@ -40,14 +40,39 @@ def _to_snake_case(value: str) -> str: return re.sub(r'(? QSet[_M]: # type: ignore @@ -85,6 +110,7 @@ class MultipleObjectsReturned(Exception): pass id: int | None = None + _meta: ClassVar[ModelOptions] def save(self: _M) -> _M: """ @@ -142,35 +168,24 @@ def db_client(cls) -> type['BaseClient']: (BaseClient): The database client class. """ - return SupabaseClient + return cls._meta.db_client @classmethod def _get_table_name(cls) -> str: """ - Gets the table name associated with the model, converting the class name to snake case. - - This method converts the class name to snake_case to determine the corresponding - table name in the database. - - Returns: - (str): The table name in snake_case. + Gets the table name associated with the model. + If no table_name is specified in Meta class, converts class name to snake_case. """ - return _to_snake_case(cls.__name__) + return cls._meta.table_name or _to_snake_case(cls.__name__) @classmethod def _get_db_client(cls) -> 'BaseClient': """ - Retrieves the database client instance for the model, configured with the table name. - - This method creates a database client instance using the `db_client()` method - and initializes it with the appropriate table name. - - Returns: - (BaseClient): An initialized instance of the database client. + Retrieves the database client instance for the model. """ - table_name = cls._get_table_name() - return cls.db_client()(table_name) + client = cls.db_client()(table_name) + return client @model_validator(mode='before') def _validate_data_from_supabase(cls, data: dict[str, Any]) -> dict[str, Any]: From 1b2a9b759e5bb9a49970d9d08d17b7995080bd31 Mon Sep 17 00:00:00 2001 From: Alexey Makridenko Date: Sun, 6 Jul 2025 14:02:07 +0200 Subject: [PATCH 2/4] move _to_snake_case function to utils.py (#75) --- supadantic/models.py | 23 +---------------------- supadantic/utils.py | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+), 22 deletions(-) create mode 100644 supadantic/utils.py diff --git a/supadantic/models.py b/supadantic/models.py index 058026a..db38587 100644 --- a/supadantic/models.py +++ b/supadantic/models.py @@ -1,5 +1,4 @@ import ast -import re from abc import ABC from copy import copy from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypeVar @@ -10,6 +9,7 @@ from .clients import SupabaseClient from .q_set import QSet from .query_builder import QueryBuilder +from .utils import _to_snake_case if TYPE_CHECKING: @@ -19,27 +19,6 @@ _M = TypeVar('_M', bound='BaseSBModel') -def _to_snake_case(value: str) -> str: - """ - Converts a string from camel case or Pascal case to snake case. - - This function uses a regular expression to find uppercase letters within - the string and inserts an underscore before them (except at the beginning - of the string). The entire string is then converted to lowercase. - - Args: - value (str): The string to convert. - - Returns: - (str): The snake_case version of the input string. - - Example: - >>> _to_snake_case("MyClassName") - 'my_class_name' - """ - return re.sub(r'(? str: + """ + Converts a string from camel case or Pascal case to snake case. + + This function uses a regular expression to find uppercase letters within + the string and inserts an underscore before them (except at the beginning + of the string). The entire string is then converted to lowercase. + + Args: + value (str): The string to convert. + + Returns: + (str): The snake_case version of the input string. + + Example: + >>> _to_snake_case("MyClassName") + 'my_class_name' + """ + return re.sub(r'(? Date: Sun, 6 Jul 2025 14:03:51 +0200 Subject: [PATCH 3/4] update documentation to reflect table name customization via Meta class (#75) --- README.md | 7 +++---- docs/index.md | 7 +++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 5200831..8eef57b 100644 --- a/README.md +++ b/README.md @@ -35,10 +35,9 @@ class User(BaseSBModel): is_active: bool = True # By default table name is class name in snake_case - # If you want to change it - you should implement _get_table_name method - @classmethod - def _get_table_name(cls) -> str: - return 'db_user' + # you can override it by setting `Meta.table_name` attribute + class Meta: + table_name = 'db_user' # Save user active_user = User(name='John Snow') diff --git a/docs/index.md b/docs/index.md index b61c45b..86fe247 100644 --- a/docs/index.md +++ b/docs/index.md @@ -31,10 +31,9 @@ class User(BaseSBModel): is_active: bool = True # By default table name is class name in snake_case - # If you want to change it - you should implement _get_table_name method - @classmethod - def _get_table_name(cls) -> str: - return 'db_user' + # you can override it by setting `Meta.table_name` attribute + class Meta: + table_name = 'db_user' # Save user active_user = User(name='John Snow') From 4ffa021c59e9332714c1d2b084ee065d2250aebd Mon Sep 17 00:00:00 2001 From: Alexey Makridenko Date: Sun, 6 Jul 2025 14:56:25 +0200 Subject: [PATCH 4/4] add schema support (#75) --- supadantic/clients/base.py | 3 ++- supadantic/clients/cache.py | 4 ++-- supadantic/clients/supabase.py | 24 ++++++++++++++++++++---- supadantic/models.py | 14 ++++++++++---- tests/fixtures/model.py | 27 +++++++++++++++++++++------ tests/test_models.py | 12 ++++++++++++ tests/test_supabase_client.py | 26 ++++++++++++++++++++++++-- 7 files changed, 91 insertions(+), 19 deletions(-) diff --git a/supadantic/clients/base.py b/supadantic/clients/base.py index 5c7ef66..aa12b57 100644 --- a/supadantic/clients/base.py +++ b/supadantic/clients/base.py @@ -28,7 +28,7 @@ class BaseClient(ABC, metaclass=BaseClientMeta): for interacting with a specific database or service. """ - def __init__(self, table_name: str) -> None: + def __init__(self, table_name: str, schema: str | None = None) -> None: """ Initializes the client with the table name. @@ -39,6 +39,7 @@ def __init__(self, table_name: str) -> None: """ self.table_name = table_name + self.schema = schema def execute(self, *, query_builder: QueryBuilder) -> list[dict[str, Any]] | int: """ diff --git a/supadantic/clients/cache.py b/supadantic/clients/cache.py index f55db39..c178a49 100644 --- a/supadantic/clients/cache.py +++ b/supadantic/clients/cache.py @@ -52,7 +52,7 @@ class CacheClient(BaseClient, metaclass=SingletoneMeta): It is NOT suitable for production environments. """ - def __init__(self, table_name: str) -> None: + def __init__(self, table_name: str, schema: str | None = None) -> None: """ Initializes the client with the table name and an empty cache. @@ -63,7 +63,7 @@ def __init__(self, table_name: str) -> None: `BaseClient` interface and may be used in future extensions of this class. """ - super().__init__(table_name=table_name) + super().__init__(table_name=table_name, schema=schema) self._cache_data: dict[int, dict[str, Any]] = {} diff --git a/supadantic/clients/supabase.py b/supadantic/clients/supabase.py index 5ccc6fb..3e1c301 100644 --- a/supadantic/clients/supabase.py +++ b/supadantic/clients/supabase.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: - from postgrest._sync.request_builder import SyncSelectRequestBuilder + from postgrest._sync.request_builder import SyncRequestBuilder, SyncSelectRequestBuilder from postgrest.base_request_builder import BaseFilterRequestBuilder from supadantic.query_builder import QueryBuilder @@ -25,7 +25,7 @@ class SupabaseClient(BaseClient): to initialize the Supabase client. """ - def __init__(self, table_name: str): + def __init__(self, table_name: str, schema: str | None = None) -> None: """ Initializes the Supabase client and sets up the query object. @@ -33,11 +33,27 @@ def __init__(self, table_name: str): table_name (str): The name of the table to interact with. """ - super().__init__(table_name=table_name) + super().__init__(table_name=table_name, schema=schema) url: str = os.getenv('SUPABASE_URL', default='') key: str = os.getenv('SUPABASE_KEY', default='') + + supabase_client = self._get_supabase_client(url=url, key=key) + self.query = supabase_client + + def _get_supabase_client(self, url: str, key: str) -> 'SyncRequestBuilder': + """ + Returns the Supabase client query object. + + This method is used to access the underlying Supabase client for executing queries. + It is primarily used internally by other methods in this class. + + Returns: + (SyncRequestBuilder): The Supabase client query object. + """ supabase_client = create_client(url, key) - self.query = supabase_client.table(table_name=self.table_name) + if self.schema: + supabase_client = supabase_client.schema(self.schema) + return supabase_client.table(self.table_name) def _delete(self, *, query_builder: 'QueryBuilder') -> list[dict[str, Any]]: """ diff --git a/supadantic/models.py b/supadantic/models.py index db38587..945777a 100644 --- a/supadantic/models.py +++ b/supadantic/models.py @@ -1,7 +1,7 @@ import ast from abc import ABC from copy import copy -from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, TypeVar from pydantic import BaseModel, model_validator from pydantic._internal._model_construction import ModelMetaclass as PydanticModelMetaclass @@ -26,11 +26,13 @@ class ModelOptions: def __init__( self, - table_name: Optional[str] = None, - db_client: Optional[type['BaseClient']] = None, + table_name: str | None = None, + db_client: type['BaseClient'] | None = None, + schema: str | None = None, ): self.table_name = table_name self.db_client = db_client or SupabaseClient + self.schema = schema class ModelMetaclass(PydanticModelMetaclass): @@ -50,6 +52,9 @@ def __new__(mcs, name, bases, namespace): if hasattr(meta, 'db_client'): options.db_client = meta.db_client + if hasattr(meta, 'schema'): + options.schema = meta.schema + cls._meta = options return cls @@ -163,7 +168,8 @@ def _get_db_client(cls) -> 'BaseClient': Retrieves the database client instance for the model. """ table_name = cls._get_table_name() - client = cls.db_client()(table_name) + schema = cls._meta.schema + client = cls.db_client()(table_name, schema) return client @model_validator(mode='before') diff --git a/tests/fixtures/model.py b/tests/fixtures/model.py index b9cdf88..c360db7 100644 --- a/tests/fixtures/model.py +++ b/tests/fixtures/model.py @@ -9,19 +9,22 @@ if TYPE_CHECKING: from collections.abc import Generator - from supadantic.clients.base import BaseClient - class ModelMock(BaseSBModel): id: int | None = None name: str age: int | None = None some_optional_list: list[str] | None = None - some_optional_tuple: tuple[str, ...] | None = None + some_optional_tuple: tuple[str, ...] | None = None # noqa: CCE001 + + class Meta: + db_client = CacheClient + - @classmethod - def db_client(cls) -> type['BaseClient']: - return CacheClient +class ModelMockCustomSchema(ModelMock): + class Meta: + db_client = CacheClient + schema = 'custom_schema' @pytest.fixture(scope='function') @@ -29,8 +32,20 @@ def model_mock() -> type[ModelMock]: return ModelMock +@pytest.fixture(scope='function') +def model_mock_custom_schema() -> type[ModelMockCustomSchema]: + return ModelMockCustomSchema + + @pytest.fixture(autouse=True, scope='function') def clean_db_cache(model_mock: type['ModelMock']) -> 'Generator': yield model_mock.objects._cache = [] model_mock.objects.all().delete() + + +@pytest.fixture(autouse=True, scope='function') +def clean_db_cache_custom_schema(model_mock_custom_schema: type['ModelMockCustomSchema']) -> 'Generator': + yield + model_mock_custom_schema.objects._cache = [] + model_mock_custom_schema.objects.all().delete() diff --git a/tests/test_models.py b/tests/test_models.py index 4466596..8407c2b 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -47,3 +47,15 @@ def test_update(self, model_mock: type['ModelMock']): def test_objects(self, model_mock: type['ModelMock']): assert isinstance(model_mock.objects, QSet) + + +class TestBaseSBModelCustomSchema: + def test_db_client_with_custom_schema(self, model_mock_custom_schema: type['ModelMock']): + # Prepare data + test_entity = model_mock_custom_schema(name='test_name') + + # Execution + db_client = test_entity._get_db_client() + + # Testing + assert db_client.schema == 'custom_schema' diff --git a/tests/test_supabase_client.py b/tests/test_supabase_client.py index bdd7152..a17727f 100644 --- a/tests/test_supabase_client.py +++ b/tests/test_supabase_client.py @@ -83,7 +83,6 @@ def test_filter(self, supabase_client: SupabaseClient, httpx_mock: 'HTTPXMock'): ), status_code=200, ) - httpx_mock.add_response(is_optional=True) query_builder = QueryBuilder() query_builder.set_equal(id=1) @@ -132,7 +131,6 @@ def test_order_by(self, supabase_client: SupabaseClient, httpx_mock: 'HTTPXMock' ), status_code=200, ) - httpx_mock.add_response(is_optional=True) query_buider = QueryBuilder() query_buider.set_not_equal(title='test') @@ -144,3 +142,27 @@ def test_order_by(self, supabase_client: SupabaseClient, httpx_mock: 'HTTPXMock' # Assert assert len(httpx_mock.get_requests()) == 1 + + def test_select_with_schema(self, httpx_mock: 'HTTPXMock'): + # Arrange + httpx_mock.add_response( + method='GET', + url=httpx.URL( + 'https://test.supabase.co/rest/v1/table_name', + params={'select': '*', 'id': 'eq.1', 'title': 'neq.test'}, + ), + status_code=200, + ) + + query_builder = QueryBuilder() + query_builder.set_equal(id=1) + query_builder.set_not_equal(title='test') + + # Act + supabase_client = SupabaseClient(table_name='table_name', schema='foo') + supabase_client.execute(query_builder=query_builder) + + # Assert + assert len(httpx_mock.get_requests()) == 1 + request = httpx_mock.get_requests()[0] + assert request.headers['accept-profile'] == 'foo'