From edb7c3ffff38a04790213343361dda12ec526c77 Mon Sep 17 00:00:00 2001 From: "Yunior C. Fonseca Reyna" Date: Wed, 22 Apr 2026 10:39:42 +0200 Subject: [PATCH 01/14] fix: make thumbnail cache writes atomic and validate cached images --- app/backend/api/routers/project_router.py | 2 +- app/backend/utils/thumbnail_service.py | 407 +++++++++++++--------- 2 files changed, 237 insertions(+), 172 deletions(-) diff --git a/app/backend/api/routers/project_router.py b/app/backend/api/routers/project_router.py index 89a7786..dc4957d 100644 --- a/app/backend/api/routers/project_router.py +++ b/app/backend/api/routers/project_router.py @@ -3050,7 +3050,7 @@ def listProjectThumbnailItems( ) response = JSONResponse(items) - response.headers["Cache-Control"] = "private, max-age=20, stale-while-revalidate=60" + response.headers["Cache-Control"] = "private, no-store" response.headers["Access-Control-Expose-Headers"] = "Cache-Control" return _attachDebugHeaders(response, currentUser) diff --git a/app/backend/utils/thumbnail_service.py b/app/backend/utils/thumbnail_service.py index 300d8ce..6920f6b 100644 --- a/app/backend/utils/thumbnail_service.py +++ b/app/backend/utils/thumbnail_service.py @@ -32,6 +32,8 @@ import hashlib from urllib.parse import quote from pathlib import Path +import tempfile +import threading from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple import numpy as np @@ -62,6 +64,8 @@ from pwem.viewers.viewers_data import RegistryViewerConfig logger = logging.getLogger(__name__) +_thumbnailBuildLocksGuard = threading.Lock() +_thumbnailBuildLocks: Dict[str, threading.Lock] = {} class ThumbnailService: @@ -74,6 +78,31 @@ def __init__(self, currentProject): # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ + def _getThumbnailBuildLock(self, cachePath: Path): + key = str(cachePath) + with _thumbnailBuildLocksGuard: + lock = 
_thumbnailBuildLocks.get(key) + if lock is None: + lock = threading.Lock() + _thumbnailBuildLocks[key] = lock + return lock + + def _isValidCachedImage(self, cachePath: Path) -> bool: + try: + if not cachePath.exists(): + return False + if cachePath.stat().st_size <= 0: + return False + with Image.open(cachePath) as img: + img.verify() + return True + except Exception: + try: + cachePath.unlink(missing_ok=True) + except Exception: + pass + return False + def listUsefulProtocols(self, maxProtocols: int = 12) -> List[Dict[str, Any]]: protocols = self._iterProtocols() candidates: List[Dict[str, Any]] = [] @@ -133,7 +162,11 @@ def buildProtocolThumbnail( if protocol is None: raise ValueError(f"Protocol {protocolId} not found") - cachePath = self._getProtocolCachePath(protocolId, size=size, outputName=outputName) + cachePath = self._getProtocolCachePath( + protocolId, + size=size, + outputName=outputName, + ) selectedCandidate: Optional[Dict[str, Any]] = None if outputName: @@ -154,88 +187,93 @@ def buildProtocolThumbnail( "exists": False, } - if cachePath.exists() and not force: - return { - "protocolId": int(protocolId), - "protocolLabel": self._getProtocolLabel(protocol), - "status": self._getProtocolStatus(protocol), - "outputName": selectedCandidate["outputName"] if selectedCandidate else outputName, - "outputClassName": selectedCandidate["outputClassName"] if selectedCandidate else None, - "absolutePath": str(cachePath), - "cached": True, - "exists": True, - } + buildLock = self._getThumbnailBuildLock(cachePath) + with buildLock: + if not force and self._isValidCachedImage(cachePath): + return { + "protocolId": int(protocolId), + "protocolLabel": self._getProtocolLabel(protocol), + "status": self._getProtocolStatus(protocol), + "outputName": selectedCandidate["outputName"] if selectedCandidate else outputName, + "outputClassName": selectedCandidate["outputClassName"] if selectedCandidate else None, + "absolutePath": str(cachePath), + "cached": True, + "exists": 
True, + } - previewImage: Optional[Image.Image] = None - candidates = [selectedCandidate] if selectedCandidate is not None else self._collectSortedOutputCandidates( - protocol) + previewImage: Optional[Image.Image] = None + candidates = ( + [selectedCandidate] + if selectedCandidate is not None + else self._collectSortedOutputCandidates(protocol) + ) - for candidate in candidates: - if candidate is None: - continue + for candidate in candidates: + if candidate is None: + continue - try: - image = self._renderProtocolPreviewImage( - protocol=protocol, - output=candidate["output"], - outputName=candidate["outputName"], - outputClassName=candidate["outputClassName"], - size=size, - ) - if image is not None: - selectedCandidate = candidate - previewImage = image - break - except Exception: - logger.debug( - "Candidate thumbnail render failed. protocolId=%s output=%s class=%s", - protocolId, - candidate.get("outputName"), - candidate.get("outputClassName"), - exc_info=True, - ) + try: + image = self._renderProtocolPreviewImage( + protocol=protocol, + output=candidate["output"], + outputName=candidate["outputName"], + outputClassName=candidate["outputClassName"], + size=size, + ) + if image is not None: + selectedCandidate = candidate + previewImage = image + break + except Exception: + logger.debug( + "Candidate thumbnail render failed. protocolId=%s output=%s class=%s", + protocolId, + candidate.get("outputName"), + candidate.get("outputClassName"), + exc_info=True, + ) - if previewImage is None and outputName is None: - try: - previewImage = self._renderProtocolFilesystemFallback(protocol, size=size) - except Exception: - logger.debug( - "Filesystem thumbnail fallback failed. protocolId=%s", - protocolId, - exc_info=True, - ) - previewImage = None + if previewImage is None and outputName is None: + try: + previewImage = self._renderProtocolFilesystemFallback(protocol, size=size) + except Exception: + logger.debug( + "Filesystem thumbnail fallback failed. 
protocolId=%s", + protocolId, + exc_info=True, + ) + previewImage = None + + if previewImage is None: + return { + "protocolId": int(protocolId), + "protocolLabel": self._getProtocolLabel(protocol), + "status": self._getProtocolStatus(protocol), + "outputName": selectedCandidate["outputName"] if selectedCandidate else outputName, + "outputClassName": selectedCandidate["outputClassName"] if selectedCandidate else None, + "absolutePath": None, + "cached": False, + "exists": False, + } + + thumbnail = self._finalizeProtocolThumbnail( + previewImage=previewImage, + size=size, + protocolId=int(protocolId), + ) + self._saveImage(thumbnail, cachePath) - if previewImage is None: return { "protocolId": int(protocolId), "protocolLabel": self._getProtocolLabel(protocol), "status": self._getProtocolStatus(protocol), "outputName": selectedCandidate["outputName"] if selectedCandidate else outputName, "outputClassName": selectedCandidate["outputClassName"] if selectedCandidate else None, - "absolutePath": None, + "absolutePath": str(cachePath), "cached": False, - "exists": False, + "exists": True, } - thumbnail = self._finalizeProtocolThumbnail( - previewImage=previewImage, - size=size, - protocolId=int(protocolId), - ) - self._saveImage(thumbnail, cachePath) - - return { - "protocolId": int(protocolId), - "protocolLabel": self._getProtocolLabel(protocol), - "status": self._getProtocolStatus(protocol), - "outputName": selectedCandidate["outputName"] if selectedCandidate else outputName, - "outputClassName": selectedCandidate["outputClassName"] if selectedCandidate else None, - "absolutePath": str(cachePath), - "cached": False, - "exists": True, - } - def buildProjectThumbnail( self, force: bool = False, @@ -243,65 +281,77 @@ def buildProjectThumbnail( maxProtocols: int = 6, ) -> Dict[str, Any]: cachePath = self._getProjectCachePath(size=size, maxProtocols=maxProtocols) - if cachePath.exists() and not force: - return { - "absolutePath": str(cachePath), - "cached": True, - "items": 
None, - } - useful = self.listUsefulProtocols(maxProtocols=max(3, int(maxProtocols) * 3)) - renderedItems: List[Dict[str, Any]] = [] - protocolThumbWidth = self._projectProtocolSize(size=int(size), maxProtocols=int(maxProtocols)) + buildLock = self._getThumbnailBuildLock(cachePath) + with buildLock: + if not force and self._isValidCachedImage(cachePath): + return { + "absolutePath": str(cachePath), + "cached": True, + "items": None, + } - for candidate in useful: - try: - built = self.buildProtocolThumbnail( - protocolId=int(candidate["protocolId"]), - force=force, - size=protocolThumbWidth, - ) + useful = self.listUsefulProtocols( + maxProtocols=max(3, int(maxProtocols) * 3), + ) + renderedItems: List[Dict[str, Any]] = [] + protocolThumbWidth = self._projectProtocolSize( + size=int(size), + maxProtocols=int(maxProtocols), + ) - if not built.get("exists") or not built.get("absolutePath"): - continue + for candidate in useful: + try: + built = self.buildProtocolThumbnail( + protocolId=int(candidate["protocolId"]), + force=force, + size=protocolThumbWidth, + ) - renderedItems.append( - { - "protocolId": int(candidate["protocolId"]), - "protocolLabel": candidate.get("protocolLabel"), - "status": candidate.get("status"), - "outputName": candidate.get("outputName"), - "outputClassName": candidate.get("outputClassName"), - "itemsCount": candidate.get("itemsCount", 0), - "absolutePath": built["absolutePath"], - } - ) + if not built.get("exists") or not built.get("absolutePath"): + continue - if len(renderedItems) >= int(maxProtocols): - break + renderedItems.append( + { + "protocolId": int(candidate["protocolId"]), + "protocolLabel": candidate.get("protocolLabel"), + "status": candidate.get("status"), + "outputName": candidate.get("outputName"), + "outputClassName": candidate.get("outputClassName"), + "itemsCount": candidate.get("itemsCount", 0), + "absolutePath": built["absolutePath"], + } + ) - except Exception: - logger.debug( - "Skipping failed protocol thumbnail 
while building project strip. protocolId=%s", - candidate.get("protocolId"), - exc_info=True, - ) + if len(renderedItems) >= int(maxProtocols): + break + + except Exception: + logger.debug( + "Skipping failed protocol thumbnail while building project strip. protocolId=%s", + candidate.get("protocolId"), + exc_info=True, + ) + + if not renderedItems: + return { + "absolutePath": None, + "cached": False, + "items": 0, + } + + strip = self._composeProjectStrip( + items=renderedItems, + size=int(size), + ) + self._saveImage(strip, cachePath) - if not renderedItems: return { - "absolutePath": None, + "absolutePath": str(cachePath), "cached": False, - "items": 0, + "items": len(renderedItems), } - strip = self._composeProjectStrip(items=renderedItems, size=int(size)) - self._saveImage(strip, cachePath) - return { - "absolutePath": str(cachePath), - "cached": False, - "items": len(renderedItems), - } - def _getBadgeFont(self, badgeH: int): fontSize = max(12, int(round(badgeH * 0.48))) candidates = [ @@ -422,78 +472,80 @@ def buildProtocolOutputThumbnail( size=size, ) - if cachePath.exists() and not force: - return { - "protocolId": int(protocolId), - "protocolLabel": self._getProtocolLabel(protocol), - "status": self._getProtocolStatus(protocol), - "outputName": outputName, - "outputClassName": outputClassName, - "absolutePath": str(cachePath), - "cached": True, - "exists": True, - } + buildLock = self._getThumbnailBuildLock(cachePath) + with buildLock: + if not force and self._isValidCachedImage(cachePath): + return { + "protocolId": int(protocolId), + "protocolLabel": self._getProtocolLabel(protocol), + "status": self._getProtocolStatus(protocol), + "outputName": outputName, + "outputClassName": outputClassName, + "absolutePath": str(cachePath), + "cached": True, + "exists": True, + } - previewImage: Optional[Image.Image] = None + previewImage: Optional[Image.Image] = None - try: - previewImage = self._renderProtocolPreviewImage( - protocol=protocol, - output=output, - 
outputName=outputName, - outputClassName=outputClassName, - size=size, - ) - except Exception: - logger.debug( - "Output thumbnail render failed. protocolId=%s output=%s class=%s", - protocolId, - outputName, - outputClassName, - exc_info=True, - ) - - if previewImage is None: try: - previewImage = self._renderGenericPreview(protocol, output, size=size) + previewImage = self._renderProtocolPreviewImage( + protocol=protocol, + output=output, + outputName=outputName, + outputClassName=outputClassName, + size=size, + ) except Exception: logger.debug( - "Generic output thumbnail render failed. protocolId=%s output=%s", + "Output thumbnail render failed. protocolId=%s output=%s class=%s", protocolId, outputName, + outputClassName, exc_info=True, ) - if previewImage is None: + if previewImage is None: + try: + previewImage = self._renderGenericPreview(protocol, output, size=size) + except Exception: + logger.debug( + "Generic output thumbnail render failed. protocolId=%s output=%s", + protocolId, + outputName, + exc_info=True, + ) + + if previewImage is None: + return { + "protocolId": int(protocolId), + "protocolLabel": self._getProtocolLabel(protocol), + "status": self._getProtocolStatus(protocol), + "outputName": outputName, + "outputClassName": outputClassName, + "absolutePath": None, + "cached": False, + "exists": False, + } + + thumbnail = self._finalizeProtocolThumbnail( + previewImage=previewImage, + size=size, + protocolId=int(protocolId), + ) + self._saveImage(thumbnail, cachePath) + return { "protocolId": int(protocolId), "protocolLabel": self._getProtocolLabel(protocol), "status": self._getProtocolStatus(protocol), "outputName": outputName, "outputClassName": outputClassName, - "absolutePath": None, + "absolutePath": str(cachePath), "cached": False, - "exists": False, + "exists": True, } - thumbnail = self._finalizeProtocolThumbnail( - previewImage=previewImage, - size=size, - protocolId=int(protocolId), - ) - self._saveImage(thumbnail, cachePath) - - return 
{ - "protocolId": int(protocolId), - "protocolLabel": self._getProtocolLabel(protocol), - "status": self._getProtocolStatus(protocol), - "outputName": outputName, - "outputClassName": outputClassName, - "absolutePath": str(cachePath), - "cached": False, - "exists": True, - } - # ------------------------------------------------------------------ # Candidate selection # ------------------------------------------------------------------ @@ -2168,7 +2220,20 @@ def _getProjectCachePath(self, size: int, maxProtocols: int) -> Path: def _saveImage(self, image: Image.Image, outputPath: Path): outputPath.parent.mkdir(parents=True, exist_ok=True) - image.save(str(outputPath), format="PNG") + + tmpPath = outputPath.with_name( + f".{outputPath.name}.{os.getpid()}.{id(image)}.tmp" + ) + + try: + image.save(str(tmpPath), format="PNG") + os.replace(str(tmpPath), str(outputPath)) + finally: + try: + if tmpPath.exists(): + tmpPath.unlink() + except Exception: + pass def _getProjectPath(self) -> Optional[str]: if self.currentProject is None: From 8183941c8535ececb83b646c05c0edee6dc8147f Mon Sep 17 00:00:00 2001 From: "Yunior C. 
Fonseca Reyna" <34970661+fonsecareyna82@users.noreply.github.com> Date: Wed, 22 Apr 2026 13:18:35 +0200 Subject: [PATCH 02/14] add protocol dependencies migration --- ...c4d8e21_add_protocol_dependencies_table.py | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 alembic/versions/9b7b3c4d8e21_add_protocol_dependencies_table.py diff --git a/alembic/versions/9b7b3c4d8e21_add_protocol_dependencies_table.py b/alembic/versions/9b7b3c4d8e21_add_protocol_dependencies_table.py new file mode 100644 index 0000000..35fb21e --- /dev/null +++ b/alembic/versions/9b7b3c4d8e21_add_protocol_dependencies_table.py @@ -0,0 +1,89 @@ +"""add protocol dependencies table + +Revision ID: 9b7b3c4d8e21 +Revises: f3f7b0a3f1aa +Create Date: 2026-04-22 00:00:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '9b7b3c4d8e21' +down_revision: Union[str, None] = 'f3f7b0a3f1aa' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade(): + op.create_unique_constraint( + 'uq_protocols_projectId_id', + 'protocols', + ['projectId', 'id'], + ) + + op.create_table( + 'protocol_dependencies', + sa.Column('projectId', sa.Integer(), nullable=False), + sa.Column('parentProtocolDbId', sa.Integer(), nullable=False), + sa.Column('childProtocolDbId', sa.Integer(), nullable=False), + sa.Column('createdAt', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.CheckConstraint('"parentProtocolDbId" <> "childProtocolDbId"', name='protocol_dependencies_no_self_loop'), + sa.ForeignKeyConstraint(['projectId'], ['projects.id'], name='protocol_dependencies_projectId_fkey', ondelete='CASCADE'), + sa.ForeignKeyConstraint( + ['projectId', 'parentProtocolDbId'], + ['protocols.projectId', 'protocols.id'], + name='protocol_dependencies_parentProtocolDbId_fkey', + ondelete='CASCADE', + ), + 
sa.ForeignKeyConstraint( + ['projectId', 'childProtocolDbId'], + ['protocols.projectId', 'protocols.id'], + name='protocol_dependencies_childProtocolDbId_fkey', + ondelete='CASCADE', + ), + sa.PrimaryKeyConstraint('projectId', 'parentProtocolDbId', 'childProtocolDbId', name='protocol_dependencies_pkey'), + ) + + op.create_index( + 'idx_protocol_dependencies_parent', + 'protocol_dependencies', + ['projectId', 'parentProtocolDbId'], + unique=False, + ) + op.create_index( + 'idx_protocol_dependencies_child', + 'protocol_dependencies', + ['projectId', 'childProtocolDbId'], + unique=False, + ) + + op.execute( + """ + INSERT INTO protocol_dependencies ( + "projectId", + "parentProtocolDbId", + "childProtocolDbId" + ) + SELECT DISTINCT + child."projectId", + parent.id, + child.id + FROM protocols child + CROSS JOIN LATERAL unnest(COALESCE(child."parentIds", ARRAY[]::integer[])) AS parent_protocol_id + JOIN protocols parent + ON parent."projectId" = child."projectId" + AND parent."protocolId" = parent_protocol_id::text + WHERE parent.id <> child.id + """ + ) + + +def downgrade(): + op.drop_index('idx_protocol_dependencies_child', table_name='protocol_dependencies') + op.drop_index('idx_protocol_dependencies_parent', table_name='protocol_dependencies') + op.drop_table('protocol_dependencies') + op.drop_constraint('uq_protocols_projectId_id', 'protocols', type_='unique') From 7fece9ba24eebaa2d01abeaac2b84549e3343a46 Mon Sep 17 00:00:00 2001 From: "Yunior C. 
Fonseca Reyna" Date: Wed, 22 Apr 2026 13:27:41 +0200 Subject: [PATCH 03/14] feat: add protocol dependencies table and backfill existing graph relations --- app/backend/api/services/project_service.py | 77 +++++++++++++ app/backend/mapper/postgresql.py | 119 ++++++++++++++++++++ 2 files changed, 196 insertions(+) diff --git a/app/backend/api/services/project_service.py b/app/backend/api/services/project_service.py index b7f75d9..926a71a 100644 --- a/app/backend/api/services/project_service.py +++ b/app/backend/api/services/project_service.py @@ -385,6 +385,69 @@ def _validateImportableScipionProject(self, sourcePath: Path) -> Dict[str, Any]: "status": statusValue or "active", } + def syncProjectProtocolsAndDependencies( + self, + mapper: PostgresqlFlatMapper, + projectId: int, + refresh: bool = False, + checkPid: bool = False, + ) -> Dict[str, int]: + if self.currentProject is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="No current project loaded", + ) + + runs = self.currentProject.getRunsGraph(refresh=refresh, checkPids=checkPid) + nodesDict = getattr(runs, "_nodesDict", {}) or {} + + protocolDbIdByScipionId: Dict[str, int] = {} + + # 1) Save all protocol nodes + for nodeId, nodeObj in nodesDict.items(): + if str(nodeId) == "PROJECT": + continue + + protocol = getattr(nodeObj, "run", None) + if protocol is None: + try: + protocol = self.currentProject.getProtocol(int(nodeId)) + except Exception: + protocol = None + + if protocol is None: + continue + + protocolContext = self._buildProtocolContext(projectId, protocol) + protocolDbId = mapper.saveProtocol(protocolContext) + protocolDbIdByScipionId[str(nodeId)] = int(protocolDbId) + + # 2) Build edges parent -> child using DB ids + edges: List[tuple[int, int]] = [] + + for nodeId, nodeObj in nodesDict.items(): + childDbId = protocolDbIdByScipionId.get(str(nodeId)) + if not childDbId: + continue + + for parent in getattr(nodeObj, "_parents", []) or []: + parentNodeId = 
str(parent.getName()) + if parentNodeId == "PROJECT": + continue + + parentDbId = protocolDbIdByScipionId.get(parentNodeId) + if not parentDbId: + continue + + edges.append((parentDbId, childDbId)) + + savedEdges = mapper.replaceProjectProtocolDependencies(projectId, edges) + + return { + "protocols": len(protocolDbIdByScipionId), + "dependencies": int(savedEdges), + } + def importProject( self, mapper: PostgresqlFlatMapper, @@ -533,6 +596,20 @@ def importProject( status=statusValue, ) + try: + self.loadProjectForThumbnails({"name": storedProjectPath}) + self.syncProjectProtocolsAndDependencies(mapper, dbProjectId) + except Exception as e: + logger.exception( + "Failed to sync imported project protocols. projectId=%s path=%s", + dbProjectId, + storedProjectPath, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Project was imported but protocols could not be synced to the database: {e}", + ) + sizePath = sourcePath if not copyProject else targetPath try: diff --git a/app/backend/mapper/postgresql.py b/app/backend/mapper/postgresql.py index b445e52..7caa132 100644 --- a/app/backend/mapper/postgresql.py +++ b/app/backend/mapper/postgresql.py @@ -133,6 +133,36 @@ def initTables(self): """ ) + # CreateProtocolDependenciesTable + self.db.execute( + """ + CREATE TABLE IF NOT EXISTS protocol_dependencies ( + "projectId" INTEGER NOT NULL REFERENCES projects(id) ON DELETE CASCADE, + "parentProtocolDbId" INTEGER NOT NULL, + "childProtocolDbId" INTEGER NOT NULL, + "createdAt" TIMESTAMPTZ NOT NULL DEFAULT NOW(), + + PRIMARY KEY ("projectId", "parentProtocolDbId", "childProtocolDbId"), + + FOREIGN KEY ("projectId", "parentProtocolDbId") + REFERENCES protocols("projectId", id) + ON DELETE CASCADE, + + FOREIGN KEY ("projectId", "childProtocolDbId") + REFERENCES protocols("projectId", id) + ON DELETE CASCADE, + + CHECK ("parentProtocolDbId" <> "childProtocolDbId") + ); + + CREATE INDEX IF NOT EXISTS protocol_dependencies_by_parent + ON 
protocol_dependencies("projectId", "parentProtocolDbId"); + + CREATE INDEX IF NOT EXISTS protocol_dependencies_by_child + ON protocol_dependencies("projectId", "childProtocolDbId"); + """ + ) + # CreateProjectSharesTable (requires users and projects) self.db.execute( """ @@ -891,6 +921,95 @@ def saveProtocol(self, protocol: Dict[str, Any]) -> int: ) return cur.fetchone()["id"] + def getProjectProtocolDbIdMap(self, projectId: int) -> Dict[str, int]: + rows = self.db.fetchAll( + """ + SELECT id, "protocolId" + FROM protocols + WHERE "projectId" = %s + """, + (projectId,), + ) + + return { + str(row["protocolId"]): int(row["id"]) + for row in rows + if row.get("protocolId") is not None and row.get("id") is not None + } + + def replaceProjectProtocolDependencies( + self, + projectId: int, + edges: List[tuple[int, int]], + ) -> int: + self.db.execute( + """ + DELETE FROM protocol_dependencies + WHERE "projectId" = %s + """, + (projectId,), + ) + + cleanEdges: List[tuple[int, int]] = [] + seen = set() + + for parentDbId, childDbId in edges or []: + try: + parentDbId = int(parentDbId) + childDbId = int(childDbId) + except Exception: + continue + + if parentDbId <= 0 or childDbId <= 0: + continue + if parentDbId == childDbId: + continue + + key = (parentDbId, childDbId) + if key in seen: + continue + + seen.add(key) + cleanEdges.append(key) + + if not cleanEdges: + return 0 + + valuesSql = ",".join(["(%s, %s, %s)"] * len(cleanEdges)) + params: List[Any] = [] + + for parentDbId, childDbId in cleanEdges: + params.extend([projectId, parentDbId, childDbId]) + + self.db.execute( + f""" + INSERT INTO protocol_dependencies ( + "projectId", + "parentProtocolDbId", + "childProtocolDbId" + ) + VALUES {valuesSql} + """, + tuple(params), + ) + + return len(cleanEdges) + + def listProjectProtocolDependencies(self, projectId: int) -> List[Dict[str, Any]]: + return self.db.fetchAll( + """ + SELECT + "projectId", + "parentProtocolDbId", + "childProtocolDbId", + "createdAt" + FROM 
protocol_dependencies + WHERE "projectId" = %s + ORDER BY "parentProtocolDbId", "childProtocolDbId" + """, + (projectId,), + ) + def getProtocolByProtocolId(self, protocolId: int, projectId: int) -> Optional[Dict]: """Retrieve a protocol by id.""" return self.db.fetchOne( From b2ad86e3ea544027eb643e4bfd8a908325a1a924 Mon Sep 17 00:00:00 2001 From: "Yunior C. Fonseca Reyna" Date: Wed, 22 Apr 2026 13:59:16 +0200 Subject: [PATCH 04/14] feat: add protocol_dependencies and sync workflow graph to postgres --- app/backend/api/services/project_service.py | 75 +++++++++++++++------ app/backend/mapper/postgresql.py | 37 +++++----- 2 files changed, 73 insertions(+), 39 deletions(-) diff --git a/app/backend/api/services/project_service.py b/app/backend/api/services/project_service.py index 926a71a..a321e79 100644 --- a/app/backend/api/services/project_service.py +++ b/app/backend/api/services/project_service.py @@ -1170,11 +1170,11 @@ def listProjectWorkflows(self): return tempList.sortListByPluginName().templates def applyWorkflowToProject( - self, - mapper: PostgresqlFlatMapper, - projectId: int, - workflowId: Union[int, str], - currentUser: dict, + self, + mapper: PostgresqlFlatMapper, + projectId: int, + workflowId: Union[int, str], + currentUser: dict, ) -> dict: """ Apply a predefined workflow template to an existing project. 
@@ -1205,7 +1205,7 @@ def applyWorkflowToProject( templateName = getattr(t, "name", None) if (templateId is not None and str(templateId) == workflowIdStr) or ( - templateName and str(templateName) == workflowIdStr + templateName and str(templateName) == workflowIdStr ): selectedTemplate = t break @@ -1245,16 +1245,24 @@ def applyWorkflowToProject( detail=f"Failed to apply workflow '{workflowIdStr}' to project {projectId}: {e}", ) - # 7) Optionally compute how many protocols are present after applying - protocolsCount = None + # 7) Sync protocols + dependencies to PostgreSQL try: - if hasattr(self.currentProject, "getProtocols"): - protocols = self.currentProject.getProtocols() - if protocols is not None: - protocolsCount = len(protocols) - except Exception: - # Ignore errors when computing protocol count - protocolsCount = None + syncInfo = self.syncProjectProtocolsAndDependencies( + mapper, + projectId, + refresh=True, + checkPid=True, + ) + except Exception as e: + logger.exception( + "Failed to sync workflow-applied project graph. 
projectId=%s workflowId=%s", + projectId, + workflowIdStr, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Workflow was applied but graph sync to PostgreSQL failed: {e}", + ) # 8) Return a compact, useful payload for the frontend return { @@ -1263,7 +1271,8 @@ def applyWorkflowToProject( "workflowId": workflowIdStr, "workflowName": getattr(selectedTemplate, "name", workflowIdStr), "workflowFile": str(workflowFile), - "protocolsCount": protocolsCount, + "protocolsCount": syncInfo.get("protocols"), + "dependenciesCount": syncInfo.get("dependencies"), "loadResult": str(loadResult) if loadResult is not None else None, } @@ -2658,13 +2667,22 @@ def duplicateProtocol(self, mapper, projectId, protocols: Any): protList = [] for protocol in protocols: protList.append(self.currentProject.getProtocol(int(protocol.id))) - resultProtList = self.currentProject.copyProtocol(protList) - for prot in resultProtList: - protocolContex = self._buildProtocolContext(projectId, prot) - mapper.saveProtocol(protocolContex) + self.currentProject.copyProtocol(protList) - return {"status": "ok", "message": "Protocol was duplicated successfully"} + syncInfo = self.syncProjectProtocolsAndDependencies( + mapper, + projectId, + refresh=True, + checkPid=True, + ) + + return { + "status": "ok", + "message": "Protocol was duplicated successfully", + "protocolsCount": syncInfo.get("protocols"), + "dependenciesCount": syncInfo.get("dependencies"), + } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -2673,8 +2691,23 @@ def deleteProtocol(self, mapper, projectId, protocols: Any): protList = [] for protocol in protocols: protList.append(self.currentProject.getProtocol(int(protocol))) + self.currentProject.deleteProtocol(*protList) mapper.deleteProtocol(projectId, protList) + + syncInfo = self.syncProjectProtocolsAndDependencies( + mapper, + projectId, + refresh=True, + checkPid=True, + ) + + return { + "status": "ok", + "message": 
"Protocol deleted successfully", + "protocolsCount": syncInfo.get("protocols"), + "dependenciesCount": syncInfo.get("dependencies"), + } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/app/backend/mapper/postgresql.py b/app/backend/mapper/postgresql.py index 7caa132..ad5939d 100644 --- a/app/backend/mapper/postgresql.py +++ b/app/backend/mapper/postgresql.py @@ -6,7 +6,7 @@ import psycopg2 import psycopg2.extras from contextlib import contextmanager -from typing import Optional, List, Dict, Any, Iterator +from typing import Optional, List, Dict, Any, Iterator, Tuple from pyworkflow.mapper.mapper import Mapper # Base class from Scipion @@ -110,26 +110,27 @@ def initTables(self): """ ) - # CreateProtocolsTableLegacy (kept as-is for now) + # CreateProtocolsTableLegacy self.db.execute( """ CREATE TABLE IF NOT EXISTS protocols ( id SERIAL PRIMARY KEY, - CREATE TABLE IF NOT EXISTS protocols ( - id SERIAL PRIMARY KEY, - "projectId" INTEGER NOT NULL REFERENCES projects(id) ON DELETE CASCADE, - "protocolId" TEXT NOT NULL, - "protocolClassName" TEXT NOT NULL, - status TEXT NOT NULL DEFAULT 'pending', - params JSONB, - "parentIds" JSONB NOT NULL DEFAULT '[]'::jsonb, - "childIds" JSONB NOT NULL DEFAULT '[]'::jsonb, - "createdAt" TIMESTAMPTZ NOT NULL DEFAULT NOW(), - "updatedAt" TIMESTAMPTZ + "projectId" INTEGER NOT NULL REFERENCES projects(id) ON DELETE CASCADE, + "protocolId" TEXT NOT NULL, + "protocolClassName" TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + params JSONB, + "parentIds" INTEGER[] NOT NULL DEFAULT ARRAY[]::INTEGER[], + "childIds" INTEGER[] NOT NULL DEFAULT ARRAY[]::INTEGER[], + "createdAt" TIMESTAMPTZ NOT NULL DEFAULT NOW(), + "updatedAt" TIMESTAMPTZ ); - + CREATE UNIQUE INDEX IF NOT EXISTS protocols_project_protocol_ux ON protocols("projectId", "protocolId"); + + CREATE UNIQUE INDEX IF NOT EXISTS protocols_project_dbid_ux + ON protocols("projectId", id); """ ) @@ -938,9 +939,9 @@ def 
getProjectProtocolDbIdMap(self, projectId: int) -> Dict[str, int]: } def replaceProjectProtocolDependencies( - self, - projectId: int, - edges: List[tuple[int, int]], + self, + projectId: int, + edges: List[Tuple[int, int]], ) -> int: self.db.execute( """ @@ -950,7 +951,7 @@ def replaceProjectProtocolDependencies( (projectId,), ) - cleanEdges: List[tuple[int, int]] = [] + cleanEdges: List[Tuple[int, int]] = [] seen = set() for parentDbId, childDbId in edges or []: From 07ca9ebaf6141029146ba1d3dfcbf419ea41a937 Mon Sep 17 00:00:00 2001 From: "Yunior C. Fonseca Reyna" Date: Wed, 22 Apr 2026 14:25:57 +0200 Subject: [PATCH 05/14] feat: read protocol graph dependencies from postgres --- app/backend/api/services/project_service.py | 38 +++++++++++++++++-- app/backend/mapper/postgresql.py | 42 +++++++++++++++++++++ 2 files changed, 76 insertions(+), 4 deletions(-) diff --git a/app/backend/api/services/project_service.py b/app/backend/api/services/project_service.py index a321e79..8782d3d 100644 --- a/app/backend/api/services/project_service.py +++ b/app/backend/api/services/project_service.py @@ -1048,14 +1048,27 @@ def _buildProjectThumbnailVersion( return f"{projectId}:{updatedText}:{protocolsCount}:{runsMtime}" - def buildProtocolsGraph(self, projectId: int, runs, tags) -> dict: + def buildProtocolsGraph( + self, + projectId, + runs, + tags, + dependencyMap: Optional[Dict[str, Dict[str, List[str]]]] = None, + ) -> dict: """Assemble dependency graph of protocols and their status.""" nodesDict = runs._nodesDict graphData = {} + usePostgresDependencies = dependencyMap is not None for nodeId, nodeObj in nodesDict.items(): - childrenIds = [child.getName() for child in nodeObj._children] - parentIds = [parent.getName() for parent in nodeObj._parents] + if nodeId != 'PROJECT' and usePostgresDependencies: + nodeDeps = dependencyMap.get(str(nodeId), {"parents": [], "children": []}) + childrenIds = list(nodeDeps.get("children") or []) + parentIds = list(nodeDeps.get("parents") or 
[]) + else: + childrenIds = [child.getName() for child in nodeObj._children] + parentIds = [parent.getName() for parent in nodeObj._parents] + status = nodeObj.run.getStatus() if nodeObj.run else '' inputs = [] outputs = [] @@ -1135,7 +1148,24 @@ def loadProject(self, dbProj: dict, mapper: PostgresqlFlatMapper = None, refresh self.currentProject.load(dbPath=self.currentProject.getDbPath()) runs = self.currentProject.getRunsGraph(refresh=refresh, checkPids=checkPid) tags = mapper.getProjectProtocolTagIdsByProtocolId(dbProj['id']) - graphData = self.buildProtocolsGraph(dbProj['id'], runs, tags) + + dependencyMap = None + if mapper is not None: + try: + dependencyMap = mapper.getProjectProtocolAdjacencyMap(dbProj['id']) + except Exception: + logger.exception( + "Failed to load protocol dependencies from PostgreSQL for project %s", + dbProj['id'], + ) + dependencyMap = None + + graphData = self.buildProtocolsGraph( + dbProj['id'], + runs, + tags, + dependencyMap=dependencyMap, + ) stats = projPath.stat() updatedAt = datetime.fromtimestamp(stats.st_mtime) diff --git a/app/backend/mapper/postgresql.py b/app/backend/mapper/postgresql.py index ad5939d..56924c1 100644 --- a/app/backend/mapper/postgresql.py +++ b/app/backend/mapper/postgresql.py @@ -1011,6 +1011,48 @@ def listProjectProtocolDependencies(self, projectId: int) -> List[Dict[str, Any] (projectId,), ) + def getProjectProtocolAdjacencyMap(self, projectId: int) -> Dict[str, Dict[str, List[str]]]: + rows = self.db.fetchAll( + """ + SELECT + parent."protocolId" AS "parentProtocolId", + child."protocolId" AS "childProtocolId" + FROM protocol_dependencies d + JOIN protocols parent + ON parent.id = d."parentProtocolDbId" + AND parent."projectId" = d."projectId" + JOIN protocols child + ON child.id = d."childProtocolDbId" + AND child."projectId" = d."projectId" + WHERE d."projectId" = %s + ORDER BY child.id, parent.id + """, + (projectId,), + ) + + adjacency: Dict[str, Dict[str, List[str]]] = {} + + for row in rows: + 
parentProtocolId = row.get("parentProtocolId") + childProtocolId = row.get("childProtocolId") + + if parentProtocolId is None or childProtocolId is None: + continue + + parentProtocolId = str(parentProtocolId) + childProtocolId = str(childProtocolId) + + adjacency.setdefault(parentProtocolId, {"parents": [], "children": []}) + adjacency.setdefault(childProtocolId, {"parents": [], "children": []}) + + if childProtocolId not in adjacency[parentProtocolId]["children"]: + adjacency[parentProtocolId]["children"].append(childProtocolId) + + if parentProtocolId not in adjacency[childProtocolId]["parents"]: + adjacency[childProtocolId]["parents"].append(parentProtocolId) + + return adjacency + def getProtocolByProtocolId(self, protocolId: int, projectId: int) -> Optional[Dict]: """Retrieve a protocol by id.""" return self.db.fetchOne( From acf08b38ed4b209482738908eb3cd0aa501607a1 Mon Sep 17 00:00:00 2001 From: "Yunior C. Fonseca Reyna" Date: Wed, 22 Apr 2026 16:23:52 +0200 Subject: [PATCH 06/14] feat: persist protocol graph in postgres and stabilize project thumbnails --- app/backend/api/routers/project_router.py | 38 ++- app/backend/api/services/project_service.py | 286 +++++++++++++++----- app/backend/utils/thumbnail_service.py | 57 +++- 3 files changed, 290 insertions(+), 91 deletions(-) diff --git a/app/backend/api/routers/project_router.py b/app/backend/api/routers/project_router.py index dc4957d..428757f 100644 --- a/app/backend/api/routers/project_router.py +++ b/app/backend/api/routers/project_router.py @@ -2945,12 +2945,19 @@ def getProtocolThumbnail( service.loadProjectForThumbnails(dbProj) - result = service.buildProtocolThumbnail( - protocolId=protocolId, - force=False, - size=size, - outputName=outputName, - ) + if outputName: + result = service.buildProtocolOutputThumbnail( + protocolId=protocolId, + outputName=outputName, + force=False, + size=size, + ) + else: + result = service.buildProtocolThumbnail( + protocolId=protocolId, + force=False, + size=size, + 
) thumbPath = result.get("absolutePath") if not thumbPath: @@ -2995,12 +3002,19 @@ def rebuildProtocolThumbnail( service.loadProjectForThumbnails(dbProj) - result = service.buildProtocolThumbnail( - protocolId=protocolId, - force=True, - size=size, - outputName=outputName, - ) + if outputName: + result = service.buildProtocolOutputThumbnail( + protocolId=protocolId, + outputName=outputName, + force=True, + size=size, + ) + else: + result = service.buildProtocolThumbnail( + protocolId=protocolId, + force=True, + size=size, + ) response = JSONResponse( { diff --git a/app/backend/api/services/project_service.py b/app/backend/api/services/project_service.py index 8782d3d..94ba92e 100644 --- a/app/backend/api/services/project_service.py +++ b/app/backend/api/services/project_service.py @@ -1049,83 +1049,208 @@ def _buildProjectThumbnailVersion( return f"{projectId}:{updatedText}:{protocolsCount}:{runsMtime}" def buildProtocolsGraph( - self, - projectId, - runs, - tags, - dependencyMap: Optional[Dict[str, Dict[str, List[str]]]] = None, + self, + projectId: int, + protocolRows: List[Dict[str, Any]], + tags: Dict[str, List[str]], + dependencyMap: Optional[Dict[str, Dict[str, List[str]]]] = None, ) -> dict: - """Assemble dependency graph of protocols and their status.""" - nodesDict = runs._nodesDict - graphData = {} - usePostgresDependencies = dependencyMap is not None + """Assemble protocol graph using PostgreSQL as source of truth for nodes + edges.""" + graphData: Dict[str, Any] = {} + adjacency = dependencyMap or {} - for nodeId, nodeObj in nodesDict.items(): - if nodeId != 'PROJECT' and usePostgresDependencies: - nodeDeps = dependencyMap.get(str(nodeId), {"parents": [], "children": []}) - childrenIds = list(nodeDeps.get("children") or []) - parentIds = list(nodeDeps.get("parents") or []) - else: - childrenIds = [child.getName() for child in nodeObj._children] - parentIds = [parent.getName() for parent in nodeObj._parents] + def sortKey(row: Dict[str, Any]): + raw = 
str(row.get("protocolId") or "") + try: + return (0, int(raw)) + except Exception: + return (1, raw) + + orderedRows = sorted(protocolRows or [], key=sortKey) + + protocolIds: List[str] = [] + for row in orderedRows: + rawId = row.get("protocolId") + if rawId is None: + continue + protocolIds.append(str(rawId)) + + # Root node synthesized from DB graph: + # protocols without parents hang directly from PROJECT + rootChildren = [ + pid for pid in protocolIds + if not (adjacency.get(pid, {}).get("parents") or []) + ] + + projectLabel = "PROJECT" + try: + if self.currentProject is not None: + projectLabel = os.path.basename(self.currentProject.getPath()) or "PROJECT" + except Exception: + projectLabel = "PROJECT" + + graphData["PROJECT"] = { + "protocolId": "PROJECT", + "children": rootChildren, + "parents": [], + "label": projectLabel, + "status": "", + "parameter": [], + "inputs": [], + "outputs": [], + "cpuTime": "", + "elapsedTime": "", + "isInteractive": False, + "numberOfSteps": 0, + "stepsDone": 0, + "tags": [], + "thumbnailUrl": None, + "thumbnailRebuildUrl": None, + } + + for row in orderedRows: + rawNodeId = row.get("protocolId") + if rawNodeId is None: + continue + + nodeId = str(rawNodeId) + nodeDeps = adjacency.get(nodeId, {"parents": [], "children": []}) + childrenIds = list(nodeDeps.get("children") or []) + parentIds = list(nodeDeps.get("parents") or []) + + statusValue = row.get("status") + status = str(statusValue) if statusValue is not None else "" + + protocolClassName = str(row.get("protocolClassName") or "") + label = protocolClassName or nodeId - status = nodeObj.run.getStatus() if nodeObj.run else '' inputs = [] outputs = [] - cpuTime = '' - elapsedTime = '' + cpuTime = "" + elapsedTime = "" isinteractive = False numberOfSteps = 0 stepsDone = 0 thumbnailUrl = None thumbnailRebuildUrl = None - if nodeId != 'PROJECT': + protocol = None + try: protocol = self.currentProject.getProtocol(int(nodeId)) - cpuTime = str(protocol.cpuTime) - elapsedTime = 
str(protocol.getElapsedTime().total_seconds()).split('.')[0] - isinteractive = protocol.isInteractive() - numberOfSteps = protocol.numberOfSteps - stepsDone = protocol.stepsDone - self.currentProject._fixProtParamsConfiguration(protocol) - - thumbnailUrl = self.buildProtocolThumbnailUrl(projectId, int(nodeId)) - thumbnailRebuildUrl = self.buildProtocolThumbnailRebuildUrl(projectId, int(nodeId)) - - for key, attr in protocol.iterInputAttributes(): - input = {} - try: - input['name'] = key - input['paramClass'] = 'PointerParam' - input['pointerClass'] = attr.get().getClassName() if attr and attr.get() else "" - input['info'] = str(attr.get()) - except Exception: - input['pointerClass'] = "" - input['info'] = "" - parentId = attr.getObjValue().getObjId() - input['value'] = "%s.%s" % (str(parentId), attr.getExtended()) - input['parentId'] = parentId - inputs.append(input) - - for key, attr in protocol.iterOutputAttributes(): - output = {} - output['name'] = key - output['paramClass'] = 'PointerParam' - output['pointerClass'] = attr.__class__.__name__ - try: - output['info'] = attr.__str__() - except Exception: - output['info'] = "" - parentId = protocol.getObjId() - output['value'] = "%s.%s" % (str(parentId), key) - output['parentId'] = parentId - outputs.append(output) + except Exception: + protocol = None + + if protocol is not None: + try: + label = str(protocol) or label + except Exception: + pass + + try: + protStatus = protocol.getStatus() + if protStatus: + status = str(protStatus) + except Exception: + pass + + try: + cpuTime = str(protocol.cpuTime) + except Exception: + cpuTime = "" + + try: + elapsedTime = str(protocol.getElapsedTime().total_seconds()).split(".")[0] + except Exception: + elapsedTime = "" + + try: + isinteractive = bool(protocol.isInteractive()) + except Exception: + isinteractive = False + + try: + numberOfSteps = protocol.numberOfSteps + except Exception: + numberOfSteps = 0 + + try: + stepsDone = protocol.stepsDone + except Exception: + 
stepsDone = 0 + + try: + self.currentProject._fixProtParamsConfiguration(protocol) + except Exception: + pass + + try: + protocolIdInt = int(nodeId) + thumbnailUrl = self.buildProtocolThumbnailUrl(projectId, protocolIdInt) + thumbnailRebuildUrl = self.buildProtocolThumbnailRebuildUrl(projectId, protocolIdInt) + except Exception: + thumbnailUrl = None + thumbnailRebuildUrl = None + + try: + for key, attr in protocol.iterInputAttributes(): + inputItem = {} + try: + inputItem["name"] = key + inputItem["paramClass"] = "PointerParam" + inputItem["pointerClass"] = attr.get().getClassName() if attr and attr.get() else "" + inputItem["info"] = str(attr.get()) + except Exception: + inputItem["pointerClass"] = "" + inputItem["info"] = "" + + try: + parentId = attr.getObjValue().getObjId() + inputItem["value"] = "%s.%s" % (str(parentId), attr.getExtended()) + inputItem["parentId"] = parentId + except Exception: + inputItem["value"] = "" + inputItem["parentId"] = None + + inputs.append(inputItem) + except Exception: + inputs = [] + + try: + for key, attr in protocol.iterOutputAttributes(): + outputItem = {} + outputItem["name"] = key + outputItem["paramClass"] = "PointerParam" + outputItem["pointerClass"] = attr.__class__.__name__ + try: + outputItem["info"] = attr.__str__() + except Exception: + outputItem["info"] = "" + + try: + parentId = protocol.getObjId() + outputItem["value"] = "%s.%s" % (str(parentId), key) + outputItem["parentId"] = parentId + except Exception: + outputItem["value"] = "" + outputItem["parentId"] = None + + outputs.append(outputItem) + except Exception: + outputs = [] + else: + try: + protocolIdInt = int(nodeId) + thumbnailUrl = self.buildProtocolThumbnailUrl(projectId, protocolIdInt) + thumbnailRebuildUrl = self.buildProtocolThumbnailRebuildUrl(projectId, protocolIdInt) + except Exception: + thumbnailUrl = None + thumbnailRebuildUrl = None graphData[nodeId] = { "protocolId": nodeId, "children": childrenIds, "parents": parentIds, - "label": 
nodeObj.getLabel(), + "label": label, "status": status, "parameter": [], "inputs": inputs, @@ -1135,7 +1260,7 @@ def buildProtocolsGraph( "isInteractive": isinteractive, "numberOfSteps": numberOfSteps, "stepsDone": stepsDone, - "tags": tags[nodeId] if nodeId in tags else [], + "tags": tags.get(nodeId, []), "thumbnailUrl": thumbnailUrl, "thumbnailRebuildUrl": thumbnailRebuildUrl, } @@ -1146,11 +1271,31 @@ def loadProject(self, dbProj: dict, mapper: PostgresqlFlatMapper = None, refresh projPath = Path(dbProj['name']) self.currentProject = ScipionProject(pyworkflow.Config.getDomain(), str(projPath)) self.currentProject.load(dbPath=self.currentProject.getDbPath()) - runs = self.currentProject.getRunsGraph(refresh=refresh, checkPids=checkPid) - tags = mapper.getProjectProtocolTagIdsByProtocolId(dbProj['id']) - dependencyMap = None + # Keep Scipion refreshed because we still use protocol objects + # to enrich the nodes (inputs, outputs, timings, thumbnails, etc.) + try: + self.currentProject.getRunsGraph(refresh=refresh, checkPids=checkPid) + except Exception: + logger.exception( + "Failed to refresh Scipion runs graph for project %s", + dbProj['id'], + ) + + tags = {} + dependencyMap = {} + protocolRows: List[Dict[str, Any]] = [] + if mapper is not None: + try: + tags = mapper.getProjectProtocolTagIdsByProtocolId(dbProj['id']) + except Exception: + logger.exception( + "Failed to load protocol tags from PostgreSQL for project %s", + dbProj['id'], + ) + tags = {} + try: dependencyMap = mapper.getProjectProtocolAdjacencyMap(dbProj['id']) except Exception: @@ -1158,11 +1303,20 @@ def loadProject(self, dbProj: dict, mapper: PostgresqlFlatMapper = None, refresh "Failed to load protocol dependencies from PostgreSQL for project %s", dbProj['id'], ) - dependencyMap = None + dependencyMap = {} + + try: + protocolRows = mapper.getProtocols(dbProj['id']) + except Exception: + logger.exception( + "Failed to load protocol rows from PostgreSQL for project %s", + dbProj['id'], + ) + 
protocolRows = [] graphData = self.buildProtocolsGraph( dbProj['id'], - runs, + protocolRows, tags, dependencyMap=dependencyMap, ) diff --git a/app/backend/utils/thumbnail_service.py b/app/backend/utils/thumbnail_service.py index 6920f6b..7d0d3bc 100644 --- a/app/backend/utils/thumbnail_service.py +++ b/app/backend/utils/thumbnail_service.py @@ -158,9 +158,22 @@ def buildProtocolThumbnail( size: int = 360, outputName: Optional[str] = None, ) -> Dict[str, Any]: - protocol = self.currentProject.getProtocol(int(protocolId)) + try: + protocol = self.currentProject.getProtocol(int(protocolId)) + except Exception: + protocol = None + if protocol is None: - raise ValueError(f"Protocol {protocolId} not found") + return { + "protocolId": int(protocolId), + "protocolLabel": f"Protocol {int(protocolId)}", + "status": "unknown", + "outputName": outputName, + "outputClassName": None, + "absolutePath": None, + "cached": False, + "exists": False, + } cachePath = self._getProtocolCachePath( protocolId, @@ -431,12 +444,34 @@ def buildProtocolOutputThumbnail( force: bool = False, size: int = 320, ) -> Dict[str, Any]: - protocol = self.currentProject.getProtocol(int(protocolId)) + try: + protocol = self.currentProject.getProtocol(int(protocolId)) + except Exception: + protocol = None + if protocol is None: - raise ValueError(f"Protocol {protocolId} not found") + return { + "protocolId": int(protocolId), + "protocolLabel": f"Protocol {int(protocolId)}", + "status": "unknown", + "outputName": outputName, + "outputClassName": None, + "absolutePath": None, + "cached": False, + "exists": False, + } if not hasattr(protocol, outputName): - raise ValueError(f"Output '{outputName}' not found in protocol {protocolId}") + return { + "protocolId": int(protocolId), + "protocolLabel": self._getProtocolLabel(protocol), + "status": self._getProtocolStatus(protocol), + "outputName": outputName, + "outputClassName": None, + "absolutePath": None, + "cached": False, + "exists": False, + } output = 
getattr(protocol, outputName, None) if output is None: @@ -744,11 +779,11 @@ def listProtocolThumbnailItems( continue try: - built = self.buildProtocolThumbnail( + built = self.buildProtocolOutputThumbnail( protocolId=protocolId, + outputName=outputName, force=force, size=size, - outputName=outputName, ) except Exception: logger.debug( @@ -768,13 +803,9 @@ def listProtocolThumbnailItems( "outputClassName": outputClassName, "exists": True, "thumbnailUrl": ( - f"/projects/{int(projectId)}/protocols/{protocolId}/thumbnail" - f"?outputName={quote(str(outputName))}" - ), - "thumbnailRebuildUrl": ( - f"/projects/{int(projectId)}/protocols/{protocolId}/thumbnail/rebuild" - f"?outputName={quote(str(outputName))}" + f"/projects/{int(projectId)}/protocols/{protocolId}/outputs/{quote(str(outputName))}/thumbnail" ), + "thumbnailRebuildUrl": None, } ) From 56676a28c560182c96e737f2c871151d49f51e79 Mon Sep 17 00:00:00 2001 From: "Yunior C. Fonseca Reyna" Date: Wed, 22 Apr 2026 16:59:42 +0200 Subject: [PATCH 07/14] feat: persist and serve protocol graph from postgres --- app/backend/api/services/project_service.py | 120 ++++++++++++++------ 1 file changed, 86 insertions(+), 34 deletions(-) diff --git a/app/backend/api/services/project_service.py b/app/backend/api/services/project_service.py index 94ba92e..4f3e617 100644 --- a/app/backend/api/services/project_service.py +++ b/app/backend/api/services/project_service.py @@ -1054,10 +1054,12 @@ def buildProtocolsGraph( protocolRows: List[Dict[str, Any]], tags: Dict[str, List[str]], dependencyMap: Optional[Dict[str, Dict[str, List[str]]]] = None, + runMap: Optional[Dict[str, Any]] = None, ) -> dict: """Assemble protocol graph using PostgreSQL as source of truth for nodes + edges.""" graphData: Dict[str, Any] = {} adjacency = dependencyMap or {} + liveRuns = runMap or {} def sortKey(row: Dict[str, Any]): raw = str(row.get("protocolId") or "") @@ -1134,11 +1136,14 @@ def sortKey(row: Dict[str, Any]): thumbnailUrl = None 
thumbnailRebuildUrl = None - protocol = None - try: - protocol = self.currentProject.getProtocol(int(nodeId)) - except Exception: - protocol = None + # Prefer the live protocol object coming from runs graph + protocol = liveRuns.get(nodeId) + + if protocol is None: + try: + protocol = self.currentProject.getProtocol(int(nodeId)) + except Exception: + protocol = None if protocol is not None: try: @@ -1272,15 +1277,23 @@ def loadProject(self, dbProj: dict, mapper: PostgresqlFlatMapper = None, refresh self.currentProject = ScipionProject(pyworkflow.Config.getDomain(), str(projPath)) self.currentProject.load(dbPath=self.currentProject.getDbPath()) - # Keep Scipion refreshed because we still use protocol objects - # to enrich the nodes (inputs, outputs, timings, thumbnails, etc.) + # Refresh Scipion graph and keep a live map of protocol objects + runs = None + runMap: Dict[str, Any] = {} + try: - self.currentProject.getRunsGraph(refresh=refresh, checkPids=checkPid) + runs = self.currentProject.getRunsGraph(refresh=refresh, checkPids=checkPid) + nodesDict = getattr(runs, "_nodesDict", {}) or {} + for nodeId, nodeObj in nodesDict.items(): + if str(nodeId) == "PROJECT": + continue + runMap[str(nodeId)] = getattr(nodeObj, "run", None) except Exception: logger.exception( "Failed to refresh Scipion runs graph for project %s", dbProj['id'], ) + runMap = {} tags = {} dependencyMap = {} @@ -1319,6 +1332,7 @@ def loadProject(self, dbProj: dict, mapper: PostgresqlFlatMapper = None, refresh protocolRows, tags, dependencyMap=dependencyMap, + runMap=runMap, ) stats = projPath.stat() @@ -2147,6 +2161,7 @@ def setPointerParam(self, protocol, key, value, parentId): def saveProtocol(self, mapper, projectId, protocolId, protocolClassName, params, setToSave=True): errorList = [] + if not protocolId: # new protocol protClass = self.currentProject.getDomain().getProtocols().get(protocolClassName) protocol = self.currentProject.newProtocol(protClass) @@ -2182,6 +2197,7 @@ def 
saveProtocol(self, mapper, projectId, protocolId, protocolClassName, params, if errors: errorListAux = ['**' + param.label.get() + '** ' + error for error in errors] errorList += errorListAux + param.set(castedValue) protocol.setAttributeValue(key, castedValue) @@ -2191,31 +2207,40 @@ def saveProtocol(self, mapper, projectId, protocolId, protocolClassName, params, logger.info(f"[INFO] Set param {key} = {castedValue}") except Exception as e: import re - cleaned = re.sub(r'[^A-Za-z0-9\s+\-*/=<>\!&|^%()\[\]{}_,.;:]', '', str(e)) + cleaned = re.sub(r'[^A-Za-z0-9\s+\-*/=<>!&|^%()\[\]{}_,.;:]', '', str(e)) errorList += ['**' + param.label.get() + '** ' + cleaned] # Apply pointer parameters errorList += self.applyParamsToProtocol(protocol, params) - # if setToSave: - # protocol.setSaved() - + # Persist protocol in Scipion if protocol.hasObjId(): self.currentProject._storeProtocol(protocol) else: self.currentProject._setupProtocol(protocol) - dbProtocol = mapper.getProtocolByProtocolId(protocolId=protocol.getObjId(), projectId=projectId) - if not dbProtocol: - protocolContex = self._buildProtocolContext(projectId, protocol) - mapper.saveProtocol(protocolContex) - pass - else: - # Update parameters and status if exists - pass - # Save dependencies - # graphData = self.currentProject.getRunsGraph(refresh=True, checkPids=True) - # self.saveProtocolDependencies(mapper, graphData._nodesDict) + # Persist protocol in PostgreSQL and resync graph + try: + protocolContext = self._buildProtocolContext(projectId, protocol) + mapper.saveProtocol(protocolContext) + + self.syncProjectProtocolsAndDependencies( + mapper, + projectId, + refresh=True, + checkPid=True, + ) + except Exception as e: + logger.exception( + "Failed to sync protocol graph after save. 
projectId=%s protocolId=%s protocolClassName=%s", + projectId, + getattr(protocol, "getObjId", lambda: protocolId)(), + protocolClassName, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Protocol was saved in Scipion but graph sync to PostgreSQL failed: {e}", + ) return protocol, errorList @@ -2234,6 +2259,13 @@ def launchProtocol(self, mapper, projectId, protocolId, protocolClassName, param if executeMode == "stop": try: self.stopProtocol([protocolId]) + + self.syncProjectProtocolsAndDependencies( + mapper, + projectId, + refresh=True, + checkPid=True, + ) return except HTTPException: raise @@ -2255,6 +2287,7 @@ def launchProtocol(self, mapper, projectId, protocolId, protocolClassName, param if protocol.useQueue(): queueParams = [params['_queueName'], params['_queueParams']] protocol.setQueueParams(queueParams) + try: validationErrors = protocol._validate() if validationErrors: @@ -2271,18 +2304,37 @@ def launchProtocol(self, mapper, projectId, protocolId, protocolClassName, param detail=errors, ) - if executeMode == "schedule": - self.currentProject.scheduleProtocol(protocol) - return - - modeToRunMode = { - "launch": MODE_RESUME, - "restart": MODE_RESTART, - } + try: + if executeMode == "schedule": + self.currentProject.scheduleProtocol(protocol) + else: + modeToRunMode = { + "launch": MODE_RESUME, + "restart": MODE_RESTART, + } + runMode = modeToRunMode[executeMode] + protocol.runMode.set(runMode) + self.currentProject.launchProtocol(protocol) - runMode = modeToRunMode[executeMode] - protocol.runMode.set(runMode) - self.currentProject.launchProtocol(protocol) + self.syncProjectProtocolsAndDependencies( + mapper, + projectId, + refresh=True, + checkPid=True, + ) + except HTTPException: + raise + except Exception as e: + logger.exception( + "Failed to sync protocol graph after execute. 
projectId=%s protocolId=%s executeMode=%s", + projectId, + getattr(protocol, "getObjId", lambda: protocolId)(), + executeMode, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Protocol execution finished but graph sync to PostgreSQL failed: {e}", + ) def findViewersWeb(self, protocol): # TODO: Find viewers... From a0c29dc27c612775fd16138535be1a24cff3430c Mon Sep 17 00:00:00 2001 From: "Yunior C. Fonseca Reyna" Date: Wed, 22 Apr 2026 22:41:50 +0200 Subject: [PATCH 08/14] fix: resync protocol graph from scipion when loading projects --- app/backend/api/services/project_service.py | 59 +++++++++++++++++---- 1 file changed, 48 insertions(+), 11 deletions(-) diff --git a/app/backend/api/services/project_service.py b/app/backend/api/services/project_service.py index 4f3e617..3341c06 100644 --- a/app/backend/api/services/project_service.py +++ b/app/backend/api/services/project_service.py @@ -1278,22 +1278,34 @@ def loadProject(self, dbProj: dict, mapper: PostgresqlFlatMapper = None, refresh self.currentProject.load(dbPath=self.currentProject.getDbPath()) # Refresh Scipion graph and keep a live map of protocol objects - runs = None runMap: Dict[str, Any] = {} + scipionProtocolCount = 0 + scipionEdgeCount = 0 try: runs = self.currentProject.getRunsGraph(refresh=refresh, checkPids=checkPid) nodesDict = getattr(runs, "_nodesDict", {}) or {} + for nodeId, nodeObj in nodesDict.items(): if str(nodeId) == "PROJECT": continue + + scipionProtocolCount += 1 runMap[str(nodeId)] = getattr(nodeObj, "run", None) + + for parent in getattr(nodeObj, "_parents", []) or []: + parentNodeId = str(parent.getName()) + if parentNodeId != "PROJECT": + scipionEdgeCount += 1 + except Exception: logger.exception( "Failed to refresh Scipion runs graph for project %s", dbProj['id'], ) runMap = {} + scipionProtocolCount = 0 + scipionEdgeCount = 0 tags = {} dependencyMap = {} @@ -1327,6 +1339,41 @@ def loadProject(self, dbProj: dict, mapper: PostgresqlFlatMapper 
= None, refresh ) protocolRows = [] + dbProtocolCount = len(protocolRows) + dbEdgeCount = sum(len(v.get("parents") or []) for v in dependencyMap.values()) + + shouldResyncGraph = ( + scipionProtocolCount != dbProtocolCount or + scipionEdgeCount != dbEdgeCount + ) + + if shouldResyncGraph: + try: + logger.info( + "Resyncing protocol graph from Scipion to PostgreSQL. " + "projectId=%s scipionProtocols=%s dbProtocols=%s scipionEdges=%s dbEdges=%s", + dbProj['id'], + scipionProtocolCount, + dbProtocolCount, + scipionEdgeCount, + dbEdgeCount, + ) + + self.syncProjectProtocolsAndDependencies( + mapper, + dbProj['id'], + refresh=False, + checkPid=False, + ) + + dependencyMap = mapper.getProjectProtocolAdjacencyMap(dbProj['id']) + protocolRows = mapper.getProtocols(dbProj['id']) + except Exception: + logger.exception( + "Failed to resync protocol graph during project load for project %s", + dbProj['id'], + ) + graphData = self.buildProtocolsGraph( dbProj['id'], protocolRows, @@ -1474,16 +1521,6 @@ def applyWorkflowToProject( "loadResult": str(loadResult) if loadResult is not None else None, } - def saveProtocolDependencies(self, mapper: PostgresqlFlatMapper, graphData: dict): - for nodeId, nodeInfo in graphData.items(): - parentIds = [int(pid) for pid in nodeInfo['parents'] if pid != 'PROJECT'] - childIds = [int(cid) for cid in nodeInfo['children']] - mapper.updateProtocolDependencies( - protocolId=nodeId, - parentIds=parentIds, - childIds=childIds - ) - @staticmethod def getProtocolColor(status: str) -> str: """Return hex color based on protocol status.""" From 844ce164f9563abb7ceccf52da3be50ca1dbb4f7 Mon Sep 17 00:00:00 2001 From: "Yunior C. 
Fonseca Reyna" Date: Wed, 22 Apr 2026 22:58:06 +0200 Subject: [PATCH 09/14] fix: sync protocol graph after rename stop restart and reset operations --- app/backend/api/routers/project_router.py | 33 ++++- app/backend/api/services/project_service.py | 144 ++++++++++++++------ 2 files changed, 133 insertions(+), 44 deletions(-) diff --git a/app/backend/api/routers/project_router.py b/app/backend/api/routers/project_router.py index 428757f..c92c87e 100644 --- a/app/backend/api/routers/project_router.py +++ b/app/backend/api/routers/project_router.py @@ -515,7 +515,6 @@ def renameProtocol( ) try: - # Basic payload validation for semantic HTTP newName = getattr(payload, "name", None) if not newName or not str(newName).strip(): return JSONResponse( @@ -526,6 +525,14 @@ def renameProtocol( ) service.renameProtocol(protocolId, str(newName).strip()) + service.syncProjectGraphAfterMutation( + mapper, + projectId, + actionLabel="rename protocol", + refresh=True, + checkPid=True, + ) + return {"status": 0, "errors": [], "workflow": []} @@ -667,6 +674,14 @@ def restartProtocolAll( "errors": errors, "workflow": []} + service.syncProjectGraphAfterMutation( + mapper, + projectId, + actionLabel="restart protocol subtree", + refresh=True, + checkPid=True, + ) + return {"status": 0, "errors": [], "workflow": []} @@ -743,6 +758,14 @@ def resetProtocolFrom( try: service.resetProtocolFrom(protocolId) + service.syncProjectGraphAfterMutation( + mapper, + projectId, + actionLabel="reset protocol from node", + refresh=True, + checkPid=True, + ) + return {"status": 0, "errors": [], "workflow": []} except HTTPException as e: @@ -785,6 +808,14 @@ def stopProtocol( ) service.stopProtocol(protocolIds) + service.syncProjectGraphAfterMutation( + mapper, + projectId, + actionLabel="stop protocol", + refresh=True, + checkPid=True, + ) + return {"status": 0, "errors": [], "workflow": []} diff --git a/app/backend/api/services/project_service.py b/app/backend/api/services/project_service.py index 
3341c06..6590145 100644 --- a/app/backend/api/services/project_service.py +++ b/app/backend/api/services/project_service.py @@ -448,6 +448,34 @@ def syncProjectProtocolsAndDependencies( "dependencies": int(savedEdges), } + def syncProjectGraphAfterMutation( + self, + mapper: PostgresqlFlatMapper, + projectId: int, + actionLabel: str, + refresh: bool = True, + checkPid: bool = True, + ) -> Dict[str, int]: + try: + return self.syncProjectProtocolsAndDependencies( + mapper, + projectId, + refresh=refresh, + checkPid=checkPid, + ) + except HTTPException: + raise + except Exception as e: + logger.exception( + "Failed to sync protocol graph after %s. projectId=%s", + actionLabel, + projectId, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"{actionLabel} succeeded but graph sync to PostgreSQL failed: {e}", + ) + def importProject( self, mapper: PostgresqlFlatMapper, @@ -2199,86 +2227,101 @@ def setPointerParam(self, protocol, key, value, parentId): def saveProtocol(self, mapper, projectId, protocolId, protocolClassName, params, setToSave=True): errorList = [] - if not protocolId: # new protocol + if not protocolId: protClass = self.currentProject.getDomain().getProtocols().get(protocolClassName) + if protClass is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Protocol class not found: {protocolClassName}", + ) protocol = self.currentProject.newProtocol(protClass) - else: # retrieve a protocol by id + else: protocol = self.currentProject.getProtocol(int(protocolId)) + if protocol is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Protocol not found: {protocolId}", + ) - # Set protected parameters protectedParams = ['_objComment', '_useQueue', '_prerequisites', 'gpuList', 'numberOfThreads'] for paramName in protectedParams: protVar = getattr(protocol, paramName, None) - if protVar is not None: - try: - if paramName in params: - value = params[paramName] - 
protVar.set(value) - except Exception: - setattr(protocol, paramName, value) + if protVar is None or paramName not in params: + continue + + value = params[paramName] + try: + protVar.set(value) + except Exception: + setattr(protocol, paramName, value) - # Set non-pointer parameters for key, value in params.items(): param = protocol.getParam(key) if param is None: - logger.warning(f"[WARN] Param not found: {key}") + logger.warning("[WARN] Param not found: %s", key) continue if isinstance(param, (PointerParam, MultiPointerParam, RelationParam)): continue - rawValue = value try: - castedValue = self.castParamValue(param, rawValue) - errors = param.validate(castedValue) if hasattr(param, 'validate') else [] + castedValue = self.castParamValue(param, value) + errors = param.validate(castedValue) if hasattr(param, "validate") else [] if errors: - errorListAux = ['**' + param.label.get() + '** ' + error for error in errors] - errorList += errorListAux + errorList += ['**' + param.label.get() + '** ' + error for error in errors] param.set(castedValue) protocol.setAttributeValue(key, castedValue) - if key == 'runName': + if key == "runName": protocol.setObjLabel(castedValue) - logger.info(f"[INFO] Set param {key} = {castedValue}") + logger.info("[INFO] Set param %s = %s", key, castedValue) except Exception as e: - import re cleaned = re.sub(r'[^A-Za-z0-9\s+\-*/=<>!&|^%()\[\]{}_,.;:]', '', str(e)) - errorList += ['**' + param.label.get() + '** ' + cleaned] + errorList.append('**' + param.label.get() + '** ' + cleaned) - # Apply pointer parameters errorList += self.applyParamsToProtocol(protocol, params) - # Persist protocol in Scipion - if protocol.hasObjId(): - self.currentProject._storeProtocol(protocol) - else: - self.currentProject._setupProtocol(protocol) - - # Persist protocol in PostgreSQL and resync graph + # Persist protocol in Scipion always. + # The setToSave flag only controls whether we also sync the graph to PostgreSQL now. 
try: - protocolContext = self._buildProtocolContext(projectId, protocol) - mapper.saveProtocol(protocolContext) - - self.syncProjectProtocolsAndDependencies( - mapper, - projectId, - refresh=True, - checkPid=True, - ) + if protocol.hasObjId(): + self.currentProject._storeProtocol(protocol) + else: + self.currentProject._setupProtocol(protocol) except Exception as e: logger.exception( - "Failed to sync protocol graph after save. projectId=%s protocolId=%s protocolClassName=%s", + "Failed to persist protocol in Scipion. projectId=%s protocolId=%s protocolClassName=%s", projectId, - getattr(protocol, "getObjId", lambda: protocolId)(), + protocolId, protocolClassName, ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Protocol was saved in Scipion but graph sync to PostgreSQL failed: {e}", + detail=f"Failed to persist protocol in Scipion: {e}", ) + if setToSave: + try: + self.syncProjectProtocolsAndDependencies( + mapper, + projectId, + refresh=True, + checkPid=True, + ) + except Exception as e: + logger.exception( + "Failed to sync protocol graph after save. 
projectId=%s protocolId=%s protocolClassName=%s", + projectId, + getattr(protocol, "getObjId", lambda: protocolId)(), + protocolClassName, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Protocol was saved in Scipion but graph sync to PostgreSQL failed: {e}", + ) + return protocol, errorList def launchProtocol(self, mapper, projectId, protocolId, protocolClassName, params, executeMode): @@ -2322,8 +2365,9 @@ def launchProtocol(self, mapper, projectId, protocolId, protocolClassName, param ) if protocol.useQueue(): - queueParams = [params['_queueName'], params['_queueParams']] - protocol.setQueueParams(queueParams) + queueName = params.get("_queueName") + queueParams = params.get("_queueParams") + protocol.setQueueParams([queueName, queueParams]) try: validationErrors = protocol._validate() @@ -2336,6 +2380,20 @@ def launchProtocol(self, mapper, projectId, protocolId, protocolClassName, param ] if errors: + try: + self.syncProjectProtocolsAndDependencies( + mapper, + projectId, + refresh=True, + checkPid=True, + ) + except Exception: + logger.exception( + "Failed to sync protocol graph after validation errors. projectId=%s protocolId=%s", + projectId, + getattr(protocol, "getObjId", lambda: protocolId)(), + ) + raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=errors, From 9cdb3e66517a8d610d686ee98d311d4b95cd8be1 Mon Sep 17 00:00:00 2001 From: "Yunior C. 
Fonseca Reyna" Date: Thu, 23 Apr 2026 06:47:03 +0200 Subject: [PATCH 10/14] test: update protocol operation tests for graph sync and thumbnail changes --- app/backend/api/routers/project_router.py | 8 ++ app/backend/api/services/project_service.py | 71 +++++++++++++++- tests/conftest.py | 16 ++++ .../api/test_projects_router_protocol_ops.py | 27 +++++++ .../test_project_service_protocols.py | 81 ++++++++++++++++--- .../utils/test_thumbnail_service_core.py | 14 ++-- 6 files changed, 196 insertions(+), 21 deletions(-) diff --git a/app/backend/api/routers/project_router.py b/app/backend/api/routers/project_router.py index c92c87e..847fcbd 100644 --- a/app/backend/api/routers/project_router.py +++ b/app/backend/api/routers/project_router.py @@ -723,6 +723,14 @@ def continueProtocolAll( try: service.continueProtocolAll(mapper, projectId, protocolId, currentUser) + service.syncProjectGraphAfterMutation( + mapper, + projectId, + actionLabel="continue protocol subtree", + refresh=True, + checkPid=True, + ) + return {"status": 0, "errors": [], "workflow": []} except HTTPException as e: diff --git a/app/backend/api/services/project_service.py b/app/backend/api/services/project_service.py index 6590145..1e98f6a 100644 --- a/app/backend/api/services/project_service.py +++ b/app/backend/api/services/project_service.py @@ -3052,8 +3052,75 @@ def restartProtocolAll(self, protocolId: int): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - def continueProtocolAll(self, mapper, projectId: int, protocolId: int, currentUser: dict): - raise NotImplementedError + def continueProtocolAll(self, mapper, projectId, protocolId, currentUser): + protocol = self.currentProject.getProtocol(int(protocolId)) + if protocol is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Protocol not found: {protocolId}", + ) + + try: + workflowProtocolList, activeProtocolList = self.currentProject._getSubworkflow(protocol) + except Exception as e: + 
logger.exception( + "Failed to resolve subworkflow for continue-all. projectId=%s protocolId=%s", + projectId, + protocolId, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to resolve protocol subworkflow: {e}", + ) + + protocolsToResume = activeProtocolList or workflowProtocolList or [] + if not protocolsToResume: + return {"status": "ok", "message": "No protocols to continue"} + + for item in protocolsToResume: + protocolToLaunch = item + + if not hasattr(protocolToLaunch, "runMode"): + try: + protocolToLaunch = self.currentProject.getProtocol(int(item)) + except Exception: + logger.exception( + "Failed to resolve protocol to continue. projectId=%s protocolId=%s item=%s", + projectId, + protocolId, + item, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to resolve protocol to continue: {item}", + ) + + try: + protocolToLaunch.runMode.set(MODE_RESUME) + except Exception: + logger.debug( + "Could not set MODE_RESUME before continue-all. projectId=%s protocolId=%s item=%s", + projectId, + protocolId, + getattr(protocolToLaunch, "getObjId", lambda: item)(), + exc_info=True, + ) + + try: + self.currentProject.launchProtocol(protocolToLaunch) + except Exception as e: + logger.exception( + "Failed to continue protocol. 
projectId=%s protocolId=%s item=%s", + projectId, + protocolId, + getattr(protocolToLaunch, "getObjId", lambda: item)(), + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to continue protocol: {e}", + ) + + return {"status": "ok", "message": "Protocol subtree continued successfully"} def resetProtocolFrom(self, protocolId: int): protocol = self.currentProject.getProtocol(int(protocolId)) diff --git a/tests/conftest.py b/tests/conftest.py index 8044c68..423a0bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -116,6 +116,7 @@ def makeProjectOut(projectId: int = 1, name: str = "Demo Project", **overrides): payload.update(overrides) return payload + class FakeProjectService: # fakeProjectService def __init__(self): @@ -176,6 +177,9 @@ def __init__(self): self.applyWorkflowError = None self.lastApplyWorkflowCall = None + self.syncProjectGraphAfterMutationError = None + self.lastSyncProjectGraphAfterMutationCall = None + self.protocolParamsResult = { "protocolId": "10", "protocolClassName": "ProtClass", @@ -369,6 +373,18 @@ def listProjectShares(self, mapper, projectId, currentUser): } return self.projectSharesResult + def syncProjectGraphAfterMutation(self, mapper, projectId, actionLabel, refresh=True, checkPid=True): + self.lastSyncProjectGraphAfterMutationCall = { + "mapper": mapper, + "projectId": projectId, + "actionLabel": actionLabel, + "refresh": refresh, + "checkPid": checkPid, + } + if self.syncProjectGraphAfterMutationError is not None: + raise self.syncProjectGraphAfterMutationError + return {"protocols": 1, "dependencies": 0} + def applyWorkflowToProject(self, mapper, projectId, workflowId, currentUser): self.lastApplyWorkflowCall = { "mapper": mapper, diff --git a/tests/integration/api/test_projects_router_protocol_ops.py b/tests/integration/api/test_projects_router_protocol_ops.py index 28ec08f..deeb2af 100644 --- a/tests/integration/api/test_projects_router_protocol_ops.py +++ 
b/tests/integration/api/test_projects_router_protocol_ops.py @@ -241,6 +241,14 @@ def test_RenameProtocolDelegatesToService(projectClient, fakeProjectService): "newName": "Renamed protocol", } + assert fakeProjectService.lastSyncProjectGraphAfterMutationCall == { + "mapper": fakeProjectService.lastSyncProjectGraphAfterMutationCall["mapper"], + "projectId": 1, + "actionLabel": "rename protocol", + "refresh": True, + "checkPid": True, + } + def test_DuplicateProtocolRejectsMissingItems(projectClient): response = projectClient.post( @@ -366,6 +374,25 @@ def test_ContinueProtocolAllDelegatesToService(projectClient, fakeProjectService }, } + assert fakeProjectService.lastSyncProjectGraphAfterMutationCall == { + "mapper": fakeProjectService.lastSyncProjectGraphAfterMutationCall["mapper"], + "projectId": 1, + "actionLabel": "continue protocol subtree", + "refresh": True, + "checkPid": True, + } + + assert fakeProjectService.lastContinueProtocolAllCall == { + "mapper": fakeProjectService.lastContinueProtocolAllCall["mapper"], + "projectId": 1, + "protocolId": 10, + "currentUser": { + "id": 1, + "email": "user@example.com", + "role": "user", + }, + } + def test_ResetProtocolFromDelegatesToService(projectClient, fakeProjectService): response = projectClient.post("/projects/1/protocols/10/reset-from") diff --git a/tests/unit/backend/api/services/test_project_service_protocols.py b/tests/unit/backend/api/services/test_project_service_protocols.py index ca8878f..40cdb08 100644 --- a/tests/unit/backend/api/services/test_project_service_protocols.py +++ b/tests/unit/backend/api/services/test_project_service_protocols.py @@ -245,6 +245,18 @@ def service(projectServiceModule): instance = object.__new__(projectServiceModule.ProjectService) instance.currentProject = FakeCurrentProject() instance.tomoList = {} + instance._buildProtocolContext = lambda projectId, protocol: { + "projectId": projectId, + "protocolId": protocol.getObjId(), + "protocolClassName": getattr(protocol, 
"_className", "ProtClass"), + "params": {}, + } + instance.syncProjectProtocolsAndDependencies = ( + lambda mapper, projectId, refresh=False, checkPid=False: { + "protocols": 0, + "dependencies": 0, + } + ) return instance @@ -298,6 +310,17 @@ def test_SaveProtocolCreatesNewProtocolAndPersistsContext(projectServiceModule, }, ) + def fakeSyncProjectProtocolsAndDependencies(mapperObj, projectId, refresh=False, checkPid=False): + for protocolObj in service.currentProject.setupProtocols: + mapperObj.saveProtocol(service._buildProtocolContext(projectId, protocolObj)) + return {"protocols": len(service.currentProject.setupProtocols), "dependencies": 0} + + monkeypatch.setattr( + service, + "syncProjectProtocolsAndDependencies", + fakeSyncProjectProtocolsAndDependencies, + ) + def buildProtocol(): protocol = FakeProtocol(objId=None, className="ProtClass") protocol.addParam("runName", FakeStringParam(label="Run name")) @@ -350,6 +373,17 @@ def test_SaveProtocolAggregatesValidationAndPointerErrors(projectServiceModule, monkeypatch.setattr(service, "applyParamsToProtocol", lambda protocolObj, params: ["pointer error"]) + def fakeSyncProjectProtocolsAndDependencies(mapperObj, projectId, refresh=False, checkPid=False): + for protocolObj in service.currentProject.storedProtocols: + mapperObj.saveProtocol(service._buildProtocolContext(projectId, protocolObj)) + return {"protocols": len(service.currentProject.storedProtocols), "dependencies": 0} + + monkeypatch.setattr( + service, + "syncProjectProtocolsAndDependencies", + fakeSyncProjectProtocolsAndDependencies, + ) + _, errors = service.saveProtocol( mapper=mapper, projectId=1, @@ -510,12 +544,22 @@ def __init__(self, itemId): protocols=[DuplicateItem("10"), DuplicateItem("11")], ) - assert result == {"status": "ok", "message": "Protocol was duplicated successfully"} + def fakeSyncProjectProtocolsAndDependencies(mapperObj, projectId, refresh=False, checkPid=False): + for protocolObj in 
service.currentProject.copiedProtocolOutputs: + mapperObj.saveProtocol(service._buildProtocolContext(projectId, protocolObj)) + return {"protocols": len(service.currentProject.copiedProtocolOutputs), "dependencies": 0} + + monkeypatch.setattr( + service, + "syncProjectProtocolsAndDependencies", + fakeSyncProjectProtocolsAndDependencies, + ) + + assert result["status"] == "ok" + assert result["message"] == "Protocol was duplicated successfully" + assert "protocolsCount" in result + assert "dependenciesCount" in result assert service.currentProject.copiedProtocolInputs == [[protocolA, protocolB]] - assert mapper.savedProtocolContexts == [ - {"projectId": 1, "protocolId": 110}, - {"projectId": 1, "protocolId": 111}, - ] def test_DeleteProtocolDelegatesToCurrentProjectAndMapper(service, mapper): @@ -565,14 +609,25 @@ def test_RestartProtocolAllReturnsCollectedErrors(service): assert result == ["cannot restart", "blocked"] -def test_ContinueProtocolAllIsNotImplemented(service, mapper): - with pytest.raises(NotImplementedError): - service.continueProtocolAll( - mapper=mapper, - projectId=1, - protocolId=10, - currentUser={"id": 1}, - ) +def test_ContinueProtocolAllLaunchesActiveProtocolsInResumeMode(projectServiceModule, service, mapper, monkeypatch): + monkeypatch.setattr(projectServiceModule, "MODE_RESUME", "resume-mode") + + protocol = FakeProtocol(objId=10) + activeProtocol = FakeProtocol(objId=20) + + service.currentProject.protocols[10] = protocol + service.currentProject._getSubworkflow = lambda protocolObj: (["wf-a", "wf-b"], [activeProtocol]) + + result = service.continueProtocolAll( + mapper=mapper, + projectId=1, + protocolId=10, + currentUser={"id": 1}, + ) + + assert result == {"status": "ok", "message": "Protocol subtree continued successfully"} + assert activeProtocol.runMode.get() == "resume-mode" + assert service.currentProject.launchedProtocols == [activeProtocol] def test_ResetProtocolFromReturnsSuccessWhenWorkflowResets(service): diff --git 
a/tests/unit/backend/utils/test_thumbnail_service_core.py b/tests/unit/backend/utils/test_thumbnail_service_core.py index 86fc87e..25811dc 100644 --- a/tests/unit/backend/utils/test_thumbnail_service_core.py +++ b/tests/unit/backend/utils/test_thumbnail_service_core.py @@ -248,6 +248,7 @@ def test_BuildProtocolThumbnailReturnsCachedEntry(service, monkeypatch, tmp_path cachePath.write_text("cached", encoding="utf-8") monkeypatch.setattr(service, "_getProtocolCachePath", lambda protocolId, size, outputName=None: cachePath) + monkeypatch.setattr(service, "_isValidCachedImage", lambda path: True) result = service.buildProtocolThumbnail(protocolId=10, force=False, size=320) @@ -293,6 +294,7 @@ def test_BuildProjectThumbnailReturnsCachedStrip(service, monkeypatch, tmp_path) cachePath.write_text("cached", encoding="utf-8") monkeypatch.setattr(service, "_getProjectCachePath", lambda size, maxProtocols: cachePath) + monkeypatch.setattr(service, "_isValidCachedImage", lambda path: True) result = service.buildProjectThumbnail(force=False, size=720, maxProtocols=6) @@ -327,8 +329,8 @@ def test_ListProtocolThumbnailItemsBuildsGroups(service, monkeypatch): ) monkeypatch.setattr( service, - "buildProtocolThumbnail", - lambda protocolId, force, size, outputName=None: { + "buildProtocolOutputThumbnail", + lambda protocolId, outputName, force, size: { "exists": True, "absolutePath": f"/tmp/{protocolId}_{outputName}.png", }, @@ -352,15 +354,15 @@ def test_ListProtocolThumbnailItemsBuildsGroups(service, monkeypatch): "outputName": "outputVol", "outputClassName": "SetOfVolumes", "exists": True, - "thumbnailUrl": "/projects/7/protocols/11/thumbnail?outputName=outputVol", - "thumbnailRebuildUrl": "/projects/7/protocols/11/thumbnail/rebuild?outputName=outputVol", + "thumbnailUrl": "/projects/7/protocols/11/outputs/outputVol/thumbnail", + "thumbnailRebuildUrl": None, }, { "outputName": "outputParticles", "outputClassName": "SetOfParticles", "exists": True, - "thumbnailUrl": 
"/projects/7/protocols/11/thumbnail?outputName=outputParticles", - "thumbnailRebuildUrl": "/projects/7/protocols/11/thumbnail/rebuild?outputName=outputParticles", + "thumbnailUrl": "/projects/7/protocols/11/outputs/outputParticles/thumbnail", + "thumbnailRebuildUrl": None, }, ], } From d9e86ac619c2652129b57135b2604199e6c14711 Mon Sep 17 00:00:00 2001 From: "Yunior C. Fonseca Reyna" Date: Thu, 23 Apr 2026 07:02:17 +0200 Subject: [PATCH 11/14] test: cover graph sync failures after protocol mutations --- app/backend/api/services/project_service.py | 68 +++++++-- .../api/test_projects_router_protocol_ops.py | 130 +++++++++++++++++- .../test_project_service_protocols.py | 39 ++++-- 3 files changed, 202 insertions(+), 35 deletions(-) diff --git a/app/backend/api/services/project_service.py b/app/backend/api/services/project_service.py index 1e98f6a..73297a6 100644 --- a/app/backend/api/services/project_service.py +++ b/app/backend/api/services/project_service.py @@ -2993,29 +2993,67 @@ def renameProtocol(self, protocolId: int, newName: str): self.currentProject._storeProtocol(protocol) return {"status": "ok", "message": "Protocol renamed successfully"} - def duplicateProtocol(self, mapper, projectId, protocols: Any): - try: - protList = [] - for protocol in protocols: - protList.append(self.currentProject.getProtocol(int(protocol.id))) + def duplicateProtocol(self, mapper, projectId, protocols): + protocolList = [] - self.currentProject.copyProtocol(protList) + for item in protocols or []: + protocolId = getattr(item, "id", None) + if protocolId is None: + continue - syncInfo = self.syncProjectProtocolsAndDependencies( + protocol = self.currentProject.getProtocol(int(protocolId)) + if protocol is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Protocol not found: {protocolId}", + ) + + protocolList.append(protocol) + + if not protocolList: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="No valid 
protocols to duplicate", + ) + + try: + copiedProtocols = self.currentProject.copyProtocol(protocolList) or [] + except Exception as e: + logger.exception( + "Failed to duplicate protocols. projectId=%s protocolIds=%s", + projectId, + [getattr(p, "getObjId", lambda: None)() for p in protocolList], + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to duplicate protocols: {e}", + ) + + try: + syncResult = self.syncProjectProtocolsAndDependencies( mapper, projectId, refresh=True, checkPid=True, ) - - return { - "status": "ok", - "message": "Protocol was duplicated successfully", - "protocolsCount": syncInfo.get("protocols"), - "dependenciesCount": syncInfo.get("dependencies"), - } + except HTTPException: + raise except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + logger.exception( + "Failed to sync protocol graph after duplication. projectId=%s", + projectId, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Protocols were duplicated in Scipion but graph sync to PostgreSQL failed: {e}", + ) + + return { + "status": "ok", + "message": "Protocol was duplicated successfully", + "protocolsCount": int(syncResult.get("protocols", 0)), + "dependenciesCount": int(syncResult.get("dependencies", 0)), + } def deleteProtocol(self, mapper, projectId, protocols: Any): try: diff --git a/tests/integration/api/test_projects_router_protocol_ops.py b/tests/integration/api/test_projects_router_protocol_ops.py index deeb2af..c86e562 100644 --- a/tests/integration/api/test_projects_router_protocol_ops.py +++ b/tests/integration/api/test_projects_router_protocol_ops.py @@ -351,6 +351,13 @@ def test_RestartProtocolAllReturnsSuccess(projectClient, fakeProjectService): } assert fakeProjectService.lastRestartProtocolAllCall == {"protocolId": 10} + assert fakeProjectService.lastSyncProjectGraphAfterMutationCall == { + "mapper": 
fakeProjectService.lastSyncProjectGraphAfterMutationCall["mapper"], + "projectId": 1, + "actionLabel": "restart protocol subtree", + "refresh": True, + "checkPid": True, + } def test_ContinueProtocolAllDelegatesToService(projectClient, fakeProjectService): @@ -363,6 +370,14 @@ def test_ContinueProtocolAllDelegatesToService(projectClient, fakeProjectService "workflow": [], } + assert fakeProjectService.lastSyncProjectGraphAfterMutationCall == { + "mapper": fakeProjectService.lastSyncProjectGraphAfterMutationCall["mapper"], + "projectId": 1, + "actionLabel": "continue protocol subtree", + "refresh": True, + "checkPid": True, + } + assert fakeProjectService.lastContinueProtocolAllCall == { "mapper": fakeProjectService.lastContinueProtocolAllCall["mapper"], "projectId": 1, @@ -374,14 +389,84 @@ def test_ContinueProtocolAllDelegatesToService(projectClient, fakeProjectService }, } + +def test_ResetProtocolFromDelegatesToService(projectClient, fakeProjectService): + response = projectClient.post("/projects/1/protocols/10/reset-from") + + assert response.status_code == 200 + assert response.json() == { + "status": 0, + "errors": [], + "workflow": [], + } + + assert fakeProjectService.lastResetProtocolFromCall == {"protocolId": 10} assert fakeProjectService.lastSyncProjectGraphAfterMutationCall == { "mapper": fakeProjectService.lastSyncProjectGraphAfterMutationCall["mapper"], "projectId": 1, - "actionLabel": "continue protocol subtree", + "actionLabel": "reset protocol from node", "refresh": True, "checkPid": True, } +def test_RenameProtocolReturnsErrorWhenGraphSyncFails(projectClient, fakeProjectService): + fakeProjectService.syncProjectGraphAfterMutationError = HTTPException( + status_code=500, + detail="rename protocol succeeded but graph sync to PostgreSQL failed", + ) + + response = projectClient.put( + "/projects/1/protocols/10/rename", + json={"name": "Renamed protocol"}, + ) + + assert response.status_code == 500 + assert response.json() == { + "status": 1, + 
"errors": ["rename protocol succeeded but graph sync to PostgreSQL failed"], + "workflow": [], + } + + assert fakeProjectService.lastRenameProtocolCall == { + "protocolId": 10, + "newName": "Renamed protocol", + } + + +def test_RestartProtocolAllReturnsErrorWhenGraphSyncFails(projectClient, fakeProjectService): + fakeProjectService.restartProtocolAllResult = [] + fakeProjectService.syncProjectGraphAfterMutationError = HTTPException( + status_code=500, + detail="restart protocol subtree succeeded but graph sync to PostgreSQL failed", + ) + + response = projectClient.post("/projects/1/protocols/10/restart-all") + + assert response.status_code == 500 + assert response.json() == { + "status": 1, + "errors": ["restart protocol subtree succeeded but graph sync to PostgreSQL failed"], + "workflow": [], + } + + assert fakeProjectService.lastRestartProtocolAllCall == {"protocolId": 10} + + +def test_ContinueProtocolAllReturnsErrorWhenGraphSyncFails(projectClient, fakeProjectService): + fakeProjectService.syncProjectGraphAfterMutationError = HTTPException( + status_code=500, + detail="continue protocol subtree succeeded but graph sync to PostgreSQL failed", + ) + + response = projectClient.post("/projects/1/protocols/10/continue-all") + + assert response.status_code == 500 + assert response.json() == { + "status": 1, + "errors": ["continue protocol subtree succeeded but graph sync to PostgreSQL failed"], + "workflow": [], + } + assert fakeProjectService.lastContinueProtocolAllCall == { "mapper": fakeProjectService.lastContinueProtocolAllCall["mapper"], "projectId": 1, @@ -394,19 +479,47 @@ def test_ContinueProtocolAllDelegatesToService(projectClient, fakeProjectService } -def test_ResetProtocolFromDelegatesToService(projectClient, fakeProjectService): +def test_ResetProtocolFromReturnsErrorWhenGraphSyncFails(projectClient, fakeProjectService): + fakeProjectService.syncProjectGraphAfterMutationError = HTTPException( + status_code=500, + detail="reset protocol from node 
succeeded but graph sync to PostgreSQL failed", + ) + response = projectClient.post("/projects/1/protocols/10/reset-from") - assert response.status_code == 200 + assert response.status_code == 500 assert response.json() == { - "status": 0, - "errors": [], + "status": 1, + "errors": ["reset protocol from node succeeded but graph sync to PostgreSQL failed"], "workflow": [], } assert fakeProjectService.lastResetProtocolFromCall == {"protocolId": 10} +def test_StopProtocolReturnsErrorWhenGraphSyncFails(projectClient, fakeProjectService): + fakeProjectService.syncProjectGraphAfterMutationError = HTTPException( + status_code=500, + detail="stop protocol succeeded but graph sync to PostgreSQL failed", + ) + + response = projectClient.post( + "/projects/1/protocols/stop", + json={"protocolIds": ["10", "11"]}, + ) + + assert response.status_code == 500 + assert response.json() == { + "status": 1, + "errors": ["stop protocol succeeded but graph sync to PostgreSQL failed"], + "workflow": [], + } + + assert fakeProjectService.lastStopProtocolCall == { + "protocolIds": ["10", "11"], + } + + def test_StopProtocolRejectsMissingProtocolIds(projectClient): response = projectClient.post( "/projects/1/protocols/stop", @@ -436,4 +549,11 @@ def test_StopProtocolDelegatesToService(projectClient, fakeProjectService): assert fakeProjectService.lastStopProtocolCall == { "protocolIds": ["10", "11"], + } + assert fakeProjectService.lastSyncProjectGraphAfterMutationCall == { + "mapper": fakeProjectService.lastSyncProjectGraphAfterMutationCall["mapper"], + "projectId": 1, + "actionLabel": "stop protocol", + "refresh": True, + "checkPid": True, } \ No newline at end of file diff --git a/tests/unit/backend/api/services/test_project_service_protocols.py b/tests/unit/backend/api/services/test_project_service_protocols.py index 40cdb08..abc6c06 100644 --- a/tests/unit/backend/api/services/test_project_service_protocols.py +++ b/tests/unit/backend/api/services/test_project_service_protocols.py @@ 
-534,20 +534,13 @@ def test_DuplicateProtocolCopiesAndPersists(service, mapper, monkeypatch): }, ) - class DuplicateItem: - def __init__(self, itemId): - self.id = itemId - - result = service.duplicateProtocol( - mapper=mapper, - projectId=1, - protocols=[DuplicateItem("10"), DuplicateItem("11")], - ) - def fakeSyncProjectProtocolsAndDependencies(mapperObj, projectId, refresh=False, checkPid=False): for protocolObj in service.currentProject.copiedProtocolOutputs: mapperObj.saveProtocol(service._buildProtocolContext(projectId, protocolObj)) - return {"protocols": len(service.currentProject.copiedProtocolOutputs), "dependencies": 0} + return { + "protocols": len(service.currentProject.copiedProtocolOutputs), + "dependencies": 0, + } monkeypatch.setattr( service, @@ -555,11 +548,27 @@ def fakeSyncProjectProtocolsAndDependencies(mapperObj, projectId, refresh=False, fakeSyncProjectProtocolsAndDependencies, ) - assert result["status"] == "ok" - assert result["message"] == "Protocol was duplicated successfully" - assert "protocolsCount" in result - assert "dependenciesCount" in result + class DuplicateItem: + def __init__(self, itemId): + self.id = itemId + + result = service.duplicateProtocol( + mapper=mapper, + projectId=1, + protocols=[DuplicateItem("10"), DuplicateItem("11")], + ) + + assert result == { + "status": "ok", + "message": "Protocol was duplicated successfully", + "protocolsCount": 2, + "dependenciesCount": 0, + } assert service.currentProject.copiedProtocolInputs == [[protocolA, protocolB]] + assert mapper.savedProtocolContexts == [ + {"projectId": 1, "protocolId": 110}, + {"projectId": 1, "protocolId": 111}, + ] def test_DeleteProtocolDelegatesToCurrentProjectAndMapper(service, mapper): From 5a574de99e8c3efa45be3d689ed2fda259399752 Mon Sep 17 00:00:00 2001 From: "Yunior C. 
Fonseca Reyna" Date: Thu, 23 Apr 2026 10:27:06 +0200 Subject: [PATCH 12/14] refactor: unify protocol mutation return contracts across ProjectService --- app/backend/api/routers/project_router.py | 9 +- app/backend/api/services/project_service.py | 170 +++++++++++++++--- .../api/test_projects_router_protocol_ops.py | 9 +- .../test_project_service_protocols.py | 10 +- 4 files changed, 154 insertions(+), 44 deletions(-) diff --git a/app/backend/api/routers/project_router.py b/app/backend/api/routers/project_router.py index 847fcbd..09718d9 100644 --- a/app/backend/api/routers/project_router.py +++ b/app/backend/api/routers/project_router.py @@ -666,14 +666,7 @@ def restartProtocolAll( ) try: - errorList = service.restartProtocolAll(protocolId) - errors = [str(e) for e in (errorList or [])] - - if errors: - return {"status": 1, - "errors": errors, - "workflow": []} - + service.restartProtocolAll(protocolId) service.syncProjectGraphAfterMutation( mapper, projectId, diff --git a/app/backend/api/services/project_service.py b/app/backend/api/services/project_service.py index 73297a6..18ff7e4 100644 --- a/app/backend/api/services/project_service.py +++ b/app/backend/api/services/project_service.py @@ -2987,11 +2987,38 @@ def getProtocolLogs(self, projectId: int, protocolId: int, "scheduleOffset": newOffsetSchedule, } - def renameProtocol(self, protocolId: int, newName: str): + @staticmethod + def _buildProtocolMutationResult(message: str, **extra) -> Dict[str, Any]: + result = { + "status": "ok", + "message": message, + } + result.update(extra or {}) + return result + + def renameProtocol(self, protocolId, newName): protocol = self.currentProject.getProtocol(int(protocolId)) - protocol.setObjLabel(newName) - self.currentProject._storeProtocol(protocol) - return {"status": "ok", "message": "Protocol renamed successfully"} + if protocol is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Protocol not found: {protocolId}", + ) + + try: + 
protocol.setObjLabel(newName) + self.currentProject._storeProtocol(protocol) + except Exception as e: + logger.exception( + "Failed to rename protocol. protocolId=%s newName=%s", + protocolId, + newName, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to rename protocol: {e}", + ) + + return self._buildProtocolMutationResult("Protocol renamed successfully") def duplicateProtocol(self, mapper, projectId, protocols): protocolList = [] @@ -3017,7 +3044,7 @@ def duplicateProtocol(self, mapper, projectId, protocols): ) try: - copiedProtocols = self.currentProject.copyProtocol(protocolList) or [] + self.currentProject.copyProtocol(protocolList) except Exception as e: logger.exception( "Failed to duplicate protocols. projectId=%s protocolIds=%s", @@ -3048,12 +3075,11 @@ def duplicateProtocol(self, mapper, projectId, protocols): detail=f"Protocols were duplicated in Scipion but graph sync to PostgreSQL failed: {e}", ) - return { - "status": "ok", - "message": "Protocol was duplicated successfully", - "protocolsCount": int(syncResult.get("protocols", 0)), - "dependenciesCount": int(syncResult.get("dependencies", 0)), - } + return self._buildProtocolMutationResult( + "Protocol was duplicated successfully", + protocolsCount=int(syncResult.get("protocols", 0)), + dependenciesCount=int(syncResult.get("dependencies", 0)), + ) def deleteProtocol(self, mapper, projectId, protocols: Any): try: @@ -3080,15 +3106,46 @@ def deleteProtocol(self, mapper, projectId, protocols: Any): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - def restartProtocolAll(self, protocolId: int): + def restartProtocolAll(self, protocolId): + protocol = self.currentProject.getProtocol(int(protocolId)) + if protocol is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Protocol not found: {protocolId}", + ) + + try: + workflowProtocolList, _activeProtocolList = 
self.currentProject._getSubworkflow(protocol) + except Exception as e: + logger.exception( + "Failed to resolve subworkflow for restart-all. protocolId=%s", + protocolId, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to resolve protocol subworkflow: {e}", + ) + + errorList = [] try: - protocol = self.currentProject.getProtocol(int(protocolId)) - workflowProtocolList, activeProtList = self.currentProject._getSubworkflow(protocol) - errorList = [] self.currentProject._restartWorkflow(errorList, workflowProtocolList) - return errorList except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + logger.exception( + "Failed to restart workflow subtree. protocolId=%s", + protocolId, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to restart protocol subtree: {e}", + ) + + if errorList: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=[str(e) for e in errorList], + ) + + return self._buildProtocolMutationResult("Protocol subtree restarted successfully") def continueProtocolAll(self, mapper, projectId, protocolId, currentUser): protocol = self.currentProject.getProtocol(int(protocolId)) @@ -3113,7 +3170,7 @@ def continueProtocolAll(self, mapper, projectId, protocolId, currentUser): protocolsToResume = activeProtocolList or workflowProtocolList or [] if not protocolsToResume: - return {"status": "ok", "message": "No protocols to continue"} + return self._buildProtocolMutationResult("No protocols to continue") for item in protocolsToResume: protocolToLaunch = item @@ -3158,25 +3215,80 @@ def continueProtocolAll(self, mapper, projectId, protocolId, currentUser): detail=f"Failed to continue protocol: {e}", ) - return {"status": "ok", "message": "Protocol subtree continued successfully"} + return self._buildProtocolMutationResult("Protocol subtree continued successfully") - def resetProtocolFrom(self, protocolId: int): + def 
resetProtocolFrom(self, protocolId): protocol = self.currentProject.getProtocol(int(protocolId)) + if protocol is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Protocol not found: {protocolId}", + ) + try: - workflowProtocolList, activeProtList = self.currentProject._getSubworkflow(protocol) - errorProtList = self.currentProject.resetWorkFlow(workflowProtocolList) - if errorProtList: - raise HTTPException(status_code=500, detail=errorProtList) + workflowProtocolList, _activeProtocolList = self.currentProject._getSubworkflow(protocol) except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + logger.exception( + "Failed to resolve subworkflow for reset-from. protocolId=%s", + protocolId, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to resolve protocol subworkflow: {e}", + ) - def stopProtocol(self, protocols: Any): try: - for protocolId in protocols: - protocol = self.currentProject.getProtocol(int(protocolId)) + resetErrors = self.currentProject.resetWorkFlow(workflowProtocolList) or [] + except Exception as e: + logger.exception( + "Failed to reset workflow subtree. 
protocolId=%s", + protocolId, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to reset protocol subtree: {e}", + ) + + if resetErrors: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=[str(e) for e in resetErrors], + ) + + return self._buildProtocolMutationResult("Protocol subtree reset successfully") + + def stopProtocol(self, protocolIds): + resolvedProtocols = [] + + for protocolId in protocolIds or []: + protocol = self.currentProject.getProtocol(int(protocolId)) + if protocol is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Protocol not found: {protocolId}", + ) + resolvedProtocols.append(protocol) + + if not resolvedProtocols: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="No valid protocols to stop", + ) + + try: + for protocol in resolvedProtocols: self.currentProject.stopProtocol(protocol) except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + logger.exception( + "Failed to stop protocols. 
protocolIds=%s", + protocolIds, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to stop protocols: {e}", + ) + + return self._buildProtocolMutationResult("Protocol stopped successfully") def _isGlobalFsBrowserMode(self, protocolId: Union[int, str]) -> bool: return str(protocolId).strip() == "-1" diff --git a/tests/integration/api/test_projects_router_protocol_ops.py b/tests/integration/api/test_projects_router_protocol_ops.py index c86e562..796498c 100644 --- a/tests/integration/api/test_projects_router_protocol_ops.py +++ b/tests/integration/api/test_projects_router_protocol_ops.py @@ -325,12 +325,15 @@ def test_DeleteProtocolDelegatesToService(projectClient, fakeProjectService): } -def test_RestartProtocolAllReturnsErrorsWhenServiceReturnsErrors(projectClient, fakeProjectService): - fakeProjectService.restartProtocolAllResult = ["cannot restart", "blocked"] +def test_RestartProtocolAllReturnsErrorsWhenServiceRaisesHttpException(projectClient, fakeProjectService): + fakeProjectService.restartProtocolAllError = HTTPException( + status_code=422, + detail=["cannot restart", "blocked"], + ) response = projectClient.post("/projects/1/protocols/10/restart-all") - assert response.status_code == 200 + assert response.status_code == 422 assert response.json() == { "status": 1, "errors": ["cannot restart", "blocked"], diff --git a/tests/unit/backend/api/services/test_project_service_protocols.py b/tests/unit/backend/api/services/test_project_service_protocols.py index abc6c06..600ab31 100644 --- a/tests/unit/backend/api/services/test_project_service_protocols.py +++ b/tests/unit/backend/api/services/test_project_service_protocols.py @@ -613,9 +613,11 @@ def test_RestartProtocolAllReturnsCollectedErrors(service): service.currentProject.protocols[10] = protocol service.currentProject.restartWorkflowInjectedErrors = ["cannot restart", "blocked"] - result = service.restartProtocolAll(10) + with pytest.raises(HTTPException) as 
exc: + service.restartProtocolAll(10) - assert result == ["cannot restart", "blocked"] + assert exc.value.status_code == 422 + assert exc.value.detail == ["cannot restart", "blocked"] def test_ContinueProtocolAllLaunchesActiveProtocolsInResumeMode(projectServiceModule, service, mapper, monkeypatch): @@ -646,7 +648,7 @@ def test_ResetProtocolFromReturnsSuccessWhenWorkflowResets(service): result = service.resetProtocolFrom(10) - assert result is None + assert result == {"status": "ok", "message": "Protocol subtree reset successfully"} def test_StopProtocolStopsEachProtocol(service): @@ -657,5 +659,5 @@ def test_StopProtocolStopsEachProtocol(service): result = service.stopProtocol(["10", "11"]) - assert result is None + assert result == {"status": "ok", "message": "Protocol stopped successfully"} assert service.currentProject.stoppedProtocols == [protocolA, protocolB] \ No newline at end of file From e9c23ac01cf97c736d605131937cdbec18509728 Mon Sep 17 00:00:00 2001 From: "Yunior C. Fonseca Reyna" Date: Thu, 23 Apr 2026 10:58:22 +0200 Subject: [PATCH 13/14] fix: isolate project service state and purge stale protocol rows on sync --- app/backend/api/services/project_service.py | 71 ++++++++------------- app/backend/mapper/postgresql.py | 40 ++++++++++++ 2 files changed, 66 insertions(+), 45 deletions(-) diff --git a/app/backend/api/services/project_service.py b/app/backend/api/services/project_service.py index 18ff7e4..d17cf7a 100644 --- a/app/backend/api/services/project_service.py +++ b/app/backend/api/services/project_service.py @@ -36,8 +36,6 @@ import threading import textwrap import shutil -import inspect -from numbers import Number import numpy as np @@ -85,7 +83,6 @@ from app.backend.utils.file_handlers import FileHandlers from app.utils.scipion_helper import serializeToJson -from contextvars import ContextVar from app.backend.api.services.plugins_revision import getPluginsRevision from app.backend.utils.thumbnail_service import ThumbnailService @@ -101,15 +98,6 
@@ _newProtocolCache: Dict[str, Dict[str, Any]] = {} _lastNewProtocolRevision = -1 - -# Per-request context for current project and tomoList -_currentProjectVar: ContextVar[Optional[ScipionProject]] = ContextVar( - "_currentProjectVar", default=None -) -_tomoListVar: ContextVar[Optional[Dict[Any, Any]]] = ContextVar( - "_tomoListVar", default=None -) - # Global lock for metadata / DAO operations (not thread-safe) _metadataLock = threading.Lock() @@ -142,40 +130,23 @@ def _invalidateNewProtocolCacheIfNeeded() -> int: class ProjectService: def __init__(self): - self.manager = Manager() - # Keep objectManager attribute for backward compatibility, - # but new HTTP endpoints use a fresh ObjectManager per request. - self.objectManager = None + def __init__(self): + self.manager = Manager() + # Keep objectManager attribute for backward compatibility, + # but new HTTP endpoints use a fresh ObjectManager per request. + self.objectManager = None + + # Real per-instance state + self.currentProject: Optional[ScipionProject] = None + self.tomoList: Dict[Any, Any] = {} # ------------------------------------------------------------------ # Per-request project / tomogram context # ------------------------------------------------------------------ - @property - def currentProject(self) -> Optional[ScipionProject]: - """Return the current ScipionProject bound to this request context.""" - return _currentProjectVar.get() - - @currentProject.setter - def currentProject(self, value: Optional[ScipionProject]): - _currentProjectVar.set(value) - - @property - def tomoList(self) -> Dict[Any, Any]: - """Return the per-request tomogram cache dictionary.""" - value = _tomoListVar.get() - if value is None: - value = {} - _tomoListVar.set(value) - return value - - @tomoList.setter - def tomoList(self, value: Dict[Any, Any]): - _tomoListVar.set(value) - def clearCurrentProject(self): """Clear per-request project and tomogram cache.""" - _currentProjectVar.set(None) - _tomoListVar.set({}) + 
self.currentProject = None + self.tomoList = {} def _createObjectManager(self) -> ObjectManager: """Create and configure a fresh ObjectManager instance. @@ -402,10 +373,12 @@ def syncProjectProtocolsAndDependencies( nodesDict = getattr(runs, "_nodesDict", {}) or {} protocolDbIdByScipionId: Dict[str, int] = {} + currentProtocolIds: Set[str] = set() - # 1) Save all protocol nodes + # 1) Save all protocol nodes that are currently present in the real Scipion graph for nodeId, nodeObj in nodesDict.items(): - if str(nodeId) == "PROJECT": + nodeIdText = str(nodeId) + if nodeIdText == "PROJECT": continue protocol = getattr(nodeObj, "run", None) @@ -420,10 +393,18 @@ def syncProjectProtocolsAndDependencies( protocolContext = self._buildProtocolContext(projectId, protocol) protocolDbId = mapper.saveProtocol(protocolContext) - protocolDbIdByScipionId[str(nodeId)] = int(protocolDbId) - # 2) Build edges parent -> child using DB ids - edges: List[tuple[int, int]] = [] + currentProtocolIds.add(nodeIdText) + protocolDbIdByScipionId[nodeIdText] = int(protocolDbId) + + # 2) Purge stale protocol rows that are no longer present in the real graph + mapper.deleteProjectProtocolsNotInProtocolIds( + projectId, + sorted(currentProtocolIds), + ) + + # 3) Build edges parent -> child using DB ids + edges: List[Tuple[int, int]] = [] for nodeId, nodeObj in nodesDict.items(): childDbId = protocolDbIdByScipionId.get(str(nodeId)) diff --git a/app/backend/mapper/postgresql.py b/app/backend/mapper/postgresql.py index 56924c1..4e113a4 100644 --- a/app/backend/mapper/postgresql.py +++ b/app/backend/mapper/postgresql.py @@ -1112,6 +1112,46 @@ def updateProtocolDependencies(self, protocolId: str, parentIds: list, childIds: query = 'UPDATE protocols SET "parentIds" = %s, "childIds" = %s, "updatedAt" = NOW() WHERE "protocolId" = %s' self.db.execute(query, (parentIds, childIds, protocolId)) + def deleteProjectProtocolsNotInProtocolIds( + self, + projectId: int, + protocolIdsToKeep: List[str], + ) -> int: + 
keepSet = { + str(protocolId).strip() + for protocolId in (protocolIdsToKeep or []) + if str(protocolId).strip() + } + + rows = self.db.fetchAll( + """ + SELECT id, "protocolId" + FROM protocols + WHERE "projectId" = %s + """, + (projectId,), + ) + + staleDbIds = [ + int(row["id"]) + for row in rows + if str(row.get("protocolId", "")).strip() not in keepSet + ] + + if not staleDbIds: + return 0 + + self.db.execute( + """ + DELETE FROM protocols + WHERE "projectId" = %s + AND id = ANY(%s) + """, + (projectId, staleDbIds), + ) + + return len(staleDbIds) + # ----------------------------- # Settings Methods # ----------------------------- From d6494ac31abe10d28a0de5c52f9d58f8aec42460 Mon Sep 17 00:00:00 2001 From: "Yunior C. Fonseca Reyna" Date: Sat, 25 Apr 2026 00:21:33 +0200 Subject: [PATCH 14/14] Return duplicated protocol id mapping --- app/backend/api/routers/project_router.py | 9 ++--- app/backend/api/services/project_service.py | 37 ++++++++++++++------- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/app/backend/api/routers/project_router.py b/app/backend/api/routers/project_router.py index 09718d9..b3fd829 100644 --- a/app/backend/api/routers/project_router.py +++ b/app/backend/api/routers/project_router.py @@ -580,11 +580,12 @@ def duplicateProtocol( "workflow": []}, ) - service.duplicateProtocol(mapper, projectId, items) + result = service.duplicateProtocol(mapper, projectId, items) # Keep 201 on success, but still return unified schema - return {"status": 0, - "errors": [], - "workflow": []} + return {"status": result['status'], + "errors": result['errors'], + "workflow": [], + "duplicated": result['duplicated']} except HTTPException as e: return JSONResponse( diff --git a/app/backend/api/services/project_service.py b/app/backend/api/services/project_service.py index d17cf7a..790f224 100644 --- a/app/backend/api/services/project_service.py +++ b/app/backend/api/services/project_service.py @@ -130,15 +130,14 @@ def 
_invalidateNewProtocolCacheIfNeeded() -> int:
 class ProjectService:
     def __init__(self):
-        def __init__(self):
-        self.manager = Manager()
-        # Keep objectManager attribute for backward compatibility,
-        # but new HTTP endpoints use a fresh ObjectManager per request.
-        self.objectManager = None
+        self.manager = Manager()
+        # Keep objectManager attribute for backward compatibility,
+        # but new HTTP endpoints use a fresh ObjectManager per request.
+        self.objectManager = None
 
-        # Real per-instance state
-        self.currentProject: Optional[ScipionProject] = None
-        self.tomoList: Dict[Any, Any] = {}
+        # Real per-instance state
+        self.currentProject: Optional[ScipionProject] = None
+        self.tomoList: Dict[Any, Any] = {}
 
     # ------------------------------------------------------------------
     # Per-request project / tomogram context
@@ -2971,8 +2970,10 @@ def getProtocolLogs(self, projectId: int, protocolId: int,
     @staticmethod
     def _buildProtocolMutationResult(message: str, **extra) -> Dict[str, Any]:
         result = {
-            "status": "ok",
+            "status": 1 if extra.get('errors') else 0,
+            "errors": extra.get('errors', []),
             "message": message,
+            "duplicated": extra.get('duplicated', [])
         }
         result.update(extra or {})
         return result
@@ -3003,12 +3004,14 @@ def renameProtocol(self, protocolId, newName):
 
     def duplicateProtocol(self, mapper, projectId, protocols):
         protocolList = []
-
+        sourceIds = []
+        duplicated = []
+        errors = []
         for item in protocols or []:
             protocolId = getattr(item, "id", None)
             if protocolId is None:
                 continue
-
+            sourceIds.append(protocolId)
             protocol = self.currentProject.getProtocol(int(protocolId))
             if protocol is None:
                 raise HTTPException(
@@ -3025,8 +3028,15 @@ def duplicateProtocol(self, mapper, projectId, protocols):
             )
 
         try:
-            self.currentProject.copyProtocol(protocolList)
+            protListResult = self.currentProject.copyProtocol(protocolList)
+            for index, prot in enumerate(protListResult):
+                protId = str(prot.getObjId())
+                duplicated.append({"sourceId": sourceIds[index], "newId": protId})
+
+        except Exception 
as e:
+            errors.append("Failed to duplicate protocols. projectId=%s protocolIds=%s"
+                          % (projectId, [getattr(p, "getObjId", lambda: None)() for p in protocolList]))
+
             logger.exception(
                 "Failed to duplicate protocols. projectId=%s protocolIds=%s",
                 projectId,
@@ -3047,6 +3057,7 @@ def duplicateProtocol(self, mapper, projectId, protocols):
         except HTTPException:
             raise
         except Exception as e:
+            errors.append("Failed to sync protocol graph after duplication. projectId=%s" % projectId)
             logger.exception(
                 "Failed to sync protocol graph after duplication. projectId=%s",
                 projectId,
@@ -3060,6 +3071,8 @@ def duplicateProtocol(self, mapper, projectId, protocols):
             "Protocol was duplicated successfully",
             protocolsCount=int(syncResult.get("protocols", 0)),
             dependenciesCount=int(syncResult.get("dependencies", 0)),
+            duplicated=duplicated,
+            errors=errors,
         )
 
     def deleteProtocol(self, mapper, projectId, protocols: Any):