diff --git a/datafusion/physical-expr-common/src/utils.rs b/datafusion/physical-expr-common/src/utils.rs index 278294a9bf2ad..01a01c6ebf32e 100644 --- a/datafusion/physical-expr-common/src/utils.rs +++ b/datafusion/physical-expr-common/src/utils.rs @@ -15,16 +15,21 @@ // specific language governing permissions and limitations // under the License. +use std::borrow::Cow; use std::sync::Arc; use crate::metrics::ExpressionEvaluatorMetrics; use crate::physical_expr::PhysicalExpr; use crate::tree_node::ExprContext; -use arrow::array::{Array, ArrayRef, BooleanArray, MutableArrayData, make_array}; -use arrow::compute::{SlicesIterator, and_kleene, is_not_null}; +use arrow::array::{Array, ArrayRef, BooleanArray, BooleanBufferBuilder, MutableArrayData, make_array, new_null_array, PrimitiveArray}; +use arrow::buffer::{BooleanBuffer, Buffer, NullBuffer}; +use arrow::compute::{SlicesIterator, and_kleene, is_not_null, prep_null_mask_filter}; +use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}; +use arrow::downcast_primitive_array; use arrow::record_batch::RecordBatch; use datafusion_common::Result; +use datafusion_common::ScalarValue::Null; use datafusion_expr_common::sort_properties::ExprProperties; /// Represents a [`PhysicalExpr`] node with associated properties (order and @@ -50,6 +55,11 @@ impl ExprPropertiesNode { } } +/// If the mask selects more than this fraction of rows, use +/// `set_slices()` to copy contiguous ranges. Otherwise iterate +/// over individual positions using `set_indices()` +const SCATTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8; + /// Scatter `truthy` array by boolean mask. When the mask evaluates `true`, next values of `truthy` /// are taken, when the mask evaluates `false` values null values are filled. /// @@ -57,40 +67,174 @@ impl ExprPropertiesNode { /// * `mask` - Boolean values used to determine where to put the `truthy` values /// * `truthy` - All values of this array are to scatter according to `mask` into final result. pub fn scatter(mask: &BooleanArray, truthy: &dyn Array) -> Result { - let truthy = truthy.to_data(); + let mask = match mask.null_count() { + 0 => Cow::Borrowed(mask), + _ => Cow::Owned(prep_null_mask_filter(mask)), + }; - // update the mask so that any null values become false - // (SlicesIterator doesn't respect nulls) - let mask = and_kleene(mask, &is_not_null(mask)?)?; + let output_len = mask.len(); + let count = mask.true_count(); - let mut mutable = MutableArrayData::new(vec![&truthy], true, mask.len()); + // Fast path: no true values mean all-null object + if count == 0 { + return Ok(new_null_array(truthy.data_type(), output_len)); + } - // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to - // fill with falsy values + // Fast path: all true means output = truthy + if count == output_len { + return Ok(truthy.slice(0, truthy.len())); + } - // keep track of how much is filled - let mut filled = 0; - // keep track of current position we have in truthy array - let mut true_pos = 0; + let selectivity = count as f64 / output_len as f64; + let mask_buffer = mask.values(); - SlicesIterator::new(&mask).for_each(|(start, end)| { - // the gap needs to be filled with nulls - if start > filled { - mutable.extend_nulls(start - filled); + scatter_array(truthy, mask_buffer, output_len, selectivity) +} + +/// Type-specific dispatch for scatter +fn scatter_array( + truthy: &dyn Array, + mask: &BooleanBuffer, + output_len: usize, + selectivity: f64, +) -> Result { + downcast_primitive_array! { + truthy => Ok(Arc::new(scatter_primitive(truthy, mask, output_len, + selectivity))), + DataType::Boolean => { + Ok(Arc::new(scatter_boolean(truthy.as_boolean(), mask, output_len, + selectivity))) + } + DataType::Utf8 => { + Ok(Arc::new(scatter_bytes(truthy.as_string::(), mask, + output_len, selectivity))) + } + DataType::LargeUtf8 => { + Ok(Arc::new(scatter_bytes(truthy.as_string::(), mask, + output_len, selectivity))) + } + DataType::Utf8View => { + Ok(Arc::new(scatter_byte_view(truthy.as_string_view(), mask, + output_len, selectivity))) + } + DataType::Binary => { + Ok(Arc::new(scatter_bytes(truthy.as_binary::(), mask, + output_len, selectivity))) + } + DataType::LargeBinary => { + Ok(Arc::new(scatter_bytes(truthy.as_binary::(), mask, + output_len, selectivity))) + } + DataType::BinaryView => { + Ok(Arc::new(scatter_byte_view(truthy.as_binary_view(), mask, + output_len, selectivity))) + } + DataType::FixedSizeBinary(_) => { + Ok(Arc::new(scatter_fixed_size_binary(truthy.as_fixed_size_binary(), mask, + output_len, selectivity))) + } + DataType::Dictionary(_, _) => { + downcast_dictionary_array! { + truthy => Ok(Arc::new(scatter_dict(truthy, mask, output_len, + selectivity))), + t => scatter_fallback(truthy, mask, output_len) + } + } + _ => scatter_fallback(truthy, mask, output_len) + } +} + +fn scatter_native( + src: &[T], + mask: &BooleanBuffer, + output_len: usize, + selectivity: f64, +) -> Buffer { + let mut output = vec![T::default(); output_len]; + let mut src_offset = 0; + + if selectivity > SCATTER_SLICES_SELECTIVITY_THRESHOLD { + for (start, end) in mask.set_slices() { + let len = end - start; + output[start..end].copy_from_slice(&src[src_offset..src_offset + len]); + src_offset += len; + } + } else { + for dst_idx in mask.set_indices() { + output[dst_idx] = src[src_offset]; + src_offset += 1; } - // fill with truthy values - let len = end - start; - mutable.extend(0, true_pos, true_pos + len); - true_pos += len; - filled = end; - }); - // the remaining part is falsy - if filled < mask.len() { - mutable.extend_nulls(mask.len() - filled); } - let data = mutable.freeze(); - Ok(make_array(data)) + output.into() +} + +fn scatter_bits( + src: &BooleanBuffer, + mask: &BooleanBuffer, + output_len: usize, + selectivity: f64, +) -> Buffer { + let mut builder = BooleanBufferBuilder::new(output_len); + builder.advance(output_len); + let mut src_offset = 0; + + if selectivity > SCATTER_SLICES_SELECTIVITY_THRESHOLD { + for (start, end) in mask.set_slices() { + for i in start..end { + if src.value(src_offset) { + builder.set_bit(i, true); + } + src_offset += 1; + } + } + } else { + for dst_idx in mask.set_indices() { + if src.value(src_offset) { + builder.set_bit(dst_idx, true); + } + src_offset += 1; + } + } + + builder.finish().into_inner() +} + +fn scatter_null_mask( + src_nulls: Option<&NullBuffer>, + mask: &BooleanBuffer, + output_len: usize, + selectivity: f64, +) -> Option { + let false_count = output_len - mask.count_set_bits(); + let src_null_count = src_nulls.map(|n| n.null_count()).unwrap_or(0); + + if src_null_count == 0 { + if false_count == 0 { + None + } else { + Some(NullBuffer::new(mask.clone())) + } + } else { + let src_nulls = src_nulls.unwrap(); + let scattered = scatter_bits(src_nulls.inner(), mask, output_len, selectivity); + let valid_count = scattered.count_set_bits_offset(0, output_len); + let null_count = output_len - valid_count; + if null_count == 0 { + None + } else { + Some(NullBuffer::new(BooleanBuffer::new( + scattered, 0, output_len, + ))) + } + } +} + +fn scatter_primitive(truthy: &PrimitiveArray,mask: &BooleanBuffer, + output_len: usize, + selectivity: f64, + ) -> PrimitiveArray { + todo!() } /// Evaluates expressions against a record batch.