diff --git a/src/agent_memory/storage/sqlite.py b/src/agent_memory/storage/sqlite.py index 3cb111a..fd6fd69 100644 --- a/src/agent_memory/storage/sqlite.py +++ b/src/agent_memory/storage/sqlite.py @@ -111,30 +111,34 @@ def initialize_database(db_path: Path | str) -> None: connection.execute( "CREATE INDEX IF NOT EXISTS idx_episodes_status_scope_importance ON episodes(status, scope, importance_score)" ) - connection.execute( - """ - CREATE TABLE IF NOT EXISTS retrieval_observations ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, - surface TEXT NOT NULL, - query_sha256 TEXT NOT NULL, - query_preview TEXT, - preferred_scope TEXT, - limit_value INTEGER NOT NULL, - statuses_json TEXT NOT NULL DEFAULT '["approved"]', - retrieved_memory_refs_json TEXT NOT NULL DEFAULT '[]', - top_memory_ref TEXT, - response_mode TEXT CHECK (response_mode IN ('direct', 'cautious', 'verify_first')), - metadata_json TEXT NOT NULL DEFAULT '{}' - ) - """ - ) - connection.execute( - "CREATE INDEX IF NOT EXISTS idx_retrieval_observations_created_at ON retrieval_observations(created_at, id)" - ) - connection.execute( - "CREATE INDEX IF NOT EXISTS idx_retrieval_observations_surface ON retrieval_observations(surface, created_at)" + _ensure_retrieval_observations_schema(connection) + + +def _ensure_retrieval_observations_schema(connection: sqlite3.Connection) -> None: + connection.execute( + """ + CREATE TABLE IF NOT EXISTS retrieval_observations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, + surface TEXT NOT NULL, + query_sha256 TEXT NOT NULL, + query_preview TEXT, + preferred_scope TEXT, + limit_value INTEGER NOT NULL, + statuses_json TEXT NOT NULL DEFAULT '["approved"]', + retrieved_memory_refs_json TEXT NOT NULL DEFAULT '[]', + top_memory_ref TEXT, + response_mode TEXT CHECK (response_mode IN ('direct', 'cautious', 'verify_first')), + metadata_json TEXT NOT NULL DEFAULT '{}' ) + """ + ) + connection.execute( + "CREATE INDEX IF NOT EXISTS idx_retrieval_observations_created_at ON retrieval_observations(created_at, id)" + ) + connection.execute( + "CREATE INDEX IF NOT EXISTS idx_retrieval_observations_surface ON retrieval_observations(surface, created_at)" + ) def _ensure_memory_table_columns( @@ -829,6 +833,7 @@ def record_retrieval_observation( top_memory_ref = retrieved_memory_refs[0] if retrieved_memory_refs else None query_sha256 = hashlib.sha256(query.encode("utf-8")).hexdigest() with connect(db_path) as connection: + _ensure_retrieval_observations_schema(connection) cursor = connection.execute( """ INSERT INTO retrieval_observations ( @@ -864,6 +869,7 @@ def record_retrieval_observation( def list_retrieval_observations(db_path: Path | str, *, limit: int = 50) -> list[RetrievalObservation]: with connect(db_path) as connection: + _ensure_retrieval_observations_schema(connection) rows = connection.execute( """ SELECT * diff --git a/tests/test_cli.py b/tests/test_cli.py index b627bfd..87d4577 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -166,6 +166,32 @@ def test_python_module_cli_retrieve_observe_records_secret_safe_local_observatio assert "abc123" not in list_result.stdout +def test_python_module_cli_observations_list_migrates_existing_database_without_observation_table(tmp_path: Path) -> None: + db_path = tmp_path / "legacy-observation.db" + initialize_database(db_path) + import sqlite3 + + with sqlite3.connect(db_path) as connection: + connection.execute("DROP TABLE retrieval_observations") + + result = subprocess.run( + [ + sys.executable, + "-m", + "agent_memory.api.cli", + "observations", + "list", + str(db_path), + ], + text=True, + capture_output=True, + check=True, + ) + + payload = json.loads(result.stdout) + assert payload["observations"] == [] + + def test_python_module_cli_retrieve_defaults_to_approved_and_hides_disputed_content(tmp_path: Path) -> None: db_path = tmp_path / "retrieve-approved-only.db" initialize_database(db_path)