-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcli.py
More file actions
310 lines (251 loc) · 10.5 KB
/
cli.py
File metadata and controls
310 lines (251 loc) · 10.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
"""RTIE CLI — Test the semantic search pipeline directly.
Usage:
    python cli.py index                      Index loader-validated functions (default; safe)
    python cli.py index --force              Re-embed loader-validated functions (skip cache)
    python cli.py index --from-disk          Walk db/modules/* on disk (W93b opt-in)
    python cli.py index --from-disk --force  Same, forcing re-embed
    python cli.py status                     Check index status
    python cli.py ask "your question"        Ask a question
Notes:
The default `index` path reads loader-populated ``graph:<schema>:<fn>``
keys from Redis (same path the backend lifespan uses at startup) — it
requires the loader to have run at least once. Start the backend via
``python run.py`` before first use.
``--from-disk`` walks ``db/modules/*`` directly and embeds every .sql
file, including functions the loader rejected. Pre-W93b default — kept
as an opt-in for rebuilds outside the loader's view; not recommended
for normal use.
"""
import asyncio
import json
import os
import sys
from dotenv import load_dotenv
# Read .env.dev at import time so the os.getenv() lookups in the commands
# below (REDIS_HOST, REDIS_PORT, EMBEDDING_MODEL, OPENAI_MODEL) can see the
# values it defines.
load_dotenv(".env.dev")
async def get_clients():
    """Build and connect the two Redis-backed clients used by every command.

    Returns:
        A ``(vector_store, cache)`` tuple. The vector store is connected and
        its index ensured; the cache client is connected with the ``rtie``
        key prefix. Both target ``REDIS_HOST`` / ``REDIS_PORT`` (defaulting to
        ``localhost:6379``).
    """
    from src.tools.vector_store import VectorStore
    from src.tools.cache_tools import CacheClient

    # Both clients talk to the same Redis endpoint; resolve it once.
    redis_endpoint = {
        "host": os.getenv("REDIS_HOST", "localhost"),
        "port": int(os.getenv("REDIS_PORT", "6379")),
    }

    vector_store = VectorStore(**redis_endpoint)
    await vector_store.connect()
    await vector_store.ensure_index()

    cache_client = CacheClient(key_prefix="rtie", **redis_endpoint)
    await cache_client.connect()

    return vector_store, cache_client
async def cmd_index(force: bool = False, from_disk: bool = False):
    """Index functions for semantic search and print a per-group summary.

    Default (``from_disk=False``) — calls :meth:`IndexerAgent.index_all_loaded`,
    the same Phase-3 path the backend lifespan uses (``src/main.py``). It scans
    ``graph:<schema>:<fn>`` keys in Redis so the indexed corpus exactly matches
    the corpus the rest of RTIE serves answers from. The loader must have run
    at least once (start the backend via ``python run.py`` first).

    ``from_disk=True`` (W93b opt-in) — calls
    :meth:`IndexerAgent.index_all_modules`, the pre-W93b path that walks
    ``db/modules/*`` and embeds every .sql file, including functions the
    loader rejected. Kept as an escape hatch for rebuilds outside the
    loader's view; not the recommended default.

    Args:
        force: Forwarded unchanged to the indexer's ``force`` argument.
        from_disk: Select the on-disk walk instead of the loader-validated
            path.

    Every Redis client opened here is released before returning, including
    when indexing raises.
    """
    from src.agents.indexer import IndexerAgent

    # get_clients() connects both clients, so both have to be closed here even
    # though only the vector store is used for indexing.
    vs, cache = await get_clients()
    try:
        indexer = IndexerAgent(
            vector_store=vs,
            embedding_model=os.getenv("EMBEDDING_MODEL", "text-embedding-3-small"),
            llm_provider="openai",
            llm_model=os.getenv("OPENAI_MODEL", "gpt-4o-mini"),
        )

        if from_disk:
            print(
                "Indexing db/modules/* from disk "
                "(--from-disk; includes loader-rejected files)..."
            )
            result = await indexer.index_all_modules(force=force)
            for module, info in result.get("results", {}).items():
                print(f"\n Module: {module}")
                print(f" Total: {info.get('total_functions', 0)}")
                print(f" Indexed: {info.get('indexed', 0)}")
                print(f" Skipped: {info.get('skipped', 0)}")
                print(f" Errors: {info.get('errors', 0)}")
                if info.get("indexed_functions"):
                    print(f" Indexed: {', '.join(info['indexed_functions'])}")
                for e in info.get("error_details") or ():
                    print(f" ERROR: {e['name']} — {e['error']}")
            return

        # Default path: loader-validated.
        import redis as _redis

        graph_redis = _redis.Redis(
            host=os.getenv("REDIS_HOST", "localhost"),
            port=int(os.getenv("REDIS_PORT", "6379")),
        )
        try:
            print("Indexing loader-validated functions (graph:<schema>:<fn>)...")
            result = await indexer.index_all_loaded(
                graph_redis_client=graph_redis,
                force=force,
            )
            per_schema = result.get("results") or {}
            if not per_schema:
                print(
                    "\n No schemas discovered — no graph:<schema>:<fn> keys in Redis.\n"
                    " Run the backend at least once (`python run.py`) to load functions,\n"
                    " or use `python cli.py index --from-disk` to walk db/modules/* directly."
                )
            else:
                for schema, info in per_schema.items():
                    print(
                        f"\n Auto-index {schema}: "
                        f"{info.get('indexed', 0)} indexed, "
                        f"{info.get('skipped', 0)} skipped, "
                        f"{info.get('errors', 0)} errors"
                    )
        finally:
            # Closing is best-effort: a failure here must not mask the
            # indexing outcome, but it should not vanish silently either.
            try:
                graph_redis.close()
            except Exception as exc:
                print(
                    f" Warning: could not close graph Redis client: {exc}",
                    file=sys.stderr,
                )
    finally:
        # Nested so the cache is still closed if closing the vector store fails.
        try:
            await vs.close()
        finally:
            await cache.close()
