Skip to content
This repository was archived by the owner on May 25, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,9 @@ class User(BaseSBModel):
is_active: bool = True

# By default table name is class name in snake_case
# If you want to change it - you should implement _get_table_name method
@classmethod
def _get_table_name(cls) -> str:
return 'db_user'
# you can override it by setting `Meta.table_name` attribute
class Meta:
table_name = 'db_user'

# Save user
active_user = User(name='John Snow')
Expand Down
7 changes: 3 additions & 4 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,9 @@ class User(BaseSBModel):
is_active: bool = True

# By default table name is class name in snake_case
# If you want to change it - you should implement _get_table_name method
@classmethod
def _get_table_name(cls) -> str:
return 'db_user'
# you can override it by setting `Meta.table_name` attribute
class Meta:
table_name = 'db_user'

# Save user
active_user = User(name='John Snow')
Expand Down
3 changes: 2 additions & 1 deletion supadantic/clients/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class BaseClient(ABC, metaclass=BaseClientMeta):
for interacting with a specific database or service.
"""

def __init__(self, table_name: str) -> None:
def __init__(self, table_name: str, schema: str | None = None) -> None:
"""
Initializes the client with the table name.

Expand All @@ -39,6 +39,7 @@ def __init__(self, table_name: str) -> None:
"""

self.table_name = table_name
self.schema = schema

def execute(self, *, query_builder: QueryBuilder) -> list[dict[str, Any]] | int:
"""
Expand Down
4 changes: 2 additions & 2 deletions supadantic/clients/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class CacheClient(BaseClient, metaclass=SingletoneMeta):
It is NOT suitable for production environments.
"""

def __init__(self, table_name: str) -> None:
def __init__(self, table_name: str, schema: str | None = None) -> None:
"""
Initializes the client with the table name and an empty cache.

Expand All @@ -63,7 +63,7 @@ def __init__(self, table_name: str) -> None:
`BaseClient` interface and may be used in future
extensions of this class.
"""
super().__init__(table_name=table_name)
super().__init__(table_name=table_name, schema=schema)

self._cache_data: dict[int, dict[str, Any]] = {}

Expand Down
24 changes: 20 additions & 4 deletions supadantic/clients/supabase.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


if TYPE_CHECKING:
from postgrest._sync.request_builder import SyncSelectRequestBuilder
from postgrest._sync.request_builder import SyncRequestBuilder, SyncSelectRequestBuilder
from postgrest.base_request_builder import BaseFilterRequestBuilder

from supadantic.query_builder import QueryBuilder
Expand All @@ -25,19 +25,35 @@ class SupabaseClient(BaseClient):
to initialize the Supabase client.
"""

def __init__(self, table_name: str):
def __init__(self, table_name: str, schema: str | None = None) -> None:
"""
Initializes the Supabase client and sets up the query object.

Args:
table_name (str): The name of the table to interact with.
"""

super().__init__(table_name=table_name)
super().__init__(table_name=table_name, schema=schema)
url: str = os.getenv('SUPABASE_URL', default='')
key: str = os.getenv('SUPABASE_KEY', default='')

supabase_client = self._get_supabase_client(url=url, key=key)
self.query = supabase_client

def _get_supabase_client(self, url: str, key: str) -> 'SyncRequestBuilder':
"""
Returns the Supabase client query object.

This method is used to access the underlying Supabase client for executing queries.
It is primarily used internally by other methods in this class.

Returns:
(SyncRequestBuilder): The Supabase client query object.
"""
supabase_client = create_client(url, key)
self.query = supabase_client.table(table_name=self.table_name)
if self.schema:
supabase_client = supabase_client.schema(self.schema)
return supabase_client.table(self.table_name)

