Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions colgrep/src/index/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,25 @@ fn run_encode_stage(
model: Colbert,
) -> Result<()> {
while let Ok(chunk) = receiver.recv() {
let raw_embeddings = model.encode_prepared_document_batches(chunk.prepared_batches)?;
// Stop pulling new work promptly on Ctrl+C instead of draining the whole
// in-flight queue (encoding is the slow stage).
if is_interrupted_outside_critical() {
break;
}

// Cancel mid-chunk too: the encoder checks this between batches, so a Ctrl+C
// lands within ~one model forward pass rather than after the whole chunk. The
// partial chunk is dropped uncommitted — `state.json` is only saved per
// completed checkpoint batch and `build_resumable` trims any partial write on
// the next run — so this stays safe and resumable, just near-immediate.
let cancel = || is_interrupted_outside_critical();
let raw_embeddings = match model
.encode_prepared_document_batches_cancellable(chunk.prepared_batches, Some(&cancel))
{
Ok(raw) => raw,
Err(_) if is_interrupted() => break,
Err(e) => return Err(e),
};

sender
.send(RawEncodedChunk {
Expand Down Expand Up @@ -806,7 +824,6 @@ fn run_metadata_stage(
pb: Option<ProgressBar>,
) -> Result<()> {
let mut filtering_exists = filtering::exists(&index_path);
let mut completed_units = 0u64;

while let Ok(chunk) = receiver.recv() {
let metadata: Vec<serde_json::Value> = chunk
Expand Down Expand Up @@ -839,9 +856,12 @@ fn run_metadata_stage(
}

filtering_exists = true;
completed_units += chunk.units.len() as u64;
// Advance the shared progress bar cumulatively. This stage runs once per
// checkpoint batch, so an absolute `set_position` from a per-batch counter
// would reset the bar to 0 every ~BUILD_CHECKPOINT_UNITS; `inc` accumulates
// across batches up to the whole-build total.
if let Some(pb) = pb.as_ref() {
pb.set_position(completed_units);
pb.inc(chunk.units.len() as u64);
}
}

Expand Down
19 changes: 14 additions & 5 deletions colgrep/src/signal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,20 @@ impl Drop for CriticalSectionGuard {
/// Returns an error if the handler cannot be set.
pub fn setup_signal_handler() -> Result<(), ctrlc::Error> {
ctrlc::set_handler(move || {
// Set the flag on first interrupt
if !INTERRUPTED.swap(true, Ordering::SeqCst)
&& CRITICAL_SECTION_DEPTH.load(Ordering::Relaxed) > 0
{
eprintln!("\n⚠️ Interrupt received, finishing current write operation...");
// Acknowledge the first interrupt immediately so the user gets feedback.
// Indexing stops at the next safe checkpoint; finished batches are persisted
// and the build resumes on the next run. (Previously this message only
// printed inside a critical section, so a Ctrl+C during the long encoding
// phase looked like it did nothing.)
if !INTERRUPTED.swap(true, Ordering::SeqCst) {
if CRITICAL_SECTION_DEPTH.load(Ordering::Relaxed) > 0 {
eprintln!("\n⚠️ Interrupt received — finishing the current write, then stopping…");
} else {
eprintln!(
"\n⚠️ Interrupt received — stopping at the next checkpoint. \
Finished batches are kept; rerun to resume."
);
}
}
})
}
Expand Down
26 changes: 26 additions & 0 deletions next-plaid-onnx/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1260,10 +1260,26 @@ impl Colbert {
pub fn encode_prepared_document_batches(
    &self,
    prepared_batches: Vec<PreparedDocumentBatch>,
) -> Result<Vec<Array2<f32>>> {
    // Non-cancellable entry point: delegate to the cancellable variant with no
    // cancel callback, so no cancellation check can ever fire for this call.
    let no_cancel: Option<&(dyn Fn() -> bool + Sync)> = None;
    self.encode_prepared_document_batches_cancellable(prepared_batches, no_cancel)
}

/// Like [`encode_prepared_document_batches`], but checks `cancel` before each
/// batch's forward pass. If `cancel()` returns true, encoding stops promptly
/// (within ~one in-flight `session.run`) and returns `Err`. Callers that drive
/// interruptible work (e.g. colgrep indexing) pass a flag check here to get
/// near-immediate Ctrl+C response instead of finishing the whole chunk.
///
/// [`encode_prepared_document_batches`]: Self::encode_prepared_document_batches
pub fn encode_prepared_document_batches_cancellable(
&self,
prepared_batches: Vec<PreparedDocumentBatch>,
cancel: Option<&(dyn Fn() -> bool + Sync)>,
) -> Result<Vec<Array2<f32>>> {
if prepared_batches.is_empty() {
return Ok(Vec::new());
}
let cancelled = || cancel.map(|c| c()).unwrap_or(false);

// Collect the original-input position for every document across all
// batches in the order they appear here. When `tokenize_documents_in_batches`
Expand All @@ -1284,10 +1300,14 @@ impl Colbert {
let encoded: Vec<Array2<f32>> = if self.sessions.len() <= 1 || prepared_batches.len() == 1 {
let mut all_embeddings = Vec::new();
for prepared_batch in prepared_batches {
if cancelled() {
anyhow::bail!("encoding cancelled");
}
all_embeddings.extend(self.encode_prepared_documents(prepared_batch)?);
}
all_embeddings
} else {
let cancel_ref = cancel;
let results: Vec<Result<Vec<Array2<f32>>>> = std::thread::scope(|scope| {
let mut handles = Vec::with_capacity(prepared_batches.len());

Expand All @@ -1298,6 +1318,12 @@ impl Colbert {
let skiplist_ids = &self.skiplist_ids;

handles.push(scope.spawn(move || {
// Skip not-yet-started batches once cancelled; already-running
// forward passes finish (one `session.run`), so the whole call
// returns within ~one batch of a Ctrl+C.
if cancel_ref.map(|c| c()).unwrap_or(false) {
anyhow::bail!("encoding cancelled");
}
let mut session = session_mutex.lock().unwrap();
encode_prepared_batch_with_session(
&mut session,
Expand Down
Loading