async def cmd_status():
    """Print index statistics and the sorted list of indexed functions.

    Output comes from ``get_index_stats()`` (index name, document count,
    record count) and ``list_indexed_functions()`` on the vector store.
    Both clients opened by ``get_clients()`` are closed before returning,
    including when one of the lookups raises.
    """
    # get_clients() connects the cache as well, so it must be closed here even
    # though this command never uses it.
    vs, cache = await get_clients()
    try:
        stats = await vs.get_index_stats()
        print(f"Index: {stats.get('index_name', 'N/A')}")
        print(f"Documents: {stats.get('num_docs', 0)}")
        print(f"Records: {stats.get('num_records', 0)}")

        functions = await vs.list_indexed_functions()
        if functions:
            print(f"\nIndexed functions ({len(functions)}):")
            for fn in sorted(functions):
                print(f" - {fn}")
    finally:
        # Nested so the cache is still closed if closing the vector store fails.
        try:
            await vs.close()
        finally:
            await cache.close()
def _rerank_by_keywords(hits, question, bonus=0.15, keep=5):
    """Add a ``boosted_score`` to each hit and return the ``keep`` best hits.

    A hit earns ``bonus`` for every query word (longer than three characters)
    found in its description, key columns or written tables. Lower score is
    better (cosine distance), so the bonus is subtracted.
    """
    # Strip punctuation glued to a word ("ACCOUNT?" -> "ACCOUNT"); otherwise
    # such words can never match the indexed text. Underscores are kept so
    # identifiers survive intact.
    edge_punctuation = "?!.,;:\"'`()[]{}"
    stripped = (raw.strip(edge_punctuation) for raw in question.upper().split())
    query_words = [word for word in stripped if len(word) > 3]

    for hit in hits:
        haystack = (
            f"{hit.get('description', '')} "
            f"{hit.get('key_columns', '')} "
            f"{hit.get('tables_written', '')}"
        ).upper()
        keyword_hits = sum(1 for word in query_words if word in haystack)
        hit["boosted_score"] = hit["score"] - (keyword_hits * bonus)

    return sorted(hits, key=lambda hit: hit["boosted_score"])[:keep]


def _truncate_source(source, max_lines=100):
    """Return the first ``max_lines`` lines of ``source`` plus a truncation note."""
    source_lines = source.split("\n")
    head = "\n".join(source_lines[:max_lines])
    if len(source_lines) > max_lines:
        head += f"\n... ({len(source_lines) - max_lines} more lines truncated)"
    return head


