diff --git a/.gitignore b/.gitignore index 686575d8..9ec274ff 100755 --- a/.gitignore +++ b/.gitignore @@ -219,4 +219,4 @@ evaluation/locomo_evaluation/results_ref/demo/results/ # i18n translation progress .translation_progress.json -.review_progress.json \ No newline at end of file +.review_progress.jsontree_less.txt diff --git a/src/agentic_layer/memory_manager.py b/src/agentic_layer/memory_manager.py index df42b05a..80d2334c 100644 --- a/src/agentic_layer/memory_manager.py +++ b/src/agentic_layer/memory_manager.py @@ -298,8 +298,8 @@ async def retrieve_mem_keyword( ) -> RetrieveMemResponse: """Keyword-based memory retrieval""" start_time = time.perf_counter() - memory_type = ( - retrieve_mem_request.memory_types[0].value + memory_type_label = ( + ','.join([mt.value for mt in retrieve_mem_request.memory_types]) if retrieve_mem_request.memory_types else 'unknown' ) @@ -312,7 +312,7 @@ async def retrieve_mem_keyword( status = 'success' if hits else 'empty_result' record_retrieve_request( - memory_type=memory_type, + memory_type=memory_type_label, retrieve_method=RetrieveMethod.KEYWORD.value, status=status, duration_seconds=duration, @@ -323,7 +323,7 @@ async def retrieve_mem_keyword( except Exception as e: duration = time.perf_counter() - start_time record_retrieve_request( - memory_type=memory_type, + memory_type=memory_type_label, retrieve_method=RetrieveMethod.KEYWORD.value, status='error', duration_seconds=duration, @@ -337,11 +337,12 @@ async def get_keyword_search_results( retrieve_mem_request: 'RetrieveMemRequest', retrieve_method: str = RetrieveMethod.KEYWORD.value, ) -> List[Dict[str, Any]]: - """Keyword search with stage-level metrics""" + """Keyword search with stage-level metrics, supports multiple memory_types""" stage_start = time.perf_counter() - memory_type = ( - retrieve_mem_request.memory_types[0].value - if retrieve_mem_request.memory_types + memory_types = retrieve_mem_request.memory_types or [] + memory_type_label = ( + ','.join([mt.value for 
mt in memory_types]) + if memory_types else 'unknown' ) @@ -356,7 +357,6 @@ async def get_keyword_search_results( group_id = retrieve_mem_request.group_id start_time = retrieve_mem_request.start_time end_time = retrieve_mem_request.end_time - memory_types = retrieve_mem_request.memory_types # Convert query string to search word list # Use jieba for search mode word segmentation, then filter stopwords @@ -375,47 +375,49 @@ async def get_keyword_search_results( if end_time is not None: date_range["lte"] = end_time - mem_type = memory_types[0] - - repo_class = ES_REPO_MAP.get(mem_type) - if not repo_class: - logger.warning(f"Unsupported memory_type: {mem_type}") - return [] + # Search across ALL memory_types, not just the first one + all_results = [] + for mem_type in memory_types: + repo_class = ES_REPO_MAP.get(mem_type) + if not repo_class: + logger.warning(f"Unsupported memory_type: {mem_type}, skipping") + continue - es_repo = get_bean_by_type(repo_class) - logger.debug(f"Using {repo_class.__name__} for {mem_type}") + es_repo = get_bean_by_type(repo_class) + logger.debug(f"Using {repo_class.__name__} for {mem_type}") - results = await es_repo.multi_search( - query=query_words, - user_id=user_id, - group_id=group_id, - size=top_k, - from_=0, - date_range=date_range, - ) + results = await es_repo.multi_search( + query=query_words, + user_id=user_id, + group_id=group_id, + size=top_k, + from_=0, + date_range=date_range, + ) - # Mark memory_type, search_source, and unified score - if results: - for r in results: - r['memory_type'] = mem_type.value - r['_search_source'] = RetrieveMethod.KEYWORD.value - r['id'] = r.get('_id', '') # Unify ES '_id' to 'id' - r['score'] = r.get('_score', 0.0) # Unified score field + # Mark memory_type, search_source, and unified score + if results: + for r in results: + r['memory_type'] = mem_type.value + r['_search_source'] = RetrieveMethod.KEYWORD.value + r['id'] = r.get('_id', '') # Unify ES '_id' to 'id' + r['score'] = r.get('_score', 
0.0) # Unified score field + all_results.extend(results) # Record stage metrics record_retrieve_stage( retrieve_method=retrieve_method, stage=RetrieveMethod.KEYWORD.value, - memory_type=memory_type, + memory_type=memory_type_label, duration_seconds=time.perf_counter() - stage_start, ) - return results or [] + return all_results except Exception as e: record_retrieve_stage( retrieve_method=retrieve_method, stage=RetrieveMethod.KEYWORD.value, - memory_type=memory_type, + memory_type=memory_type_label, duration_seconds=time.perf_counter() - stage_start, ) record_retrieve_error( @@ -433,8 +435,8 @@ async def retrieve_mem_vector( ) -> RetrieveMemResponse: """Vector-based memory retrieval""" start_time = time.perf_counter() - memory_type = ( - retrieve_mem_request.memory_types[0].value + memory_type_label = ( + ','.join([mt.value for mt in retrieve_mem_request.memory_types]) if retrieve_mem_request.memory_types else 'unknown' ) @@ -447,7 +449,7 @@ async def retrieve_mem_vector( status = 'success' if hits else 'empty_result' record_retrieve_request( - memory_type=memory_type, + memory_type=memory_type_label, retrieve_method=RetrieveMethod.VECTOR.value, status=status, duration_seconds=duration, @@ -458,7 +460,7 @@ async def retrieve_mem_vector( except Exception as e: duration = time.perf_counter() - start_time record_retrieve_request( - memory_type=memory_type, + memory_type=memory_type_label, retrieve_method=RetrieveMethod.VECTOR.value, status='error', duration_seconds=duration, @@ -472,12 +474,14 @@ async def get_vector_search_results( retrieve_mem_request: 'RetrieveMemRequest', retrieve_method: str = RetrieveMethod.VECTOR.value, ) -> List[Dict[str, Any]]: - """Vector search with stage-level metrics (embedding + milvus_search)""" - memory_type = ( - retrieve_mem_request.memory_types[0].value - if retrieve_mem_request.memory_types + """Vector search with stage-level metrics (embedding + milvus_search), supports multiple memory_types""" + memory_types = 
retrieve_mem_request.memory_types or [] + memory_type_label = ( + ','.join([mt.value for mt in memory_types]) + if memory_types else 'unknown' ) + stage_start = time.perf_counter() try: # Get parameters from Request @@ -497,7 +501,6 @@ async def get_vector_search_results( top_k = retrieve_mem_request.top_k start_time = retrieve_mem_request.start_time end_time = retrieve_mem_request.end_time - mem_type = retrieve_mem_request.memory_types[0] logger.debug( f"retrieve_mem_vector called with query: {query}, user_id: {user_id}, group_id: {group_id}, top_k: {top_k}" @@ -506,7 +509,7 @@ async def get_vector_search_results( # Get vectorization service vectorize_service = get_vectorize_service() - # Convert query text to vector (embedding stage) + # Convert query text to vector (embedding stage) - only once for all types logger.debug(f"Starting to vectorize query text: {query}") embedding_start = time.perf_counter() query_vector = await vectorize_service.get_embedding(query) @@ -514,100 +517,108 @@ async def get_vector_search_results( record_retrieve_stage( retrieve_method=retrieve_method, stage='embedding', - memory_type=memory_type, + memory_type=memory_type_label, duration_seconds=time.perf_counter() - embedding_start, ) logger.debug( f"Query text vectorization completed, vector dimension: {len(query_vector_list)}" ) - # Select Milvus repository based on memory type - match mem_type: - case MemoryType.FORESIGHT: - milvus_repo = get_bean_by_type(ForesightMilvusRepository) - case MemoryType.EVENT_LOG: - milvus_repo = get_bean_by_type(EventLogMilvusRepository) - case MemoryType.EPISODIC_MEMORY: - milvus_repo = get_bean_by_type(EpisodicMemoryMilvusRepository) - case _: - raise ValueError(f"Unsupported memory type: {mem_type}") - - # Handle time range filter conditions - start_time_dt = None - end_time_dt = None - current_time_dt = None - - if start_time is not None: - start_time_dt = ( - from_iso_format(start_time) - if isinstance(start_time, str) - else start_time - ) + # 
Map memory type to Milvus repository + MILVUS_REPO_MAP = { + MemoryType.FORESIGHT: ForesightMilvusRepository, + MemoryType.EVENT_LOG: EventLogMilvusRepository, + MemoryType.EPISODIC_MEMORY: EpisodicMemoryMilvusRepository, + } + + # Search across ALL memory_types, not just the first one + all_search_results = [] + for mem_type in memory_types: + milvus_repo_class = MILVUS_REPO_MAP.get(mem_type) + if not milvus_repo_class: + logger.warning(f"Unsupported memory type for vector search: {mem_type}, skipping") + continue + + milvus_repo = get_bean_by_type(milvus_repo_class) + + # Handle time range filter conditions + start_time_dt = None + end_time_dt = None + current_time_dt = None + + if start_time is not None: + start_time_dt = ( + from_iso_format(start_time) + if isinstance(start_time, str) + else start_time + ) - if end_time is not None: - if isinstance(end_time, str): - end_time_dt = from_iso_format(end_time) - # If date only format, set to end of day - if len(end_time) == 10: - end_time_dt = end_time_dt.replace(hour=23, minute=59, second=59) + if end_time is not None: + if isinstance(end_time, str): + end_time_dt = from_iso_format(end_time) + # If date only format, set to end of day + if len(end_time) == 10: + end_time_dt = end_time_dt.replace(hour=23, minute=59, second=59) + else: + end_time_dt = end_time + + # Handle foresight time range (only valid for foresight) + if mem_type == MemoryType.FORESIGHT: + if retrieve_mem_request.start_time: + start_time_dt = from_iso_format(retrieve_mem_request.start_time) + if retrieve_mem_request.end_time: + end_time_dt = from_iso_format(retrieve_mem_request.end_time) + if retrieve_mem_request.current_time: + current_time_dt = from_iso_format(retrieve_mem_request.current_time) + + # Call Milvus vector search (pass different parameters based on memory type) + milvus_start = time.perf_counter() + if mem_type == MemoryType.FORESIGHT: + # Foresight: supports time range and validity filtering, supports radius parameter + 
search_results = await milvus_repo.vector_search( + query_vector=query_vector_list, + user_id=user_id, + group_id=group_id, + start_time=start_time_dt, + end_time=end_time_dt, + current_time=current_time_dt, + limit=top_k, + score_threshold=0.0, + radius=retrieve_mem_request.radius, + ) else: - end_time_dt = end_time - - # Handle foresight time range (only valid for foresight) - if mem_type == MemoryType.FORESIGHT: - if retrieve_mem_request.start_time: - start_time_dt = from_iso_format(retrieve_mem_request.start_time) - if retrieve_mem_request.end_time: - end_time_dt = from_iso_format(retrieve_mem_request.end_time) - if retrieve_mem_request.current_time: - current_time_dt = from_iso_format(retrieve_mem_request.current_time) - - # Call Milvus vector search (pass different parameters based on memory type) - milvus_start = time.perf_counter() - if mem_type == MemoryType.FORESIGHT: - # Foresight: supports time range and validity filtering, supports radius parameter - search_results = await milvus_repo.vector_search( - query_vector=query_vector_list, - user_id=user_id, - group_id=group_id, - start_time=start_time_dt, - end_time=end_time_dt, - current_time=current_time_dt, - limit=top_k, - score_threshold=0.0, - radius=retrieve_mem_request.radius, - ) - else: - # Episodic memory and event log: use timestamp filtering, supports radius parameter - search_results = await milvus_repo.vector_search( - query_vector=query_vector_list, - user_id=user_id, - group_id=group_id, - start_time=start_time_dt, - end_time=end_time_dt, - limit=top_k, - score_threshold=0.0, - radius=retrieve_mem_request.radius, + # Episodic memory and event log: use timestamp filtering, supports radius parameter + search_results = await milvus_repo.vector_search( + query_vector=query_vector_list, + user_id=user_id, + group_id=group_id, + start_time=start_time_dt, + end_time=end_time_dt, + limit=top_k, + score_threshold=0.0, + radius=retrieve_mem_request.radius, + ) + record_retrieve_stage( + 
retrieve_method=retrieve_method, + stage='milvus_search', + memory_type=mem_type.value, + duration_seconds=time.perf_counter() - milvus_start, ) - record_retrieve_stage( - retrieve_method=retrieve_method, - stage='milvus_search', - memory_type=memory_type, - duration_seconds=time.perf_counter() - milvus_start, - ) - for r in search_results: - r['memory_type'] = mem_type.value - r['_search_source'] = RetrieveMethod.VECTOR.value - # Milvus already uses 'score', no need to rename + for r in search_results: + r['memory_type'] = mem_type.value + r['_search_source'] = RetrieveMethod.VECTOR.value + # Milvus already uses 'score', no need to rename - return search_results + all_search_results.extend(search_results) + + return all_search_results except Exception as e: record_retrieve_stage( retrieve_method=retrieve_method, stage=RetrieveMethod.VECTOR.value, - memory_type=memory_type, - duration_seconds=time.perf_counter() - milvus_start, + memory_type=memory_type_label, + duration_seconds=time.perf_counter() - stage_start, ) record_retrieve_error( @@ -615,6 +626,7 @@ async def get_vector_search_results( error_type=self._classify_retrieve_error(e), ) logger.error(f"Error in get_vector_search_results: {e}") + # NOTE: let the exception propagate so callers record status='error' raise # Hybrid memory retrieval @@ -624,8 +636,8 @@ async def retrieve_mem_hybrid( ) -> RetrieveMemResponse: """Hybrid memory retrieval: keyword + vector + rerank""" start_time = time.perf_counter() - memory_type = ( - retrieve_mem_request.memory_types[0].value + memory_type_label = ( + ','.join([mt.value for mt in retrieve_mem_request.memory_types]) if retrieve_mem_request.memory_types else 'unknown' ) @@ -638,7 +650,7 @@ async def retrieve_mem_hybrid( status = 'success' if hits else 'empty_result' record_retrieve_request( - memory_type=memory_type, + memory_type=memory_type_label, retrieve_method=RetrieveMethod.HYBRID.value, status=status, duration_seconds=duration, @@ -649,7 +661,7 @@ async def retrieve_mem_hybrid( except 
Exception as e: duration = time.perf_counter() - start_time record_retrieve_request( - memory_type=memory_type, + memory_type=memory_type_label, retrieve_method=RetrieveMethod.HYBRID.value, status='error', duration_seconds=duration, @@ -699,8 +711,10 @@ async def _search_hybrid( retrieve_method: str = RetrieveMethod.HYBRID.value, ) -> List[Dict]: """Core hybrid search: keyword + vector + rerank, returns flat list""" - memory_type = ( - request.memory_types[0].value if request.memory_types else 'unknown' + memory_type_label = ( + ','.join([mt.value for mt in request.memory_types]) + if request.memory_types + else 'unknown' ) # Run keyword and vector search concurrently kw_results, vec_results = await asyncio.gather( @@ -713,7 +727,7 @@ async def _search_hybrid( h for h in vec_results if h.get('id') not in seen_ids ] return await self._rerank( - request.query, merged_results, request.top_k, memory_type, retrieve_method + request.query, merged_results, request.top_k, memory_type_label, retrieve_method ) async def _search_rrf( @@ -722,8 +736,10 @@ async def _search_rrf( retrieve_method: str = RetrieveMethod.RRF.value, ) -> List[Dict]: """Core RRF search: keyword + vector + RRF fusion, returns flat list""" - memory_type = ( - request.memory_types[0].value if request.memory_types else 'unknown' + memory_type_label = ( + ','.join([mt.value for mt in request.memory_types]) + if request.memory_types + else 'unknown' ) # Run keyword and vector search concurrently @@ -740,7 +756,7 @@ async def _search_rrf( record_retrieve_stage( retrieve_method=retrieve_method, stage='rrf_fusion', - memory_type=memory_type, + memory_type=memory_type_label, duration_seconds=time.perf_counter() - rrf_start, ) @@ -766,7 +782,11 @@ async def _to_response( """Convert flat hits list to grouped RetrieveMemResponse""" user_id = req.user_id if req else "" source_type = req.retrieve_method.value - memory_type = req.memory_types[0].value + memory_type_label = ( + ','.join([mt.value for mt in 
req.memory_types]) + if req.memory_types + else 'unknown' + ) if not hits: return RetrieveMemResponse( @@ -777,10 +797,10 @@ async def _to_response( total_count=0, has_more=False, query_metadata=Metadata( - source=source_type, user_id=user_id or "", memory_type=memory_type + source=source_type, user_id=user_id or "", memory_type=memory_type_label ), metadata=Metadata( - source=source_type, user_id=user_id or "", memory_type=memory_type + source=source_type, user_id=user_id or "", memory_type=memory_type_label ), ) memories, scores, importance_scores, original_data, total_count = ( @@ -794,10 +814,10 @@ async def _to_response( total_count=total_count, has_more=False, query_metadata=Metadata( - source=source_type, user_id=user_id or "", memory_type=memory_type + source=source_type, user_id=user_id or "", memory_type=memory_type_label ), metadata=Metadata( - source=source_type, user_id=user_id or "", memory_type=memory_type + source=source_type, user_id=user_id or "", memory_type=memory_type_label ), ) @@ -808,8 +828,8 @@ async def retrieve_mem_rrf( ) -> RetrieveMemResponse: """RRF-based memory retrieval: keyword + vector + RRF fusion""" start_time = time.perf_counter() - memory_type = ( - retrieve_mem_request.memory_types[0].value + memory_type_label = ( + ','.join([mt.value for mt in retrieve_mem_request.memory_types]) if retrieve_mem_request.memory_types else 'unknown' ) @@ -822,7 +842,7 @@ async def retrieve_mem_rrf( status = 'success' if hits else 'empty_result' record_retrieve_request( - memory_type=memory_type, + memory_type=memory_type_label, retrieve_method=RetrieveMethod.RRF.value, status=status, duration_seconds=duration, @@ -833,7 +853,7 @@ async def retrieve_mem_rrf( except Exception as e: duration = time.perf_counter() - start_time record_retrieve_request( - memory_type=memory_type, + memory_type=memory_type_label, retrieve_method=RetrieveMethod.RRF.value, status='error', duration_seconds=duration, @@ -855,7 +875,11 @@ async def retrieve_mem_agentic( req 
= retrieve_mem_request # alias top_k = req.top_k config = AgenticConfig() - memory_type = req.memory_types[0].value if req.memory_types else 'unknown' + memory_type_label = ( + ','.join([mt.value for mt in req.memory_types]) + if req.memory_types + else 'unknown' + ) try: llm_provider = LLMProvider( @@ -883,7 +907,7 @@ async def retrieve_mem_agentic( if not round1: duration = time.perf_counter() - start_time record_retrieve_request( - memory_type=memory_type, + memory_type=memory_type_label, retrieve_method=RetrieveMethod.AGENTIC.value, status='empty_result', duration_seconds=duration, @@ -894,7 +918,7 @@ async def retrieve_mem_agentic( # ========== Rerank → max(5, top_k) for LLM & return ========== rerank_n = max(config.round1_rerank_top_n, top_k) reranked = await self._rerank( - req.query, round1, rerank_n, memory_type, 'agentic', + req.query, round1, rerank_n, memory_type_label, 'agentic', instruction=config.reranker_instruction, ) # Use top 5 for sufficiency check @@ -917,7 +941,7 @@ async def retrieve_mem_agentic( final_results = reranked[:top_k] duration = time.perf_counter() - start_time record_retrieve_request( - memory_type=memory_type, + memory_type=memory_type_label, retrieve_method=RetrieveMethod.AGENTIC.value, status='success', duration_seconds=duration, @@ -964,13 +988,13 @@ async def do_search(q: str) -> List[Dict]: # ========== Final Rerank ========== final = await self._rerank( - req.query, combined, top_k, memory_type, 'agentic', + req.query, combined, top_k, memory_type_label, 'agentic', instruction=config.reranker_instruction, ) duration = time.perf_counter() - start_time record_retrieve_request( - memory_type=memory_type, + memory_type=memory_type_label, retrieve_method=RetrieveMethod.AGENTIC.value, status='success', duration_seconds=duration, @@ -982,7 +1006,7 @@ async def do_search(q: str) -> List[Dict]: except Exception as e: duration = time.perf_counter() - start_time record_retrieve_request( - memory_type=memory_type, + 
memory_type=memory_type_label, retrieve_method=RetrieveMethod.AGENTIC.value, status='error', duration_seconds=duration, diff --git a/src/api_specs/dtos/memory.py b/src/api_specs/dtos/memory.py index 4b5a6f57..81b26f90 100644 --- a/src/api_specs/dtos/memory.py +++ b/src/api_specs/dtos/memory.py @@ -922,3 +922,48 @@ class DeleteMemoriesResponse(BaseApiResponse[DeleteMemoriesResult]): ] } } + + +class RestoreMemoriesResult(BaseModel): + """Restore soft-deleted memories result data""" + + filters: List[str] = Field( + default_factory=list, + description="List of filter types used for restoration", + examples=[["event_id"], ["user_id"]], + ) + count: int = Field( + default=0, description="Number of memories restored", examples=[1, 25] + ) + + +class RestoreMemoriesResponse(BaseApiResponse[RestoreMemoriesResult]): + """Restore soft-deleted memories API response + + Response for POST /api/v1/memories/restore endpoint. + """ + + result: RestoreMemoriesResult = Field(description="Restore operation result") + + model_config = { + "json_schema_extra": { + "examples": [ + { + "summary": "Restore by event_id", + "value": { + "status": "ok", + "message": "Successfully restored 1 memory", + "result": {"filters": ["event_id"], "count": 1}, + }, + }, + { + "summary": "Restore by user_id", + "value": { + "status": "ok", + "message": "Successfully restored 25 memories", + "result": {"filters": ["user_id"], "count": 25}, + }, + }, + ] + } + } diff --git a/src/infra_layer/adapters/input/api/memory/memory_controller.py b/src/infra_layer/adapters/input/api/memory/memory_controller.py index a4a4d57e..36da4cfb 100644 --- a/src/infra_layer/adapters/input/api/memory/memory_controller.py +++ b/src/infra_layer/adapters/input/api/memory/memory_controller.py @@ -50,12 +50,14 @@ SaveConversationMetaResponse, PatchConversationMetaResponse, DeleteMemoriesResponse, + RestoreMemoriesResponse, ) from core.request.timeout_background import timeout_to_background from core.request import log_request from 
core.component.redis_provider import RedisProvider from service.memory_request_log_service import MemoryRequestLogService from service.memcell_delete_service import MemCellDeleteService +from service.memcell_restore_service import MemCellRestoreService from service.conversation_meta_service import ConversationMetaService from api_specs.memory_types import RawDataType from agentic_layer.metrics.memorize_metrics import ( @@ -1120,3 +1122,121 @@ async def delete_memories( status_code=500, detail="Failed to delete memories, please try again later", ) from e + + @post( + "/restore", + response_model=RestoreMemoriesResponse, + summary="Restore soft-deleted memories", + description=""" + Restore soft-deleted memory records based on filter criteria + + ## Functionality: + - Restore previously soft-deleted memories + - Supports restore by event_id (single memory) or user_id (batch) + - At least one filter must be specified + + ## Filter Parameters (provide one): + - **event_id**: Restore a specific memory by its event_id + - **user_id**: Restore all soft-deleted memories of a user + + ## Use Cases: + - Undo accidental deletion + - Recover user data + - Data restoration after testing + """, + responses={ + 400: { + "description": "Request parameter error", + "content": { + "application/json": { + "example": { + "status": ErrorStatus.FAILED.value, + "code": ErrorCode.INVALID_PARAMETER.value, + "message": "At least one of event_id or user_id must be provided", + } + } + }, + }, + 500: { + "description": "Internal server error", + "content": { + "application/json": { + "example": { + "status": ErrorStatus.FAILED.value, + "code": ErrorCode.SYSTEM_ERROR.value, + "message": "Failed to restore memories, please try again later", + } + } + }, + }, + }, + ) + async def restore_memories( + self, + fastapi_request: FastAPIRequest, + request_body=None, + ) -> RestoreMemoriesResponse: + """ + Restore soft-deleted memory data based on filter criteria + """ + del request_body + + try: + 
from core.oxm.constants import MAGIC_ALL + + params = await self._collect_request_params(fastapi_request) + + event_id = params.get("event_id", MAGIC_ALL) + user_id = params.get("user_id", MAGIC_ALL) + + # Validate: at least one filter required + if (not event_id or event_id == MAGIC_ALL) and ( + not user_id or user_id == MAGIC_ALL + ): + raise HTTPException( + status_code=400, + detail="At least one of event_id or user_id must be provided", + ) + + logger.info( + "Received restore request: event_id=%s, user_id=%s", + event_id, + user_id, + ) + + # Get restore service + restore_service = get_bean_by_type(MemCellRestoreService) + + # Execute restore + result = await restore_service.restore_by_combined_criteria( + event_id=event_id, + user_id=user_id, + ) + + if not result["success"]: + error_msg = result.get( + "error", "No soft-deleted memories found matching the criteria" + ) + logger.warning("Restore operation returned no results: %s", result) + raise HTTPException(status_code=404, detail=error_msg) + + logger.info( + "Restore request completed successfully: filters=%s, count=%d", + result["filters"], + result["count"], + ) + + return { + "status": ErrorStatus.OK.value, + "message": f"Successfully restored {result['count']} {'memory' if result['count'] == 1 else 'memories'}", + "result": {"filters": result["filters"], "count": result["count"]}, + } + + except HTTPException: + raise + except Exception as e: + logger.error("Restore request processing failed: %s", e, exc_info=True) + raise HTTPException( + status_code=500, + detail="Failed to restore memories, please try again later", + ) from e diff --git a/src/service/memcell_restore_service.py b/src/service/memcell_restore_service.py new file mode 100644 index 00000000..31489064 --- /dev/null +++ b/src/service/memcell_restore_service.py @@ -0,0 +1,95 @@ +""" +MemCell Restore Service - Handle restore logic for soft-deleted MemCells + +Provides multiple restoration methods: +- Restore by single event_id +- Batch 
restore by user_id +- Batch restore by combined criteria +""" + +from typing import Optional +from core.di.decorators import component +from core.observation.logger import get_logger +from infra_layer.adapters.out.persistence.repository.memcell_raw_repository import ( + MemCellRawRepository, +) + +logger = get_logger(__name__) + + +@component("memcell_restore_service") +class MemCellRestoreService: + """MemCell restore service for soft-deleted records""" + + def __init__(self, memcell_repository: MemCellRawRepository): + self.memcell_repository = memcell_repository + logger.info("MemCellRestoreService initialized") + + async def restore_by_combined_criteria( + self, + event_id: Optional[str] = None, + user_id: Optional[str] = None, + ) -> dict: + """ + Restore soft-deleted MemCells based on combined criteria + + Args: + event_id: The event_id of MemCell + user_id: User ID (batch restore all deleted memories of a user) + + Returns: + dict: Dictionary containing restoration results + - filters: List of filter conditions used + - count: Number of restored records + - success: Whether the operation succeeded + """ + from core.oxm.constants import MAGIC_ALL + + filters_used = [] + + # Restore by event_id + if event_id and event_id != MAGIC_ALL: + filters_used.append("event_id") + try: + success = await self.memcell_repository.restore_by_event_id(event_id) + return { + "filters": filters_used, + "count": 1 if success else 0, + "success": success, + } + except Exception as e: + logger.error("Failed to restore by event_id: %s", e) + return { + "filters": filters_used, + "count": 0, + "success": False, + "error": str(e), + } + + # Restore by user_id + if user_id and user_id != MAGIC_ALL: + filters_used.append("user_id") + try: + count = await self.memcell_repository.restore_by_user_id(user_id) + return { + "filters": filters_used, + "count": count, + "success": count > 0, + } + except Exception as e: + logger.error("Failed to restore by user_id: %s", e) + return { + 
"filters": filters_used, + "count": 0, + "success": False, + "error": str(e), + } + + # No filter conditions provided + logger.warning("No restore criteria provided") + return { + "filters": [], + "count": 0, + "success": False, + "error": "At least one of event_id or user_id must be provided", + }