diff --git a/anyblok_pyramid/anyblok.py b/anyblok_pyramid/anyblok.py index 7829055..a331101 100644 --- a/anyblok_pyramid/anyblok.py +++ b/anyblok_pyramid/anyblok.py @@ -24,6 +24,7 @@ from transaction.interfaces import IDataManagerSavepoint from zope.interface import implementer + from zope.sqlalchemy.datamanager import ( # noqa isort:skip _SESSION_STATE, NO_SAVEPOINT_SUPPORT, @@ -40,8 +41,8 @@ def __init__( self, session, status, transaction_manager, keep_session=False ): self.transaction_manager = transaction_manager - self.registry = session._query_cls.registry - self.transaction = self.registry.session.transaction + self.session = session + self.transaction = self.session._transaction transaction_manager.get().join(self) _SESSION_STATE[session] = status self.state = "init" @@ -49,14 +50,14 @@ def __init__( def _finish(self, final_state): assert self.transaction is not None - del _SESSION_STATE[self.registry.session] - registry = self.registry - self.transaction = self.registry = None + del _SESSION_STATE[self.session] + session = self.session + self.transaction = self.session = None self.state = final_state if not self.keep_session: - registry.session.close() + session.close() else: - registry.session.expire_all() + session.expire_all() EnvironmentManager.set("_precommit_hook", []) @@ -65,18 +66,18 @@ def abort(self, trans): self._finish("aborted") def tpc_begin(self, trans): - self.registry.session.flush() + self.session.flush() def commit(self, trans): - status = _SESSION_STATE[self.registry.session] + status = _SESSION_STATE[self.session] if status is not STATUS_INVALIDATED: - if self.registry.session.expire_on_commit: - self.registry.session.expire_all() + if self.session.expire_on_commit: + self.session.expire_all() self._finish("no work") def tpc_vote(self, trans): if self.transaction is not None: - self.registry.commit() + self.commit() self._finish("committed") def tpc_finish(self, trans): @@ -98,9 +99,9 @@ def savepoint(self): raise AttributeError("savepoint") return self._savepoint - def _savepoint(self): - self.registry.System.Cache.clear_invalidate_cache() - return AnyBlokSessionSavepoint(self.registry) + # def _savepoint(self): + # self.registry.System.Cache.clear_invalidate_cache() + # return AnyBlokSessionSavepoint(self.registry) def should_retry(self, error): if isinstance(error, ConcurrentModificationError): @@ -204,7 +205,7 @@ def after_bulk_delete(self, session, query, query_context, result): def before_commit(self, session): assert ( - session.transaction.nested + session._transaction.nested or self.transaction_manager.get().status # noqa in (ZopeStatus.COMMITTING, ZopeStatus.ACTIVE) ), ("Transaction must be committed using the transaction manager") diff --git a/anyblok_pyramid/pyramid_config.py b/anyblok_pyramid/pyramid_config.py index 15e5e29..d06aff2 100644 --- a/anyblok_pyramid/pyramid_config.py +++ b/anyblok_pyramid/pyramid_config.py @@ -14,6 +14,7 @@ from pkg_resources import iter_entry_points from pyramid.config import Configurator as PConfigurator +from anyblok.config import get_db_name as gdb from .common import get_registry_for logger = getLogger(__name__) @@ -42,7 +43,11 @@ def registry(self): The db_name must be defined """ - dbname = Configuration.get("get_db_name")(self.request) + dbname = Configuration.get("get_db_name") + if dbname: + dbname=dbname(self.request) + else: + dbname = gdb() if Registry.db_exists(db_name=dbname): return get_registry_for(dbname) else: @@ -88,7 +93,6 @@ def __call__(self, context, request): if not request.anyblok.registry: return False # pragma: no cover - # use this method because she is cached return request.anyblok.registry.System.Blok.is_installed(self.blok_name)