async def cmd_ask(question: str):
    """Answer ``question`` with the semantic search pipeline and print the result.

    Stages, matching the ``[n/4]`` progress lines:

    1. Embed the question, fetch the 10 nearest functions from the vector
       store and keep the 5 best after a keyword boost.
    2. Load the SQL source of each hit from disk.
    3. Ask gpt-4o, one function at a time, whether and how it is relevant.
    4. Print the combined answer.

    Returns early (after printing a hint) when the search finds nothing or no
    source file can be located. The vector store and cache clients are closed
    on every exit path, including errors.

    Args:
        question: The natural-language question to answer.
    """
    from langchain_core.messages import HumanMessage
    from langchain_openai import OpenAIEmbeddings
    from src.agents.metadata_interpreter import _read_sql_file, _scan_modules_for_file
    from src.llm_factory import create_llm

    vs, cache = await get_clients()
    try:
        print(f"\n{'='*60}")
        print(f"Question: {question}")
        print(f"{'='*60}")

        # Stage 1: embed the question, vector-search, then keyword-boost.
        print("\n[1/4] Embedding query and searching Redis...")
        embeddings = OpenAIEmbeddings(
            model=os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
        )
        query_vec = await embeddings.aembed_query(question)
        vector_results = await vs.search(query_embedding=query_vec, top_k=10)
        results = _rerank_by_keywords(vector_results, question)

        if not results:
            print(" No results found! Make sure functions are indexed (python cli.py index)")
            return

        print(f" Found {len(results)} relevant functions:")
        for r in results:
            print(f" - {r['function_name']} (vec: {r['score']:.4f}, boosted: {r['boosted_score']:.4f})")

        # Stage 2: load the source code of each hit.
        print("\n[2/4] Fetching source code...")
        multi_source = {}
        for r in results:
            fn_name = r["function_name"]
            filepath = _scan_modules_for_file(fn_name)
            if not filepath:
                print(f" {fn_name}: FILE NOT FOUND")
                continue
            lines = _read_sql_file(filepath)
            # The reader may yield dicts carrying a "text" field or plain
            # values; both are flattened into a single source string.
            source_text = "".join(
                line["text"] if isinstance(line, dict) else str(line) for line in lines
            )
            multi_source[fn_name] = {
                "source": source_text,
                "description": r.get("description", ""),
                "tables_read": r.get("tables_read", ""),
                "tables_written": r.get("tables_written", ""),
                "score": r["score"],
            }
            print(f" {fn_name}: {len(lines)} lines loaded")

        if not multi_source:
            print(" No source code found!")
            return

        # Stage 3: ask the LLM about each function separately.
        print(f"\n[3/4] Analyzing {len(multi_source)} functions via gpt-4o...")
        llm = create_llm(provider="openai", model="gpt-4o", temperature=0, max_tokens=2000)
        per_function_answers = []
        for fn_name, data in multi_source.items():
            # Send only the description plus the first 100 lines to keep the
            # payload small.
            truncated = _truncate_source(data["source"], max_lines=100)
            prompt = (
                f"Question: {question}\n\n"
                f"Function: {fn_name}\n"
                f"Description: {data['description']}\n"
                f"Tables Read: {data['tables_read']}\n"
                f"Tables Written: {data['tables_written']}\n\n"
                f"Source (first 100 lines):\n{truncated}\n\n"
                f"If this function is relevant to the question, explain how. "
                f"If not relevant, say 'NOT RELEVANT' and nothing else."
            )
            try:
                resp = await llm.ainvoke([HumanMessage(content=prompt)])
                answer = resp.content.strip()
                if "NOT RELEVANT" not in answer.upper():
                    per_function_answers.append(f"### {fn_name}\n{answer}")
                    print(f" {fn_name}: relevant")
                else:
                    print(f" {fn_name}: not relevant (skipped)")
            except Exception as e:
                # One failing call must not abort the remaining functions.
                print(f" {fn_name}: ERROR — {e}")

        # Stage 4: combine the per-function answers and print them.
        if per_function_answers:
            combined = "\n\n".join(per_function_answers)
        else:
            combined = "None of the found functions appear directly relevant to the question."
        answer_text = f"## Answer: {question}\n\n{combined}"

        print("\n[4/4] Answer:")
        print(f"{'='*60}")
        print(answer_text)
        print(f"{'='*60}")
    finally:
        # Nested so the cache is still closed if closing the vector store fails.
        try:
            await vs.close()
        finally:
            await cache.close()
async def main():
    """Dispatch ``sys.argv`` to the matching sub-command.

    Recognised forms:

    * ``index [--force] [--from-disk]`` — run :func:`cmd_index`; ``--help`` or
      ``-h`` anywhere in the arguments prints the usage text instead.
    * ``status`` — run :func:`cmd_status`.
    * ``ask <words...>`` — join the remaining arguments with spaces and run
      :func:`cmd_ask`.

    With no arguments the module usage text is printed. Any other input is
    treated as a usage error: a one-line reason is written to stderr (unless
    the user explicitly asked for help) and the usage text is printed to
    stdout as before.
    """
    args = sys.argv[1:]
    if not args:
        print(__doc__)
        return

    cmd, rest = args[0], args[1:]

    if cmd == "index":
        if "--help" in args or "-h" in args:
            print(__doc__)
            return
        await cmd_index(force="--force" in args, from_disk="--from-disk" in args)
        return

    if cmd == "status":
        await cmd_status()
        return

    if cmd == "ask" and rest:
        await cmd_ask(" ".join(rest))
        return

    # Nothing matched: say why before showing the usage text, so a typo or a
    # missing question is not mistaken for a successful run.
    if cmd == "ask":
        print('Error: "ask" requires a question.', file=sys.stderr)
    elif cmd not in ("--help", "-h", "help"):
        print(f"Error: unknown command {cmd!r}.", file=sys.stderr)
    print(__doc__)


if __name__ == "__main__":
    asyncio.run(main())