-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgraph_query.py
More file actions
212 lines (182 loc) · 9.77 KB
/
graph_query.py
File metadata and controls
212 lines (182 loc) · 9.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import logging
import sqlite3
from typing import List, Dict, Any, Set
from opentelemetry import trace
from .telemetry import get_tracer
# Module-scoped logger; records carry this module's dotted name.
logger = logging.getLogger(name=__name__)
# Tracer that GraphQueryService uses to open its "reverie.graph.*" spans.
tracer = get_tracer(__name__)
class GraphQueryService:
    """Graph-traversal queries over the ``memory_relations`` / ``entities`` tables.

    Traversal is an iterative, level-by-level expansion driven from Python;
    each level issues one (non-recursive) CTE query per batch of frontier
    nodes.

    ``db_manager`` must provide ``get_cursor()`` and a
    ``trace_query(operation, table, sql, params)`` context manager whose value
    supports ``set_attribute``; those are the only members used here.
    """

    # Historical default of SQLITE_MAX_VARIABLE_NUMBER: SQLite < 3.32.0 accepts
    # at most 999 bound parameters per statement. Every IN/VALUES list below is
    # sized against it so queries work on old and new SQLite builds alike.
    _SQLITE_MAX_VARIABLES = 999
    # Upper bound on frontier nodes expanded by a single bulk query.
    _MAX_LAYER_BATCH = 400

    def __init__(self, db_manager):
        self.db = db_manager

    @staticmethod
    def _chunked(items, size):
        """Yield consecutive lists of at most ``size`` elements from ``items``."""
        items = list(items)
        for start in range(0, len(items), size):
            yield items[start:start + size]

    @staticmethod
    def _build_expansion_query(batch_len, edge_filter, anchor_placeholders):
        """Build the bulk neighbour-expansion SQL for ``batch_len`` frontier nodes.

        The query (1) collects forward and backward neighbours of every batch
        node, (2) ranks them per source node with ROW_NUMBER() -- entities
        first, then by boosted score, then by relation id -- and (3) keeps at
        most ``per_node_limit`` rows per source node, all inside the database.

        Bound parameters, in textual order: ``batch_len`` (node_id, node_type)
        pairs, the edge-filter values twice (forward then backward branch),
        the anchor IDs, ``gravity`` twice, and ``per_node_limit``.
        """
        values_placeholders = ",".join(["(?, ?)"] * batch_len)
        return f"""
            WITH current_layer_nodes(node_id, node_type) AS (
                VALUES {values_placeholders}
            ),
            candidates AS (
                -- Forward edges: current_layer is source
                SELECT
                    r.target_id as next_id, r.target_type as next_type, r.confidence_score, r.id as rel_id,
                    r.source_id, r.source_type
                FROM memory_relations r
                JOIN current_layer_nodes cl ON r.source_id = cl.node_id AND r.source_type = cl.node_type
                WHERE 1=1 {edge_filter}
                UNION ALL
                -- Backward edges: current_layer is target
                SELECT
                    r.source_id as next_id, r.source_type as next_type, r.confidence_score, r.id as rel_id,
                    r.target_id as source_id, r.target_type as source_type
                FROM memory_relations r
                JOIN current_layer_nodes cl ON r.target_id = cl.node_id AND r.target_type = cl.node_type
                WHERE 1=1 {edge_filter}
            ),
            scored AS (
                SELECT
                    next_id, next_type, confidence_score, rel_id, source_id, source_type,
                    CASE WHEN next_type = 'ENTITY' AND next_id IN ({anchor_placeholders}) THEN 1 ELSE 0 END as is_anchor
                FROM candidates
            ),
            ranked AS (
                SELECT
                    next_id, next_type,
                    (confidence_score * (1 + (is_anchor * ?))) as d_score,
                    ROW_NUMBER() OVER (
                        PARTITION BY source_id, source_type
                        ORDER BY
                            (CASE WHEN next_type = 'ENTITY' THEN 1 ELSE 0 END) DESC,
                            (confidence_score * (1 + (is_anchor * ?))) DESC,
                            rel_id ASC
                    ) as rn
                FROM scored
            )
            SELECT next_id, next_type, d_score
            FROM ranked
            WHERE rn <= ?
            ORDER BY d_score DESC
        """

    def get_related_memories(self, start_memory_ids: List[int], anchor_entities: List[str] = None, gravity: float = 1.0, depth: int = 2, per_node_limit: int = 10, allowed_edges: List[str] = None, max_results: int = 50) -> Dict[int, float]:
        """
        Traverse the relation graph outward from the seed memories.

        Path shape: Memory <-> Entity <-> Entity <-> Memory. Each level expands
        every frontier node along both forward (source -> target) and backward
        (target -> source) relations; a node is expanded at most once.

        Args:
            start_memory_ids: Seed memory IDs. Seeds are never returned.
            anchor_entities: Entity names; edges leading to these entities get
                their score multiplied by ``(1 + gravity)``.
            gravity: Boost applied to edges that reach an anchor entity.
            depth: Maximum number of expansion levels.
            per_node_limit: Maximum neighbours kept per expanded node
                (entities preferred, then higher score, then lower relation id).
            allowed_edges: If given, only relations whose ``relation_type`` is
                in this list are followed.
            max_results: Traversal stops after any level at which at least this
                many memories have been found. The check is per level, so the
                result may exceed it. Defaults to the former fixed cap of 50.

        Returns:
            Dict[memory_id, max_discovery_score]: for each non-seed memory
            reached, the highest edge score with which it was reached.
        """
        with tracer.start_as_current_span("reverie.graph.traversal") as span:
            span.set_attribute("graph.start_nodes", len(start_memory_ids) if start_memory_ids else 0)
            span.set_attribute("graph.depth", depth)
            if allowed_edges:
                span.set_attribute("graph.allowed_edges", allowed_edges)
            if not start_memory_ids:
                return {}

            cursor = self.db.get_cursor()

            # Resolve anchor entity names to IDs, chunked to respect the
            # bound-variable limit.
            anchor_ids = set()
            if anchor_entities:
                for names in self._chunked(anchor_entities, self._SQLITE_MAX_VARIABLES):
                    name_placeholders = ",".join(["?"] * len(names))
                    query = f"SELECT id FROM entities WHERE name IN ({name_placeholders})"
                    params = tuple(names)
                    with self.db.trace_query("SELECT", "entities", query, params):
                        cursor.execute(query, params)
                        anchor_ids.update(row[0] for row in cursor.fetchall())

            # Traversal state: (node_id, node_type) -> score at first discovery.
            # Seeds start at full confidence and are never reported as results.
            visited = {(mid, 'MEMORY'): 1.0 for mid in start_memory_ids}
            current_layer = [(mid, 'MEMORY') for mid in start_memory_ids]
            # Non-seed memories found so far: id -> best score.
            found_memories = {}

            # Optional relation-type filter, applied to both edge directions.
            edge_filter = ""
            edge_params = []
            if allowed_edges:
                edge_placeholders = ",".join(["?"] * len(allowed_edges))
                edge_filter = f"AND r.relation_type IN ({edge_placeholders})"
                edge_params = list(allowed_edges)

            # Loop-invariant anchor binding; [-1] keeps the IN (...) list non-empty.
            anchor_list = list(anchor_ids) if anchor_ids else [-1]
            anchor_placeholders = ",".join(["?"] * len(anchor_list))

            # Each bulk query binds, besides two values per batch node: the edge
            # filter twice, the anchors once, gravity twice and the limit once.
            # Shrink the batch when those would push past the variable limit.
            fixed_param_count = 2 * len(edge_params) + len(anchor_list) + 3
            batch_size = max(1, min(self._MAX_LAYER_BATCH, (self._SQLITE_MAX_VARIABLES - fixed_param_count) // 2))

            for _level in range(depth):
                if not current_layer:
                    break
                next_layer = []
                for batch in self._chunked(current_layer, batch_size):
                    bulk_query = self._build_expansion_query(len(batch), edge_filter, anchor_placeholders)
                    values_params = [value for node in batch for value in node]
                    params = tuple(values_params + edge_params + edge_params + anchor_list + [gravity, gravity, per_node_limit])
                    with self.db.trace_query("SELECT", "memory_relations", bulk_query, params) as query_span:
                        query_span.set_attribute("graph.batch_size", len(batch))
                        cursor.execute(bulk_query, params)
                        rows = cursor.fetchall()

                    for next_id, next_type, d_score in rows:
                        key = (next_id, next_type)
                        if key not in visited:
                            visited[key] = d_score
                            next_layer.append(key)
                            if next_type == 'MEMORY':
                                found_memories[next_id] = max(found_memories.get(next_id, 0), d_score)
                        elif next_type == 'MEMORY' and next_id in found_memories:
                            # Already-discovered memory reached again: keep its
                            # best score but do not expand it a second time.
                            # Seeds are absent from found_memories, so they
                            # never enter the result through this branch.
                            found_memories[next_id] = max(found_memories[next_id], d_score)

                current_layer = next_layer
                # Global cap, checked once per level.
                if len(found_memories) >= max_results:
                    break
            return found_memories

    def get_memories_by_entities(self, entity_names: List[str]) -> List[int]:
        """Find memories directly linked to any of the given entity names.

        Both directions are considered: MEMORY -> ENTITY and ENTITY -> MEMORY.
        Names are queried in chunks so large inputs stay within SQLite's
        bound-variable limit (each name is bound twice per query).

        Returns:
            Distinct memory IDs in the order the database produced them; when
            several chunks are needed, the first occurrence wins.
        """
        with tracer.start_as_current_span("reverie.graph.entity_lookup") as span:
            span.set_attribute("graph.entity_count", len(entity_names) if entity_names else 0)
            if not entity_names:
                return []
            cursor = self.db.get_cursor()
            # dict preserves insertion order and de-duplicates across chunks.
            memory_ids = {}
            for names in self._chunked(entity_names, self._SQLITE_MAX_VARIABLES // 2):
                placeholders = ','.join(['?'] * len(names))
                query = f"""
                    SELECT DISTINCT source_id
                    FROM memory_relations
                    WHERE source_type = 'MEMORY'
                    AND target_type = 'ENTITY'
                    AND target_id IN (SELECT id FROM entities WHERE name IN ({placeholders}))
                    UNION
                    SELECT DISTINCT target_id
                    FROM memory_relations
                    WHERE target_type = 'MEMORY'
                    AND source_type = 'ENTITY'
                    AND source_id IN (SELECT id FROM entities WHERE name IN ({placeholders}))
                """
                params = tuple(names) + tuple(names)
                with self.db.trace_query("SELECT", "memory_relations", query, params):
                    cursor.execute(query, params)
                    for row in cursor.fetchall():
                        memory_ids.setdefault(row[0], None)
            return list(memory_ids)

    def get_neighbors_summaries(self, memory_ids: List[int]) -> Dict[int, str]:
        """Return a mapping of memory_id -> neighbours summary string.

        Each summary has the form
        ``" Linked Entities: [label: name (relation_type)], ..."`` and is built
        from outgoing MEMORY -> ENTITY relations only. Memories with no such
        relation are absent from the result.

        IDs are queried in chunks to stay within SQLite's bound-variable limit.
        This enrichment is best-effort: if any chunk fails, the failure is
        logged at DEBUG level and an empty dict is returned.
        """
        if not memory_ids:
            return {}
        with tracer.start_as_current_span("reverie.graph.batch_summary"):
            cursor = self.db.get_cursor()
            # Group "[label: name (relation)]" links by source memory id.
            grouped = {}
            try:
                for chunk in self._chunked(memory_ids, self._SQLITE_MAX_VARIABLES):
                    placeholders = ",".join(["?"] * len(chunk))
                    query = f"""
                        SELECT ma.source_id, e.name, e.label, ma.relation_type
                        FROM memory_relations ma
                        JOIN entities e ON ma.target_id = e.id AND ma.target_type = 'ENTITY'
                        WHERE ma.source_id IN ({placeholders}) AND ma.source_type = 'MEMORY'
                    """
                    params = tuple(chunk)
                    with self.db.trace_query("SELECT", "memory_relations", query, params):
                        cursor.execute(query, params)
                        rows = cursor.fetchall()
                    for mid, name, label, rel_type in rows:
                        grouped.setdefault(mid, []).append(f"[{label}: {name} ({rel_type})]")
            except Exception as e:
                logger.debug("Failed to fetch batch neighbors summary: %s", e)
                return {}
            return {mid: " Linked Entities: " + ", ".join(links) for mid, links in grouped.items()}