def _delete(self, *, query_builder: 'QueryBuilder') -> list[dict[str, Any]]:
"""
Expand Down
82 changes: 41 additions & 41 deletions supadantic/models.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import ast
import re
from abc import ABC
from copy import copy
from typing import TYPE_CHECKING, Any, TypeVar
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar

from pydantic import BaseModel, model_validator
from pydantic._internal._model_construction import ModelMetaclass as PydanticModelMetaclass

from .clients import SupabaseClient
from .q_set import QSet
from .query_builder import QueryBuilder
from .utils import _to_snake_case


if TYPE_CHECKING:
Expand All @@ -19,35 +19,44 @@
_M = TypeVar('_M', bound='BaseSBModel')


def _to_snake_case(value: str) -> str:
class ModelOptions:
"""
Configuration class to store model metadata and options.
"""
Converts a string from camel case or Pascal case to snake case.

This function uses a regular expression to find uppercase letters within
the string and inserts an underscore before them (except at the beginning
of the string). The entire string is then converted to lowercase.

Args:
value (str): The string to convert.
def __init__(
self,
table_name: str | None = None,
db_client: type['BaseClient'] | None = None,
schema: str | None = None,
):
self.table_name = table_name
self.db_client = db_client or SupabaseClient
self.schema = schema

Returns:
(str): The snake_case version of the input string.

Example:
>>> _to_snake_case("MyClassName")
'my_class_name'
class ModelMetaclass(PydanticModelMetaclass):
"""
Metaclass for BaseSBModel, handling Meta class configuration and objects property.
"""
return re.sub(r'(?<!^)(?=[A-Z])', '_', value).lower()

def __new__(mcs, name, bases, namespace):
cls = super().__new__(mcs, name, bases, namespace)
meta = namespace.get('Meta')
options = ModelOptions()

class ModelMetaclass(PydanticModelMetaclass):
"""
Metaclass for BaseSBModel, adding a custom `objects` property.
if meta is not None:
if hasattr(meta, 'table_name'):
options.table_name = meta.table_name

This metaclass extends Pydantic's ModelMetaclass to provide a custom `objects`
property on each class that uses it. The `objects` property returns a `QSet`
instance, which is used for performing database queries related to the model.
"""
if hasattr(meta, 'db_client'):
options.db_client = meta.db_client

if hasattr(meta, 'schema'):
options.schema = meta.schema

cls._meta = options
return cls

@property
def objects(cls: type[_M]) -> QSet[_M]: # type: ignore
Expand Down Expand Up @@ -85,6 +94,7 @@ class MultipleObjectsReturned(Exception):
pass

id: int | None = None
_meta: ClassVar[ModelOptions]

def save(self: _M) -> _M:
"""
Expand Down Expand Up @@ -142,35 +152,25 @@ def db_client(cls) -> type['BaseClient']:
(BaseClient): The database client class.
"""

return SupabaseClient
return cls._meta.db_client

@classmethod
def _get_table_name(cls) -> str:
"""
Gets the table name associated with the model, converting the class name to snake case.

This method converts the class name to snake_case to determine the corresponding
table name in the database.

Returns:
(str): The table name in snake_case.
Gets the table name associated with the model.
If no table_name is specified in Meta class, converts class name to snake_case.
"""
return _to_snake_case(cls.__name__)
return cls._meta.table_name or _to_snake_case(cls.__name__)

@classmethod
def _get_db_client(cls) -> 'BaseClient':
"""
Retrieves the database client instance for the model, configured with the table name.

This method creates a database client instance using the `db_client()` method
and initializes it with the appropriate table name.

Returns:
(BaseClient): An initialized instance of the database client.
Retrieves the database client instance for the model.
"""

table_name = cls._get_table_name()
return cls.db_client()(table_name)
schema = cls._meta.schema
client = cls.db_client()(table_name, schema)
return client

@model_validator(mode='before')
def _validate_data_from_supabase(cls, data: dict[str, Any]) -> dict[str, Any]:
Expand Down
22 changes: 22 additions & 0 deletions supadantic/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import re


def _to_snake_case(value: str) -> str:
"""
Converts a string from camel case or Pascal case to snake case.

