From f389c35e5f1527f842c009e171802ec1e6688f14 Mon Sep 17 00:00:00 2001 From: Mikhail Kot Date: Thu, 4 Jun 2026 18:14:03 +0100 Subject: [PATCH 1/5] initial len() --- vortex-array/src/scalar_fn/fns/byte_length.rs | 9 + vortex-duckdb/build.rs | 3 +- vortex-duckdb/cpp/include/duckdb_vx.h | 1 + .../cpp/include/duckdb_vx/optimizer.h | 141 ++++++++++ .../cpp/include/duckdb_vx/table_function.h | 17 ++ vortex-duckdb/cpp/optimizer.cpp | 211 +++++++++++++++ vortex-duckdb/cpp/table_function.cpp | 32 ++- vortex-duckdb/include/vortex.h | 6 + vortex-duckdb/src/convert/expr.rs | 28 ++ vortex-duckdb/src/convert/mod.rs | 1 + vortex-duckdb/src/duckdb/database.rs | 8 + vortex-duckdb/src/ffi.rs | 16 ++ vortex-duckdb/src/lib.rs | 1 + vortex-duckdb/src/projection.rs | 112 ++++++-- vortex-duckdb/src/table_function.rs | 21 ++ vortex-layout/src/layouts/dict/reader.rs | 8 +- .../duckdb/projection_expression_pushdown.slt | 242 ++++++++++++++++++ 17 files changed, 820 insertions(+), 37 deletions(-) create mode 100644 vortex-duckdb/cpp/include/duckdb_vx/optimizer.h create mode 100644 vortex-duckdb/cpp/optimizer.cpp create mode 100644 vortex-sqllogictest/slt/duckdb/projection_expression_pushdown.slt diff --git a/vortex-array/src/scalar_fn/fns/byte_length.rs b/vortex-array/src/scalar_fn/fns/byte_length.rs index 13a4f3158b5..aa9c508ea89 100644 --- a/vortex-array/src/scalar_fn/fns/byte_length.rs +++ b/vortex-array/src/scalar_fn/fns/byte_length.rs @@ -24,6 +24,7 @@ use crate::arrays::varbinview::VarBinViewArrayExt; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; +use crate::expr::Expression; use crate::kernel::ExecuteParentKernel; use crate::scalar::Scalar; use crate::scalar_fn::Arity; @@ -122,6 +123,14 @@ impl ScalarFnVTable for ByteLength { } } + fn validity( + &self, + _: &Self::Options, + expression: &Expression, + ) -> VortexResult> { + Ok(Some(expression.child(0).validity()?)) + } + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { false } diff --git a/vortex-duckdb/build.rs b/vortex-duckdb/build.rs index 3451af771bc..507af8a911e 100644 --- a/vortex-duckdb/build.rs +++ b/vortex-duckdb/build.rs @@ -24,7 +24,7 @@ const DEFAULT_DUCKDB_VERSION: &str = "1.5.3"; const BUILD_ARTIFACTS: [&str; 3] = ["libduckdb.dylib", "libduckdb.so", "libduckdb_static.a"]; -const SOURCE_FILES: [&str; 17] = [ +const SOURCE_FILES: [&str; 18] = [ "cpp/client_context.cpp", "cpp/config.cpp", "cpp/copy_function.cpp", @@ -34,6 +34,7 @@ const SOURCE_FILES: [&str; 17] = [ "cpp/expr.cpp", "cpp/file_system.cpp", "cpp/logical_type.cpp", + "cpp/optimizer.cpp", "cpp/replacement_scan.cpp", "cpp/reusable_dict.cpp", "cpp/scalar_function.cpp", diff --git a/vortex-duckdb/cpp/include/duckdb_vx.h b/vortex-duckdb/cpp/include/duckdb_vx.h index dcad0ae1487..176b40a415a 100644 --- a/vortex-duckdb/cpp/include/duckdb_vx.h +++ b/vortex-duckdb/cpp/include/duckdb_vx.h @@ -4,6 +4,7 @@ #pragma once #include "duckdb_vx/client_context.h" +#include "duckdb_vx/optimizer.h" #include "duckdb_vx/config.h" #include "duckdb_vx/copy_function.h" #include "duckdb_vx/data.h" diff --git a/vortex-duckdb/cpp/include/duckdb_vx/optimizer.h b/vortex-duckdb/cpp/include/duckdb_vx/optimizer.h new file mode 100644 index 00000000000..01851a79eaf --- /dev/null +++ b/vortex-duckdb/cpp/include/duckdb_vx/optimizer.h @@ -0,0 +1,141 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors +#pragma once +#include "duckdb.h" + +#ifdef __cplusplus +extern "C" { +#endif + +duckdb_state duckdb_vx_optimizer_extension_register(duckdb_database ffi_db); + +#ifdef __cplusplus +} +#endif + +#ifdef __cplusplus +#include "duckdb/optimizer/optimizer_extension.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include + +// Only one consumer of this header file, so "using" is fine +using namespace duckdb; + +using ExpressionPtr = unique_ptr; +using LogicalOperatorPtr = unique_ptr; + +/** + * Column index in requested scan. Example: + * + * CREATE TABLE t (a1 INTEGER, a2 INTEGER, a3 INTEGER); + * SELECT a2, a3 FROM t; + * + * a2's TableColumnScanIndex is 0, a3's TableColumnScanIndex is 1, + * index is index in SELECT clause. + */ +using TableColumnScanIndex = idx_t; + +/** + * Column index in table's storage. Example: + * + * CREATE TABLE t (a1 INTEGER, a2 INTEGER, a3 INTEGER); + * SELECT a2, a3 FROM t; + * + * a2's TableColumnStorageIndex is 1, a3's TableColumnScanIndex is 2, + * index is index of column in table storage. + * + * for i: TableColumnScanIndex, column_ids[i].GetPrimaryIndex() is + * TableColumnStorageIndex + */ +using TableColumnStorageIndex = idx_t; + +using TableIndex = idx_t; + +struct GetAnalysis { + LogicalGet &get; + /** + * for fn(col), mapping of "col scan index" -> "fn expression". + * "fn expression" is nullptr iff column is used with a different function + * or without function application in the query plan. + */ + unordered_map col_to_fn; +}; + +using Analyses = unordered_map; + +/* + * Query plans may have PROJECTIONs which wrap GETs. One example is VIEWs for + * our benchmarks: + * + * CREATE VIEW view AS (SELECT * FROM '*.vortex'); + * SELECT len(col) FROM view; + * + * Second query "col"'s table_index would be 1 (VIEW) and not 0 (GET for + * vortex). But we want to push down len(col) to vortex. So we keep an aliases + * mapping of + * + * "projection table index" to "projection operator". + * + * to resolve this. + * For simplicity, current implementation is limited to one level i.e. + * VIEW -> GET is pushed down but VIEW->VIEW->GET or VIEW->CTE->GET is not. + */ +using Projections = unordered_map; + +/** + * Collect fn(col) expressions i.e. expressions where a single function (not + * a function chain) wraps a single bound column. If "col" is used without + * function application in "plan", record in "analyses.conflicts" + */ +struct ScalarFnCollect final : LogicalOperatorVisitor { + Analyses &analyses; + const Projections &projections; + + ScalarFnCollect(Analyses &analyses, const Projections &projections); + void VisitOperator(LogicalOperator &op) override; + ExpressionPtr VisitReplace(BoundColumnRefExpression &expr, ExpressionPtr *ptr) override; + ExpressionPtr VisitReplace(BoundFunctionExpression &expr, ExpressionPtr *ptr) override; +}; + +/* + * For "col" in columns collected by ScalarFnCollect, replace fn(col) to "col" + * if "col" doesn't have conflicting usage. Update return types for bound + * columns and logical projections referencing this column. + */ +struct ScalarFnReplace final : LogicalOperatorVisitor { + Analyses &analyses; + const Projections &projections; + + ScalarFnReplace(Analyses &analyses, const Projections &aliases); + ExpressionPtr VisitReplace(BoundColumnRefExpression &expr, ExpressionPtr *ptr) override; + ExpressionPtr VisitReplace(BoundFunctionExpression &expr, ExpressionPtr *ptr) override; +}; + +void FindGetsAndAliases(LogicalOperator &op, + Analyses &analyses, + Projections &aliases, + LogicalOperator *parent = nullptr); + +LogicalOperatorPtr TryPushdownScalarFunctions(ClientContext &context, LogicalOperatorPtr plan); +void VortexOptimizeFunction(OptimizerExtensionInput &input, LogicalOperatorPtr &plan); + +struct VortexOptimizerExtension final : OptimizerExtension { + inline VortexOptimizerExtension() : OptimizerExtension(VortexOptimizeFunction, nullptr, {}) { + } +}; + +struct Binding { + GetAnalysis &analysis; + TableColumnScanIndex column_index; +}; + +/* + * Given a column binding, resolve it to a GET and a GET's column scan index. + * Returns nullopt for virtual columns and columns which are neither part of + * GET nor part of PROJECTION wrapping a GET. + */ +std::optional Resolve(ColumnBinding binding, Analyses &analyses, const Projections &projections); + +#endif diff --git a/vortex-duckdb/cpp/include/duckdb_vx/table_function.h b/vortex-duckdb/cpp/include/duckdb_vx/table_function.h index 4299d3a3b00..a5979376175 100644 --- a/vortex-duckdb/cpp/include/duckdb_vx/table_function.h +++ b/vortex-duckdb/cpp/include/duckdb_vx/table_function.h @@ -10,6 +10,23 @@ #ifdef __cplusplus static_assert(sizeof(idx_t) == 8); + +#include "duckdb/main/capi/capi_internal.hpp" + +duckdb::unique_ptr bind(duckdb::ClientContext &context, + duckdb::TableFunctionBindInput &input, + duckdb::vector &return_types, + duckdb::vector &names); + +struct TableFunctionProjectionExpressionInput { + const duckdb::LogicalGet &get; + const duckdb::Expression &expression; + idx_t projection_idx; +}; + +// true if we can push down the expression, false otherwise +bool projection_expression_pushdown(duckdb::ClientContext &context, + const TableFunctionProjectionExpressionInput &input); #endif #ifdef __cplusplus diff --git a/vortex-duckdb/cpp/optimizer.cpp b/vortex-duckdb/cpp/optimizer.cpp new file mode 100644 index 00000000000..425441e0422 --- /dev/null +++ b/vortex-duckdb/cpp/optimizer.cpp @@ -0,0 +1,211 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb_vx/optimizer.h" +#include "duckdb_vx/table_function.h" +#include "vortex.h" +#include + +extern "C" duckdb_state duckdb_vx_optimizer_extension_register(duckdb_database ffi_db) { + D_ASSERT(ffi_db); + const DatabaseWrapper &wrapper = *reinterpret_cast(ffi_db); + DatabaseInstance &db = *wrapper.database->instance; + try { + DBConfig::GetConfig(db).GetCallbackManager().Register(VortexOptimizerExtension()); + } catch (const std::exception &e) { + ErrorData data(e); + DUCKDB_LOG_ERROR(db, "Failed to create Vortex optimizer extension:\t" + data.Message()); + return DuckDBError; + } + return DuckDBSuccess; +} + +void VortexOptimizeFunction(OptimizerExtensionInput &input, LogicalOperatorPtr &plan) { + plan = TryPushdownScalarFunctions(input.context, std::move(plan)); +} + +LogicalOperatorPtr TryPushdownScalarFunctions(ClientContext &context, LogicalOperatorPtr plan) { + Analyses analyses; + Projections projections; + FindGetsAndAliases(*plan, analyses, projections); + if (analyses.empty()) { + return plan; + } + ScalarFnCollect(analyses, projections).VisitOperator(*plan); + + bool any_pushed = false; + for (auto &[_, analysis] : analyses) { + for (auto &[column_index, expr] : analysis.col_to_fn) { + if (expr == nullptr) { // Conflict for column + continue; + } + const TableColumnStorageIndex storage_index = + analysis.get.GetColumnIds()[column_index].GetPrimaryIndex(); + TableFunctionProjectionExpressionInput input {analysis.get, *expr, storage_index}; + if (projection_expression_pushdown(context, input)) { + analysis.get.types[column_index] = expr->return_type; + analysis.get.returned_types[storage_index] = expr->return_type; + any_pushed = true; + } else { // failed to push down expression, can't replace it + expr = nullptr; + } + } + } + + if (any_pushed) { + ScalarFnReplace(analyses, projections).VisitOperator(*plan); + } + return plan; +} + +void FindGetsAndAliases(LogicalOperator &op, + Analyses &analyses, + Projections &projections, + LogicalOperator *parent) { + if (op.type == LogicalOperatorType::LOGICAL_GET) { + auto &get = op.Cast(); + if (get.function.bind == bind) { + analyses.emplace(get.table_index, GetAnalysis {get, {}}); + if (parent && parent->type == LogicalOperatorType::LOGICAL_PROJECTION) { + const auto &projection = parent->Cast(); + projections.emplace(projection.table_index, projection); + } + } + } + for (auto &child : op.children) { + FindGetsAndAliases(*child, analyses, projections, &op); + } +} + +std::optional Resolve(ColumnBinding binding, Analyses &analyses, const Projections &projections) { + if (IsVirtualColumn(binding.column_index)) { + return std::nullopt; + } + if (const auto it = analyses.find(binding.table_index); it != analyses.end()) { + return {{it->second, binding.column_index}}; + } + + const auto projection_it = projections.find(binding.table_index); + if (projection_it == projections.end()) { + return std::nullopt; + } + + const ExpressionPtr &inner = projection_it->second.expressions[binding.column_index]; + if (inner->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + return std::nullopt; + } + const ColumnBinding &get_binding = inner->Cast().binding; + if (IsVirtualColumn(get_binding.column_index)) { + return std::nullopt; + } + if (const auto it = analyses.find(get_binding.table_index); it != analyses.end()) { + return {{it->second, get_binding.column_index}}; + } + return std::nullopt; +} + +void ScalarFnCollect::VisitOperator(LogicalOperator &op) { + /* + * Logical projection expressions are columns which reference underlying + * GETs. Don't process them, as they would add conflicts for every column + * used in projection. Example: PROJECTION(col) -> GET(col). We don't want + * to visit BoundColumnRefExpression in PROJECTION. + * + * However, ScalarFnReplace will visit them because we need to update their + * types if pushdown succeeded. + */ + if (op.type == LogicalOperatorType::LOGICAL_PROJECTION && + projections.count(op.Cast().table_index)) { + VisitOperatorChildren(op); + return; + } + LogicalOperatorVisitor::VisitOperator(op); +} + +ExpressionPtr ScalarFnCollect::VisitReplace(BoundColumnRefExpression &expr, ExpressionPtr *ptr) { + const auto binding = Resolve(expr.binding, analyses, projections); + if (!binding) { + return std::move(*ptr); + } + auto &[analysis, column_index] = *binding; + + // Column is used without function applied to it, register a conflict. + // Not emplace() as we need to update the value if it was present + analysis.col_to_fn[column_index] = nullptr; + return std::move(*ptr); +} + +ExpressionPtr ScalarFnCollect::VisitReplace(BoundFunctionExpression &expr, ExpressionPtr *ptr) { + if (expr.children.size() != 1 || + expr.children[0]->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + return std::move(*ptr); + } + const auto &bound_col = expr.children[0]->Cast(); + const auto binding = Resolve(bound_col.binding, analyses, projections); + if (!binding) { + return std::move(*ptr); + } + auto &[analysis, column_index] = *binding; + + if (auto it = analysis.col_to_fn.find(column_index); it == analysis.col_to_fn.end()) { + // This is the first time we see the column used by a single function. + analysis.col_to_fn.emplace(column_index, &expr); + } else if (it->second == nullptr || !it->second->Equals(expr)) { + // Either column is used with different function in "expr" or + // there already is a conflict. + it->second = nullptr; + } + + // We don't want to descend into child BoundColumnRefExpression because we + // have already registered a conflict if it was present. + return std::move(*ptr); +} + +static bool conflict(const GetAnalysis &analysis, TableColumnScanIndex idx) { + const auto it = analysis.col_to_fn.find(idx); + return it == analysis.col_to_fn.end() || it->second == nullptr; +} + +ExpressionPtr ScalarFnReplace::VisitReplace(BoundColumnRefExpression &expr, ExpressionPtr *ptr) { + const auto binding = Resolve(expr.binding, analyses, projections); + if (!binding) { + return std::move(*ptr); + } + + const auto &[analysis, column_index] = *binding; + if (!conflict(analysis, column_index)) { + expr.return_type = analysis.get.types[column_index]; + } + + return std::move(*ptr); +} + +ExpressionPtr ScalarFnReplace::VisitReplace(BoundFunctionExpression &expr, ExpressionPtr *ptr) { + if (expr.children.size() != 1 || + expr.children[0]->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + return std::move(*ptr); + } + ExpressionPtr &bound_col_base = expr.children[0]; + const auto &bound_col = bound_col_base->Cast(); + const auto binding = Resolve(bound_col.binding, analyses, projections); + if (!binding) { + return std::move(*ptr); + } + + const auto &[analysis, column_index] = *binding; + if (conflict(analysis, column_index)) { + return std::move(*ptr); + } + + bound_col_base->return_type = analysis.get.types[column_index]; + return std::move(bound_col_base); +} + +ScalarFnCollect::ScalarFnCollect(Analyses &analyses, const Projections &projections) + : analyses(analyses), projections(projections) { +} + +ScalarFnReplace::ScalarFnReplace(Analyses &analyses, const Projections &projections) + : analyses(analyses), projections(projections) { +} diff --git a/vortex-duckdb/cpp/table_function.cpp b/vortex-duckdb/cpp/table_function.cpp index 6c87d209b4d..1e08a86bd24 100644 --- a/vortex-duckdb/cpp/table_function.cpp +++ b/vortex-duckdb/cpp/table_function.cpp @@ -4,6 +4,7 @@ #include "duckdb_vx/data.hpp" #include "duckdb_vx/error.hpp" #include "duckdb_vx/table_function.h" +#include "duckdb_vx/expr.h" #include "vortex.h" #include "duckdb.h" @@ -14,6 +15,7 @@ #include "duckdb/main/capi/capi_internal.hpp" #include "duckdb/main/connection.hpp" #include "duckdb/parser/parsed_data/create_table_function_info.hpp" +#include "duckdb/planner/operator/logical_get.hpp" using namespace std::string_literals; using namespace duckdb; @@ -170,15 +172,35 @@ struct CTableBindResult { vector &names; }; +bool projection_expression_pushdown(ClientContext &, const TableFunctionProjectionExpressionInput &input) { + const auto &bind = input.get.bind_data->Cast(); + + // This is a flaw of Duckdb API which doesn't allow passing non-const + // expressions. We never modify the value on Rust side. + auto ffi_expr = reinterpret_cast(const_cast(&input.expression)); + void *const ffi_bind = bind.ffi_data->DataPtr(); + duckdb_vx_error error_out = nullptr; + + const bool ret = duckdb_table_function_pushdown_projection_expression( // + ffi_bind, + ffi_expr, + input.projection_idx, + &error_out); + if (error_out) { + throw BinderException(IntoErrString(error_out)); + } + return ret; +} + /** * Called for every new query. For example, if there is a VIEW over *.vortex, * and after a query another file is added matching the glob, for second query * bind() will be called again. */ -unique_ptr c_bind(ClientContext &context, - TableFunctionBindInput &input, - vector &return_types, - vector &names) { +unique_ptr bind(ClientContext &context, + TableFunctionBindInput &input, + vector &return_types, + vector &names) { CTableBindResult result = {return_types, names}; duckdb_vx_error error_out = nullptr; @@ -364,7 +386,7 @@ InsertionOrderPreservingMap c_to_string(TableFunctionToStringInput &inpu } duckdb_state register_table_function(DatabaseInstance &db, LogicalType parameter, const std::string &name) { - TableFunction tf(name, {}, function, c_bind, c_init_global, init_local); + TableFunction tf(name, {}, function, bind, c_init_global, init_local); tf.projection_pushdown = true; tf.filter_pushdown = true; diff --git a/vortex-duckdb/include/vortex.h b/vortex-duckdb/include/vortex.h index 03835870848..3480e625fb9 100644 --- a/vortex-duckdb/include/vortex.h +++ b/vortex-duckdb/include/vortex.h @@ -61,6 +61,12 @@ bool duckdb_table_function_pushdown_complex_filter(void *bind_data, duckdb_vx_expr expr, duckdb_vx_error *error_out); +extern +bool duckdb_table_function_pushdown_projection_expression(void *bind_data, + duckdb_vx_expr expr, + size_t column_id, + duckdb_vx_error *error_out); + extern void duckdb_table_function_scan(void *global_init_data, void *local_init_data, diff --git a/vortex-duckdb/src/convert/expr.rs b/vortex-duckdb/src/convert/expr.rs index 2fdd6d9c033..c6fd262316b 100644 --- a/vortex-duckdb/src/convert/expr.rs +++ b/vortex-duckdb/src/convert/expr.rs @@ -4,7 +4,9 @@ use std::sync::Arc; use tracing::debug; +use vortex::dtype::DType; use vortex::dtype::Nullability; +use vortex::dtype::PType; use vortex::error::VortexError; use vortex::error::VortexExpect; use vortex::error::VortexResult; @@ -13,6 +15,8 @@ use vortex::error::vortex_ensure; use vortex::error::vortex_err; use vortex::expr::Expression; use vortex::expr::and_collect; +use vortex::expr::byte_length; +use vortex::expr::cast; use vortex::expr::col; use vortex::expr::get_item; use vortex::expr::is_not_null; @@ -21,6 +25,7 @@ use vortex::expr::list_contains; use vortex::expr::lit; use vortex::expr::not; use vortex::expr::or_collect; +use vortex::expr::root; use vortex::scalar::Scalar; use vortex::scalar_fn::ScalarFnVTableExt; use vortex::scalar_fn::fns::between::Between; @@ -43,6 +48,7 @@ use crate::duckdb::ExpressionClass::BoundComparison; use crate::duckdb::ExpressionClass::BoundConjunction; use crate::duckdb::ExpressionClass::BoundConstant; use crate::duckdb::ExpressionClass::BoundRef; +use crate::projection::DuckdbField; fn from_bound_str(value: &duckdb::ExpressionRef) -> VortexResult { match value.as_class().vortex_expect("unknown class") { @@ -171,6 +177,28 @@ pub fn can_push_expression(value: &duckdb::ExpressionRef) -> bool { } } +pub fn try_from_projection_expression( + value: &duckdb::ExpressionRef, + field: &DuckdbField, +) -> VortexResult> { + let Some(value) = value.as_class() else { + return Ok(None); + }; + let ExpressionClass::BoundFunction(func) = value else { + return Ok(None); + }; + Ok(match func.scalar_function.name() { + "strlen" => { + let col = byte_length(get_item(field.name.as_str(), root())); + // byte_length returns u64, strlen expects i64 + let dtype = DType::Primitive(PType::I64, field.dtype.nullability()); + let col = cast(col, dtype); + Some(col) + } + _ => None, + }) +} + // If you want to add support for other expressions, also change // can_push_expression fn try_from_expression_inner( diff --git a/vortex-duckdb/src/convert/mod.rs b/vortex-duckdb/src/convert/mod.rs index 0742d2a5b1f..0d641a73457 100644 --- a/vortex-duckdb/src/convert/mod.rs +++ b/vortex-duckdb/src/convert/mod.rs @@ -10,6 +10,7 @@ mod vector; pub use dtype::FromLogicalType; pub use expr::can_push_expression; pub use expr::try_from_bound_expression; +pub use expr::try_from_projection_expression; pub use scalar::*; pub use table_filter::try_from_table_filter; pub use table_filter::try_from_virtual_column_filter; diff --git a/vortex-duckdb/src/duckdb/database.rs b/vortex-duckdb/src/duckdb/database.rs index 65afae1af8c..6ad9623c4e8 100644 --- a/vortex-duckdb/src/duckdb/database.rs +++ b/vortex-duckdb/src/duckdb/database.rs @@ -139,6 +139,14 @@ impl DatabaseRef { Ok(()) } + pub fn register_optimizer_extension(&self) -> VortexResult<()> { + duckdb_try!( + unsafe { cpp::duckdb_vx_optimizer_extension_register(self.as_ptr()) }, + "Failed to register optimizer extension" + ); + Ok(()) + } + pub fn register_copy_function(&self) -> VortexResult<()> { duckdb_try!( unsafe { cpp::duckdb_vx_register_copy_function(self.as_ptr()) }, diff --git a/vortex-duckdb/src/ffi.rs b/vortex-duckdb/src/ffi.rs index e77c0651fb2..de243365efc 100644 --- a/vortex-duckdb/src/ffi.rs +++ b/vortex-duckdb/src/ffi.rs @@ -39,6 +39,7 @@ use crate::table_function::get_partition_data; use crate::table_function::init_global; use crate::table_function::init_local; use crate::table_function::pushdown_complex_filter; +use crate::table_function::pushdown_projection_expression; use crate::table_function::scan; use crate::table_function::statistics; use crate::table_function::table_scan_progress; @@ -111,6 +112,21 @@ unsafe extern "C-unwind" fn duckdb_table_function_pushdown_complex_filter( try_or(error_out, || pushdown_complex_filter(bind_data, expr)) } +#[unsafe(no_mangle)] +unsafe extern "C-unwind" fn duckdb_table_function_pushdown_projection_expression( + bind_data: *mut c_void, + expr: cpp::duckdb_vx_expr, + column_id: usize, + error_out: *mut cpp::duckdb_vx_error, +) -> bool { + let bind_data = unsafe { bind_data.cast::().as_mut() } + .vortex_expect("bind_data null pointer"); + let expr = unsafe { Expression::borrow(expr) }; + try_or(error_out, || { + pushdown_projection_expression(bind_data, expr, column_id) + }) +} + #[unsafe(no_mangle)] unsafe extern "C-unwind" fn duckdb_table_function_scan( global_init_data: *mut c_void, diff --git a/vortex-duckdb/src/lib.rs b/vortex-duckdb/src/lib.rs index 9988b0b4549..616365568b5 100644 --- a/vortex-duckdb/src/lib.rs +++ b/vortex-duckdb/src/lib.rs @@ -74,6 +74,7 @@ pub fn initialize(db: &DatabaseRef) -> VortexResult<()> { Value::from("vortex"), )?; db.register_table_functions()?; + db.register_optimizer_extension()?; db.register_copy_function() } diff --git a/vortex-duckdb/src/projection.rs b/vortex-duckdb/src/projection.rs index a27056e5f01..7d4ae1a7820 100644 --- a/vortex-duckdb/src/projection.rs +++ b/vortex-duckdb/src/projection.rs @@ -1,17 +1,16 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors use std::ops::Range; -use std::sync::Arc; use num_traits::AsPrimitive as _; use vortex::dtype::DType; -use vortex::dtype::FieldNames; use vortex::error::VortexExpect; use vortex::error::VortexResult; use vortex::error::vortex_err; use vortex::expr::Expression; use vortex::expr::and_collect; use vortex::expr::col; +use vortex::expr::get_item; use vortex::expr::merge; use vortex::expr::pack; use vortex::expr::root; @@ -43,6 +42,9 @@ pub struct DuckdbField { pub name: String, pub logical_type: LogicalType, pub dtype: DType, + /// Function to use instead of get_item(col, root()), e.g. len(col). + /// It does not include column name so it's just "len" and not "len(col)" + pub projection_fn: Option, } pub struct Projection { @@ -68,6 +70,7 @@ impl Projection { let mut file_row_number_column_pos = None; let mut is_star = true; let mut real_column_count = 0; + let mut fn_col_count = 0; // DuckDB uses u64 as column indices but Rust uses usize for (column_pos, &column_id) in ids.iter().enumerate() { @@ -86,47 +89,92 @@ impl Projection { file_row_number_column_pos = Some(column_pos); continue; } + if is_virtual_column(column_id) { + continue; + } // In SELECT * DuckDB requests all columns from 0 to column_fields in // increasing order. After removing virtual columns, compare column_id // with (0..column_fields.len()) range. is_star &= column_id == real_column_count; + + // Example: if we SELECT len(str), we can't use root() as we try to + // pushdown scalar functions. + let column_id: usize = column_id.as_(); + let is_projected_col = column_fields[column_id].projection_fn.is_some(); + fn_col_count += is_projected_col as usize; + is_star &= !is_projected_col; + real_column_count += 1; } // Duckdb can request less columns than there are in table i.e. [0, 1] with // 5 columns total. is_star &= real_column_count == column_fields.len() as u64; - let select = if is_star { - root() - } else { - let names = ids - .iter() - .map(|&column_id| { - if has_projection_ids { - let column_id: usize = column_id.as_(); - column_ids[column_id] - } else { - column_id - } - }) - .filter(|&col_id| !is_virtual_column(col_id)) - .map(|column_id| { - let column_id: usize = column_id.as_(); - Arc::from(column_fields[column_id].name.as_str()) - }) - .collect::(); + let has_file_row_number = file_row_number_column_pos.is_some(); + if is_star { + let projection = if has_file_row_number { + // row_idx will be moved to correct position in scan(), prepend here + let row_idx_struct = pack([("file_row_number", row_idx())], false.into()); + merge([row_idx_struct, root()]) + } else { + root() + }; + return Projection { + projection, + file_index_column_pos, + file_row_number_column_pos, + }; + } - select(names, root()) - }; + let has_fn_columns = fn_col_count > 0; + let mut all_exprs = Vec::with_capacity( + (ids.len() + has_file_row_number as usize) * has_fn_columns as usize, + ); + let mut named_fields = Vec::with_capacity(ids.len() * !has_fn_columns as usize); - // file_index column will be filled later when exporting the chunk. - let projection = if file_row_number_column_pos.is_some() { + if has_file_row_number && has_fn_columns { // row_idx will be moved to correct position in scan(), prepend here + all_exprs.push(("file_row_number", row_idx())); + } + + for &column_id in ids { + let column_id = if has_projection_ids { + let column_id: usize = column_id.as_(); + column_ids[column_id] + } else { + column_id + }; + if is_virtual_column(column_id) { + continue; + } + let column_id: usize = column_id.as_(); + let name = column_fields[column_id].name.as_str(); + if !has_fn_columns { + named_fields.push(name); + continue; + } + + let column_field = &column_fields[column_id]; + let expr = match &column_field.projection_fn { + None => get_item(name, root()), + Some(func) => func.clone(), + }; + all_exprs.push((name, expr)); + } + + let projection = if has_fn_columns { + // If file_row_number was requested, it's in all_exprs as first + // element + pack(all_exprs, false.into()) + } else if has_file_row_number { + let select = select(named_fields, root()); + // Here we need to prepend it manually + // row_idx will be moved to correct position in scan() let row_idx_struct = pack([("file_row_number", row_idx())], false.into()); merge([row_idx_struct, select]) } else { - select + select(named_fields, root()) }; Self { @@ -224,6 +272,7 @@ pub fn extract_schema_from_dtype(dtype: &DType) -> VortexResult name: field_name.to_string(), logical_type, dtype: field_dtype, + projection_fn: None, }); } Ok(fields) @@ -232,6 +281,7 @@ pub fn extract_schema_from_dtype(dtype: &DType) -> VortexResult #[cfg(test)] mod tests { use vortex::dtype::DType; + use vortex::expr::lit; use vortex::expr::merge; use vortex::expr::pack; use vortex::expr::root; @@ -242,21 +292,24 @@ mod tests { #[test] fn test_select_star() { let ids = [0, 1, 2]; - let fields = [ + let mut fields = [ DuckdbField { name: "".to_owned(), logical_type: LogicalType::null(), dtype: DType::Null, + projection_fn: None, }, DuckdbField { name: "".to_owned(), logical_type: LogicalType::null(), dtype: DType::Null, + projection_fn: None, }, DuckdbField { name: "".to_owned(), logical_type: LogicalType::null(), dtype: DType::Null, + projection_fn: None, }, ]; @@ -285,5 +338,10 @@ mod tests { let ids = [2, 1, 0]; assert_ne!(Projection::new(None, &ids, &fields).projection, root()); + + // If any column has a projection expression, we can't use SELECT * + fields[0].projection_fn = Some(lit(true)); + let ids = [0, 1, 2]; + assert_ne!(Projection::new(None, &ids, &fields).projection, root()); } } diff --git a/vortex-duckdb/src/table_function.rs b/vortex-duckdb/src/table_function.rs index 11c5851af27..4d789e53091 100644 --- a/vortex-duckdb/src/table_function.rs +++ b/vortex-duckdb/src/table_function.rs @@ -46,6 +46,7 @@ use crate::SESSION; use crate::column_statistics::ColumnStatistics; use crate::column_statistics::ColumnStatisticsAggregate; use crate::convert::try_from_bound_expression; +use crate::convert::try_from_projection_expression; use crate::duckdb::BindInputRef; use crate::duckdb::BindResultRef; use crate::duckdb::ClientContextRef; @@ -429,6 +430,26 @@ pub fn pushdown_complex_filter( Ok(report_pushed) } +pub fn pushdown_projection_expression( + bind_data: &mut TableFunctionBind, + expr: &ExpressionRef, + projection_id: usize, +) -> VortexResult { + let field = &bind_data.column_fields[projection_id]; + debug!(%expr, %projection_id, col_name=field.name, "pushing down projection expression"); + match try_from_projection_expression(expr, field)? { + None => { + debug!(%expr, "failed to push down expression"); + Ok(false) + } + Some(vx_expr) => { + debug!(%expr, "pushed down expression"); + bind_data.column_fields[projection_id].projection_fn = Some(vx_expr); + Ok(true) + } + } +} + /// Get column-wise statistics. Available only if we're reading a single file. pub fn statistics(bind_data: &TableFunctionBind, column_index: usize) -> Option { let children = bind_data.data_source.children(); diff --git a/vortex-layout/src/layouts/dict/reader.rs b/vortex-layout/src/layouts/dict/reader.rs index 002b4b1e902..b82fc8c2382 100644 --- a/vortex-layout/src/layouts/dict/reader.rs +++ b/vortex-layout/src/layouts/dict/reader.rs @@ -100,10 +100,10 @@ impl DictReader { ) .vortex_expect("must construct dict values array evaluation") .map_err(Arc::new) - .map(move |array| { - let array = array?; - Ok(SharedArray::new(array).into_array()) - }) + //.map(move |array| { + // let array = array?; + // Ok(SharedArray::new(array).into_array()) + //}) .boxed() .shared() }) diff --git a/vortex-sqllogictest/slt/duckdb/projection_expression_pushdown.slt b/vortex-sqllogictest/slt/duckdb/projection_expression_pushdown.slt new file mode 100644 index 00000000000..28e49b1d4b0 --- /dev/null +++ b/vortex-sqllogictest/slt/duckdb/projection_expression_pushdown.slt @@ -0,0 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright the Vortex contributors + +include ../setup.slt.no + +# column with fn +# column, column with fn +# column +# + cte +# + nested cte +# + view +# + nested view + +# + virtual columns + +# We need to test pushdown for column which scan index +# (i.e. SELECT col) is 0 but which storage index is 1. +# Table where str has storage_index=1 (not 0): exercises scan_index != storage_index. +query I +CREATE TABLE pep_test (before INTEGER, str VARCHAR, after INTEGER); +---- + +query I +INSERT INTO pep_test VALUES (0, 'Hello', 0), (1, 'Hi', 1), (2, 'Hey', 2); +---- +3 + +query I +COPY (SELECT before, str, after FROM pep_test) TO '$__TEST_DIR__/pe-storage-idx.vortex'; +---- +3 + +query I +COPY (SELECT * FROM (VALUES ('Hello'), ('Hi'), ('Hey')) AS t(str)) TO '$__TEST_DIR__/pe-pushdown.vortex'; +---- +3 + +query T +SELECT str FROM '$__TEST_DIR__/pe-pushdown.vortex' ORDER BY str; +---- +Hello +Hey +Hi + +query I +SELECT len(str) AS l FROM '$__TEST_DIR__/pe-pushdown.vortex' ORDER BY l DESC; +---- +5 +3 +2 + +# Can't push down prefix: prefix() is not a supported pushdown function, +# returns BOOLEAN, result is unaffected by any incorrect pushdown. +query I +SELECT prefix(str, 'H')::INTEGER AS l FROM '$__TEST_DIR__/pe-pushdown.vortex' ORDER BY str; +---- +1 +1 +1 + +# len() used in WHERE with bare str in SELECT: conflict on str, no pushdown. +query TI +SELECT str, len(str) AS l FROM '$__TEST_DIR__/pe-pushdown.vortex' +WHERE len(str) > 3 ORDER BY str; +---- +Hello 5 + +# str used bare and in len() in the same query: conflict, no pushdown, correct results. +query TI +SELECT str, len(str) AS l FROM '$__TEST_DIR__/pe-pushdown.vortex' ORDER BY str; +---- +Hello 5 +Hey 3 +Hi 2 + +query IT +SELECT len(str) AS l, str FROM '$__TEST_DIR__/pe-pushdown.vortex' ORDER BY str; +---- +5 Hello +3 Hey +2 Hi + +# fn -> cte +query I +SELECT len(str) FROM (SELECT str FROM '$__TEST_DIR__/pe-pushdown.vortex' ORDER BY str); +---- +5 +3 +2 + +# fn -> cte -> cte +query I +WITH cte1 AS (SELECT str FROM '$__TEST_DIR__/pe-pushdown.vortex' ORDER BY str), + cte2 AS (SELECT str from cte1) +SELECT len(str) FROM cte2; +---- +5 +3 +2 + +# cte -> fn -> cte +query TI +WITH cte1 AS (SELECT str FROM '$__TEST_DIR__/pe-pushdown.vortex' ORDER BY str), + cte2 AS (SELECT str, len(str) as l from cte1) +SELECT str, l FROM cte2; +---- +Hello 5 +Hey 3 +Hi 2 + +query I +CREATE VIEW pe1 AS SELECT str FROM '$__TEST_DIR__/pe-pushdown.vortex' ORDER BY str; +---- + +query T +SELECT str FROM pe1; +---- +Hello +Hey +Hi + +query I +SELECT len(str) AS l FROM pe1; +---- +5 +3 +2 + +# nested view +query I +CREATE VIEW pe2 AS SELECT str FROM pe1; +---- + +query T +SELECT str FROM pe2; +---- +Hello +Hey +Hi + +query I +SELECT len(str) AS l FROM pe2; +---- +5 +3 +2 + +query TI +SELECT str, len(str) AS l FROM pe2; +---- +Hello 5 +Hey 3 +Hi 2 + +# fn -> cte with fn -> cte +query I +WITH cte1 AS (SELECT str FROM '$__TEST_DIR__/pe-pushdown.vortex' ORDER BY str), + cte2 AS (SELECT str, len(str) as l from cte1) +SELECT len(str) FROM cte2; +---- +5 +3 +2 + +# --- storage_index != scan_index --- +# str is the second column in storage (storage_index=1), but scan_index=0 when +# only str is selected. Verifies that storageIndex() is used correctly for +# returned_types and projection_expression_pushdown. +query I +SELECT len(str) AS l FROM '$__TEST_DIR__/pe-storage-idx.vortex' ORDER BY l DESC; +---- +5 +3 +2 + +query II +SELECT before, len(str) AS l FROM '$__TEST_DIR__/pe-storage-idx.vortex' ORDER BY before; +---- +0 5 +1 2 +2 3 + +query II +SELECT len(str) AS l, after FROM '$__TEST_DIR__/pe-storage-idx.vortex' ORDER BY after; +---- +5 0 +2 1 +3 2 + +# --- virtual columns: len() on real column, file_row_number alongside --- +# Virtual columns must not block pushdown of len() on a real column. +query II +SELECT len(str) AS l, file_row_number FROM '$__TEST_DIR__/pe-pushdown.vortex' ORDER BY l DESC; +---- +5 0 +3 2 +2 1 + +# len() on real column alongside file_index virtual column. +query II +SELECT len(str) AS l, file_index FROM '$__TEST_DIR__/pe-pushdown.vortex' ORDER BY l DESC; +---- +5 0 +3 0 +2 0 + +# --- nested VIEW: pushdown is skipped (two levels deep), results still correct --- +# pe2 is SELECT str FROM pe1, which is SELECT str FROM vortex. Two PROJECTION +# levels: optimizer only resolves one level, so no pushdown here. Correct output +# must still be produced by DuckDB computing len() after the scan. +query I +SELECT len(str) AS l FROM pe2 ORDER BY l DESC; +---- +5 +3 +2 + +# nested VIEW with both str and len(str): no pushdown (two levels), correct results. +query TI +SELECT str, len(str) AS l FROM pe2 ORDER BY str; +---- +Hello 5 +Hey 3 +Hi 2 + +# nested VIEW via CTE wrapping pe2. +query I +WITH cte AS (SELECT str FROM pe2) +SELECT len(str) FROM cte ORDER BY str; +---- +5 +3 +2 + +# --- can't push prefix through VIEW (not a supported function) --- +# Verifies that an unsupported function in a view does not corrupt the plan. +query I +SELECT prefix(str, 'H')::INTEGER FROM pe1 ORDER BY str; +---- +1 +1 +1 From b388f139c573f1a4f368070d5de0924ef59714d3 Mon Sep 17 00:00:00 2001 From: Mikhail Kot Date: Wed, 10 Jun 2026 14:17:29 +0100 Subject: [PATCH 2/5] better --- .../cpp/include/duckdb_vx/optimizer.h | 4 ++-- vortex-duckdb/cpp/optimizer.cpp | 20 ++++++++----------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/vortex-duckdb/cpp/include/duckdb_vx/optimizer.h b/vortex-duckdb/cpp/include/duckdb_vx/optimizer.h index 01851a79eaf..c9f1538a7ad 100644 --- a/vortex-duckdb/cpp/include/duckdb_vx/optimizer.h +++ b/vortex-duckdb/cpp/include/duckdb_vx/optimizer.h @@ -126,7 +126,7 @@ struct VortexOptimizerExtension final : OptimizerExtension { } }; -struct Binding { +struct GetBinding { GetAnalysis &analysis; TableColumnScanIndex column_index; }; @@ -136,6 +136,6 @@ struct Binding { * Returns nullopt for virtual columns and columns which are neither part of * GET nor part of PROJECTION wrapping a GET. */ -std::optional Resolve(ColumnBinding binding, Analyses &analyses, const Projections &projections); +std::optional Resolve(ColumnBinding binding, Analyses &analyses, const Projections &projections); #endif diff --git a/vortex-duckdb/cpp/optimizer.cpp b/vortex-duckdb/cpp/optimizer.cpp index 425441e0422..719dbbd06d1 100644 --- a/vortex-duckdb/cpp/optimizer.cpp +++ b/vortex-duckdb/cpp/optimizer.cpp @@ -78,7 +78,7 @@ void FindGetsAndAliases(LogicalOperator &op, } } -std::optional Resolve(ColumnBinding binding, Analyses &analyses, const Projections &projections) { +std::optional Resolve(ColumnBinding binding, Analyses &analyses, const Projections &projections) { if (IsVirtualColumn(binding.column_index)) { return std::nullopt; } @@ -124,15 +124,11 @@ void ScalarFnCollect::VisitOperator(LogicalOperator &op) { } ExpressionPtr ScalarFnCollect::VisitReplace(BoundColumnRefExpression &expr, ExpressionPtr *ptr) { - const auto binding = Resolve(expr.binding, analyses, projections); - if (!binding) { - return std::move(*ptr); + if (const auto binding = Resolve(expr.binding, analyses, projections)) { + // Column is used without function applied to it, register a conflict. + // Not emplace() as we need to update the value if it was present + binding->analysis.col_to_fn[binding->column_index] = nullptr; } - auto &[analysis, column_index] = *binding; - - // Column is used without function applied to it, register a conflict. - // Not emplace() as we need to update the value if it was present - analysis.col_to_fn[column_index] = nullptr; return std::move(*ptr); } @@ -146,11 +142,11 @@ ExpressionPtr ScalarFnCollect::VisitReplace(BoundFunctionExpression &expr, Expre if (!binding) { return std::move(*ptr); } - auto &[analysis, column_index] = *binding; + auto &col_to_fn = binding->analysis.col_to_fn; - if (auto it = analysis.col_to_fn.find(column_index); it == analysis.col_to_fn.end()) { + if (auto it = col_to_fn.find(binding->column_index); it == col_to_fn.end()) { // This is the first time we see the column used by a single function. - analysis.col_to_fn.emplace(column_index, &expr); + col_to_fn.emplace(binding->column_index, &expr); } else if (it->second == nullptr || !it->second->Equals(expr)) { // Either column is used with different function in "expr" or // there already is a conflict. From 2dbe33f55cf7f8d657b58ab8bb52158200d9ad51 Mon Sep 17 00:00:00 2001 From: Mikhail Kot Date: Wed, 10 Jun 2026 15:33:10 +0100 Subject: [PATCH 3/5] still use sharedarray --- vortex-layout/src/layouts/dict/reader.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vortex-layout/src/layouts/dict/reader.rs b/vortex-layout/src/layouts/dict/reader.rs index b82fc8c2382..002b4b1e902 100644 --- a/vortex-layout/src/layouts/dict/reader.rs +++ b/vortex-layout/src/layouts/dict/reader.rs @@ -100,10 +100,10 @@ impl DictReader { ) .vortex_expect("must construct dict values array evaluation") .map_err(Arc::new) - //.map(move |array| { - // let array = array?; - // Ok(SharedArray::new(array).into_array()) - //}) + .map(move |array| { + let array = array?; + Ok(SharedArray::new(array).into_array()) + }) .boxed() .shared() }) From 583cbed2cb8d308bef6dd7c4eb389d46037ded6f Mon Sep 17 00:00:00 2001 From: Mikhail Kot Date: Wed, 10 Jun 2026 17:08:28 +0100 Subject: [PATCH 4/5] print select projections in duckdb --- vortex-duckdb/src/table_function.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/vortex-duckdb/src/table_function.rs b/vortex-duckdb/src/table_function.rs index 4d789e53091..d8cc804869c 100644 --- a/vortex-duckdb/src/table_function.rs +++ b/vortex-duckdb/src/table_function.rs @@ -518,6 +518,19 @@ pub fn to_string(bind_data: &TableFunctionBind, map: &mut DuckdbStringMapRef) { let mut filters = bind_data.filter_exprs.iter().map(|f| format!("{f}")); map.push("Filters", &filters.join("\n")); } + let projections = bind_data + .column_fields + .iter() + .filter_map(|field| { + field + .projection_fn + .as_ref() + .map(|expr| format!("{}: {expr}", field.name)) + }) + .join("\n"); + if !projections.is_empty() { + map.push("SELECT projections", &projections); + } } fn progress(bytes_read: &AtomicU64, bytes_total: &AtomicU64) -> f64 { From 9099587908b7a78ab5a0823bd409b7eea7d1c1e7 Mon Sep 17 00:00:00 2001 From: Mikhail Kot Date: Wed, 10 Jun 2026 18:02:53 +0100 Subject: [PATCH 5/5] Pushdown some expressions to Dict layout reader --- vortex-array/src/scalar_fn/mod.rs | 15 +++ vortex-layout/src/layouts/dict/reader.rs | 137 +++++++++++++++++++++-- 2 files changed, 145 insertions(+), 7 deletions(-) diff --git a/vortex-array/src/scalar_fn/mod.rs b/vortex-array/src/scalar_fn/mod.rs index 590ccb44224..710d3a95b48 100644 --- a/vortex-array/src/scalar_fn/mod.rs +++ b/vortex-array/src/scalar_fn/mod.rs @@ -48,3 +48,18 @@ mod sealed { /// This can be the **only** implementor for [`super::typed::DynScalarFn`]. impl Sealed for TypedScalarFnInstance {} } + +/* + * A scalar function has a negative cost if applying it to an array and + * canonicalizing is cheaper than canonicalizing an array and applying it. + * + * Example of negative cost expressions are byte_length() and get_item() since + * they don't depend on input size. + * + * Example of non-negative cost expression is like() + */ +pub fn is_negative_cost(id: ScalarFnId) -> bool { + id == Id::new_static("vortex.byte_length") + || id == Id::new_static("vortex.get_item") + || id == Id::new_static("vortex.literal") +} diff --git a/vortex-layout/src/layouts/dict/reader.rs b/vortex-layout/src/layouts/dict/reader.rs index 002b4b1e902..d64a2e47d59 100644 --- a/vortex-layout/src/layouts/dict/reader.rs +++ b/vortex-layout/src/layouts/dict/reader.rs @@ -19,8 +19,15 @@ use vortex_array::arrays::SharedArray; use vortex_array::dtype::DType; use vortex_array::dtype::FieldMask; use vortex_array::expr::Expression; +use vortex_array::expr::is_root; +use vortex_array::expr::label_is_fallible; +use vortex_array::expr::label_null_sensitive; use vortex_array::expr::root; +use vortex_array::expr::traversal::NodeExt; +use vortex_array::expr::traversal::Transformed; +use vortex_array::expr::traversal::TraversalOrder; use vortex_array::optimizer::ArrayOptimizer; +use vortex_array::scalar_fn::is_negative_cost; use vortex_error::VortexError; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -100,10 +107,7 @@ impl DictReader { ) .vortex_expect("must construct dict values array evaluation") .map_err(Arc::new) - .map(move |array| { - let array = array?; - Ok(SharedArray::new(array).into_array()) - }) + .map(move |array| Ok(SharedArray::new(array?).into_array())) .boxed() .shared() }) @@ -155,6 +159,49 @@ impl DictReader { } } +fn references_root(expr: &Expression) -> bool { + is_root(expr) || expr.children().iter().any(references_root) +} + +/// Split expression into two parts: +/// +/// left is the optional outer part that we want to apply to array after +/// canonicalizing. +/// right is the optional inner part that we want to apply to array before +/// canonicalizing. +/// +/// We want to push to array only if expression has a negative cost, is +/// infallible and null-insensitive. +fn split_expression_for_pushdown(expr: Expression) -> (Option, Option) { + let labelled_expr = expr.clone(); + let fallible = label_is_fallible(&labelled_expr); + let null_sensitive = label_null_sensitive(&labelled_expr); + let mut inner: Option = None; + + let outer = expr + .transform_down(|node| { + if is_negative_cost(node.id()) + && references_root(&node) + && !fallible.get(&node).copied().unwrap_or(true) + && !null_sensitive.get(&node).copied().unwrap_or(true) + { + inner = Some(node); + Ok(Transformed { + value: root(), + changed: true, + order: TraversalOrder::Skip, + }) + } else { + Ok(Transformed::no(node)) + } + }) + .vortex_expect("infallible") + .into_inner(); + + let outer = (!is_root(&outer)).then_some(outer); + (outer, inner) +} + impl LayoutReader for DictReader { fn name(&self) -> &Arc { &self.name @@ -229,13 +276,18 @@ impl LayoutReader for DictReader { mask: MaskFuture, ) -> VortexResult>> { // TODO: fix up expr partitioning with fallible & null sensitive annotations - let values_eval = self.values_array(); let codes_eval = self .codes .projection_evaluation(row_range, &root(), mask) .map_err(|err| err.with_context("While evaluating projection on codes"))?; - let expr = expr.clone(); + let (expr_outer, expr_inner) = split_expression_for_pushdown(expr.clone()); + + let values_eval = if let Some(inner) = expr_inner { + self.values_eval(inner) + } else { + self.values_array() + }; let all_values_referenced = self.layout.has_all_values_referenced(); Ok(async move { let (values, codes) = try_join!(values_eval.map_err(VortexError::from), codes_eval)?; @@ -252,7 +304,11 @@ impl LayoutReader for DictReader { .into_array() .optimize()?; - array.apply(&expr) + if let Some(expr) = expr_outer { + array.apply(&expr) + } else { + Ok(array) + } } .boxed()) } @@ -281,11 +337,20 @@ mod tests { use vortex_array::dtype::FieldName; use vortex_array::dtype::FieldNames; use vortex_array::dtype::Nullability; + use vortex_array::dtype::PType; + use vortex_array::expr::Expression; + use vortex_array::expr::byte_length; + use vortex_array::expr::cast; use vortex_array::expr::eq; use vortex_array::expr::is_not_null; + use vortex_array::expr::is_root; + use vortex_array::expr::like; use vortex_array::expr::lit; use vortex_array::expr::pack; use vortex_array::expr::root; + use vortex_array::expr::traversal::NodeExt; + use vortex_array::expr::traversal::Transformed; + use vortex_array::expr::traversal::TraversalOrder; use vortex_array::scalar_fn::session::ScalarFnSession; use vortex_array::session::ArraySession; use vortex_array::validity::Validity; @@ -296,6 +361,7 @@ mod tests { use vortex_io::session::RuntimeSessionExt; use vortex_session::VortexSession; + use super::split_expression_for_pushdown; use crate::LayoutId; use crate::LayoutRef; use crate::LayoutStrategy; @@ -542,4 +608,61 @@ mod tests { assert_arrays_eq!(actual_canonical, expected); }) } + + fn join_split_expr(initial: &Expression, outer: Option, inner: Option) { + let outer_expr = outer.unwrap_or_else(root); + let inner_expr = inner.unwrap_or_else(root); + let expected = outer_expr + .transform_down(|node| { + if !is_root(&node) { + return Ok(Transformed::no(node)); + } + Ok(Transformed { + value: inner_expr.clone(), + changed: true, + order: TraversalOrder::Skip, + }) + }) + .vortex_expect("infallible"); + assert_eq!(&expected.into_inner(), initial); + } + + #[test] + fn split_expr_cast_root() { + let (outer, inner) = split_expression_for_pushdown(root()); + assert_eq!(outer, None); + assert_eq!(inner, None); // Applying root to array is useless work + } + + #[test] + fn split_expr_partial_pushdown() { + let dtype = DType::Primitive(PType::U64, Nullability::NonNullable); + let expr = cast(byte_length(root()), dtype.clone()); + let (outer, inner) = split_expression_for_pushdown(expr.clone()); + // [0] = cast([1], dtype) + // [1] = byte_length(root) + assert_eq!(outer, Some(cast(root(), dtype))); + assert_eq!(inner, Some(byte_length(root()))); + join_split_expr(&expr, outer, inner); + } + + #[test] + fn split_expr_full_pushdown() { + let expr = byte_length(root()); + let (outer, inner) = split_expression_for_pushdown(expr.clone()); + assert_eq!(outer, None); + assert_eq!(inner, Some(byte_length(root()))); + join_split_expr(&expr, outer, inner); + } + + #[test] + fn split_expr_no_pushdown() { + // We can push down lit(), but it we replace + // lit() with root(), the semantics change. + let expr = like(root(), lit(1u64)); + let (outer, inner) = split_expression_for_pushdown(expr.clone()); + assert_eq!(outer, Some(expr.clone())); + assert_eq!(inner, None); + join_split_expr(&expr, outer, inner); + } }