diff --git a/muninn/core/memory.py b/muninn/core/memory.py index fec88f5..f85ed46 100644 --- a/muninn/core/memory.py +++ b/muninn/core/memory.py @@ -21,37 +21,39 @@ import asyncio import hashlib import json -import uuid -import time import logging import os -from typing import List, Optional, Dict, Any, Tuple +import time from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple +from muninn.advanced.cross_agent import FederationManager +from muninn.advanced.temporal_kg import TemporalKnowledgeGraph +from muninn.chains import MemoryChainDetector +from muninn.consolidation.daemon import ConsolidationDaemon +from muninn.core.config import SUPPORTED_MODEL_PROFILES, MuninnConfig +from muninn.core.feature_flags import get_flags +from muninn.core.ingestion_manager import IngestionManager from muninn.core.types import ( - MemoryRecord, MemoryType, Provenance, SearchResult, - ExtractionResult, Entity, Relation, + ExtractionResult, + MemoryRecord, + MemoryType, + Provenance, + SearchResult, ) -from muninn.core.config import MuninnConfig, SUPPORTED_MODEL_PROFILES -from muninn.store.sqlite_metadata import SQLiteMetadataStore -from muninn.store.vector_store import VectorStore -from muninn.store.graph_store import GraphStore -from muninn.retrieval.bm25 import BM25Index -from muninn.retrieval.reranker import Reranker -from muninn.retrieval.hybrid import HybridRetriever -from muninn.retrieval.scout import MuninnScout from muninn.extraction.pipeline import ExtractionPipeline -from muninn.scoring.importance import calculate_importance, calculate_novelty -from muninn.consolidation.daemon import ConsolidationDaemon from muninn.goal import GoalCompass -from muninn.observability import OTelGenAITracer -from muninn.chains import MemoryChainDetector -from muninn.ingestion import IngestionPipeline, discover_legacy_sources as discover_legacy_sources_catalog +from muninn.ingestion import IngestionPipeline +from muninn.ingestion import discover_legacy_sources as discover_legacy_sources_catalog from muninn.ingestion.parser import infer_source_type -from muninn.core.ingestion_manager import IngestionManager -from muninn.advanced.temporal_kg import TemporalKnowledgeGraph -from muninn.advanced.cross_agent import FederationManager -from muninn.core.feature_flags import get_flags +from muninn.observability import OTelGenAITracer +from muninn.retrieval.bm25 import BM25Index +from muninn.retrieval.hybrid import HybridRetriever +from muninn.retrieval.reranker import Reranker +from muninn.retrieval.scout import MuninnScout +from muninn.store.graph_store import GraphStore +from muninn.store.sqlite_metadata import SQLiteMetadataStore +from muninn.store.vector_store import VectorStore logger = logging.getLogger("Muninn") @@ -409,19 +411,7 @@ def _upsert_memory_chain_links( candidate_records=candidate_records, ) - persisted = 0 - for link in links: - created = self._graph.add_chain_link( - predecessor_id=link.predecessor_id, - successor_id=link.successor_id, - relation_type=link.relation_type, - confidence=link.confidence, - reason=link.reason, - shared_entities=link.shared_entities, - hours_apart=link.hours_apart, - ) - if created: - persisted += 1 + persisted = self._graph.add_chain_links_batch(links) return persisted except Exception as e: logger.warning("Memory-chain linking failed (non-fatal): %s", e) @@ -476,8 +466,8 @@ async def add( # Handle terminal early returns (DEDUP_SKIP, CONFLICT_SKIP) # Note: DEDUP_SIGNAL_UPDATE is handled below as it may fall through to ADD. if ( - processed.get("id") is None - and "event" in processed + processed.get("id") is None + and "event" in processed and processed["event"] not in ("PROCESS_COMPLETE", "DEDUP_SIGNAL_UPDATE") ): return processed @@ -487,7 +477,7 @@ async def add( dedup_result = processed["dedup"] embedding = processed["embedding"] record = processed["record"] - + merged_successfully = False async with self._write_lock: existing = await asyncio.to_thread(self._metadata.get, dedup_result.existing_memory_id) @@ -514,7 +504,7 @@ async def add( asyncio.to_thread(self._bm25.add, dedup_result.existing_memory_id, merged_content, user_id, namespace) ) merged_successfully = True - + if merged_successfully: return { "id": dedup_result.existing_memory_id, @@ -557,7 +547,7 @@ def _write_graph(): uid = record.metadata.get("user_id", "global") ns = record.namespace self._graph.add_memory_node( - record.id, + record.id, extraction.summary or content[:200], user_id=uid, namespace=ns ) @@ -612,7 +602,7 @@ def _write_colbert(): } if conflict_info: result["conflict"] = conflict_info - + if self._goal_compass is not None and record.project: drift = await self._goal_compass.evaluate_drift( text=content, @@ -680,7 +670,7 @@ async def search( {"query_preview": self._otel.maybe_content(query)}, ) effective_filters = dict(filters or {}) - + # v3.24.0: Default to excluding archived memories if "archived" not in effective_filters: effective_filters["archived"] = False @@ -918,7 +908,7 @@ async def record_retrieval_feedback( ) # Update Elo rating based on feedback outcome - from muninn.scoring.elo import calculate_elo_update, INITIAL_ELO + from muninn.scoring.elo import INITIAL_ELO, calculate_elo_update record = await asyncio.to_thread(self._metadata.get, memory_id) if record: current_elo = record.metadata.get("elo_rating", INITIAL_ELO) if record.metadata else INITIAL_ELO @@ -1047,7 +1037,7 @@ async def set_project_goal( raise RuntimeError("Goal compass is disabled by feature flag") if not goal_statement.strip(): raise ValueError("goal_statement cannot be empty") - + async with self._write_lock: return await self._goal_compass.set_goal( user_id=user_id, @@ -1313,10 +1303,10 @@ async def _add_chunk_task(chunk, source_context, record_ref): chunk_metadata = dict(base_metadata) chunk_metadata.update(source_context) chunk_metadata.update(chunk.metadata) - + # Map chunk.source_type to media_type if possible media_type = chunk.source_type if chunk.source_type in ["image", "audio", "video"] else "text" - + async with semaphore: try: add_result = await self.add( @@ -1327,7 +1317,7 @@ async def _add_chunk_task(chunk, source_context, record_ref): provenance=Provenance.INGESTED, media_type=media_type, ) - + if add_result.get("event") in {"DEDUP_SKIP", "CONFLICT_SKIP"}: skipped_chunks += 1 record_ref["chunks_skipped"] += 1 @@ -1354,12 +1344,12 @@ async def _add_chunk_task(chunk, source_context, record_ref): "chunks_failed": 0, } source_payloads.append(source_record) - + if source_result.status != "processed": continue source_context = source_context_by_path.get(source_result.source_path, {}) - + # Create tasks for all chunks in this source tasks = [ _add_chunk_task(chunk, source_context, source_record) @@ -1765,7 +1755,7 @@ def _update_graph(): ns = record.namespace self._graph.delete_memory_references(record.id) self._graph.add_memory_node( - record.id, + record.id, extraction.summary or data[:200], user_id=uid, namespace=ns ) @@ -1784,7 +1774,7 @@ def _update_graph(): user_id=uid, namespace=ns, ) - + def _update_bm25(): if data is not None: uid = (record.metadata or {}).get("user_id", "global") @@ -2372,7 +2362,7 @@ async def get_temporal_knowledge( self._check_initialized() if not self._temporal_kg: return [] - + ts = float(timestamp) if timestamp is not None else time.time() # This is a read operation, usually fast, but we can offload if needed. # Kuzu reads are blocking, so offload to thread. diff --git a/muninn/store/graph_store.py b/muninn/store/graph_store.py index 8ae5a50..18f5fa5 100644 --- a/muninn/store/graph_store.py +++ b/muninn/store/graph_store.py @@ -4,13 +4,16 @@ Kuzu-based knowledge graph for entity relationships and graph-enhanced retrieval. """ -import logging -import time import json +import logging +import math import threading +import time from pathlib import Path -from typing import Optional, List, Dict, Any, Tuple -import math +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple + +if TYPE_CHECKING: + from muninn.chains.detector import MemoryChainLink import kuzu @@ -146,10 +149,10 @@ def _initialize(self): logger.info(f"Graph store initialized at {self.db_path}") def add_entity( - self, - name: str, - entity_type: str, - user_id: str = "global", + self, + name: str, + entity_type: str, + user_id: str = "global", namespace: str = "global" ) -> bool: conn = self._get_conn() @@ -185,7 +188,7 @@ def create_relation( ) -> bool: conn = self._get_conn() now = time.time() - + s_id = f"{user_id}/{namespace}/{subject}" o_id = f"{user_id}/{namespace}/{obj}" @@ -230,19 +233,19 @@ def add_memory_node( return False def link_memory_to_entity( - self, - memory_id: str, - entity_name: str, + self, + memory_id: str, + entity_name: str, role: str = "mention", user_id: str = "global", namespace: str = "global" ) -> bool: conn = self._get_conn() e_id = f"{user_id}/{namespace}/{entity_name}" - + # Ensure entity exists in this scope self.add_entity(entity_name, "unknown", user_id, namespace) - + try: conn.execute( "MATCH (m:Memory {id: $mid}), (e:Entity {id: $eid}) " @@ -254,9 +257,73 @@ def link_memory_to_entity( logger.debug(f"Memory-entity link: {e}") return False + def add_chain_links_batch(self, links: 'Iterable[MemoryChainLink]') -> int: + """ + Batch add directed memory-to-memory chain edges. + Returns the number of links successfully created. + """ + from collections import defaultdict + + conn = self._get_conn() + now = time.time() + + # Group valid links by relation type + grouped_data = defaultdict(list) + + for link in links: + rel = str(link.relation_type or "PRECEDES").upper() + if rel not in {"PRECEDES", "CAUSES"}: + continue + if link.predecessor_id == link.successor_id: + continue + + conf = max(0.0, min(1.0, float(link.confidence))) + hours = float(link.hours_apart) if link.hours_apart is not None else None + payload = json.dumps(link.shared_entities or [], ensure_ascii=False) + + grouped_data[rel].append({ + "pred": link.predecessor_id, + "succ": link.successor_id, + "conf": conf, + "reason": (link.reason or "")[:500], + "shared": payload, + "hours": hours, + "now": now, + }) + + persisted = 0 + for rel, data in grouped_data.items(): + if not data: + continue + + try: + conn.execute( + f"UNWIND $data AS d MATCH (a:Memory {{id: d.pred}}), (b:Memory {{id: d.succ}}) " + f"CREATE (a)-[:{rel} {{confidence: d.conf, reason: d.reason, " + f"shared_entities_json: d.shared, hours_apart: d.hours, created_at: d.now}}]->(b)", + {"data": data} + ) + persisted += len(data) + except Exception as e: + logger.debug(f"Batch add memory chain links for {rel}: {e}") + # Fallback to individual inserts if batch fails + for row in data: + try: + conn.execute( + f"MATCH (a:Memory {{id: $pred}}), (b:Memory {{id: $succ}}) " + f"CREATE (a)-[:{rel} {{confidence: $conf, reason: $reason, " + f"shared_entities_json: $shared, hours_apart: $hours, created_at: $now}}]->(b)", + row + ) + persisted += 1 + except Exception as inner_e: + logger.debug(f"Fallback individual add memory chain link: {inner_e}") + + return persisted + def find_related_memories( - self, - query_entities: List[str], + self, + query_entities: List[str], limit: int = 20, user_id: str = "global", namespace: str = "global" @@ -384,9 +451,9 @@ def search_memories( unique = sorted(seen.values(), key=lambda x: x["score"], reverse=True) return unique[:limit] def get_entity_centrality( - self, - entity_name: str, - user_id: str = "global", + self, + entity_name: str, + user_id: str = "global", namespace: str = "global" ) -> float: """Get degree centrality of an entity (normalized by max possible degree) within a scope.""" @@ -468,14 +535,14 @@ def get_entity_count(self) -> int: return 0 def get_all_entities( - self, + self, limit: int = 100, user_id: Optional[str] = None, namespace: Optional[str] = None ) -> List[Dict[str, Any]]: conn = self._get_conn() entities = [] - + where_clause = "WHERE 1=1" params = {"limit": limit} if user_id: @@ -649,4 +716,4 @@ def close(self): self._db = None # Clear current thread's connection if it exists if hasattr(self._thread_local, "conn"): - del self._thread_local.conn \ No newline at end of file + del self._thread_local.conn diff --git a/tests/test_memory_chains.py b/tests/test_memory_chains.py index 226d964..8520fef 100644 --- a/tests/test_memory_chains.py +++ b/tests/test_memory_chains.py @@ -87,7 +87,7 @@ async def _embed(_text): memory._vectors = MagicMock() memory._vectors.count.return_value = 0 memory._graph = MagicMock() - memory._graph.add_chain_link.return_value = True + memory._graph.add_chain_links_batch.return_value = 1 memory._bm25 = MagicMock() memory._goal_compass = None memory._ingestion_manager = IngestionManager(memory) @@ -103,7 +103,7 @@ async def _embed(_text): assert result["event"] == "ADD" assert result["chain_links_created"] >= 1 - assert memory._graph.add_chain_link.call_count >= 1 + assert memory._graph.add_chain_links_batch.call_count >= 1 stored_record = memory._metadata.add.call_args.args[0] assert stored_record.metadata["entity_names"] == ["Redis", "Queue"] \ No newline at end of file