Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions anyblok_pyramid/anyblok.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from transaction.interfaces import IDataManagerSavepoint
from zope.interface import implementer


from zope.sqlalchemy.datamanager import ( # noqa isort:skip
_SESSION_STATE,
NO_SAVEPOINT_SUPPORT,
Expand All @@ -40,23 +41,23 @@ def __init__(
self, session, status, transaction_manager, keep_session=False
):
self.transaction_manager = transaction_manager
self.registry = session._query_cls.registry
self.transaction = self.registry.session.transaction
self.session = session
self.transaction = self.session._transaction
transaction_manager.get().join(self)
_SESSION_STATE[session] = status
self.state = "init"
self.keep_session = keep_session

def _finish(self, final_state):
assert self.transaction is not None
del _SESSION_STATE[self.registry.session]
registry = self.registry
self.transaction = self.registry = None
del _SESSION_STATE[self.session]
session = self.session
self.transaction = self.session = None
self.state = final_state
if not self.keep_session:
registry.session.close()
session.close()
else:
registry.session.expire_all()
session.expire_all()

EnvironmentManager.set("_precommit_hook", [])

Expand All @@ -65,18 +66,18 @@ def abort(self, trans):
self._finish("aborted")

def tpc_begin(self, trans):
self.registry.session.flush()
self.session.flush()

def commit(self, trans):
status = _SESSION_STATE[self.registry.session]
status = _SESSION_STATE[self.session]
if status is not STATUS_INVALIDATED:
if self.registry.session.expire_on_commit:
self.registry.session.expire_all()
if self.session.expire_on_commit:
self.session.expire_all()
self._finish("no work")

def tpc_vote(self, trans):
if self.transaction is not None:
self.registry.commit()
self.commit()
self._finish("committed")

def tpc_finish(self, trans):
Expand All @@ -98,9 +99,9 @@ def savepoint(self):
raise AttributeError("savepoint")
return self._savepoint

def _savepoint(self):
self.registry.System.Cache.clear_invalidate_cache()
return AnyBlokSessionSavepoint(self.registry)
# def _savepoint(self):
# self.registry.System.Cache.clear_invalidate_cache()
# return AnyBlokSessionSavepoint(self.registry)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how to get proper anyblok registry here

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


def should_retry(self, error):
if isinstance(error, ConcurrentModificationError):
Expand Down Expand Up @@ -204,7 +205,7 @@ def after_bulk_delete(self, session, query, query_context, result):

def before_commit(self, session):
assert (
session.transaction.nested
session._transaction.nested
or self.transaction_manager.get().status # noqa
in (ZopeStatus.COMMITTING, ZopeStatus.ACTIVE)
), ("Transaction must be committed using the transaction manager")
Expand Down
8 changes: 6 additions & 2 deletions anyblok_pyramid/pyramid_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pkg_resources import iter_entry_points
from pyramid.config import Configurator as PConfigurator

from anyblok.config import get_db_name as gdb
from .common import get_registry_for

logger = getLogger(__name__)
Expand Down Expand Up @@ -42,7 +43,11 @@ def registry(self):
The db_name must be defined

"""
dbname = Configuration.get("get_db_name")(self.request)
dbname = Configuration.get("get_db_name")
if dbname:
dbname=dbname(self.request)
else:
dbname = gdb()
if Registry.db_exists(db_name=dbname):
return get_registry_for(dbname)
else:
Expand Down Expand Up @@ -88,7 +93,6 @@ def __call__(self, context, request):

if not request.anyblok.registry:
return False # pragma: no cover

# use this method because she is cached
return request.anyblok.registry.System.Blok.is_installed(self.blok_name)

Expand Down