diff --git a/alembic/versions/9b7b3c4d8e21_add_protocol_dependencies_table.py b/alembic/versions/9b7b3c4d8e21_add_protocol_dependencies_table.py new file mode 100644 index 0000000..35fb21e --- /dev/null +++ b/alembic/versions/9b7b3c4d8e21_add_protocol_dependencies_table.py @@ -0,0 +1,89 @@ +"""add protocol dependencies table + +Revision ID: 9b7b3c4d8e21 +Revises: f3f7b0a3f1aa +Create Date: 2026-04-22 00:00:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '9b7b3c4d8e21' +down_revision: Union[str, None] = 'f3f7b0a3f1aa' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade(): + op.create_unique_constraint( + 'uq_protocols_projectId_id', + 'protocols', + ['projectId', 'id'], + ) + + op.create_table( + 'protocol_dependencies', + sa.Column('projectId', sa.Integer(), nullable=False), + sa.Column('parentProtocolDbId', sa.Integer(), nullable=False), + sa.Column('childProtocolDbId', sa.Integer(), nullable=False), + sa.Column('createdAt', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.CheckConstraint('"parentProtocolDbId" <> "childProtocolDbId"', name='protocol_dependencies_no_self_loop'), + sa.ForeignKeyConstraint(['projectId'], ['projects.id'], name='protocol_dependencies_projectId_fkey', ondelete='CASCADE'), + sa.ForeignKeyConstraint( + ['projectId', 'parentProtocolDbId'], + ['protocols.projectId', 'protocols.id'], + name='protocol_dependencies_parentProtocolDbId_fkey', + ondelete='CASCADE', + ), + sa.ForeignKeyConstraint( + ['projectId', 'childProtocolDbId'], + ['protocols.projectId', 'protocols.id'], + name='protocol_dependencies_childProtocolDbId_fkey', + ondelete='CASCADE', + ), + sa.PrimaryKeyConstraint('projectId', 'parentProtocolDbId', 'childProtocolDbId', name='protocol_dependencies_pkey'), + ) + + op.create_index( + 'idx_protocol_dependencies_parent', + 'protocol_dependencies', + ['projectId', 'parentProtocolDbId'], + unique=False, + ) + op.create_index( + 'idx_protocol_dependencies_child', + 'protocol_dependencies', + ['projectId', 'childProtocolDbId'], + unique=False, + ) + + op.execute( + """ + INSERT INTO protocol_dependencies ( + "projectId", + "parentProtocolDbId", + "childProtocolDbId" + ) + SELECT DISTINCT + child."projectId", + parent.id, + child.id + FROM protocols child + CROSS JOIN LATERAL unnest(COALESCE(child."parentIds", ARRAY[]::integer[])) AS parent_protocol_id + JOIN protocols parent + ON parent."projectId" = child."projectId" + AND parent."protocolId" = parent_protocol_id::text + WHERE parent.id <> child.id + """ + ) + + +def downgrade(): + op.drop_index('idx_protocol_dependencies_child', table_name='protocol_dependencies') + op.drop_index('idx_protocol_dependencies_parent', table_name='protocol_dependencies') + op.drop_table('protocol_dependencies') + op.drop_constraint('uq_protocols_projectId_id', 'protocols', type_='unique') diff --git a/app/backend/api/routers/project_router.py b/app/backend/api/routers/project_router.py index 89a7786..b3fd829 100644 --- a/app/backend/api/routers/project_router.py +++ b/app/backend/api/routers/project_router.py @@ -515,7 +515,6 @@ def renameProtocol( ) try: - # Basic payload validation for semantic HTTP newName = getattr(payload, "name", None) if not newName or not str(newName).strip(): return JSONResponse( @@ -526,6 +525,14 @@ def renameProtocol( ) service.renameProtocol(protocolId, str(newName).strip()) + service.syncProjectGraphAfterMutation( + mapper, + projectId, + actionLabel="rename protocol", + refresh=True, + checkPid=True, + ) + return {"status": 0, "errors": [], "workflow": []} @@ -573,11 +580,12 @@ def duplicateProtocol( "workflow": []}, ) - service.duplicateProtocol(mapper, projectId, items) + result = service.duplicateProtocol(mapper, projectId, items) # Keep 201 on success, but still return unified schema - return {"status": 0, - "errors": [], - "workflow": []} + return {"status": result['status'], + "errors": result['errors'], + "workflow": [], + "duplicated": result['duplicated']} except HTTPException as e: return JSONResponse( @@ -659,13 +667,14 @@ def restartProtocolAll( ) try: - errorList = service.restartProtocolAll(protocolId) - errors = [str(e) for e in (errorList or [])] - - if errors: - return {"status": 1, - "errors": errors, - "workflow": []} + service.restartProtocolAll(protocolId) + service.syncProjectGraphAfterMutation( + mapper, + projectId, + actionLabel="restart protocol subtree", + refresh=True, + checkPid=True, + ) return {"status": 0, "errors": [], @@ -708,6 +717,14 @@ def continueProtocolAll( try: service.continueProtocolAll(mapper, projectId, protocolId, currentUser) + service.syncProjectGraphAfterMutation( + mapper, + projectId, + actionLabel="continue protocol subtree", + refresh=True, + checkPid=True, + ) + return {"status": 0, "errors": [], "workflow": []} except HTTPException as e: @@ -743,6 +760,14 @@ def resetProtocolFrom( try: service.resetProtocolFrom(protocolId) + service.syncProjectGraphAfterMutation( + mapper, + projectId, + actionLabel="reset protocol from node", + refresh=True, + checkPid=True, + ) + return {"status": 0, "errors": [], "workflow": []} except HTTPException as e: @@ -785,6 +810,14 @@ def stopProtocol( ) service.stopProtocol(protocolIds) + service.syncProjectGraphAfterMutation( + mapper, + projectId, + actionLabel="stop protocol", + refresh=True, + checkPid=True, + ) + return {"status": 0, "errors": [], "workflow": []} @@ -2945,12 +2978,19 @@ def getProtocolThumbnail( service.loadProjectForThumbnails(dbProj) - result = service.buildProtocolThumbnail( - protocolId=protocolId, - force=False, - size=size, - outputName=outputName, - ) + if outputName: + result = service.buildProtocolOutputThumbnail( + protocolId=protocolId, + outputName=outputName, + force=False, + size=size, + ) + else: + result = service.buildProtocolThumbnail( + protocolId=protocolId, + force=False, + size=size, + ) thumbPath = result.get("absolutePath") if not thumbPath: @@ -2995,12 +3035,19 @@ def rebuildProtocolThumbnail( service.loadProjectForThumbnails(dbProj) - result = service.buildProtocolThumbnail( - protocolId=protocolId, - force=True, - size=size, - outputName=outputName, - ) + if outputName: + result = service.buildProtocolOutputThumbnail( + protocolId=protocolId, + outputName=outputName, + force=True, + size=size, + ) + else: + result = service.buildProtocolThumbnail( + protocolId=protocolId, + force=True, + size=size, + ) response = JSONResponse( { @@ -3050,7 +3097,7 @@ def listProjectThumbnailItems( ) response = JSONResponse(items) - response.headers["Cache-Control"] = "private, max-age=20, stale-while-revalidate=60" + response.headers["Cache-Control"] = "private, no-store" response.headers["Access-Control-Expose-Headers"] = "Cache-Control" return _attachDebugHeaders(response, currentUser) diff --git a/app/backend/api/services/project_service.py b/app/backend/api/services/project_service.py index b7f75d9..790f224 100644 --- a/app/backend/api/services/project_service.py +++ b/app/backend/api/services/project_service.py @@ -36,8 +36,6 @@ import threading import textwrap import shutil -import inspect -from numbers import Number import numpy as np @@ -85,7 +83,6 @@ from app.backend.utils.file_handlers import FileHandlers from app.utils.scipion_helper import serializeToJson -from contextvars import ContextVar from app.backend.api.services.plugins_revision import getPluginsRevision from app.backend.utils.thumbnail_service import ThumbnailService @@ -101,15 +98,6 @@ _newProtocolCache: Dict[str, Dict[str, Any]] = {} _lastNewProtocolRevision = -1 - -# Per-request context for current project and tomoList -_currentProjectVar: ContextVar[Optional[ScipionProject]] = ContextVar( - "_currentProjectVar", default=None -) -_tomoListVar: ContextVar[Optional[Dict[Any, Any]]] = ContextVar( - "_tomoListVar", default=None -) - # Global lock for metadata / DAO operations (not thread-safe) _metadataLock = threading.Lock() @@ -147,35 +135,17 @@ def __init__(self): # but new HTTP endpoints use a fresh ObjectManager per request. self.objectManager = None + # Real per-instance state + self.currentProject: Optional[ScipionProject] = None + self.tomoList: Dict[Any, Any] = {} + # ------------------------------------------------------------------ # Per-request project / tomogram context # ------------------------------------------------------------------ - @property - def currentProject(self) -> Optional[ScipionProject]: - """Return the current ScipionProject bound to this request context.""" - return _currentProjectVar.get() - - @currentProject.setter - def currentProject(self, value: Optional[ScipionProject]): - _currentProjectVar.set(value) - - @property - def tomoList(self) -> Dict[Any, Any]: - """Return the per-request tomogram cache dictionary.""" - value = _tomoListVar.get() - if value is None: - value = {} - _tomoListVar.set(value) - return value - - @tomoList.setter - def tomoList(self, value: Dict[Any, Any]): - _tomoListVar.set(value) - def clearCurrentProject(self): """Clear per-request project and tomogram cache.""" - _currentProjectVar.set(None) - _tomoListVar.set({}) + self.currentProject = None + self.tomoList = {} def _createObjectManager(self) -> ObjectManager: """Create and configure a fresh ObjectManager instance. @@ -385,6 +355,107 @@ def _validateImportableScipionProject(self, sourcePath: Path) -> Dict[str, Any]: "status": statusValue or "active", } + def syncProjectProtocolsAndDependencies( + self, + mapper: PostgresqlFlatMapper, + projectId: int, + refresh: bool = False, + checkPid: bool = False, + ) -> Dict[str, int]: + if self.currentProject is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="No current project loaded", + ) + + runs = self.currentProject.getRunsGraph(refresh=refresh, checkPids=checkPid) + nodesDict = getattr(runs, "_nodesDict", {}) or {} + + protocolDbIdByScipionId: Dict[str, int] = {} + currentProtocolIds: Set[str] = set() + + # 1) Save all protocol nodes that are currently present in the real Scipion graph + for nodeId, nodeObj in nodesDict.items(): + nodeIdText = str(nodeId) + if nodeIdText == "PROJECT": + continue + + protocol = getattr(nodeObj, "run", None) + if protocol is None: + try: + protocol = self.currentProject.getProtocol(int(nodeId)) + except Exception: + protocol = None + + if protocol is None: + continue + + protocolContext = self._buildProtocolContext(projectId, protocol) + protocolDbId = mapper.saveProtocol(protocolContext) + + currentProtocolIds.add(nodeIdText) + protocolDbIdByScipionId[nodeIdText] = int(protocolDbId) + + # 2) Purge stale protocol rows that are no longer present in the real graph + mapper.deleteProjectProtocolsNotInProtocolIds( + projectId, + sorted(currentProtocolIds), + ) + + # 3) Build edges parent -> child using DB ids + edges: List[Tuple[int, int]] = [] + + for nodeId, nodeObj in nodesDict.items(): + childDbId = protocolDbIdByScipionId.get(str(nodeId)) + if not childDbId: + continue + + for parent in getattr(nodeObj, "_parents", []) or []: + parentNodeId = str(parent.getName()) + if parentNodeId == "PROJECT": + continue + + parentDbId = protocolDbIdByScipionId.get(parentNodeId) + if not parentDbId: + continue + + edges.append((parentDbId, childDbId)) + + savedEdges = mapper.replaceProjectProtocolDependencies(projectId, edges) + + return { + "protocols": len(protocolDbIdByScipionId), + "dependencies": int(savedEdges), + } + + def syncProjectGraphAfterMutation( + self, + mapper: PostgresqlFlatMapper, + projectId: int, + actionLabel: str, + refresh: bool = True, + checkPid: bool = True, + ) -> Dict[str, int]: + try: + return self.syncProjectProtocolsAndDependencies( + mapper, + projectId, + refresh=refresh, + checkPid=checkPid, + ) + except HTTPException: + raise + except Exception as e: + logger.exception( + "Failed to sync protocol graph after %s. projectId=%s", + actionLabel, + projectId, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"{actionLabel} succeeded but graph sync to PostgreSQL failed: {e}", + ) + def importProject( self, mapper: PostgresqlFlatMapper, @@ -533,6 +604,20 @@ def importProject( status=statusValue, ) + try: + self.loadProjectForThumbnails({"name": storedProjectPath}) + self.syncProjectProtocolsAndDependencies(mapper, dbProjectId) + except Exception as e: + logger.exception( + "Failed to sync imported project protocols. projectId=%s path=%s", + dbProjectId, + storedProjectPath, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Project was imported but protocols could not be synced to the database: {e}", + ) + sizePath = sourcePath if not copyProject else targetPath try: @@ -971,71 +1056,214 @@ def _buildProjectThumbnailVersion( return f"{projectId}:{updatedText}:{protocolsCount}:{runsMtime}" - def buildProtocolsGraph(self, projectId: int, runs, tags) -> dict: - """Assemble dependency graph of protocols and their status.""" - nodesDict = runs._nodesDict - graphData = {} + def buildProtocolsGraph( + self, + projectId: int, + protocolRows: List[Dict[str, Any]], + tags: Dict[str, List[str]], + dependencyMap: Optional[Dict[str, Dict[str, List[str]]]] = None, + runMap: Optional[Dict[str, Any]] = None, + ) -> dict: + """Assemble protocol graph using PostgreSQL as source of truth for nodes + edges.""" + graphData: Dict[str, Any] = {} + adjacency = dependencyMap or {} + liveRuns = runMap or {} + + def sortKey(row: Dict[str, Any]): + raw = str(row.get("protocolId") or "") + try: + return (0, int(raw)) + except Exception: + return (1, raw) + + orderedRows = sorted(protocolRows or [], key=sortKey) + + protocolIds: List[str] = [] + for row in orderedRows: + rawId = row.get("protocolId") + if rawId is None: + continue + protocolIds.append(str(rawId)) + + # Root node synthesized from DB graph: + # protocols without parents hang directly from PROJECT + rootChildren = [ + pid for pid in protocolIds + if not (adjacency.get(pid, {}).get("parents") or []) + ] + + projectLabel = "PROJECT" + try: + if self.currentProject is not None: + projectLabel = os.path.basename(self.currentProject.getPath()) or "PROJECT" + except Exception: + projectLabel = "PROJECT" + + graphData["PROJECT"] = { + "protocolId": "PROJECT", + "children": rootChildren, + "parents": [], + "label": projectLabel, + "status": "", + "parameter": [], + "inputs": [], + "outputs": [], + "cpuTime": "", + "elapsedTime": "", + "isInteractive": False, + "numberOfSteps": 0, + "stepsDone": 0, + "tags": [], + "thumbnailUrl": None, + "thumbnailRebuildUrl": None, + } + + for row in orderedRows: + rawNodeId = row.get("protocolId") + if rawNodeId is None: + continue + + nodeId = str(rawNodeId) + nodeDeps = adjacency.get(nodeId, {"parents": [], "children": []}) + childrenIds = list(nodeDeps.get("children") or []) + parentIds = list(nodeDeps.get("parents") or []) + + statusValue = row.get("status") + status = str(statusValue) if statusValue is not None else "" + + protocolClassName = str(row.get("protocolClassName") or "") + label = protocolClassName or nodeId - for nodeId, nodeObj in nodesDict.items(): - childrenIds = [child.getName() for child in nodeObj._children] - parentIds = [parent.getName() for parent in nodeObj._parents] - status = nodeObj.run.getStatus() if nodeObj.run else '' inputs = [] outputs = [] - cpuTime = '' - elapsedTime = '' + cpuTime = "" + elapsedTime = "" isinteractive = False numberOfSteps = 0 stepsDone = 0 thumbnailUrl = None thumbnailRebuildUrl = None - if nodeId != 'PROJECT': - protocol = self.currentProject.getProtocol(int(nodeId)) - cpuTime = str(protocol.cpuTime) - elapsedTime = str(protocol.getElapsedTime().total_seconds()).split('.')[0] - isinteractive = protocol.isInteractive() - numberOfSteps = protocol.numberOfSteps - stepsDone = protocol.stepsDone - self.currentProject._fixProtParamsConfiguration(protocol) + # Prefer the live protocol object coming from runs graph + protocol = liveRuns.get(nodeId) - thumbnailUrl = self.buildProtocolThumbnailUrl(projectId, int(nodeId)) - thumbnailRebuildUrl = self.buildProtocolThumbnailRebuildUrl(projectId, int(nodeId)) + if protocol is None: + try: + protocol = self.currentProject.getProtocol(int(nodeId)) + except Exception: + protocol = None - for key, attr in protocol.iterInputAttributes(): - input = {} - try: - input['name'] = key - input['paramClass'] = 'PointerParam' - input['pointerClass'] = attr.get().getClassName() if attr and attr.get() else "" - input['info'] = str(attr.get()) - except Exception: - input['pointerClass'] = "" - input['info'] = "" - parentId = attr.getObjValue().getObjId() - input['value'] = "%s.%s" % (str(parentId), attr.getExtended()) - input['parentId'] = parentId - inputs.append(input) - - for key, attr in protocol.iterOutputAttributes(): - output = {} - output['name'] = key - output['paramClass'] = 'PointerParam' - output['pointerClass'] = attr.__class__.__name__ - try: - output['info'] = attr.__str__() - except Exception: - output['info'] = "" - parentId = protocol.getObjId() - output['value'] = "%s.%s" % (str(parentId), key) - output['parentId'] = parentId - outputs.append(output) + if protocol is not None: + try: + label = str(protocol) or label + except Exception: + pass + + try: + protStatus = protocol.getStatus() + if protStatus: + status = str(protStatus) + except Exception: + pass + + try: + cpuTime = str(protocol.cpuTime) + except Exception: + cpuTime = "" + + try: + elapsedTime = str(protocol.getElapsedTime().total_seconds()).split(".")[0] + except Exception: + elapsedTime = "" + + try: + isinteractive = bool(protocol.isInteractive()) + except Exception: + isinteractive = False + + try: + numberOfSteps = protocol.numberOfSteps + except Exception: + numberOfSteps = 0 + + try: + stepsDone = protocol.stepsDone + except Exception: + stepsDone = 0 + + try: + self.currentProject._fixProtParamsConfiguration(protocol) + except Exception: + pass + + try: + protocolIdInt = int(nodeId) + thumbnailUrl = self.buildProtocolThumbnailUrl(projectId, protocolIdInt) + thumbnailRebuildUrl = self.buildProtocolThumbnailRebuildUrl(projectId, protocolIdInt) + except Exception: + thumbnailUrl = None + thumbnailRebuildUrl = None + + try: + for key, attr in protocol.iterInputAttributes(): + inputItem = {} + try: + inputItem["name"] = key + inputItem["paramClass"] = "PointerParam" + inputItem["pointerClass"] = attr.get().getClassName() if attr and attr.get() else "" + inputItem["info"] = str(attr.get()) + except Exception: + inputItem["pointerClass"] = "" + inputItem["info"] = "" + + try: + parentId = attr.getObjValue().getObjId() + inputItem["value"] = "%s.%s" % (str(parentId), attr.getExtended()) + inputItem["parentId"] = parentId + except Exception: + inputItem["value"] = "" + inputItem["parentId"] = None + + inputs.append(inputItem) + except Exception: + inputs = [] + + try: + for key, attr in protocol.iterOutputAttributes(): + outputItem = {} + outputItem["name"] = key + outputItem["paramClass"] = "PointerParam" + outputItem["pointerClass"] = attr.__class__.__name__ + try: + outputItem["info"] = attr.__str__() + except Exception: + outputItem["info"] = "" + + try: + parentId = protocol.getObjId() + outputItem["value"] = "%s.%s" % (str(parentId), key) + outputItem["parentId"] = parentId + except Exception: + outputItem["value"] = "" + outputItem["parentId"] = None + + outputs.append(outputItem) + except Exception: + outputs = [] + else: + try: + protocolIdInt = int(nodeId) + thumbnailUrl = self.buildProtocolThumbnailUrl(projectId, protocolIdInt) + thumbnailRebuildUrl = self.buildProtocolThumbnailRebuildUrl(projectId, protocolIdInt) + except Exception: + thumbnailUrl = None + thumbnailRebuildUrl = None graphData[nodeId] = { "protocolId": nodeId, "children": childrenIds, "parents": parentIds, - "label": nodeObj.getLabel(), + "label": label, "status": status, "parameter": [], "inputs": inputs, @@ -1045,7 +1273,7 @@ def buildProtocolsGraph(self, projectId: int, runs, tags) -> dict: "isInteractive": isinteractive, "numberOfSteps": numberOfSteps, "stepsDone": stepsDone, - "tags": tags[nodeId] if nodeId in tags else [], + "tags": tags.get(nodeId, []), "thumbnailUrl": thumbnailUrl, "thumbnailRebuildUrl": thumbnailRebuildUrl, } @@ -1056,9 +1284,111 @@ def loadProject(self, dbProj: dict, mapper: PostgresqlFlatMapper = None, refresh projPath = Path(dbProj['name']) self.currentProject = ScipionProject(pyworkflow.Config.getDomain(), str(projPath)) self.currentProject.load(dbPath=self.currentProject.getDbPath()) - runs = self.currentProject.getRunsGraph(refresh=refresh, checkPids=checkPid) - tags = mapper.getProjectProtocolTagIdsByProtocolId(dbProj['id']) - graphData = self.buildProtocolsGraph(dbProj['id'], runs, tags) + + # Refresh Scipion graph and keep a live map of protocol objects + runMap: Dict[str, Any] = {} + scipionProtocolCount = 0 + scipionEdgeCount = 0 + + try: + runs = self.currentProject.getRunsGraph(refresh=refresh, checkPids=checkPid) + nodesDict = getattr(runs, "_nodesDict", {}) or {} + + for nodeId, nodeObj in nodesDict.items(): + if str(nodeId) == "PROJECT": + continue + + scipionProtocolCount += 1 + runMap[str(nodeId)] = getattr(nodeObj, "run", None) + + for parent in getattr(nodeObj, "_parents", []) or []: + parentNodeId = str(parent.getName()) + if parentNodeId != "PROJECT": + scipionEdgeCount += 1 + + except Exception: + logger.exception( + "Failed to refresh Scipion runs graph for project %s", + dbProj['id'], + ) + runMap = {} + scipionProtocolCount = 0 + scipionEdgeCount = 0 + + tags = {} + dependencyMap = {} + protocolRows: List[Dict[str, Any]] = [] + + if mapper is not None: + try: + tags = mapper.getProjectProtocolTagIdsByProtocolId(dbProj['id']) + except Exception: + logger.exception( + "Failed to load protocol tags from PostgreSQL for project %s", + dbProj['id'], + ) + tags = {} + + try: + dependencyMap = mapper.getProjectProtocolAdjacencyMap(dbProj['id']) + except Exception: + logger.exception( + "Failed to load protocol dependencies from PostgreSQL for project %s", + dbProj['id'], + ) + dependencyMap = {} + + try: + protocolRows = mapper.getProtocols(dbProj['id']) + except Exception: + logger.exception( + "Failed to load protocol rows from PostgreSQL for project %s", + dbProj['id'], + ) + protocolRows = [] + + dbProtocolCount = len(protocolRows) + dbEdgeCount = sum(len(v.get("parents") or []) for v in dependencyMap.values()) + + shouldResyncGraph = ( + scipionProtocolCount != dbProtocolCount or + scipionEdgeCount != dbEdgeCount + ) + + if shouldResyncGraph: + try: + logger.info( + "Resyncing protocol graph from Scipion to PostgreSQL. " + "projectId=%s scipionProtocols=%s dbProtocols=%s scipionEdges=%s dbEdges=%s", + dbProj['id'], + scipionProtocolCount, + dbProtocolCount, + scipionEdgeCount, + dbEdgeCount, + ) + + self.syncProjectProtocolsAndDependencies( + mapper, + dbProj['id'], + refresh=False, + checkPid=False, + ) + + dependencyMap = mapper.getProjectProtocolAdjacencyMap(dbProj['id']) + protocolRows = mapper.getProtocols(dbProj['id']) + except Exception: + logger.exception( + "Failed to resync protocol graph during project load for project %s", + dbProj['id'], + ) + + graphData = self.buildProtocolsGraph( + dbProj['id'], + protocolRows, + tags, + dependencyMap=dependencyMap, + runMap=runMap, + ) stats = projPath.stat() updatedAt = datetime.fromtimestamp(stats.st_mtime) @@ -1093,11 +1423,11 @@ def listProjectWorkflows(self): return tempList.sortListByPluginName().templates def applyWorkflowToProject( - self, - mapper: PostgresqlFlatMapper, - projectId: int, - workflowId: Union[int, str], - currentUser: dict, + self, + mapper: PostgresqlFlatMapper, + projectId: int, + workflowId: Union[int, str], + currentUser: dict, ) -> dict: """ Apply a predefined workflow template to an existing project. @@ -1128,7 +1458,7 @@ def applyWorkflowToProject( templateName = getattr(t, "name", None) if (templateId is not None and str(templateId) == workflowIdStr) or ( - templateName and str(templateName) == workflowIdStr + templateName and str(templateName) == workflowIdStr ): selectedTemplate = t break @@ -1168,16 +1498,24 @@ def applyWorkflowToProject( detail=f"Failed to apply workflow '{workflowIdStr}' to project {projectId}: {e}", ) - # 7) Optionally compute how many protocols are present after applying - protocolsCount = None + # 7) Sync protocols + dependencies to PostgreSQL try: - if hasattr(self.currentProject, "getProtocols"): - protocols = self.currentProject.getProtocols() - if protocols is not None: - protocolsCount = len(protocols) - except Exception: - # Ignore errors when computing protocol count - protocolsCount = None + syncInfo = self.syncProjectProtocolsAndDependencies( + mapper, + projectId, + refresh=True, + checkPid=True, + ) + except Exception as e: + logger.exception( + "Failed to sync workflow-applied project graph. projectId=%s workflowId=%s", + projectId, + workflowIdStr, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Workflow was applied but graph sync to PostgreSQL failed: {e}", + ) # 8) Return a compact, useful payload for the frontend return { @@ -1186,20 +1524,11 @@ def applyWorkflowToProject( "workflowId": workflowIdStr, "workflowName": getattr(selectedTemplate, "name", workflowIdStr), "workflowFile": str(workflowFile), - "protocolsCount": protocolsCount, + "protocolsCount": syncInfo.get("protocols"), + "dependenciesCount": syncInfo.get("dependencies"), "loadResult": str(loadResult) if loadResult is not None else None, } - def saveProtocolDependencies(self, mapper: PostgresqlFlatMapper, graphData: dict): - for nodeId, nodeInfo in graphData.items(): - parentIds = [int(pid) for pid in nodeInfo['parents'] if pid != 'PROJECT'] - childIds = [int(cid) for cid in nodeInfo['children']] - mapper.updateProtocolDependencies( - protocolId=nodeId, - parentIds=parentIds, - childIds=childIds - ) - @staticmethod def getProtocolColor(status: str) -> str: """Return hex color based on protocol status.""" @@ -1877,75 +2206,101 @@ def setPointerParam(self, protocol, key, value, parentId): def saveProtocol(self, mapper, projectId, protocolId, protocolClassName, params, setToSave=True): errorList = [] - if not protocolId: # new protocol + + if not protocolId: protClass = self.currentProject.getDomain().getProtocols().get(protocolClassName) + if protClass is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Protocol class not found: {protocolClassName}", + ) protocol = self.currentProject.newProtocol(protClass) - else: # retrieve a protocol by id + else: protocol = self.currentProject.getProtocol(int(protocolId)) + if protocol is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Protocol not found: {protocolId}", + ) - # Set protected parameters protectedParams = ['_objComment', '_useQueue', '_prerequisites', 'gpuList', 'numberOfThreads'] for paramName in protectedParams: protVar = getattr(protocol, paramName, None) - if protVar is not None: - try: - if paramName in params: - value = params[paramName] - protVar.set(value) - except Exception: - setattr(protocol, paramName, value) + if protVar is None or paramName not in params: + continue + + value = params[paramName] + try: + protVar.set(value) + except Exception: + setattr(protocol, paramName, value) - # Set non-pointer parameters for key, value in params.items(): param = protocol.getParam(key) if param is None: - logger.warning(f"[WARN] Param not found: {key}") + logger.warning("[WARN] Param not found: %s", key) continue if isinstance(param, (PointerParam, MultiPointerParam, RelationParam)): continue - rawValue = value try: - castedValue = self.castParamValue(param, rawValue) - errors = param.validate(castedValue) if hasattr(param, 'validate') else [] + castedValue = self.castParamValue(param, value) + errors = param.validate(castedValue) if hasattr(param, "validate") else [] if errors: - errorListAux = ['**' + param.label.get() + '** ' + error for error in errors] - errorList += errorListAux + errorList += ['**' + param.label.get() + '** ' + error for error in errors] + param.set(castedValue) protocol.setAttributeValue(key, castedValue) - if key == 'runName': + if key == "runName": protocol.setObjLabel(castedValue) - logger.info(f"[INFO] Set param {key} = {castedValue}") + logger.info("[INFO] Set param %s = %s", key, castedValue) except Exception as e: - import re - cleaned = re.sub(r'[^A-Za-z0-9\s+\-*/=<>\!&|^%()\[\]{}_,.;:]', '', str(e)) - errorList += ['**' + param.label.get() + '** ' + cleaned] + cleaned = re.sub(r'[^A-Za-z0-9\s+\-*/=<>!&|^%()\[\]{}_,.;:]', '', str(e)) + errorList.append('**' + param.label.get() + '** ' + cleaned) - # Apply pointer parameters errorList += self.applyParamsToProtocol(protocol, params) - # if setToSave: - # protocol.setSaved() - - if protocol.hasObjId(): - self.currentProject._storeProtocol(protocol) - else: - self.currentProject._setupProtocol(protocol) + # Persist protocol in Scipion always. + # The setToSave flag only controls whether we also sync the graph to PostgreSQL now. + try: + if protocol.hasObjId(): + self.currentProject._storeProtocol(protocol) + else: + self.currentProject._setupProtocol(protocol) + except Exception as e: + logger.exception( + "Failed to persist protocol in Scipion. projectId=%s protocolId=%s protocolClassName=%s", + projectId, + protocolId, + protocolClassName, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to persist protocol in Scipion: {e}", + ) - dbProtocol = mapper.getProtocolByProtocolId(protocolId=protocol.getObjId(), projectId=projectId) - if not dbProtocol: - protocolContex = self._buildProtocolContext(projectId, protocol) - mapper.saveProtocol(protocolContex) - pass - else: - # Update parameters and status if exists - pass - # Save dependencies - # graphData = self.currentProject.getRunsGraph(refresh=True, checkPids=True) - # self.saveProtocolDependencies(mapper, graphData._nodesDict) + if setToSave: + try: + self.syncProjectProtocolsAndDependencies( + mapper, + projectId, + refresh=True, + checkPid=True, + ) + except Exception as e: + logger.exception( + "Failed to sync protocol graph after save. projectId=%s protocolId=%s protocolClassName=%s", + projectId, + getattr(protocol, "getObjId", lambda: protocolId)(), + protocolClassName, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Protocol was saved in Scipion but graph sync to PostgreSQL failed: {e}", + ) return protocol, errorList @@ -1964,6 +2319,13 @@ def launchProtocol(self, mapper, projectId, protocolId, protocolClassName, param if executeMode == "stop": try: self.stopProtocol([protocolId]) + + self.syncProjectProtocolsAndDependencies( + mapper, + projectId, + refresh=True, + checkPid=True, + ) return except HTTPException: raise @@ -1983,8 +2345,10 @@ def launchProtocol(self, mapper, projectId, protocolId, protocolClassName, param ) if protocol.useQueue(): - queueParams = [params['_queueName'], params['_queueParams']] - protocol.setQueueParams(queueParams) + queueName = params.get("_queueName") + queueParams = params.get("_queueParams") + protocol.setQueueParams([queueName, queueParams]) + try: validationErrors = protocol._validate() if validationErrors: @@ -1996,23 +2360,56 @@ def launchProtocol(self, mapper, projectId, protocolId, protocolClassName, param ] if errors: + try: + self.syncProjectProtocolsAndDependencies( + mapper, + projectId, + refresh=True, + checkPid=True, + ) + except Exception: + logger.exception( + "Failed to sync protocol graph after validation errors. projectId=%s protocolId=%s", + projectId, + getattr(protocol, "getObjId", lambda: protocolId)(), + ) + raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=errors, ) - if executeMode == "schedule": - self.currentProject.scheduleProtocol(protocol) - return - - modeToRunMode = { - "launch": MODE_RESUME, - "restart": MODE_RESTART, - } + try: + if executeMode == "schedule": + self.currentProject.scheduleProtocol(protocol) + else: + modeToRunMode = { + "launch": MODE_RESUME, + "restart": MODE_RESTART, + } + runMode = modeToRunMode[executeMode] + protocol.runMode.set(runMode) + self.currentProject.launchProtocol(protocol) - runMode = modeToRunMode[executeMode] - protocol.runMode.set(runMode) - self.currentProject.launchProtocol(protocol) + self.syncProjectProtocolsAndDependencies( + mapper, + projectId, + refresh=True, + checkPid=True, + ) + except HTTPException: + raise + except Exception as e: + logger.exception( + "Failed to sync protocol graph after execute. projectId=%s protocolId=%s executeMode=%s", + projectId, + getattr(protocol, "getObjId", lambda: protocolId)(), + executeMode, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Protocol execution finished but graph sync to PostgreSQL failed: {e}", + ) def findViewersWeb(self, protocol): # TODO: Find viewers... @@ -2570,67 +2967,322 @@ def getProtocolLogs(self, projectId: int, protocolId: int, "scheduleOffset": newOffsetSchedule, } - def renameProtocol(self, protocolId: int, newName: str): + @staticmethod + def _buildProtocolMutationResult(message: str, **extra) -> Dict[str, Any]: + result = { + "status": 1 if extra['errors'] else 0, + "errors": extra['errors'], + "message": message, + "duplicated": extra['duplicated'] + } + result.update(extra or {}) + return result + + def renameProtocol(self, protocolId, newName): protocol = self.currentProject.getProtocol(int(protocolId)) - protocol.setObjLabel(newName) - self.currentProject._storeProtocol(protocol) - return {"status": "ok", "message": "Protocol renamed successfully"} + if protocol is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Protocol not found: {protocolId}", + ) - def duplicateProtocol(self, mapper, projectId, protocols: Any): try: - protList = [] - for protocol in protocols: - protList.append(self.currentProject.getProtocol(int(protocol.id))) - resultProtList = self.currentProject.copyProtocol(protList) + protocol.setObjLabel(newName) + self.currentProject._storeProtocol(protocol) + except Exception as e: + logger.exception( + "Failed to rename protocol. protocolId=%s newName=%s", + protocolId, + newName, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to rename protocol: {e}", + ) + + return self._buildProtocolMutationResult("Protocol renamed successfully") - for prot in resultProtList: - protocolContex = self._buildProtocolContext(projectId, prot) - mapper.saveProtocol(protocolContex) + def duplicateProtocol(self, mapper, projectId, protocols): + protocolList = [] + sourceIds = [] + duplicated = [] + errors = [] + for item in protocols or []: + protocolId = getattr(item, "id", None) + if protocolId is None: + continue + sourceIds.append(protocolId) + protocol = self.currentProject.getProtocol(int(protocolId)) + if protocol is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Protocol not found: {protocolId}", + ) + + protocolList.append(protocol) + + if not protocolList: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="No valid protocols to duplicate", + ) + + try: + protListResult = self.currentProject.copyProtocol(protocolList) + for index, prot in enumerate(protListResult): + protId = str(prot.getObjId()) + duplicated.append({"sourceId": sourceIds[index], "newId": protId}) - return {"status": "ok", "message": "Protocol was duplicated successfully"} except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + errors.append("Failed to duplicate protocols. projectId=%s protocolIds=%s" %projectId, + [getattr(p, "getObjId", lambda: None)() for p in protocolList]) + + logger.exception( + "Failed to duplicate protocols. projectId=%s protocolIds=%s", + projectId, + [getattr(p, "getObjId", lambda: None)() for p in protocolList], + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to duplicate protocols: {e}", + ) + + try: + syncResult = self.syncProjectProtocolsAndDependencies( + mapper, + projectId, + refresh=True, + checkPid=True, + ) + except HTTPException: + raise + except Exception as e: + errors.append("Failed to sync protocol graph after duplication. projectId=%s" %projectId) + logger.exception( + "Failed to sync protocol graph after duplication. projectId=%s", + projectId, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Protocols were duplicated in Scipion but graph sync to PostgreSQL failed: {e}", + ) + + return self._buildProtocolMutationResult( + "Protocol was duplicated successfully", + protocolsCount=int(syncResult.get("protocols", 0)), + dependenciesCount=int(syncResult.get("dependencies", 0)), + duplicated=duplicated, + errors=errors, + ) def deleteProtocol(self, mapper, projectId, protocols: Any): try: protList = [] for protocol in protocols: protList.append(self.currentProject.getProtocol(int(protocol))) + self.currentProject.deleteProtocol(*protList) mapper.deleteProtocol(projectId, protList) + + syncInfo = self.syncProjectProtocolsAndDependencies( + mapper, + projectId, + refresh=True, + checkPid=True, + ) + + return { + "status": "ok", + "message": "Protocol deleted successfully", + "protocolsCount": syncInfo.get("protocols"), + "dependenciesCount": syncInfo.get("dependencies"), + } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - def restartProtocolAll(self, protocolId: int): + def restartProtocolAll(self, protocolId): + protocol = self.currentProject.getProtocol(int(protocolId)) + if protocol is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Protocol not found: {protocolId}", + ) + + try: + workflowProtocolList, _activeProtocolList = self.currentProject._getSubworkflow(protocol) + except Exception as e: + logger.exception( + "Failed to resolve subworkflow for restart-all. protocolId=%s", + protocolId, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to resolve protocol subworkflow: {e}", + ) + + errorList = [] try: - protocol = self.currentProject.getProtocol(int(protocolId)) - workflowProtocolList, activeProtList = self.currentProject._getSubworkflow(protocol) - errorList = [] self.currentProject._restartWorkflow(errorList, workflowProtocolList) - return errorList except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + logger.exception( + "Failed to restart workflow subtree. protocolId=%s", + protocolId, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to restart protocol subtree: {e}", + ) - def continueProtocolAll(self, mapper, projectId: int, protocolId: int, currentUser: dict): - raise NotImplementedError + if errorList: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=[str(e) for e in errorList], + ) + + return self._buildProtocolMutationResult("Protocol subtree restarted successfully") - def resetProtocolFrom(self, protocolId: int): + def continueProtocolAll(self, mapper, projectId, protocolId, currentUser): protocol = self.currentProject.getProtocol(int(protocolId)) + if protocol is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Protocol not found: {protocolId}", + ) + try: - workflowProtocolList, activeProtList = self.currentProject._getSubworkflow(protocol) - errorProtList = self.currentProject.resetWorkFlow(workflowProtocolList) - if errorProtList: - raise HTTPException(status_code=500, detail=errorProtList) + workflowProtocolList, activeProtocolList = self.currentProject._getSubworkflow(protocol) except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + logger.exception( + "Failed to resolve subworkflow for continue-all. projectId=%s protocolId=%s", + projectId, + protocolId, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to resolve protocol subworkflow: {e}", + ) + + protocolsToResume = activeProtocolList or workflowProtocolList or [] + if not protocolsToResume: + return self._buildProtocolMutationResult("No protocols to continue") + + for item in protocolsToResume: + protocolToLaunch = item + + if not hasattr(protocolToLaunch, "runMode"): + try: + protocolToLaunch = self.currentProject.getProtocol(int(item)) + except Exception: + logger.exception( + "Failed to resolve protocol to continue. projectId=%s protocolId=%s item=%s", + projectId, + protocolId, + item, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to resolve protocol to continue: {item}", + ) + + try: + protocolToLaunch.runMode.set(MODE_RESUME) + except Exception: + logger.debug( + "Could not set MODE_RESUME before continue-all. projectId=%s protocolId=%s item=%s", + projectId, + protocolId, + getattr(protocolToLaunch, "getObjId", lambda: item)(), + exc_info=True, + ) + + try: + self.currentProject.launchProtocol(protocolToLaunch) + except Exception as e: + logger.exception( + "Failed to continue protocol. projectId=%s protocolId=%s item=%s", + projectId, + protocolId, + getattr(protocolToLaunch, "getObjId", lambda: item)(), + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to continue protocol: {e}", + ) + + return self._buildProtocolMutationResult("Protocol subtree continued successfully") + + def resetProtocolFrom(self, protocolId): + protocol = self.currentProject.getProtocol(int(protocolId)) + if protocol is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Protocol not found: {protocolId}", + ) - def stopProtocol(self, protocols: Any): try: - for protocolId in protocols: - protocol = self.currentProject.getProtocol(int(protocolId)) + workflowProtocolList, _activeProtocolList = self.currentProject._getSubworkflow(protocol) + except Exception as e: + logger.exception( + "Failed to resolve subworkflow for reset-from. protocolId=%s", + protocolId, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to resolve protocol subworkflow: {e}", + ) + + try: + resetErrors = self.currentProject.resetWorkFlow(workflowProtocolList) or [] + except Exception as e: + logger.exception( + "Failed to reset workflow subtree. protocolId=%s", + protocolId, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to reset protocol subtree: {e}", + ) + + if resetErrors: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=[str(e) for e in resetErrors], + ) + + return self._buildProtocolMutationResult("Protocol subtree reset successfully") + + def stopProtocol(self, protocolIds): + resolvedProtocols = [] + + for protocolId in protocolIds or []: + protocol = self.currentProject.getProtocol(int(protocolId)) + if protocol is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Protocol not found: {protocolId}", + ) + resolvedProtocols.append(protocol) + + if not resolvedProtocols: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="No valid protocols to stop", + ) + + try: + for protocol in resolvedProtocols: self.currentProject.stopProtocol(protocol) except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + logger.exception( + "Failed to stop protocols. protocolIds=%s", + protocolIds, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to stop protocols: {e}", + ) + + return self._buildProtocolMutationResult("Protocol stopped successfully") def _isGlobalFsBrowserMode(self, protocolId: Union[int, str]) -> bool: return str(protocolId).strip() == "-1" diff --git a/app/backend/mapper/postgresql.py b/app/backend/mapper/postgresql.py index b445e52..4e113a4 100644 --- a/app/backend/mapper/postgresql.py +++ b/app/backend/mapper/postgresql.py @@ -6,7 +6,7 @@ import psycopg2 import psycopg2.extras from contextlib import contextmanager -from typing import Optional, List, Dict, Any, Iterator +from typing import Optional, List, Dict, Any, Iterator, Tuple from pyworkflow.mapper.mapper import Mapper # Base class from Scipion @@ -110,26 +110,57 @@ def initTables(self): """ ) - # CreateProtocolsTableLegacy (kept as-is for now) + # CreateProtocolsTableLegacy self.db.execute( """ CREATE TABLE IF NOT EXISTS protocols ( id SERIAL PRIMARY KEY, - CREATE TABLE IF NOT EXISTS protocols ( - id SERIAL PRIMARY KEY, - "projectId" INTEGER NOT NULL REFERENCES projects(id) ON DELETE CASCADE, - "protocolId" TEXT NOT NULL, - "protocolClassName" TEXT NOT NULL, - status TEXT NOT NULL DEFAULT 'pending', - params JSONB, - "parentIds" JSONB NOT NULL DEFAULT '[]'::jsonb, - "childIds" JSONB NOT NULL DEFAULT '[]'::jsonb, - "createdAt" TIMESTAMPTZ NOT NULL DEFAULT NOW(), - "updatedAt" TIMESTAMPTZ + "projectId" INTEGER NOT NULL REFERENCES projects(id) ON DELETE CASCADE, + "protocolId" TEXT NOT NULL, + "protocolClassName" TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + params JSONB, + "parentIds" INTEGER[] NOT NULL DEFAULT ARRAY[]::INTEGER[], + "childIds" INTEGER[] NOT NULL DEFAULT ARRAY[]::INTEGER[], + "createdAt" TIMESTAMPTZ NOT NULL DEFAULT NOW(), + "updatedAt" TIMESTAMPTZ ); - + CREATE UNIQUE INDEX IF NOT EXISTS protocols_project_protocol_ux ON protocols("projectId", "protocolId"); + + CREATE UNIQUE INDEX IF NOT EXISTS protocols_project_dbid_ux + ON protocols("projectId", id); + """ + ) + + # CreateProtocolDependenciesTable + self.db.execute( + """ + CREATE TABLE IF NOT EXISTS protocol_dependencies ( + "projectId" INTEGER NOT NULL REFERENCES projects(id) ON DELETE CASCADE, + "parentProtocolDbId" INTEGER NOT NULL, + "childProtocolDbId" INTEGER NOT NULL, + "createdAt" TIMESTAMPTZ NOT NULL DEFAULT NOW(), + + PRIMARY KEY ("projectId", "parentProtocolDbId", "childProtocolDbId"), + + FOREIGN KEY ("projectId", "parentProtocolDbId") + REFERENCES protocols("projectId", id) + ON DELETE CASCADE, + + FOREIGN KEY ("projectId", "childProtocolDbId") + REFERENCES protocols("projectId", id) + ON DELETE CASCADE, + + CHECK ("parentProtocolDbId" <> "childProtocolDbId") + ); + + CREATE INDEX IF NOT EXISTS protocol_dependencies_by_parent + ON protocol_dependencies("projectId", "parentProtocolDbId"); + + CREATE INDEX IF NOT EXISTS protocol_dependencies_by_child + ON protocol_dependencies("projectId", "childProtocolDbId"); """ ) @@ -891,6 +922,137 @@ def saveProtocol(self, protocol: Dict[str, Any]) -> int: ) return cur.fetchone()["id"] + def getProjectProtocolDbIdMap(self, projectId: int) -> Dict[str, int]: + rows = self.db.fetchAll( + """ + SELECT id, "protocolId" + FROM protocols + WHERE "projectId" = %s + """, + (projectId,), + ) + + return { + str(row["protocolId"]): int(row["id"]) + for row in rows + if row.get("protocolId") is not None and row.get("id") is not None + } + + def replaceProjectProtocolDependencies( + self, + projectId: int, + edges: List[Tuple[int, int]], + ) -> int: + self.db.execute( + """ + DELETE FROM protocol_dependencies + WHERE "projectId" = %s + """, + (projectId,), + ) + + cleanEdges: List[Tuple[int, int]] = [] + seen = set() + + for parentDbId, childDbId in edges or []: + try: + parentDbId = int(parentDbId) + childDbId = int(childDbId) + except Exception: + continue + + if parentDbId <= 0 or childDbId <= 0: + continue + if parentDbId == childDbId: + continue + + key = (parentDbId, childDbId) + if key in seen: + continue + + seen.add(key) + cleanEdges.append(key) + + if not cleanEdges: + return 0 + + valuesSql = ",".join(["(%s, %s, %s)"] * len(cleanEdges)) + params: List[Any] = [] + + for parentDbId, childDbId in cleanEdges: + params.extend([projectId, parentDbId, childDbId]) + + self.db.execute( + f""" + INSERT INTO protocol_dependencies ( + "projectId", + "parentProtocolDbId", + "childProtocolDbId" + ) + VALUES {valuesSql} + """, + tuple(params), + ) + + return len(cleanEdges) + + def listProjectProtocolDependencies(self, projectId: int) -> List[Dict[str, Any]]: + return self.db.fetchAll( + """ + SELECT + "projectId", + "parentProtocolDbId", + "childProtocolDbId", + "createdAt" + FROM protocol_dependencies + WHERE "projectId" = %s + ORDER BY "parentProtocolDbId", "childProtocolDbId" + """, + (projectId,), + ) + + def getProjectProtocolAdjacencyMap(self, projectId: int) -> Dict[str, Dict[str, List[str]]]: + rows = self.db.fetchAll( + """ + SELECT + parent."protocolId" AS "parentProtocolId", + child."protocolId" AS "childProtocolId" + FROM protocol_dependencies d + JOIN protocols parent + ON parent.id = d."parentProtocolDbId" + AND parent."projectId" = d."projectId" + JOIN protocols child + ON child.id = d."childProtocolDbId" + AND child."projectId" = d."projectId" + WHERE d."projectId" = %s + ORDER BY child.id, parent.id + """, + (projectId,), + ) + + adjacency: Dict[str, Dict[str, List[str]]] = {} + + for row in rows: + parentProtocolId = row.get("parentProtocolId") + childProtocolId = row.get("childProtocolId") + + if parentProtocolId is None or childProtocolId is None: + continue + + parentProtocolId = str(parentProtocolId) + childProtocolId = str(childProtocolId) + + adjacency.setdefault(parentProtocolId, {"parents": [], "children": []}) + adjacency.setdefault(childProtocolId, {"parents": [], "children": []}) + + if childProtocolId not in adjacency[parentProtocolId]["children"]: + adjacency[parentProtocolId]["children"].append(childProtocolId) + + if parentProtocolId not in adjacency[childProtocolId]["parents"]: + adjacency[childProtocolId]["parents"].append(parentProtocolId) + + return adjacency + def getProtocolByProtocolId(self, protocolId: int, projectId: int) -> Optional[Dict]: """Retrieve a protocol by id.""" return self.db.fetchOne( @@ -950,6 +1112,46 @@ def updateProtocolDependencies(self, protocolId: str, parentIds: list, childIds: query = 'UPDATE protocols SET "parentIds" = %s, "childIds" = %s, "updatedAt" = NOW() WHERE "protocolId" = %s' self.db.execute(query, (parentIds, childIds, protocolId)) + def deleteProjectProtocolsNotInProtocolIds( + self, + projectId: int, + protocolIdsToKeep: List[str], + ) -> int: + keepSet = { + str(protocolId).strip() + for protocolId in (protocolIdsToKeep or []) + if str(protocolId).strip() + } + + rows = self.db.fetchAll( + """ + SELECT id, "protocolId" + FROM protocols + WHERE "projectId" = %s + """, + (projectId,), + ) + + staleDbIds = [ + int(row["id"]) + for row in rows + if str(row.get("protocolId", "")).strip() not in keepSet + ] + + if not staleDbIds: + return 0 + + self.db.execute( + """ + DELETE FROM protocols + WHERE "projectId" = %s + AND id = ANY(%s) + """, + (projectId, staleDbIds), + ) + + return len(staleDbIds) + # ----------------------------- # Settings Methods # ----------------------------- diff --git a/app/backend/utils/thumbnail_service.py b/app/backend/utils/thumbnail_service.py index 300d8ce..7d0d3bc 100644 --- a/app/backend/utils/thumbnail_service.py +++ b/app/backend/utils/thumbnail_service.py @@ -32,6 +32,8 @@ import hashlib from urllib.parse import quote from pathlib import Path +import tempfile +import threading from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple import numpy as np @@ -62,6 +64,8 @@ from pwem.viewers.viewers_data import RegistryViewerConfig logger = logging.getLogger(__name__) +_thumbnailBuildLocksGuard = threading.Lock() +_thumbnailBuildLocks: Dict[str, threading.Lock] = {} class ThumbnailService: @@ -74,6 +78,31 @@ def __init__(self, currentProject): # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ + def _getThumbnailBuildLock(self, cachePath: Path): + key = str(cachePath) + with _thumbnailBuildLocksGuard: + lock = _thumbnailBuildLocks.get(key) + if lock is None: + lock = threading.Lock() + _thumbnailBuildLocks[key] = lock + return lock + + def _isValidCachedImage(self, cachePath: Path) -> bool: + try: + if not cachePath.exists(): + return False + if cachePath.stat().st_size <= 0: + return False + with Image.open(cachePath) as img: + img.verify() + return True + except Exception: + try: + cachePath.unlink(missing_ok=True) + except Exception: + pass + return False + def listUsefulProtocols(self, maxProtocols: int = 12) -> List[Dict[str, Any]]: protocols = self._iterProtocols() candidates: List[Dict[str, Any]] = [] @@ -129,11 +158,28 @@ def buildProtocolThumbnail( size: int = 360, outputName: Optional[str] = None, ) -> Dict[str, Any]: - protocol = self.currentProject.getProtocol(int(protocolId)) + try: + protocol = self.currentProject.getProtocol(int(protocolId)) + except Exception: + protocol = None + if protocol is None: - raise ValueError(f"Protocol {protocolId} not found") + return { + "protocolId": int(protocolId), + "protocolLabel": f"Protocol {int(protocolId)}", + "status": "unknown", + "outputName": outputName, + "outputClassName": None, + "absolutePath": None, + "cached": False, + "exists": False, + } - cachePath = self._getProtocolCachePath(protocolId, size=size, outputName=outputName) + cachePath = self._getProtocolCachePath( + protocolId, + size=size, + outputName=outputName, + ) selectedCandidate: Optional[Dict[str, Any]] = None if outputName: @@ -154,88 +200,93 @@ def buildProtocolThumbnail( "exists": False, } - if cachePath.exists() and not force: - return { - "protocolId": int(protocolId), - "protocolLabel": self._getProtocolLabel(protocol), - "status": self._getProtocolStatus(protocol), - "outputName": selectedCandidate["outputName"] if selectedCandidate else outputName, - "outputClassName": selectedCandidate["outputClassName"] if selectedCandidate else None, - "absolutePath": str(cachePath), - "cached": True, - "exists": True, - } + buildLock = self._getThumbnailBuildLock(cachePath) + with buildLock: + if not force and self._isValidCachedImage(cachePath): + return { + "protocolId": int(protocolId), + "protocolLabel": self._getProtocolLabel(protocol), + "status": self._getProtocolStatus(protocol), + "outputName": selectedCandidate["outputName"] if selectedCandidate else outputName, + "outputClassName": selectedCandidate["outputClassName"] if selectedCandidate else None, + "absolutePath": str(cachePath), + "cached": True, + "exists": True, + } - previewImage: Optional[Image.Image] = None - candidates = [selectedCandidate] if selectedCandidate is not None else self._collectSortedOutputCandidates( - protocol) + previewImage: Optional[Image.Image] = None + candidates = ( + [selectedCandidate] + if selectedCandidate is not None + else self._collectSortedOutputCandidates(protocol) + ) - for candidate in candidates: - if candidate is None: - continue + for candidate in candidates: + if candidate is None: + continue - try: - image = self._renderProtocolPreviewImage( - protocol=protocol, - output=candidate["output"], - outputName=candidate["outputName"], - outputClassName=candidate["outputClassName"], - size=size, - ) - if image is not None: - selectedCandidate = candidate - previewImage = image - break - except Exception: - logger.debug( - "Candidate thumbnail render failed. protocolId=%s output=%s class=%s", - protocolId, - candidate.get("outputName"), - candidate.get("outputClassName"), - exc_info=True, - ) + try: + image = self._renderProtocolPreviewImage( + protocol=protocol, + output=candidate["output"], + outputName=candidate["outputName"], + outputClassName=candidate["outputClassName"], + size=size, + ) + if image is not None: + selectedCandidate = candidate + previewImage = image + break + except Exception: + logger.debug( + "Candidate thumbnail render failed. protocolId=%s output=%s class=%s", + protocolId, + candidate.get("outputName"), + candidate.get("outputClassName"), + exc_info=True, + ) - if previewImage is None and outputName is None: - try: - previewImage = self._renderProtocolFilesystemFallback(protocol, size=size) - except Exception: - logger.debug( - "Filesystem thumbnail fallback failed. protocolId=%s", - protocolId, - exc_info=True, - ) - previewImage = None + if previewImage is None and outputName is None: + try: + previewImage = self._renderProtocolFilesystemFallback(protocol, size=size) + except Exception: + logger.debug( + "Filesystem thumbnail fallback failed. protocolId=%s", + protocolId, + exc_info=True, + ) + previewImage = None + + if previewImage is None: + return { + "protocolId": int(protocolId), + "protocolLabel": self._getProtocolLabel(protocol), + "status": self._getProtocolStatus(protocol), + "outputName": selectedCandidate["outputName"] if selectedCandidate else outputName, + "outputClassName": selectedCandidate["outputClassName"] if selectedCandidate else None, + "absolutePath": None, + "cached": False, + "exists": False, + } + + thumbnail = self._finalizeProtocolThumbnail( + previewImage=previewImage, + size=size, + protocolId=int(protocolId), + ) + self._saveImage(thumbnail, cachePath) - if previewImage is None: return { "protocolId": int(protocolId), "protocolLabel": self._getProtocolLabel(protocol), "status": self._getProtocolStatus(protocol), "outputName": selectedCandidate["outputName"] if selectedCandidate else outputName, "outputClassName": selectedCandidate["outputClassName"] if selectedCandidate else None, - "absolutePath": None, + "absolutePath": str(cachePath), "cached": False, - "exists": False, + "exists": True, } - thumbnail = self._finalizeProtocolThumbnail( - previewImage=previewImage, - size=size, - protocolId=int(protocolId), - ) - self._saveImage(thumbnail, cachePath) - - return { - "protocolId": int(protocolId), - "protocolLabel": self._getProtocolLabel(protocol), - "status": self._getProtocolStatus(protocol), - "outputName": selectedCandidate["outputName"] if selectedCandidate else outputName, - "outputClassName": selectedCandidate["outputClassName"] if selectedCandidate else None, - "absolutePath": str(cachePath), - "cached": False, - "exists": True, - } - def buildProjectThumbnail( self, force: bool = False, @@ -243,65 +294,77 @@ def buildProjectThumbnail( maxProtocols: int = 6, ) -> Dict[str, Any]: cachePath = self._getProjectCachePath(size=size, maxProtocols=maxProtocols) - if cachePath.exists() and not force: - return { - "absolutePath": str(cachePath), - "cached": True, - "items": None, - } - useful = self.listUsefulProtocols(maxProtocols=max(3, int(maxProtocols) * 3)) - renderedItems: List[Dict[str, Any]] = [] - protocolThumbWidth = self._projectProtocolSize(size=int(size), maxProtocols=int(maxProtocols)) + buildLock = self._getThumbnailBuildLock(cachePath) + with buildLock: + if not force and self._isValidCachedImage(cachePath): + return { + "absolutePath": str(cachePath), + "cached": True, + "items": None, + } - for candidate in useful: - try: - built = self.buildProtocolThumbnail( - protocolId=int(candidate["protocolId"]), - force=force, - size=protocolThumbWidth, - ) + useful = self.listUsefulProtocols( + maxProtocols=max(3, int(maxProtocols) * 3), + ) + renderedItems: List[Dict[str, Any]] = [] + protocolThumbWidth = self._projectProtocolSize( + size=int(size), + maxProtocols=int(maxProtocols), + ) - if not built.get("exists") or not built.get("absolutePath"): - continue + for candidate in useful: + try: + built = self.buildProtocolThumbnail( + protocolId=int(candidate["protocolId"]), + force=force, + size=protocolThumbWidth, + ) - renderedItems.append( - { - "protocolId": int(candidate["protocolId"]), - "protocolLabel": candidate.get("protocolLabel"), - "status": candidate.get("status"), - "outputName": candidate.get("outputName"), - "outputClassName": candidate.get("outputClassName"), - "itemsCount": candidate.get("itemsCount", 0), - "absolutePath": built["absolutePath"], - } - ) + if not built.get("exists") or not built.get("absolutePath"): + continue - if len(renderedItems) >= int(maxProtocols): - break + renderedItems.append( + { + "protocolId": int(candidate["protocolId"]), + "protocolLabel": candidate.get("protocolLabel"), + "status": candidate.get("status"), + "outputName": candidate.get("outputName"), + "outputClassName": candidate.get("outputClassName"), + "itemsCount": candidate.get("itemsCount", 0), + "absolutePath": built["absolutePath"], + } + ) - except Exception: - logger.debug( - "Skipping failed protocol thumbnail while building project strip. protocolId=%s", - candidate.get("protocolId"), - exc_info=True, - ) + if len(renderedItems) >= int(maxProtocols): + break + + except Exception: + logger.debug( + "Skipping failed protocol thumbnail while building project strip. protocolId=%s", + candidate.get("protocolId"), + exc_info=True, + ) + + if not renderedItems: + return { + "absolutePath": None, + "cached": False, + "items": 0, + } + + strip = self._composeProjectStrip( + items=renderedItems, + size=int(size), + ) + self._saveImage(strip, cachePath) - if not renderedItems: return { - "absolutePath": None, + "absolutePath": str(cachePath), "cached": False, - "items": 0, + "items": len(renderedItems), } - strip = self._composeProjectStrip(items=renderedItems, size=int(size)) - self._saveImage(strip, cachePath) - return { - "absolutePath": str(cachePath), - "cached": False, - "items": len(renderedItems), - } - def _getBadgeFont(self, badgeH: int): fontSize = max(12, int(round(badgeH * 0.48))) candidates = [ @@ -381,12 +444,34 @@ def buildProtocolOutputThumbnail( force: bool = False, size: int = 320, ) -> Dict[str, Any]: - protocol = self.currentProject.getProtocol(int(protocolId)) + try: + protocol = self.currentProject.getProtocol(int(protocolId)) + except Exception: + protocol = None + if protocol is None: - raise ValueError(f"Protocol {protocolId} not found") + return { + "protocolId": int(protocolId), + "protocolLabel": f"Protocol {int(protocolId)}", + "status": "unknown", + "outputName": outputName, + "outputClassName": None, + "absolutePath": None, + "cached": False, + "exists": False, + } if not hasattr(protocol, outputName): - raise ValueError(f"Output '{outputName}' not found in protocol {protocolId}") + return { + "protocolId": int(protocolId), + "protocolLabel": self._getProtocolLabel(protocol), + "status": self._getProtocolStatus(protocol), + "outputName": outputName, + "outputClassName": None, + "absolutePath": None, + "cached": False, + "exists": False, + } output = getattr(protocol, outputName, None) if output is None: @@ -422,78 +507,80 @@ def buildProtocolOutputThumbnail( size=size, ) - if cachePath.exists() and not force: - return { - "protocolId": int(protocolId), - "protocolLabel": self._getProtocolLabel(protocol), - "status": self._getProtocolStatus(protocol), - "outputName": outputName, - "outputClassName": outputClassName, - "absolutePath": str(cachePath), - "cached": True, - "exists": True, - } - - previewImage: Optional[Image.Image] = None + buildLock = self._getThumbnailBuildLock(cachePath) + with buildLock: + if not force and self._isValidCachedImage(cachePath): + return { + "protocolId": int(protocolId), + "protocolLabel": self._getProtocolLabel(protocol), + "status": self._getProtocolStatus(protocol), + "outputName": outputName, + "outputClassName": outputClassName, + "absolutePath": str(cachePath), + "cached": True, + "exists": True, + } - try: - previewImage = self._renderProtocolPreviewImage( - protocol=protocol, - output=output, - outputName=outputName, - outputClassName=outputClassName, - size=size, - ) - except Exception: - logger.debug( - "Output thumbnail render failed. protocolId=%s output=%s class=%s", - protocolId, - outputName, - outputClassName, - exc_info=True, - ) + previewImage: Optional[Image.Image] = None - if previewImage is None: try: - previewImage = self._renderGenericPreview(protocol, output, size=size) + previewImage = self._renderProtocolPreviewImage( + protocol=protocol, + output=output, + outputName=outputName, + outputClassName=outputClassName, + size=size, + ) except Exception: logger.debug( - "Generic output thumbnail render failed. protocolId=%s output=%s", + "Output thumbnail render failed. protocolId=%s output=%s class=%s", protocolId, outputName, + outputClassName, exc_info=True, ) - if previewImage is None: + if previewImage is None: + try: + previewImage = self._renderGenericPreview(protocol, output, size=size) + except Exception: + logger.debug( + "Generic output thumbnail render failed. protocolId=%s output=%s", + protocolId, + outputName, + exc_info=True, + ) + + if previewImage is None: + return { + "protocolId": int(protocolId), + "protocolLabel": self._getProtocolLabel(protocol), + "status": self._getProtocolStatus(protocol), + "outputName": outputName, + "outputClassName": outputClassName, + "absolutePath": None, + "cached": False, + "exists": False, + } + + thumbnail = self._finalizeProtocolThumbnail( + previewImage=previewImage, + size=size, + protocolId=int(protocolId), + ) + self._saveImage(thumbnail, cachePath) + return { "protocolId": int(protocolId), "protocolLabel": self._getProtocolLabel(protocol), "status": self._getProtocolStatus(protocol), "outputName": outputName, "outputClassName": outputClassName, - "absolutePath": None, + "absolutePath": str(cachePath), "cached": False, - "exists": False, + "exists": True, } - thumbnail = self._finalizeProtocolThumbnail( - previewImage=previewImage, - size=size, - protocolId=int(protocolId), - ) - self._saveImage(thumbnail, cachePath) - - return { - "protocolId": int(protocolId), - "protocolLabel": self._getProtocolLabel(protocol), - "status": self._getProtocolStatus(protocol), - "outputName": outputName, - "outputClassName": outputClassName, - "absolutePath": str(cachePath), - "cached": False, - "exists": True, - } - # ------------------------------------------------------------------ # Candidate selection # ------------------------------------------------------------------ @@ -692,11 +779,11 @@ def listProtocolThumbnailItems( continue try: - built = self.buildProtocolThumbnail( + built = self.buildProtocolOutputThumbnail( protocolId=protocolId, + outputName=outputName, force=force, size=size, - outputName=outputName, ) except Exception: logger.debug( @@ -716,13 +803,9 @@ def listProtocolThumbnailItems( "outputClassName": outputClassName, "exists": True, "thumbnailUrl": ( - f"/projects/{int(projectId)}/protocols/{protocolId}/thumbnail" - f"?outputName={quote(str(outputName))}" - ), - "thumbnailRebuildUrl": ( - f"/projects/{int(projectId)}/protocols/{protocolId}/thumbnail/rebuild" - f"?outputName={quote(str(outputName))}" + f"/projects/{int(projectId)}/protocols/{protocolId}/outputs/{quote(str(outputName))}/thumbnail" ), + "thumbnailRebuildUrl": None, } ) @@ -2168,7 +2251,20 @@ def _getProjectCachePath(self, size: int, maxProtocols: int) -> Path: def _saveImage(self, image: Image.Image, outputPath: Path): outputPath.parent.mkdir(parents=True, exist_ok=True) - image.save(str(outputPath), format="PNG") + + tmpPath = outputPath.with_name( + f".{outputPath.name}.{os.getpid()}.{id(image)}.tmp" + ) + + try: + image.save(str(tmpPath), format="PNG") + os.replace(str(tmpPath), str(outputPath)) + finally: + try: + if tmpPath.exists(): + tmpPath.unlink() + except Exception: + pass def _getProjectPath(self) -> Optional[str]: if self.currentProject is None: diff --git a/tests/conftest.py b/tests/conftest.py index 8044c68..423a0bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -116,6 +116,7 @@ def makeProjectOut(projectId: int = 1, name: str = "Demo Project", **overrides): payload.update(overrides) return payload + class FakeProjectService: # fakeProjectService def __init__(self): @@ -176,6 +177,9 @@ def __init__(self): self.applyWorkflowError = None self.lastApplyWorkflowCall = None + self.syncProjectGraphAfterMutationError = None + self.lastSyncProjectGraphAfterMutationCall = None + self.protocolParamsResult = { "protocolId": "10", "protocolClassName": "ProtClass", @@ -369,6 +373,18 @@ def listProjectShares(self, mapper, projectId, currentUser): } return self.projectSharesResult + def syncProjectGraphAfterMutation(self, mapper, projectId, actionLabel, refresh=True, checkPid=True): + self.lastSyncProjectGraphAfterMutationCall = { + "mapper": mapper, + "projectId": projectId, + "actionLabel": actionLabel, + "refresh": refresh, + "checkPid": checkPid, + } + if self.syncProjectGraphAfterMutationError is not None: + raise self.syncProjectGraphAfterMutationError + return {"protocols": 1, "dependencies": 0} + def applyWorkflowToProject(self, mapper, projectId, workflowId, currentUser): self.lastApplyWorkflowCall = { "mapper": mapper, diff --git a/tests/integration/api/test_projects_router_protocol_ops.py b/tests/integration/api/test_projects_router_protocol_ops.py index 28ec08f..796498c 100644 --- a/tests/integration/api/test_projects_router_protocol_ops.py +++ b/tests/integration/api/test_projects_router_protocol_ops.py @@ -241,6 +241,14 @@ def test_RenameProtocolDelegatesToService(projectClient, fakeProjectService): "newName": "Renamed protocol", } + assert fakeProjectService.lastSyncProjectGraphAfterMutationCall == { + "mapper": fakeProjectService.lastSyncProjectGraphAfterMutationCall["mapper"], + "projectId": 1, + "actionLabel": "rename protocol", + "refresh": True, + "checkPid": True, + } + def test_DuplicateProtocolRejectsMissingItems(projectClient): response = projectClient.post( @@ -317,12 +325,15 @@ def test_DeleteProtocolDelegatesToService(projectClient, fakeProjectService): } -def test_RestartProtocolAllReturnsErrorsWhenServiceReturnsErrors(projectClient, fakeProjectService): - fakeProjectService.restartProtocolAllResult = ["cannot restart", "blocked"] +def test_RestartProtocolAllReturnsErrorsWhenServiceRaisesHttpException(projectClient, fakeProjectService): + fakeProjectService.restartProtocolAllError = HTTPException( + status_code=422, + detail=["cannot restart", "blocked"], + ) response = projectClient.post("/projects/1/protocols/10/restart-all") - assert response.status_code == 200 + assert response.status_code == 422 assert response.json() == { "status": 1, "errors": ["cannot restart", "blocked"], @@ -343,6 +354,13 @@ def test_RestartProtocolAllReturnsSuccess(projectClient, fakeProjectService): } assert fakeProjectService.lastRestartProtocolAllCall == {"protocolId": 10} + assert fakeProjectService.lastSyncProjectGraphAfterMutationCall == { + "mapper": fakeProjectService.lastSyncProjectGraphAfterMutationCall["mapper"], + "projectId": 1, + "actionLabel": "restart protocol subtree", + "refresh": True, + "checkPid": True, + } def test_ContinueProtocolAllDelegatesToService(projectClient, fakeProjectService): @@ -355,6 +373,14 @@ def test_ContinueProtocolAllDelegatesToService(projectClient, fakeProjectService "workflow": [], } + assert fakeProjectService.lastSyncProjectGraphAfterMutationCall == { + "mapper": fakeProjectService.lastSyncProjectGraphAfterMutationCall["mapper"], + "projectId": 1, + "actionLabel": "continue protocol subtree", + "refresh": True, + "checkPid": True, + } + assert fakeProjectService.lastContinueProtocolAllCall == { "mapper": fakeProjectService.lastContinueProtocolAllCall["mapper"], "projectId": 1, @@ -378,6 +404,123 @@ def test_ResetProtocolFromDelegatesToService(projectClient, fakeProjectService): } assert fakeProjectService.lastResetProtocolFromCall == {"protocolId": 10} + assert fakeProjectService.lastSyncProjectGraphAfterMutationCall == { + "mapper": fakeProjectService.lastSyncProjectGraphAfterMutationCall["mapper"], + "projectId": 1, + "actionLabel": "reset protocol from node", + "refresh": True, + "checkPid": True, + } + +def test_RenameProtocolReturnsErrorWhenGraphSyncFails(projectClient, fakeProjectService): + fakeProjectService.syncProjectGraphAfterMutationError = HTTPException( + status_code=500, + detail="rename protocol succeeded but graph sync to PostgreSQL failed", + ) + + response = projectClient.put( + "/projects/1/protocols/10/rename", + json={"name": "Renamed protocol"}, + ) + + assert response.status_code == 500 + assert response.json() == { + "status": 1, + "errors": ["rename protocol succeeded but graph sync to PostgreSQL failed"], + "workflow": [], + } + + assert fakeProjectService.lastRenameProtocolCall == { + "protocolId": 10, + "newName": "Renamed protocol", + } + + +def test_RestartProtocolAllReturnsErrorWhenGraphSyncFails(projectClient, fakeProjectService): + fakeProjectService.restartProtocolAllResult = [] + fakeProjectService.syncProjectGraphAfterMutationError = HTTPException( + status_code=500, + detail="restart protocol subtree succeeded but graph sync to PostgreSQL failed", + ) + + response = projectClient.post("/projects/1/protocols/10/restart-all") + + assert response.status_code == 500 + assert response.json() == { + "status": 1, + "errors": ["restart protocol subtree succeeded but graph sync to PostgreSQL failed"], + "workflow": [], + } + + assert fakeProjectService.lastRestartProtocolAllCall == {"protocolId": 10} + + +def test_ContinueProtocolAllReturnsErrorWhenGraphSyncFails(projectClient, fakeProjectService): + fakeProjectService.syncProjectGraphAfterMutationError = HTTPException( + status_code=500, + detail="continue protocol subtree succeeded but graph sync to PostgreSQL failed", + ) + + response = projectClient.post("/projects/1/protocols/10/continue-all") + + assert response.status_code == 500 + assert response.json() == { + "status": 1, + "errors": ["continue protocol subtree succeeded but graph sync to PostgreSQL failed"], + "workflow": [], + } + + assert fakeProjectService.lastContinueProtocolAllCall == { + "mapper": fakeProjectService.lastContinueProtocolAllCall["mapper"], + "projectId": 1, + "protocolId": 10, + "currentUser": { + "id": 1, + "email": "user@example.com", + "role": "user", + }, + } + + +def test_ResetProtocolFromReturnsErrorWhenGraphSyncFails(projectClient, fakeProjectService): + fakeProjectService.syncProjectGraphAfterMutationError = HTTPException( + status_code=500, + detail="reset protocol from node succeeded but graph sync to PostgreSQL failed", + ) + + response = projectClient.post("/projects/1/protocols/10/reset-from") + + assert response.status_code == 500 + assert response.json() == { + "status": 1, + "errors": ["reset protocol from node succeeded but graph sync to PostgreSQL failed"], + "workflow": [], + } + + assert fakeProjectService.lastResetProtocolFromCall == {"protocolId": 10} + + +def test_StopProtocolReturnsErrorWhenGraphSyncFails(projectClient, fakeProjectService): + fakeProjectService.syncProjectGraphAfterMutationError = HTTPException( + status_code=500, + detail="stop protocol succeeded but graph sync to PostgreSQL failed", + ) + + response = projectClient.post( + "/projects/1/protocols/stop", + json={"protocolIds": ["10", "11"]}, + ) + + assert response.status_code == 500 + assert response.json() == { + "status": 1, + "errors": ["stop protocol succeeded but graph sync to PostgreSQL failed"], + "workflow": [], + } + + assert fakeProjectService.lastStopProtocolCall == { + "protocolIds": ["10", "11"], + } def test_StopProtocolRejectsMissingProtocolIds(projectClient): @@ -409,4 +552,11 @@ def test_StopProtocolDelegatesToService(projectClient, fakeProjectService): assert fakeProjectService.lastStopProtocolCall == { "protocolIds": ["10", "11"], + } + assert fakeProjectService.lastSyncProjectGraphAfterMutationCall == { + "mapper": fakeProjectService.lastSyncProjectGraphAfterMutationCall["mapper"], + "projectId": 1, + "actionLabel": "stop protocol", + "refresh": True, + "checkPid": True, } \ No newline at end of file diff --git a/tests/unit/backend/api/services/test_project_service_protocols.py b/tests/unit/backend/api/services/test_project_service_protocols.py index ca8878f..600ab31 100644 --- a/tests/unit/backend/api/services/test_project_service_protocols.py +++ b/tests/unit/backend/api/services/test_project_service_protocols.py @@ -245,6 +245,18 @@ def service(projectServiceModule): instance = object.__new__(projectServiceModule.ProjectService) instance.currentProject = FakeCurrentProject() instance.tomoList = {} + instance._buildProtocolContext = lambda projectId, protocol: { + "projectId": projectId, + "protocolId": protocol.getObjId(), + "protocolClassName": getattr(protocol, "_className", "ProtClass"), + "params": {}, + } + instance.syncProjectProtocolsAndDependencies = ( + lambda mapper, projectId, refresh=False, checkPid=False: { + "protocols": 0, + "dependencies": 0, + } + ) return instance @@ -298,6 +310,17 @@ def test_SaveProtocolCreatesNewProtocolAndPersistsContext(projectServiceModule, }, ) + def fakeSyncProjectProtocolsAndDependencies(mapperObj, projectId, refresh=False, checkPid=False): + for protocolObj in service.currentProject.setupProtocols: + mapperObj.saveProtocol(service._buildProtocolContext(projectId, protocolObj)) + return {"protocols": len(service.currentProject.setupProtocols), "dependencies": 0} + + monkeypatch.setattr( + service, + "syncProjectProtocolsAndDependencies", + fakeSyncProjectProtocolsAndDependencies, + ) + def buildProtocol(): protocol = FakeProtocol(objId=None, className="ProtClass") protocol.addParam("runName", FakeStringParam(label="Run name")) @@ -350,6 +373,17 @@ def test_SaveProtocolAggregatesValidationAndPointerErrors(projectServiceModule, monkeypatch.setattr(service, "applyParamsToProtocol", lambda protocolObj, params: ["pointer error"]) + def fakeSyncProjectProtocolsAndDependencies(mapperObj, projectId, refresh=False, checkPid=False): + for protocolObj in service.currentProject.storedProtocols: + mapperObj.saveProtocol(service._buildProtocolContext(projectId, protocolObj)) + return {"protocols": len(service.currentProject.storedProtocols), "dependencies": 0} + + monkeypatch.setattr( + service, + "syncProjectProtocolsAndDependencies", + fakeSyncProjectProtocolsAndDependencies, + ) + _, errors = service.saveProtocol( mapper=mapper, projectId=1, @@ -500,6 +534,20 @@ def test_DuplicateProtocolCopiesAndPersists(service, mapper, monkeypatch): }, ) + def fakeSyncProjectProtocolsAndDependencies(mapperObj, projectId, refresh=False, checkPid=False): + for protocolObj in service.currentProject.copiedProtocolOutputs: + mapperObj.saveProtocol(service._buildProtocolContext(projectId, protocolObj)) + return { + "protocols": len(service.currentProject.copiedProtocolOutputs), + "dependencies": 0, + } + + monkeypatch.setattr( + service, + "syncProjectProtocolsAndDependencies", + fakeSyncProjectProtocolsAndDependencies, + ) + class DuplicateItem: def __init__(self, itemId): self.id = itemId @@ -510,7 +558,12 @@ def __init__(self, itemId): protocols=[DuplicateItem("10"), DuplicateItem("11")], ) - assert result == {"status": "ok", "message": "Protocol was duplicated successfully"} + assert result == { + "status": "ok", + "message": "Protocol was duplicated successfully", + "protocolsCount": 2, + "dependenciesCount": 0, + } assert service.currentProject.copiedProtocolInputs == [[protocolA, protocolB]] assert mapper.savedProtocolContexts == [ {"projectId": 1, "protocolId": 110}, @@ -560,19 +613,32 @@ def test_RestartProtocolAllReturnsCollectedErrors(service): service.currentProject.protocols[10] = protocol service.currentProject.restartWorkflowInjectedErrors = ["cannot restart", "blocked"] - result = service.restartProtocolAll(10) + with pytest.raises(HTTPException) as exc: + service.restartProtocolAll(10) - assert result == ["cannot restart", "blocked"] + assert exc.value.status_code == 422 + assert exc.value.detail == ["cannot restart", "blocked"] -def test_ContinueProtocolAllIsNotImplemented(service, mapper): - with pytest.raises(NotImplementedError): - service.continueProtocolAll( - mapper=mapper, - projectId=1, - protocolId=10, - currentUser={"id": 1}, - ) +def test_ContinueProtocolAllLaunchesActiveProtocolsInResumeMode(projectServiceModule, service, mapper, monkeypatch): + monkeypatch.setattr(projectServiceModule, "MODE_RESUME", "resume-mode") + + protocol = FakeProtocol(objId=10) + activeProtocol = FakeProtocol(objId=20) + + service.currentProject.protocols[10] = protocol + service.currentProject._getSubworkflow = lambda protocolObj: (["wf-a", "wf-b"], [activeProtocol]) + + result = service.continueProtocolAll( + mapper=mapper, + projectId=1, + protocolId=10, + currentUser={"id": 1}, + ) + + assert result == {"status": "ok", "message": "Protocol subtree continued successfully"} + assert activeProtocol.runMode.get() == "resume-mode" + assert service.currentProject.launchedProtocols == [activeProtocol] def test_ResetProtocolFromReturnsSuccessWhenWorkflowResets(service): @@ -582,7 +648,7 @@ def test_ResetProtocolFromReturnsSuccessWhenWorkflowResets(service): result = service.resetProtocolFrom(10) - assert result is None + assert result == {"status": "ok", "message": "Protocol subtree reset successfully"} def test_StopProtocolStopsEachProtocol(service): @@ -593,5 +659,5 @@ def test_StopProtocolStopsEachProtocol(service): result = service.stopProtocol(["10", "11"]) - assert result is None + assert result == {"status": "ok", "message": "Protocol stopped successfully"} assert service.currentProject.stoppedProtocols == [protocolA, protocolB] \ No newline at end of file diff --git a/tests/unit/backend/utils/test_thumbnail_service_core.py b/tests/unit/backend/utils/test_thumbnail_service_core.py index 86fc87e..25811dc 100644 --- a/tests/unit/backend/utils/test_thumbnail_service_core.py +++ b/tests/unit/backend/utils/test_thumbnail_service_core.py @@ -248,6 +248,7 @@ def test_BuildProtocolThumbnailReturnsCachedEntry(service, monkeypatch, tmp_path cachePath.write_text("cached", encoding="utf-8") monkeypatch.setattr(service, "_getProtocolCachePath", lambda protocolId, size, outputName=None: cachePath) + monkeypatch.setattr(service, "_isValidCachedImage", lambda path: True) result = service.buildProtocolThumbnail(protocolId=10, force=False, size=320) @@ -293,6 +294,7 @@ def test_BuildProjectThumbnailReturnsCachedStrip(service, monkeypatch, tmp_path) cachePath.write_text("cached", encoding="utf-8") monkeypatch.setattr(service, "_getProjectCachePath", lambda size, maxProtocols: cachePath) + monkeypatch.setattr(service, "_isValidCachedImage", lambda path: True) result = service.buildProjectThumbnail(force=False, size=720, maxProtocols=6) @@ -327,8 +329,8 @@ def test_ListProtocolThumbnailItemsBuildsGroups(service, monkeypatch): ) monkeypatch.setattr( service, - "buildProtocolThumbnail", - lambda protocolId, force, size, outputName=None: { + "buildProtocolOutputThumbnail", + lambda protocolId, outputName, force, size: { "exists": True, "absolutePath": f"/tmp/{protocolId}_{outputName}.png", }, @@ -352,15 +354,15 @@ def test_ListProtocolThumbnailItemsBuildsGroups(service, monkeypatch): "outputName": "outputVol", "outputClassName": "SetOfVolumes", "exists": True, - "thumbnailUrl": "/projects/7/protocols/11/thumbnail?outputName=outputVol", - "thumbnailRebuildUrl": "/projects/7/protocols/11/thumbnail/rebuild?outputName=outputVol", + "thumbnailUrl": "/projects/7/protocols/11/outputs/outputVol/thumbnail", + "thumbnailRebuildUrl": None, }, { "outputName": "outputParticles", "outputClassName": "SetOfParticles", "exists": True, - "thumbnailUrl": "/projects/7/protocols/11/thumbnail?outputName=outputParticles", - "thumbnailRebuildUrl": "/projects/7/protocols/11/thumbnail/rebuild?outputName=outputParticles", + "thumbnailUrl": "/projects/7/protocols/11/outputs/outputParticles/thumbnail", + "thumbnailRebuildUrl": None, }, ], }