diff --git a/README.md b/README.md index 5200831..8eef57b 100644 --- a/README.md +++ b/README.md @@ -35,10 +35,9 @@ class User(BaseSBModel): is_active: bool = True # By default table name is class name in snake_case - # If you want to change it - you should implement _get_table_name method - @classmethod - def _get_table_name(cls) -> str: - return 'db_user' + # you can override it by setting `Meta.table_name` attribute + class Meta: + table_name = 'db_user' # Save user active_user = User(name='John Snow') diff --git a/docs/index.md b/docs/index.md index b61c45b..86fe247 100644 --- a/docs/index.md +++ b/docs/index.md @@ -31,10 +31,9 @@ class User(BaseSBModel): is_active: bool = True # By default table name is class name in snake_case - # If you want to change it - you should implement _get_table_name method - @classmethod - def _get_table_name(cls) -> str: - return 'db_user' + # you can override it by setting `Meta.table_name` attribute + class Meta: + table_name = 'db_user' # Save user active_user = User(name='John Snow') diff --git a/supadantic/clients/base.py b/supadantic/clients/base.py index 5c7ef66..aa12b57 100644 --- a/supadantic/clients/base.py +++ b/supadantic/clients/base.py @@ -28,7 +28,7 @@ class BaseClient(ABC, metaclass=BaseClientMeta): for interacting with a specific database or service. """ - def __init__(self, table_name: str) -> None: + def __init__(self, table_name: str, schema: str | None = None) -> None: """ Initializes the client with the table name. @@ -39,6 +39,7 @@ def __init__(self, table_name: str) -> None: """ self.table_name = table_name + self.schema = schema def execute(self, *, query_builder: QueryBuilder) -> list[dict[str, Any]] | int: """ diff --git a/supadantic/clients/cache.py b/supadantic/clients/cache.py index f55db39..c178a49 100644 --- a/supadantic/clients/cache.py +++ b/supadantic/clients/cache.py @@ -52,7 +52,7 @@ class CacheClient(BaseClient, metaclass=SingletoneMeta): It is NOT suitable for production environments. """ - def __init__(self, table_name: str) -> None: + def __init__(self, table_name: str, schema: str | None = None) -> None: """ Initializes the client with the table name and an empty cache. @@ -63,7 +63,7 @@ def __init__(self, table_name: str) -> None: `BaseClient` interface and may be used in future extensions of this class. """ - super().__init__(table_name=table_name) + super().__init__(table_name=table_name, schema=schema) self._cache_data: dict[int, dict[str, Any]] = {} diff --git a/supadantic/clients/supabase.py b/supadantic/clients/supabase.py index 5ccc6fb..3e1c301 100644 --- a/supadantic/clients/supabase.py +++ b/supadantic/clients/supabase.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: - from postgrest._sync.request_builder import SyncSelectRequestBuilder + from postgrest._sync.request_builder import SyncRequestBuilder, SyncSelectRequestBuilder from postgrest.base_request_builder import BaseFilterRequestBuilder from supadantic.query_builder import QueryBuilder @@ -25,7 +25,7 @@ class SupabaseClient(BaseClient): to initialize the Supabase client. """ - def __init__(self, table_name: str): + def __init__(self, table_name: str, schema: str | None = None) -> None: """ Initializes the Supabase client and sets up the query object. @@ -33,11 +33,27 @@ def __init__(self, table_name: str): table_name (str): The name of the table to interact with. """ - super().__init__(table_name=table_name) + super().__init__(table_name=table_name, schema=schema) url: str = os.getenv('SUPABASE_URL', default='') key: str = os.getenv('SUPABASE_KEY', default='') + + supabase_client = self._get_supabase_client(url=url, key=key) + self.query = supabase_client + + def _get_supabase_client(self, url: str, key: str) -> 'SyncRequestBuilder': + """ + Returns the Supabase client query object. + + This method is used to access the underlying Supabase client for executing queries. + It is primarily used internally by other methods in this class. + + Returns: + (SyncRequestBuilder): The Supabase client query object. + """ supabase_client = create_client(url, key) - self.query = supabase_client.table(table_name=self.table_name) + if self.schema: + supabase_client = supabase_client.schema(self.schema) + return supabase_client.table(self.table_name) def _delete(self, *, query_builder: 'QueryBuilder') -> list[dict[str, Any]]: """ diff --git a/supadantic/models.py b/supadantic/models.py index a76987d..945777a 100644 --- a/supadantic/models.py +++ b/supadantic/models.py @@ -1,8 +1,7 @@ import ast -import re from abc import ABC from copy import copy -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, TypeVar from pydantic import BaseModel, model_validator from pydantic._internal._model_construction import ModelMetaclass as PydanticModelMetaclass @@ -10,6 +9,7 @@ from .clients import SupabaseClient from .q_set import QSet from .query_builder import QueryBuilder +from .utils import _to_snake_case if TYPE_CHECKING: @@ -19,35 +19,44 @@ _M = TypeVar('_M', bound='BaseSBModel') -def _to_snake_case(value: str) -> str: +class ModelOptions: + """ + Configuration class to store model metadata and options. """ - Converts a string from camel case or Pascal case to snake case. - - This function uses a regular expression to find uppercase letters within - the string and inserts an underscore before them (except at the beginning - of the string). The entire string is then converted to lowercase. - Args: - value (str): The string to convert. + def __init__( + self, + table_name: str | None = None, + db_client: type['BaseClient'] | None = None, + schema: str | None = None, + ): + self.table_name = table_name + self.db_client = db_client or SupabaseClient + self.schema = schema - Returns: - (str): The snake_case version of the input string. - Example: - >>> _to_snake_case("MyClassName") - 'my_class_name' +class ModelMetaclass(PydanticModelMetaclass): + """ + Metaclass for BaseSBModel, handling Meta class configuration and objects property. """ - return re.sub(r'(? QSet[_M]: # type: ignore @@ -85,6 +94,7 @@ class MultipleObjectsReturned(Exception): pass id: int | None = None + _meta: ClassVar[ModelOptions] def save(self: _M) -> _M: """ @@ -142,35 +152,25 @@ def db_client(cls) -> type['BaseClient']: (BaseClient): The database client class. """ - return SupabaseClient + return cls._meta.db_client @classmethod def _get_table_name(cls) -> str: """ - Gets the table name associated with the model, converting the class name to snake case. - - This method converts the class name to snake_case to determine the corresponding - table name in the database. - - Returns: - (str): The table name in snake_case. + Gets the table name associated with the model. + If no table_name is specified in Meta class, converts class name to snake_case. """ - return _to_snake_case(cls.__name__) + return cls._meta.table_name or _to_snake_case(cls.__name__) @classmethod def _get_db_client(cls) -> 'BaseClient': """ - Retrieves the database client instance for the model, configured with the table name. - - This method creates a database client instance using the `db_client()` method - and initializes it with the appropriate table name. - - Returns: - (BaseClient): An initialized instance of the database client. + Retrieves the database client instance for the model. """ - table_name = cls._get_table_name() - return cls.db_client()(table_name) + schema = cls._meta.schema + client = cls.db_client()(table_name, schema) + return client @model_validator(mode='before') def _validate_data_from_supabase(cls, data: dict[str, Any]) -> dict[str, Any]: diff --git a/supadantic/utils.py b/supadantic/utils.py new file mode 100644 index 0000000..9a77bcd --- /dev/null +++ b/supadantic/utils.py @@ -0,0 +1,22 @@ +import re + + +def _to_snake_case(value: str) -> str: + """ + Converts a string from camel case or Pascal case to snake case. + + This function uses a regular expression to find uppercase letters within + the string and inserts an underscore before them (except at the beginning + of the string). The entire string is then converted to lowercase. + + Args: + value (str): The string to convert. + + Returns: + (str): The snake_case version of the input string. + + Example: + >>> _to_snake_case("MyClassName") + 'my_class_name' + """ + return re.sub(r'(? type['BaseClient']: - return CacheClient +class ModelMockCustomSchema(ModelMock): + class Meta: + db_client = CacheClient + schema = 'custom_schema' @pytest.fixture(scope='function') @@ -29,8 +32,20 @@ def model_mock() -> type[ModelMock]: return ModelMock +@pytest.fixture(scope='function') +def model_mock_custom_schema() -> type[ModelMockCustomSchema]: + return ModelMockCustomSchema + + @pytest.fixture(autouse=True, scope='function') def clean_db_cache(model_mock: type['ModelMock']) -> 'Generator': yield model_mock.objects._cache = [] model_mock.objects.all().delete() + + +@pytest.fixture(autouse=True, scope='function') +def clean_db_cache_custom_schema(model_mock_custom_schema: type['ModelMockCustomSchema']) -> 'Generator': + yield + model_mock_custom_schema.objects._cache = [] + model_mock_custom_schema.objects.all().delete() diff --git a/tests/test_models.py b/tests/test_models.py index 4466596..8407c2b 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -47,3 +47,15 @@ def test_update(self, model_mock: type['ModelMock']): def test_objects(self, model_mock: type['ModelMock']): assert isinstance(model_mock.objects, QSet) + + +class TestBaseSBModelCustomSchema: + def test_db_client_with_custom_schema(self, model_mock_custom_schema: type['ModelMock']): + # Prepare data + test_entity = model_mock_custom_schema(name='test_name') + + # Execution + db_client = test_entity._get_db_client() + + # Testing + assert db_client.schema == 'custom_schema' diff --git a/tests/test_supabase_client.py b/tests/test_supabase_client.py index bdd7152..a17727f 100644 --- a/tests/test_supabase_client.py +++ b/tests/test_supabase_client.py @@ -83,7 +83,6 @@ def test_filter(self, supabase_client: SupabaseClient, httpx_mock: 'HTTPXMock'): ), status_code=200, ) - httpx_mock.add_response(is_optional=True) query_builder = QueryBuilder() query_builder.set_equal(id=1) @@ -132,7 +131,6 @@ def test_order_by(self, supabase_client: SupabaseClient, httpx_mock: 'HTTPXMock' ), status_code=200, ) - httpx_mock.add_response(is_optional=True) query_buider = QueryBuilder() query_buider.set_not_equal(title='test') @@ -144,3 +142,27 @@ def test_order_by(self, supabase_client: SupabaseClient, httpx_mock: 'HTTPXMock' # Assert assert len(httpx_mock.get_requests()) == 1 + + def test_select_with_schema(self, httpx_mock: 'HTTPXMock'): + # Arrange + httpx_mock.add_response( + method='GET', + url=httpx.URL( + 'https://test.supabase.co/rest/v1/table_name', + params={'select': '*', 'id': 'eq.1', 'title': 'neq.test'}, + ), + status_code=200, + ) + + query_builder = QueryBuilder() + query_builder.set_equal(id=1) + query_builder.set_not_equal(title='test') + + # Act + supabase_client = SupabaseClient(table_name='table_name', schema='foo') + supabase_client.execute(query_builder=query_builder) + + # Assert + assert len(httpx_mock.get_requests()) == 1 + request = httpx_mock.get_requests()[0] + assert request.headers['accept-profile'] == 'foo'