From ae6e1d8ad6816ab6761ec9b911d9721e4d27b02c Mon Sep 17 00:00:00 2001 From: SpollaL Date: Thu, 23 Apr 2026 21:09:59 +0200 Subject: [PATCH] feat: parallel rule execution via tokio JoinSet MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replaces sequential for-loop with tokio::task::JoinSet for concurrent rule execution - SessionContext wrapped in Arc and cloned into each spawned task - Results sorted back to original rule submission order after concurrent completion - run_rules_parallel() extracted as testable public function in runner.rs - Rule and Check derive Clone to support move into tasks - Zero CLI change — pure internal improvement; existing API unchanged - 3 new parallel-specific tests: all-pass, with-failure, order-preservation Co-Authored-By: Claude Sonnet 4.6 --- src/main.rs | 20 ++---- src/rules.rs | 4 +- src/runner.rs | 170 +++++++++++++++++++++++++++++++++++++++++-------- src/storage.rs | 4 +- 4 files changed, 155 insertions(+), 43 deletions(-) diff --git a/src/main.rs b/src/main.rs index fa02f08..63ab848 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ use anyhow::Context; use clap::Parser; use datafusion::prelude::*; +use std::sync::Arc; use tracing_subscriber::EnvFilter; mod output; @@ -10,12 +11,11 @@ mod storage; use output::OutputFormat; use rules::RulesFile; -use runner::run_rule; use storage::register_data; use crate::{ output::format_results, - runner::{run_sql, RuleResult, RuleStatus}, + runner::{run_rules_parallel, run_sql, RuleResult, RuleStatus}, }; #[derive(Parser)] @@ -63,7 +63,7 @@ async fn run(args: Cli) -> anyhow::Result<()> { let format: OutputFormat = args .format .context("Could not parse output format. Valid options are json or table")?; - let ctx = SessionContext::new(); + let ctx = Arc::new(SessionContext::new()); register_data(&ctx, &args.file).await?; let schema_cols: Vec = ctx .table("data") @@ -95,21 +95,13 @@ async fn run(args: Cli) -> anyhow::Result<()> { ); return Ok(()); } - let mut any_failed = false; let total_rows = run_sql(&ctx, "SELECT COUNT(*) FROM data".into()).await?; if total_rows == 0 { anyhow::bail!("Input file is empty"); } - let mut results: Vec = Vec::new(); - for rule in &rules.rules { - let result = run_rule(&ctx, rule, total_rows) - .await - .with_context(|| format!("Rule '{}' failed to execute", rule.name))?; - if matches!(result.status, RuleStatus::Fail) { - any_failed = true; - } - results.push(result); - } + let results: Vec = + run_rules_parallel(Arc::clone(&ctx), rules.rules, total_rows).await?; + let any_failed = results.iter().any(|r| matches!(r.status, RuleStatus::Fail)); let out = format_results(&results, &format); println!("{}", out); if any_failed { diff --git a/src/rules.rs b/src/rules.rs index 745b778..a2820e3 100644 --- a/src/rules.rs +++ b/src/rules.rs @@ -7,7 +7,7 @@ pub struct RulesFile { } /// A single data-quality rule targeting one column. -#[derive(Debug, Deserialize)] +#[derive(Debug, Clone, Deserialize)] pub struct Rule { /// Human-readable name shown in output. pub name: String, @@ -27,7 +27,7 @@ pub struct Rule { } /// The type of check to perform on a column. -#[derive(Debug, Deserialize)] +#[derive(Debug, Clone, Deserialize)] #[serde(rename_all = "snake_case")] pub enum Check { /// Fails if any value in the column is NULL. diff --git a/src/runner.rs b/src/runner.rs index 2e5ffcc..b9926f1 100644 --- a/src/runner.rs +++ b/src/runner.rs @@ -3,6 +3,7 @@ use anyhow::Context; use datafusion::arrow::array::Int64Array; use datafusion::prelude::*; use serde::Serialize; +use std::sync::Arc; #[derive(Debug, Serialize)] #[serde(rename_all = "snake_case")] @@ -116,12 +117,12 @@ pub async fn run_sql(ctx: &SessionContext, sql: String) -> anyhow::Result { } pub async fn run_rule( - ctx: &SessionContext, + ctx: Arc, rule: &Rule, total_rows: u64, ) -> anyhow::Result { let sql = build_sql(rule)?; - let violations = run_sql(ctx, sql).await?; + let violations = run_sql(&ctx, sql).await?; let violation_rate = violations as f64 / total_rows as f64; let status = { if violation_rate <= rule.threshold.unwrap_or(0.0) { @@ -139,14 +140,53 @@ pub async fn run_rule( }) } +/// Run all rules concurrently via `tokio::task::JoinSet` and return results in +/// the same order as the input `rules` slice. +pub async fn run_rules_parallel( + ctx: Arc, + rules: Vec, + total_rows: u64, +) -> anyhow::Result> { + // Build a position map so we can restore original order after concurrent execution. + let rule_order: std::collections::HashMap = rules + .iter() + .enumerate() + .map(|(i, r)| (r.name.clone(), i)) + .collect(); + + let mut set = tokio::task::JoinSet::new(); + for rule in rules { + let ctx = Arc::clone(&ctx); + set.spawn(async move { + let name = rule.name.clone(); + run_rule(ctx, &rule, total_rows) + .await + .with_context(|| format!("Rule '{}' failed to execute", name)) + }); + } + + let mut results: Vec = Vec::new(); + while let Some(res) = set.join_next().await { + match res { + Ok(Ok(result)) => results.push(result), + Ok(Err(e)) => return Err(e), + Err(join_err) => return Err(anyhow::anyhow!("Task panicked: {}", join_err)), + } + } + + // Restore original submission order. + results.sort_by_key(|r| rule_order.get(&r.name).copied().unwrap_or(usize::MAX)); + Ok(results) +} + #[cfg(test)] mod test { use super::*; - async fn make_ctx(sql: &str) -> SessionContext { + async fn make_ctx(sql: &str) -> Arc { let ctx = SessionContext::new(); ctx.sql(sql).await.unwrap().collect().await.unwrap(); - ctx + Arc::new(ctx) } fn make_rule(name: &str, column: &str, check: Check) -> Rule { @@ -166,7 +206,7 @@ mod test { async fn test_not_null_fails_when_nulls_present() { let ctx = make_ctx("CREATE TABLE data AS SELECT * FROM (VALUES (1, 'alice'), (2, 'bob'), (NULL, 'carol')) AS t(age, name)").await; let rule = make_rule("age_not_null", "age", Check::NotNull); - let res = run_rule(&ctx, &rule, 3).await.unwrap(); + let res = run_rule(ctx, &rule, 3).await.unwrap(); assert!(matches!(res.status, RuleStatus::Fail)); assert_eq!(res.violations, 1) } @@ -175,7 +215,7 @@ mod test { async fn test_not_null_pass_when_nulls_not_present() { let ctx = make_ctx("CREATE TABLE data AS SELECT * FROM (VALUES (1, 'alice'), (2, 'bob'), (NULL, 'carol')) AS t(age, name)").await; let rule = make_rule("name_not_null", "name", Check::NotNull); - let res = run_rule(&ctx, &rule, 3).await.unwrap(); + let res = run_rule(ctx, &rule, 3).await.unwrap(); assert!(matches!(res.status, RuleStatus::Pass)); assert!(res.violations == 0) } @@ -187,7 +227,7 @@ mod test { min: Some(3.0), ..make_rule("age_gt_3", "age", Check::Min) }; - let res = run_rule(&ctx, &rule, 3).await.unwrap(); + let res = run_rule(ctx, &rule, 3).await.unwrap(); assert!(matches!(res.status, RuleStatus::Fail)); assert_eq!(res.violations, 1) } @@ -199,7 +239,7 @@ mod test { min: Some(1.0), ..make_rule("age_gt_1", "age", Check::Min) }; - let res = run_rule(&ctx, &rule, 3).await.unwrap(); + let res = run_rule(ctx, &rule, 3).await.unwrap(); assert!(matches!(res.status, RuleStatus::Pass)); assert!(res.violations == 0) } @@ -208,7 +248,7 @@ mod test { async fn test_not_empty_fails_when_empty_present() { let ctx = make_ctx("CREATE TABLE data AS SELECT * FROM (VALUES (1, ''), (2, 'bob'), (NULL, 'carol')) AS t(age, name)").await; let rule = make_rule("name_not_empty", "name", Check::NotEmpty); - let res = run_rule(&ctx, &rule, 3).await.unwrap(); + let res = run_rule(ctx, &rule, 3).await.unwrap(); assert!(matches!(res.status, RuleStatus::Fail)); assert_eq!(res.violations, 1) } @@ -217,7 +257,7 @@ mod test { async fn test_not_em_pass_when_empty_not_present() { let ctx = make_ctx("CREATE TABLE data AS SELECT * FROM (VALUES (1, 'alice'), (2, 'bob'), (NULL, 'carol')) AS t(age, name)").await; let rule = make_rule("name_not_empty", "name", Check::NotEmpty); - let res = run_rule(&ctx, &rule, 3).await.unwrap(); + let res = run_rule(ctx, &rule, 3).await.unwrap(); assert!(matches!(res.status, RuleStatus::Pass)); assert!(res.violations == 0) } @@ -229,7 +269,7 @@ mod test { max: Some(2.0), ..make_rule("age_st_2", "age", Check::Max) }; - let res = run_rule(&ctx, &rule, 3).await.unwrap(); + let res = run_rule(ctx, &rule, 3).await.unwrap(); assert!(matches!(res.status, RuleStatus::Fail)); assert_eq!(res.violations, 2) } @@ -241,7 +281,7 @@ mod test { max: Some(2.0), ..make_rule("age_st_2", "age", Check::Max) }; - let res = run_rule(&ctx, &rule, 3).await.unwrap(); + let res = run_rule(ctx, &rule, 3).await.unwrap(); assert!(matches!(res.status, RuleStatus::Pass)); assert!(res.violations == 0) } @@ -255,7 +295,7 @@ mod test { max: Some(8.0), ..make_rule("age_between", "age", Check::Between) }; - let res = run_rule(&ctx, &rule, 3).await.unwrap(); + let res = run_rule(ctx, &rule, 3).await.unwrap(); assert!(matches!(res.status, RuleStatus::Fail)); assert_eq!(res.violations, 2); // 1 and 10 are out of range } @@ -269,7 +309,7 @@ mod test { max: Some(10.0), ..make_rule("age_between", "age", Check::Between) }; - let res = run_rule(&ctx, &rule, 3).await.unwrap(); + let res = run_rule(ctx, &rule, 3).await.unwrap(); assert!(matches!(res.status, RuleStatus::Pass)); assert_eq!(res.violations, 0); } @@ -280,7 +320,7 @@ mod test { make_ctx("CREATE TABLE data AS SELECT * FROM (VALUES ('a'), ('b'), ('a')) AS t(name)") .await; let rule = make_rule("name_unique", "name", Check::Unique); - let res = run_rule(&ctx, &rule, 3).await.unwrap(); + let res = run_rule(ctx, &rule, 3).await.unwrap(); assert!(matches!(res.status, RuleStatus::Fail)); assert_eq!(res.violations, 2); // both 'a' rows are duplicates } @@ -291,7 +331,7 @@ mod test { make_ctx("CREATE TABLE data AS SELECT * FROM (VALUES ('a'), ('b'), ('c')) AS t(name)") .await; let rule = make_rule("name_unique", "name", Check::Unique); - let res = run_rule(&ctx, &rule, 3).await.unwrap(); + let res = run_rule(ctx, &rule, 3).await.unwrap(); assert!(matches!(res.status, RuleStatus::Pass)); assert_eq!(res.violations, 0); } @@ -303,7 +343,7 @@ mod test { pattern: Some("^[^@]+@[^@]+$".to_string()), ..make_rule("email_regex", "email", Check::Regex) }; - let res = run_rule(&ctx, &rule, 2).await.unwrap(); + let res = run_rule(ctx, &rule, 2).await.unwrap(); assert!(matches!(res.status, RuleStatus::Fail)); assert_eq!(res.violations, 1); } @@ -318,7 +358,7 @@ mod test { pattern: Some("^[^@]+@[^@]+$".to_string()), ..make_rule("email_regex", "email", Check::Regex) }; - let res = run_rule(&ctx, &rule, 2).await.unwrap(); + let res = run_rule(ctx, &rule, 2).await.unwrap(); assert!(matches!(res.status, RuleStatus::Pass)); assert_eq!(res.violations, 0); } @@ -332,7 +372,7 @@ mod test { threshold: Some(0.5), // allow up to 50% nulls — 1/3 = 33% should pass ..make_rule("age_not_null", "age", Check::NotNull) }; - let res = run_rule(&ctx, &rule, 3).await.unwrap(); + let res = run_rule(ctx, &rule, 3).await.unwrap(); assert!(matches!(res.status, RuleStatus::Pass)); } @@ -340,7 +380,7 @@ mod test { async fn test_min_without_min_value_returns_error() { let ctx = make_ctx("CREATE TABLE data AS SELECT * FROM (VALUES (1)) AS t(age)").await; let rule = make_rule("bad_rule", "age", Check::Min); // no min set - let res = run_rule(&ctx, &rule, 1).await; + let res = run_rule(ctx, &rule, 1).await; assert!(res.is_err()); } @@ -348,7 +388,7 @@ mod test { async fn test_max_without_max_value_returns_error() { let ctx = make_ctx("CREATE TABLE data AS SELECT * FROM (VALUES (1)) AS t(age)").await; let rule = make_rule("bad_rule", "age", Check::Max); // no max set - let res = run_rule(&ctx, &rule, 1).await; + let res = run_rule(ctx, &rule, 1).await; assert!(res.is_err()); } @@ -357,7 +397,7 @@ mod test { let ctx = make_ctx("CREATE TABLE data AS SELECT * FROM (VALUES ('hello')) AS t(name)").await; let rule = make_rule("bad_rule", "name", Check::Regex); // no pattern set - let res = run_rule(&ctx, &rule, 1).await; + let res = run_rule(ctx, &rule, 1).await; assert!(res.is_err()); } @@ -371,7 +411,7 @@ mod test { pattern: Some("it's".to_string()), ..make_rule("quote_test", "name", Check::Regex) }; - let res = run_rule(&ctx, &rule, 2).await.unwrap(); + let res = run_rule(ctx, &rule, 2).await.unwrap(); assert!(matches!(res.status, RuleStatus::Fail)); assert_eq!(res.violations, 1); } @@ -383,7 +423,7 @@ mod test { sql: Some("SELECT COUNT(*) FROM data WHERE age IS NULL".into()), ..make_rule("age_not_null", "age", Check::Custom) }; - let res = run_rule(&ctx, &rule, 3).await.unwrap(); + let res = run_rule(ctx, &rule, 3).await.unwrap(); assert!(matches!(res.status, RuleStatus::Fail)); assert_eq!(res.violations, 1) } @@ -395,7 +435,7 @@ mod test { sql: Some("SELECT COUNT(*) FROM data WHERE age IS NULL".into()), ..make_rule("age_not_null", "age", Check::Custom) }; - let res = run_rule(&ctx, &rule, 3).await.unwrap(); + let res = run_rule(ctx, &rule, 3).await.unwrap(); assert!(matches!(res.status, RuleStatus::Pass)); assert_eq!(res.violations, 0) } @@ -441,4 +481,84 @@ mod test { ]; assert!(validate_threshold(&rules).is_ok()); } + + // ── Parallel execution tests ────────────────────────────────────────────── + + /// Create a shared in-memory context with a simple numeric table. + async fn make_numeric_ctx() -> Arc { + make_ctx("CREATE TABLE data AS SELECT * FROM (VALUES (1), (2), (3), (4), (5)) AS t(age)") + .await + } + + #[tokio::test] + async fn test_parallel_execution_all_pass() { + let ctx = make_numeric_ctx().await; + // 5 rules that all pass (age >= 1, age <= 10, etc.) + let rules: Vec = (0..5) + .map(|i| Rule { + max: Some(10.0), + ..make_rule(&format!("rule_{}", i), "age", Check::Max) + }) + .collect(); + let total_rows = 5; + let results = run_rules_parallel(ctx, rules.clone(), total_rows) + .await + .unwrap(); + assert_eq!(results.len(), 5); + for res in &results { + assert!(matches!(res.status, RuleStatus::Pass)); + assert_eq!(res.violations, 0); + } + // Verify ordering matches input order + for (i, res) in results.iter().enumerate() { + assert_eq!(res.name, format!("rule_{}", i)); + } + } + + #[tokio::test] + async fn test_parallel_execution_with_failure() { + let ctx = make_numeric_ctx().await; + // Mix: rule_0 passes (max=10), rule_1 fails (max=2, values 3/4/5 violate), rule_2 passes + let rules = vec![ + Rule { + max: Some(10.0), + ..make_rule("pass_rule", "age", Check::Max) + }, + Rule { + max: Some(2.0), + ..make_rule("fail_rule", "age", Check::Max) + }, + Rule { + min: Some(1.0), + ..make_rule("pass_rule_2", "age", Check::Min) + }, + ]; + let results = run_rules_parallel(ctx, rules, 5).await.unwrap(); + assert_eq!(results.len(), 3); + assert!(matches!(results[0].status, RuleStatus::Pass)); + assert!(matches!(results[1].status, RuleStatus::Fail)); + assert_eq!(results[1].violations, 3); // values 3, 4, 5 exceed max=2 + assert!(matches!(results[2].status, RuleStatus::Pass)); + // Names preserved in order + assert_eq!(results[0].name, "pass_rule"); + assert_eq!(results[1].name, "fail_rule"); + assert_eq!(results[2].name, "pass_rule_2"); + } + + #[tokio::test] + async fn test_parallel_execution_preserves_order() { + let ctx = make_numeric_ctx().await; + // Submit 10 rules; completion order is non-deterministic, but output must match input order. + let rules: Vec = (0..10) + .map(|i| Rule { + max: Some(100.0), + ..make_rule(&format!("ordered_rule_{:02}", i), "age", Check::Max) + }) + .collect(); + let expected_names: Vec = rules.iter().map(|r| r.name.clone()).collect(); + let results = run_rules_parallel(ctx, rules, 5).await.unwrap(); + assert_eq!(results.len(), 10); + let actual_names: Vec = results.iter().map(|r| r.name.clone()).collect(); + assert_eq!(actual_names, expected_names); + } } diff --git a/src/storage.rs b/src/storage.rs index 8054052..15c3c4c 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -112,7 +112,7 @@ mod test { threshold: None, sql: None, }; - let result = run_rule(&ctx, &rule, 3).await.unwrap(); + let result = run_rule(Arc::new(ctx), &rule, 3).await.unwrap(); assert!(matches!(result.status, RuleStatus::Fail)); assert_eq!(result.violations, 1); } @@ -170,7 +170,7 @@ mod test { threshold: None, sql: None, }; - let result = run_rule(&ctx, &rule, 3).await.unwrap(); + let result = run_rule(Arc::new(ctx), &rule, 3).await.unwrap(); assert!(matches!(result.status, RuleStatus::Fail)); assert_eq!(result.violations, 1); }