diff --git a/crates/ruvector-postgres/src/gnn/operators.rs b/crates/ruvector-postgres/src/gnn/operators.rs
index 8e967edca..a5187073a 100644
--- a/crates/ruvector-postgres/src/gnn/operators.rs
+++ b/crates/ruvector-postgres/src/gnn/operators.rs
@@ -285,15 +285,16 @@ mod tests {

     #[pg_test]
     fn test_ruvector_gcn_forward() {
-        let embeddings = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
+        let embeddings = JsonB(serde_json::json!([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]));
         let src = vec![0, 1, 2];
         let dst = vec![1, 2, 0];

         let result = ruvector_gcn_forward(embeddings, src, dst, None, 2);
+        let result_arr = result.0.as_array().unwrap();

-        assert_eq!(result.len(), 3);
-        assert_eq!(result[0].len(), 2);
+        assert_eq!(result_arr.len(), 3);
+        assert_eq!(result_arr[0].as_array().unwrap().len(), 2);
     }

     #[pg_test]
@@ -325,15 +326,16 @@ mod tests {

     #[pg_test]
     fn test_ruvector_graphsage_forward() {
-        let embeddings = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
+        let embeddings = JsonB(serde_json::json!([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]));
         let src = vec![0, 1, 2];
         let dst = vec![1, 2, 0];

         let result = ruvector_graphsage_forward(embeddings, src, dst, 2, 2);
+        let result_arr = result.0.as_array().unwrap();

-        assert_eq!(result.len(), 3);
-        assert_eq!(result[0].len(), 2);
+        assert_eq!(result_arr.len(), 3);
+        assert_eq!(result_arr[0].as_array().unwrap().len(), 2);
     }

     #[pg_test]
@@ -352,24 +354,26 @@ mod tests {

     #[pg_test]
     fn test_empty_inputs() {
-        let empty_embeddings: Vec<Vec<f64>> = vec![];
+        let empty_embeddings = JsonB(serde_json::json!([]));
         let empty_src: Vec<i64> = vec![];
         let empty_dst: Vec<i64> = vec![];

         let result = ruvector_gcn_forward(empty_embeddings, empty_src, empty_dst, None, 4);
+        let result_arr = result.0.as_array().unwrap();

-        assert_eq!(result.len(), 0);
+        assert_eq!(result_arr.len(), 0);
     }

     #[pg_test]
     fn test_weighted_gcn() {
-        let embeddings = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
+        let embeddings = JsonB(serde_json::json!([[1.0, 2.0], [3.0, 4.0]]));
         let src = vec![0];
         let dst = vec![1];
         let weights = Some(vec![2.0]);

         let result = ruvector_gcn_forward(embeddings, src, dst, weights, 2);
+        let result_arr = result.0.as_array().unwrap();

-        assert_eq!(result.len(), 2);
+        assert_eq!(result_arr.len(), 2);
     }
 }
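The GNN operator tests above now build the embeddings as a JSONB document with serde_json::json! and read the output back through the JsonB wrapper's inner serde_json::Value (result.0). A minimal, standalone sketch of that round-trip using serde_json directly; `forward` below is a hypothetical stand-in for ruvector_gcn_forward, not the real operator:

use serde_json::{json, Value};

// Hypothetical stand-in for the operator: echoes the embeddings back unchanged.
fn forward(embeddings: Value) -> Value {
    embeddings
}

fn main() {
    // Same nested-array shape the tests build with serde_json::json!
    let embeddings = json!([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);

    let result = forward(embeddings);

    // Same unwrapping the tests perform on JsonB's inner Value (result.0)
    let rows = result.as_array().unwrap();
    assert_eq!(rows.len(), 3);
    assert_eq!(rows[0].as_array().unwrap().len(), 2);
}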
diff --git a/crates/ruvector-postgres/src/index/hnsw.rs b/crates/ruvector-postgres/src/index/hnsw.rs
index d8927a6e9..fbd1505fd 100644
--- a/crates/ruvector-postgres/src/index/hnsw.rs
+++ b/crates/ruvector-postgres/src/index/hnsw.rs
@@ -523,6 +523,7 @@ mod tests {
             max_elements: 1000,
             metric: DistanceMetric::Euclidean,
             seed: 42,
+            max_layers: 16,
         };

         let index = HnswIndex::new(3, config);
@@ -583,6 +584,7 @@ mod tests {
             max_elements: 10000,
             metric: DistanceMetric::Euclidean,
             seed: 42,
+            max_layers: 16,
         };

         let index = HnswIndex::new(dims, config);
diff --git a/crates/ruvector-postgres/src/index/hnsw_am.rs b/crates/ruvector-postgres/src/index/hnsw_am.rs
index b2fda9f85..d4edfc7e3 100644
--- a/crates/ruvector-postgres/src/index/hnsw_am.rs
+++ b/crates/ruvector-postgres/src/index/hnsw_am.rs
@@ -3,17 +3,18 @@
 //! This module implements HNSW as a proper PostgreSQL index access method,
 //! storing the graph structure in PostgreSQL pages for persistence.

+use pgrx::pg_sys::{
+    self, bytea, BlockNumber, Buffer, Cost, Datum, IndexAmRoutine, IndexBuildResult,
+    IndexBulkDeleteCallback, IndexBulkDeleteResult, IndexInfo, IndexPath, IndexScanDesc,
+    IndexUniqueCheck, IndexVacuumInfo, ItemPointer, ItemPointerData, NodeTag, Page, PageHeaderData,
+    PlannerInfo, Relation, ScanDirection, ScanKey, Selectivity, Size, TIDBitmap,
+};
 use pgrx::prelude::*;
-use pgrx::pg_sys::{self, Relation, IndexInfo, IndexBuildResult, IndexVacuumInfo,
-    IndexBulkDeleteResult, IndexBulkDeleteCallback, PlannerInfo, IndexPath,
-    Cost, Selectivity, IndexScanDesc, ScanDirection, TIDBitmap, ScanKey,
-    IndexUniqueCheck, ItemPointer, Datum, Buffer, BlockNumber, Page,
-    IndexAmRoutine, NodeTag, bytea, ItemPointerData, PageHeaderData, Size};
 use pgrx::Internal;
-use std::ptr;
 use std::mem::size_of;
+use std::ptr;

-use crate::distance::{DistanceMetric, distance};
+use crate::distance::{distance, DistanceMetric};
 use crate::index::HnswConfig;

 // ============================================================================
@@ -31,11 +32,11 @@ const HNSW_PAGE_DELETED: u8 = 2;

 /// Maximum neighbors per node (aligned with default M)
 #[allow(dead_code)]
-const MAX_NEIGHBORS_L0: usize = 32;  // 2*M for layer 0
+const MAX_NEIGHBORS_L0: usize = 32; // 2*M for layer 0
 #[allow(dead_code)]
-const MAX_NEIGHBORS: usize = 16;     // M for other layers
+const MAX_NEIGHBORS: usize = 16; // M for other layers
 #[allow(dead_code)]
-const MAX_LAYERS: usize = 16;        // Maximum graph layers
+const MAX_LAYERS: usize = 16; // Maximum graph layers

 /// P_NEW equivalent for allocating new pages
 const P_NEW_BLOCK: BlockNumber = pg_sys::InvalidBlockNumber;
@@ -73,10 +74,10 @@ impl Default for HnswMetaPage {
             ef_construction: 64,
             entry_point: pg_sys::InvalidBlockNumber,
             max_layer: 0,
-            metric: 0,         // L2 by default
+            metric: 0, // L2 by default
             _padding: 0,
             node_count: 0,
-            next_block: 1,     // First node page
+            next_block: 1, // First node page
         }
     }
 }
@@ -89,7 +90,7 @@ struct HnswNodePageHeader {
     #[allow(dead_code)]
     max_layer: u8,
     _padding: [u8; 2],
-    item_id: ItemPointerData,  // TID of the heap tuple
+    item_id: ItemPointerData, // TID of the heap tuple
 }

 /// Neighbor entry in the graph
@@ -137,7 +138,8 @@ unsafe fn get_meta_page(index_rel: Relation) -> (Page, Buffer) {
 unsafe fn get_or_create_meta_page(index_rel: Relation, for_write: bool) -> (Page, Buffer) {
     // Check if the relation has any blocks
     // Use MAIN_FORKNUM (0) for the main relation fork
-    let nblocks = pg_sys::RelationGetNumberOfBlocksInFork(index_rel, pg_sys::ForkNumber::MAIN_FORKNUM);
+    let nblocks =
+        pg_sys::RelationGetNumberOfBlocksInFork(index_rel, pg_sys::ForkNumber::MAIN_FORKNUM);

     let buffer = if nblocks == 0 {
         // New index - allocate first page using P_NEW (InvalidBlockNumber)
@@ -166,7 +168,8 @@ unsafe fn read_metadata(page: Page) -> HnswMetaPage {
 /// Write metadata to page
 unsafe fn write_metadata(page: Page, meta: &HnswMetaPage) {
     let header = page as *mut PageHeaderData;
-    let data_ptr = (header as *mut u8).add(std::mem::size_of::<PageHeaderData>()) as *mut HnswMetaPage;
+    let data_ptr =
+        (header as *mut u8).add(std::mem::size_of::<PageHeaderData>()) as *mut HnswMetaPage;
     ptr::write(data_ptr, *meta);
 }

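write_metadata above stores the HnswMetaPage struct in the bytes immediately following the page header. The matching read path is not part of this hunk; purely as a sketch, a mirrored counterpart would look roughly like the following (read_metadata_sketch is an illustrative name, and HnswMetaPage is the metadata struct defined earlier in this module):

use pgrx::pg_sys::{Page, PageHeaderData};
use std::ptr;

// Sketch only: read back the metadata written by write_metadata, assuming it
// sits directly after the PageHeaderData, exactly where the write side put it.
unsafe fn read_metadata_sketch(page: Page) -> HnswMetaPage {
    let header = page as *const PageHeaderData;
    let data_ptr =
        (header as *const u8).add(std::mem::size_of::<PageHeaderData>()) as *const HnswMetaPage;
    ptr::read(data_ptr)
}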
@@ -259,7 +262,11 @@ unsafe fn calculate_distance(
 // ============================================================================
 // Access Method Callbacks
 // ============================================================================

-/// Build callback - builds the index from scratch
+/// Build callback - builds the HNSW index from scratch
+///
+/// Extracts vector dimensions from the indexed column's type modifier.
+/// The column must be declared with explicit dimensions, e.g., `ruvector(384)`.
+/// Returns an error if dimensions are not specified.
 #[pg_guard]
 unsafe extern "C" fn hnsw_build(
     _heap: Relation,
@@ -268,8 +275,32 @@
 ) -> *mut IndexBuildResult {
     pgrx::log!("HNSW: Starting index build");

-    // Parse index options
-    let dimensions = 128; // TODO: Extract from index definition
+    // Extract dimensions from the indexed column's type modifier
+    // When user defines ruvector(384), typmod = 384
+    let dimensions = {
+        // RelationGetDescr(index) -> (*index).rd_att
+        let index_desc = (*index).rd_att;
+        if index_desc.is_null() || (*index_desc).natts < 1 {
+            pgrx::error!("HNSW: Cannot build index - no indexed columns found");
+        }
+
+        // TupleDescAttr(desc, 0) -> (*desc).attrs.as_ptr().add(0)
+        let attr = (*index_desc).attrs.as_ptr().add(0);
+        let typmod = (*attr).atttypmod;
+
+        if typmod > 0 {
+            typmod as u32
+        } else {
+            // typmod = -1 means dimensions not specified in type declaration
+            // This happens with: CREATE TABLE t (v ruvector) instead of ruvector(384)
+            pgrx::error!(
+                "HNSW: Vector column must have dimensions specified. \
+                 Use ruvector(dimensions) instead of ruvector, e.g., ruvector(384)"
+            );
+        }
+    };
+
+    pgrx::log!("HNSW: Building index with {} dimensions", dimensions);
     let config = HnswConfig::default();

     // Initialize metadata page
@@ -298,7 +329,10 @@
     // This is a simplified version - full implementation would use IndexBuildHeapScan
     let tuple_count = 0.0;

-    pgrx::log!("HNSW: Index build complete, {} tuples indexed", tuple_count as u64);
+    pgrx::log!(
+        "HNSW: Index build complete, {} tuples indexed",
+        tuple_count as u64
+    );

     // Return build result
     let mut result = PgBox::<IndexBuildResult>::alloc0();
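Both hnsw_build above and hnsw_buildempty in the next hunk read `atttypmod` with the same rule: a positive typmod is the dimension count, anything else means the column was declared as bare `ruvector` with no dimensions. A small, purely illustrative helper showing that rule in isolation; `dimensions_from_typmod` is a hypothetical name, not a function from this diff:

/// Hypothetical helper: interpret a pg_attribute typmod as a vector dimension count.
/// A positive typmod (e.g. from `ruvector(384)`) is the dimension count; -1 (or 0)
/// means dimensions were not specified in the type declaration.
fn dimensions_from_typmod(typmod: i32) -> Option<u32> {
    if typmod > 0 {
        Some(typmod as u32)
    } else {
        None
    }
}

fn main() {
    assert_eq!(dimensions_from_typmod(384), Some(384)); // ruvector(384)
    assert_eq!(dimensions_from_typmod(-1), None);       // bare ruvector
}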
@@ -308,15 +342,38 @@
 }

 /// Build empty index callback
+///
+/// Creates an empty HNSW index with proper dimensions from the column's type modifier.
 #[pg_guard]
 unsafe extern "C" fn hnsw_buildempty(index: Relation) {
     pgrx::log!("HNSW: Building empty index");

+    // Extract dimensions from the indexed column's type modifier
+    let dimensions = {
+        // RelationGetDescr(index) -> (*index).rd_att
+        let index_desc = (*index).rd_att;
+        if !index_desc.is_null() && (*index_desc).natts >= 1 {
+            // TupleDescAttr(desc, 0) -> (*desc).attrs.as_ptr().add(0)
+            let attr = (*index_desc).attrs.as_ptr().add(0);
+            let typmod = (*attr).atttypmod;
+            if typmod > 0 {
+                typmod as u32
+            } else {
+                0
+            }
+        } else {
+            0
+        }
+    };
+
     // Initialize metadata page only
     let (page, buffer) = get_or_create_meta_page(index, true);
     pg_sys::PageInit(page, pg_sys::BLCKSZ as Size, 0);

-    let meta = HnswMetaPage::default();
+    let meta = HnswMetaPage {
+        dimensions,
+        ..Default::default()
+    };
     write_metadata(page, &meta);

     pg_sys::MarkBufferDirty(buffer);
@@ -410,12 +467,12 @@ unsafe extern "C" fn hnsw_costestimate(

     // Total cost is O(log n) for HNSW
     let log_tuples = tuples.max(1.0).ln();
-    *index_total_cost = log_tuples * 10.0;  // Scale factor for page accesses
+    *index_total_cost = log_tuples * 10.0; // Scale factor for page accesses

     // HNSW provides good selectivity for top-k queries
-    *index_selectivity = 0.01;  // Typically returns ~1% of tuples
-    *index_correlation = 0.0;   // No correlation with physical order
-    *index_pages = (tuples / 100.0).max(1.0);  // Rough estimate
+    *index_selectivity = 0.01; // Typically returns ~1% of tuples
+    *index_correlation = 0.0; // No correlation with physical order
+    *index_pages = (tuples / 100.0).max(1.0); // Rough estimate
 }

 /// Get tuple callback (for index scans)
@@ -480,10 +537,7 @@ unsafe extern "C" fn hnsw_canreturn(_index: Relation, attno: ::std::os::raw::c_i

 /// Options callback - parse index options
 #[pg_guard]
-unsafe extern "C" fn hnsw_options(
-    _reloptions: Datum,
-    _validate: bool,
-) -> *mut bytea {
+unsafe extern "C" fn hnsw_options(_reloptions: Datum, _validate: bool) -> *mut bytea {
     pgrx::log!("HNSW: Parsing options");

     // TODO: Parse m, ef_construction, metric from reloptions
@@ -501,14 +555,14 @@ static HNSW_AM_HANDLER: IndexAmRoutine = IndexAmRoutine {
     type_: NodeTag::T_IndexAmRoutine,

     // Index structure capabilities
-    amstrategies: 1,       // One strategy: nearest neighbor
-    amsupport: 1,          // One support function: distance
+    amstrategies: 1, // One strategy: nearest neighbor
+    amsupport: 1,    // One support function: distance
     amoptsprocnum: 0,
     amcanorder: false,
-    amcanorderbyop: true,  // Supports ORDER BY with distance operators
+    amcanorderbyop: true, // Supports ORDER BY with distance operators
     amcanbackward: false,
     amcanunique: false,
-    amcanmulticol: false,  // Single column only (vector)
+    amcanmulticol: false, // Single column only (vector)
     amoptionalkey: true,
     amsearcharray: false,
     amsearchnulls: false,
diff --git a/crates/ruvector-postgres/src/learning/patterns.rs b/crates/ruvector-postgres/src/learning/patterns.rs
index a0d3c28df..2084c7435 100644
--- a/crates/ruvector-postgres/src/learning/patterns.rs
+++ b/crates/ruvector-postgres/src/learning/patterns.rs
@@ -355,10 +355,9 @@ mod tests {
         let extractor = PatternExtractor::new(2);

         // Consistent trajectories
-        let trajs: Vec<&QueryTrajectory> = vec![
-            &QueryTrajectory::new(vec![1.0], vec![1], 1000, 50, 10),
-            &QueryTrajectory::new(vec![1.0], vec![1], 1000, 50, 10),
-        ];
+        let traj1 = QueryTrajectory::new(vec![1.0], vec![1], 1000, 50, 10);
+        let traj2 = QueryTrajectory::new(vec![1.0], vec![1], 1000, 50, 10);
+        let trajs: Vec<&QueryTrajectory> = vec![&traj1, &traj2];

         let confidence = extractor.calculate_confidence(&trajs);
         assert!(confidence > 0.0 && confidence <= 1.0);
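The patterns.rs change binds each trajectory to a named local before collecting references, so the Vec<&QueryTrajectory> no longer borrows temporaries that are dropped at the end of the vec![...] expression. A self-contained sketch of the same bind-then-borrow pattern with a plain struct; the names here are illustrative, not from the crate:

struct Trajectory {
    latency_us: u64,
}

fn average_latency(trajs: &[&Trajectory]) -> f64 {
    trajs.iter().map(|t| t.latency_us as f64).sum::<f64>() / trajs.len() as f64
}

fn main() {
    // Bind the owned values first so the references below outlive the call.
    let t1 = Trajectory { latency_us: 1000 };
    let t2 = Trajectory { latency_us: 1000 };
    let trajs: Vec<&Trajectory> = vec![&t1, &t2];

    assert_eq!(average_latency(&trajs), 1000.0);
}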