Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions crates/ruvector-postgres/src/gnn/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,15 +285,16 @@ mod tests {

/// Smoke test for `ruvector_gcn_forward` with JSONB-encoded embeddings.
///
/// Feeds three 2-element embedding rows plus three `src`/`dst` index pairs
/// and checks only the *shape* of the returned JSON: an outer array with one
/// entry per input row, whose first entry is itself an array of length 2.
/// The trailing `2` argument is presumably the output width — confirm
/// against the operator's signature.
#[pg_test]
fn test_ruvector_gcn_forward() {
    let embeddings = JsonB(serde_json::json!([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]));

    let src = vec![0, 1, 2];
    let dst = vec![1, 2, 0];

    let result = ruvector_gcn_forward(embeddings, src, dst, None, 2);
    // The operator returns a `JsonB`; `.0` is the wrapped `serde_json::Value`.
    let result_arr = result
        .0
        .as_array()
        .expect("GCN forward result should be a JSON array");

    assert_eq!(result_arr.len(), 3);
    assert_eq!(
        result_arr[0]
            .as_array()
            .expect("each GCN output row should be a JSON array")
            .len(),
        2
    );
}

#[pg_test]
Expand Down Expand Up @@ -325,15 +326,16 @@ mod tests {

/// Smoke test for `ruvector_graphsage_forward` with JSONB-encoded embeddings.
///
/// Passes three 2-element embedding rows and three `src`/`dst` index pairs,
/// then verifies the shape of the JSON result: three outer entries, the first
/// of which is an array of length 2. The meaning of the two trailing `2`
/// arguments is not visible here — check the operator's signature.
#[pg_test]
fn test_ruvector_graphsage_forward() {
    let embeddings = JsonB(serde_json::json!([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]));

    let src = vec![0, 1, 2];
    let dst = vec![1, 2, 0];

    let result = ruvector_graphsage_forward(embeddings, src, dst, 2, 2);
    // The operator returns a `JsonB`; `.0` is the wrapped `serde_json::Value`.
    let result_arr = result
        .0
        .as_array()
        .expect("GraphSAGE forward result should be a JSON array");

    assert_eq!(result_arr.len(), 3);
    assert_eq!(
        result_arr[0]
            .as_array()
            .expect("each GraphSAGE output row should be a JSON array")
            .len(),
        2
    );
}

#[pg_test]
Expand All @@ -352,24 +354,26 @@ mod tests {

/// Edge case: `ruvector_gcn_forward` with an empty JSON array of embeddings
/// and empty `src`/`dst` lists must yield an empty JSON array.
#[pg_test]
fn test_empty_inputs() {
    let empty_embeddings = JsonB(serde_json::json!([]));
    let empty_src: Vec<i32> = vec![];
    let empty_dst: Vec<i32> = vec![];

    let result = ruvector_gcn_forward(empty_embeddings, empty_src, empty_dst, None, 4);
    // The operator returns a `JsonB`; `.0` is the wrapped `serde_json::Value`.
    let result_arr = result
        .0
        .as_array()
        .expect("GCN forward result should be a JSON array even for empty input");

    assert_eq!(result_arr.len(), 0);
}

/// Smoke test for `ruvector_gcn_forward` when the optional fourth argument
/// is supplied (`Some(vec![2.0])`, one value for the single `src`/`dst`
/// pair). Only the number of returned rows is checked: one per input row.
#[pg_test]
fn test_weighted_gcn() {
    let embeddings = JsonB(serde_json::json!([[1.0, 2.0], [3.0, 4.0]]));
    let src = vec![0];
    let dst = vec![1];
    let weights = Some(vec![2.0]);

    let result = ruvector_gcn_forward(embeddings, src, dst, weights, 2);
    // The operator returns a `JsonB`; `.0` is the wrapped `serde_json::Value`.
    let result_arr = result
        .0
        .as_array()
        .expect("weighted GCN forward result should be a JSON array");

    assert_eq!(result_arr.len(), 2);
}
}
2 changes: 2 additions & 0 deletions crates/ruvector-postgres/src/index/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,7 @@ mod tests {
max_elements: 1000,
metric: DistanceMetric::Euclidean,
seed: 42,
max_layers: 16,
};

let index = HnswIndex::new(3, config);
Expand Down Expand Up @@ -583,6 +584,7 @@ mod tests {
max_elements: 10000,
metric: DistanceMetric::Euclidean,
seed: 42,
max_layers: 16,
};

let index = HnswIndex::new(dims, config);
Expand Down
118 changes: 86 additions & 32 deletions crates/ruvector-postgres/src/index/hnsw_am.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@
//! This module implements HNSW as a proper PostgreSQL index access method,
//! storing the graph structure in PostgreSQL pages for persistence.

use pgrx::pg_sys::{
self, bytea, BlockNumber, Buffer, Cost, Datum, IndexAmRoutine, IndexBuildResult,
IndexBulkDeleteCallback, IndexBulkDeleteResult, IndexInfo, IndexPath, IndexScanDesc,
IndexUniqueCheck, IndexVacuumInfo, ItemPointer, ItemPointerData, NodeTag, Page, PageHeaderData,
PlannerInfo, Relation, ScanDirection, ScanKey, Selectivity, Size, TIDBitmap,
};
use pgrx::prelude::*;
use pgrx::pg_sys::{self, Relation, IndexInfo, IndexBuildResult, IndexVacuumInfo,
IndexBulkDeleteResult, IndexBulkDeleteCallback, PlannerInfo, IndexPath,
Cost, Selectivity, IndexScanDesc, ScanDirection, TIDBitmap, ScanKey,
IndexUniqueCheck, ItemPointer, Datum, Buffer, BlockNumber, Page,
IndexAmRoutine, NodeTag, bytea, ItemPointerData, PageHeaderData, Size};
use pgrx::Internal;
use std::ptr;
use std::mem::size_of;
use std::ptr;

use crate::distance::{DistanceMetric, distance};
use crate::distance::{distance, DistanceMetric};
use crate::index::HnswConfig;

// ============================================================================
Expand All @@ -31,11 +32,11 @@ const HNSW_PAGE_DELETED: u8 = 2;

/// Maximum neighbors per node (aligned with default M)
#[allow(dead_code)]
const MAX_NEIGHBORS_L0: usize = 32; // 2*M for layer 0
const MAX_NEIGHBORS_L0: usize = 32; // 2*M for layer 0
#[allow(dead_code)]
const MAX_NEIGHBORS: usize = 16; // M for other layers
const MAX_NEIGHBORS: usize = 16; // M for other layers
#[allow(dead_code)]
const MAX_LAYERS: usize = 16; // Maximum graph layers
const MAX_LAYERS: usize = 16; // Maximum graph layers

/// P_NEW equivalent for allocating new pages
const P_NEW_BLOCK: BlockNumber = pg_sys::InvalidBlockNumber;
Expand Down Expand Up @@ -73,10 +74,10 @@ impl Default for HnswMetaPage {
ef_construction: 64,
entry_point: pg_sys::InvalidBlockNumber,
max_layer: 0,
metric: 0, // L2 by default
metric: 0, // L2 by default
_padding: 0,
node_count: 0,
next_block: 1, // First node page
next_block: 1, // First node page
}
}
}
Expand All @@ -89,7 +90,7 @@ struct HnswNodePageHeader {
#[allow(dead_code)]
max_layer: u8,
_padding: [u8; 2],
item_id: ItemPointerData, // TID of the heap tuple
item_id: ItemPointerData, // TID of the heap tuple
}