This function uses a regular expression to find uppercase letters within
the string and inserts an underscore before them (except at the beginning
of the string). The entire string is then converted to lowercase.

Args:
value (str): The string to convert.

Returns:
(str): The snake_case version of the input string.

Example:
>>> _to_snake_case("MyClassName")
'my_class_name'
"""
return re.sub(r'(?<!^)(?=[A-Z])', '_', value).lower()
27 changes: 21 additions & 6 deletions tests/fixtures/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,43 @@
if TYPE_CHECKING:
from collections.abc import Generator

from supadantic.clients.base import BaseClient


class ModelMock(BaseSBModel):
id: int | None = None
name: str
age: int | None = None
some_optional_list: list[str] | None = None
some_optional_tuple: tuple[str, ...] | None = None
some_optional_tuple: tuple[str, ...] | None = None # noqa: CCE001

class Meta:
db_client = CacheClient


@classmethod
def db_client(cls) -> type['BaseClient']:
return CacheClient
class ModelMockCustomSchema(ModelMock):
class Meta:
db_client = CacheClient
schema = 'custom_schema'


@pytest.fixture(scope='function')
def model_mock() -> type[ModelMock]:
return ModelMock


@pytest.fixture(scope='function')
def model_mock_custom_schema() -> type[ModelMockCustomSchema]:
return ModelMockCustomSchema


@pytest.fixture(autouse=True, scope='function')
def clean_db_cache(model_mock: type['ModelMock']) -> 'Generator':
yield
model_mock.objects._cache = []
model_mock.objects.all().delete()


@pytest.fixture(autouse=True, scope='function')
def clean_db_cache_custom_schema(model_mock_custom_schema: type['ModelMockCustomSchema']) -> 'Generator':
yield
model_mock_custom_schema.objects._cache = []
model_mock_custom_schema.objects.all().delete()
12 changes: 12 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,15 @@ def test_update(self, model_mock: type['ModelMock']):

def test_objects(self, model_mock: type['ModelMock']):
assert isinstance(model_mock.objects, QSet)


class TestBaseSBModelCustomSchema:
def test_db_client_with_custom_schema(self, model_mock_custom_schema: type['ModelMock']):
# Prepare data
test_entity = model_mock_custom_schema(name='test_name')

# Execution
db_client = test_entity._get_db_client()

# Testing
assert db_client.schema == 'custom_schema'
26 changes: 24 additions & 2 deletions tests/test_supabase_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def test_filter(self, supabase_client: SupabaseClient, httpx_mock: 'HTTPXMock'):
),
status_code=200,
)
httpx_mock.add_response(is_optional=True)

query_builder = QueryBuilder()
query_builder.set_equal(id=1)
Expand Down Expand Up @@ -132,7 +131,6 @@ def test_order_by(self, supabase_client: SupabaseClient, httpx_mock: 'HTTPXMock'
),
status_code=200,
)
httpx_mock.add_response(is_optional=True)

query_buider = QueryBuilder()
query_buider.set_not_equal(title='test')
Expand All @@ -144,3 +142,27 @@ def test_order_by(self, supabase_client: SupabaseClient, httpx_mock: 'HTTPXMock'

# Assert
assert len(httpx_mock.get_requests()) == 1

def test_select_with_schema(self, httpx_mock: 'HTTPXMock'):
# Arrange
httpx_mock.add_response(
method='GET',
url=httpx.URL(
'https://test.supabase.co/rest/v1/table_name',
params={'select': '*', 'id': 'eq.1', 'title': 'neq.test'},
),
status_code=200,
)

query_builder = QueryBuilder()
query_builder.set_equal(id=1)
query_builder.set_not_equal(title='test')

# Act
supabase_client = SupabaseClient(table_name='table_name', schema='foo')
supabase_client.execute(query_builder=query_builder)

# Assert
assert len(httpx_mock.get_requests()) == 1
request = httpx_mock.get_requests()[0]
assert request.headers['accept-profile'] == 'foo'