/// Neighbor entry in the graph
Expand Down Expand Up @@ -137,7 +138,8 @@ unsafe fn get_meta_page(index_rel: Relation) -> (Page, Buffer) {
unsafe fn get_or_create_meta_page(index_rel: Relation, for_write: bool) -> (Page, Buffer) {
// Check if the relation has any blocks
// Use MAIN_FORKNUM (0) for the main relation fork
let nblocks = pg_sys::RelationGetNumberOfBlocksInFork(index_rel, pg_sys::ForkNumber::MAIN_FORKNUM);
let nblocks =
pg_sys::RelationGetNumberOfBlocksInFork(index_rel, pg_sys::ForkNumber::MAIN_FORKNUM);

let buffer = if nblocks == 0 {
// New index - allocate first page using P_NEW (InvalidBlockNumber)
Expand Down Expand Up @@ -166,7 +168,8 @@ unsafe fn read_metadata(page: Page) -> HnswMetaPage {
/// Write metadata to page
unsafe fn write_metadata(page: Page, meta: &HnswMetaPage) {
let header = page as *mut PageHeaderData;
let data_ptr = (header as *mut u8).add(std::mem::size_of::<PageHeaderData>()) as *mut HnswMetaPage;
let data_ptr =
(header as *mut u8).add(std::mem::size_of::<PageHeaderData>()) as *mut HnswMetaPage;
ptr::write(data_ptr, *meta);
}

Expand Down Expand Up @@ -259,7 +262,11 @@ unsafe fn calculate_distance(
// Access Method Callbacks
// ============================================================================

/// Build callback - builds the index from scratch
/// Build callback - builds the HNSW index from scratch
///
/// Extracts vector dimensions from the indexed column's type modifier.
/// The column must be declared with explicit dimensions, e.g., `ruvector(384)`.
/// Returns an error if dimensions are not specified.
#[pg_guard]
unsafe extern "C" fn hnsw_build(
_heap: Relation,
Expand All @@ -268,8 +275,32 @@ unsafe extern "C" fn hnsw_build(
) -> *mut IndexBuildResult {
pgrx::log!("HNSW: Starting index build");

// Parse index options
let dimensions = 128; // TODO: Extract from index definition
// Extract dimensions from the indexed column's type modifier
// When user defines ruvector(384), typmod = 384
let dimensions = {
// RelationGetDescr(index) -> (*index).rd_att
let index_desc = (*index).rd_att;
if index_desc.is_null() || (*index_desc).natts < 1 {
pgrx::error!("HNSW: Cannot build index - no indexed columns found");
}

// TupleDescAttr(desc, 0) -> (*desc).attrs.as_ptr().add(0)
let attr = (*index_desc).attrs.as_ptr().add(0);
let typmod = (*attr).atttypmod;

if typmod > 0 {
typmod as u32
} else {
// typmod = -1 means dimensions not specified in type declaration
// This happens with: CREATE TABLE t (v ruvector) instead of ruvector(384)
pgrx::error!(
"HNSW: Vector column must have dimensions specified. \
Use ruvector(dimensions) instead of ruvector, e.g., ruvector(384)"
);
}
};

pgrx::log!("HNSW: Building index with {} dimensions", dimensions);
let config = HnswConfig::default();

// Initialize metadata page
Expand Down Expand Up @@ -298,7 +329,10 @@ unsafe extern "C" fn hnsw_build(
// This is a simplified version - full implementation would use IndexBuildHeapScan
let tuple_count = 0.0;

pgrx::log!("HNSW: Index build complete, {} tuples indexed", tuple_count as u64);
pgrx::log!(
"HNSW: Index build complete, {} tuples indexed",
tuple_count as u64
);

// Return build result
let mut result = PgBox::<IndexBuildResult>::alloc0();
Expand All @@ -308,15 +342,38 @@ unsafe extern "C" fn hnsw_build(
}

/// Build empty index callback
///
/// Creates an empty HNSW index with proper dimensions from the column's type modifier.
#[pg_guard]
unsafe extern "C" fn hnsw_buildempty(index: Relation) {
pgrx::log!("HNSW: Building empty index");

// Extract dimensions from the indexed column's type modifier
let dimensions = {
// RelationGetDescr(index) -> (*index).rd_att
let index_desc = (*index).rd_att;
if !index_desc.is_null() && (*index_desc).natts >= 1 {
// TupleDescAttr(desc, 0) -> (*desc).attrs.as_ptr().add(0)
let attr = (*index_desc).attrs.as_ptr().add(0);
let typmod = (*attr).atttypmod;
if typmod > 0 {
typmod as u32
} else {
0
}
} else {
0
}
};

// Initialize metadata page only
let (page, buffer) = get_or_create_meta_page(index, true);
pg_sys::PageInit(page, pg_sys::BLCKSZ as Size, 0);

let meta = HnswMetaPage::default();
let meta = HnswMetaPage {
dimensions,
..Default::default()
};
write_metadata(page, &meta);

pg_sys::MarkBufferDirty(buffer);
Expand Down Expand Up @@ -410,12 +467,12 @@ unsafe extern "C" fn hnsw_costestimate(

// Total cost is O(log n) for HNSW
let log_tuples = tuples.max(1.0).ln();
*index_total_cost = log_tuples * 10.0; // Scale factor for page accesses
*index_total_cost = log_tuples * 10.0; // Scale factor for page accesses

// HNSW provides good selectivity for top-k queries
*index_selectivity = 0.01; // Typically returns ~1% of tuples
*index_correlation = 0.0; // No correlation with physical order
*index_pages = (tuples / 100.0).max(1.0); // Rough estimate
*index_selectivity = 0.01; // Typically returns ~1% of tuples
*index_correlation = 0.0; // No correlation with physical order
*index_pages = (tuples / 100.0).max(1.0); // Rough estimate
}

/// Get tuple callback (for index scans)
Expand Down Expand Up @@ -480,10 +537,7 @@ unsafe extern "C" fn hnsw_canreturn(_index: Relation, attno: ::std::os::raw::c_i

/// Options callback - parse index options
#[pg_guard]
unsafe extern "C" fn hnsw_options(
_reloptions: Datum,
_validate: bool,
) -> *mut bytea {
unsafe extern "C" fn hnsw_options(_reloptions: Datum, _validate: bool) -> *mut bytea {
pgrx::log!("HNSW: Parsing options");

// TODO: Parse m, ef_construction, metric from reloptions
Expand All @@ -501,14 +555,14 @@ static HNSW_AM_HANDLER: IndexAmRoutine = IndexAmRoutine {
type_: NodeTag::T_IndexAmRoutine,

// Index structure capabilities
amstrategies: 1, // One strategy: nearest neighbor
amsupport: 1, // One support function: distance
amstrategies: 1, // One strategy: nearest neighbor
amsupport: 1, // One support function: distance
amoptsprocnum: 0,
amcanorder: false,
amcanorderbyop: true, // Supports ORDER BY with distance operators
amcanorderbyop: true, // Supports ORDER BY with distance operators
amcanbackward: false,
amcanunique: false,
amcanmulticol: false, // Single column only (vector)
amcanmulticol: false, // Single column only (vector)
amoptionalkey: true,
amsearcharray: false,
amsearchnulls: false,
Expand Down
7 changes: 3 additions & 4 deletions crates/ruvector-postgres/src/learning/patterns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,10 +355,9 @@ mod tests {
let extractor = PatternExtractor::new(2);

// Consistent trajectories
let trajs: Vec<&QueryTrajectory> = vec![
&QueryTrajectory::new(vec![1.0], vec![1], 1000, 50, 10),
&QueryTrajectory::new(vec![1.0], vec![1], 1000, 50, 10),
];
let traj1 = QueryTrajectory::new(vec![1.0], vec![1], 1000, 50, 10);
let traj2 = QueryTrajectory::new(vec![1.0], vec![1], 1000, 50, 10);
let trajs: Vec<&QueryTrajectory> = vec![&traj1, &traj2];

let confidence = extractor.calculate_confidence(&trajs);
assert!(confidence > 0.0 && confidence <= 1.0);
Expand Down
Loading