diff --git a/datafusion/spark/src/function/math/bin.rs b/datafusion/spark/src/function/math/bin.rs new file mode 100644 index 0000000000000..0eb6e7a22673c --- /dev/null +++ b/datafusion/spark/src/function/math/bin.rs @@ -0,0 +1,177 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, AsArray, StringArray}; +use arrow::datatypes::{ + DataType, Decimal32Type, Decimal64Type, Field, FieldRef, Float16Type, Float32Type, + Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, +}; +use bigdecimal::ToPrimitive; +use datafusion::logical_expr::{ColumnarValue, Signature, TypeSignature, Volatility}; +use datafusion_common::types::{NativeType, logical_int64}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::{Coercion, ScalarFunctionArgs, ScalarUDFImpl, TypeSignatureClass}; +use datafusion_functions::utils::make_scalar_function; +use std::any::Any; +use std::sync::Arc; + +/// Spark-compatible `bin` expression +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkBin { + signature: Signature, +} + +impl Default for SparkBin { + fn default() -> Self { + Self::new() + } +} + +impl SparkBin { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![TypeSignature::Coercible(vec![Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Numeric], + NativeType::Int64, + )])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkBin { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "bin" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args( + &self, + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + Ok(Arc::new(Field::new( + self.name(), + DataType::Utf8, + args.arg_fields[0].is_nullable(), + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_bin_inner, vec![])(&args.args) + } +} + +pub fn spark_bin_inner(arg: &[ArrayRef]) -> Result { + let [array] = take_function_args("bin", arg)?; + match &array.data_type() { + DataType::Int8 => { + let result: StringArray = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| spark_bin(value.into()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Int16 => { + let result: StringArray = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| spark_bin(value.into()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Int32 => { + let result: StringArray = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| spark_bin(value.into()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Int64 => { + let result: StringArray = array + .as_primitive::() + .iter() + .map(|opt| opt.map(spark_bin)) + .collect(); + Ok(Arc::new(result)) + } + DataType::Float16 => { + let result: StringArray = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| spark_bin(value.to_i64().unwrap()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Float32 => { + let result: StringArray = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| spark_bin(value.to_i64().unwrap()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Float64 => { + let result: StringArray = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| spark_bin(value.to_i64().unwrap()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Decimal32(_, _) => { + let result: StringArray = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| spark_bin(value.into()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Decimal64(_, _) => { + let result: StringArray = array + .as_primitive::() + .iter() + .map(|opt| opt.map(spark_bin)) + .collect(); + Ok(Arc::new(result)) + } + data_type => { + internal_err!("bin does not support: {data_type}") + } + } +} + +fn spark_bin(value: i64) -> String { + format!("{value:b}") +} diff --git a/datafusion/spark/src/function/math/mod.rs b/datafusion/spark/src/function/math/mod.rs index 92d8e90ac372e..7f7d04e06b0be 100644 --- a/datafusion/spark/src/function/math/mod.rs +++ b/datafusion/spark/src/function/math/mod.rs @@ -16,6 +16,7 @@ // under the License. pub mod abs; +pub mod bin; pub mod expm1; pub mod factorial; pub mod hex; @@ -42,6 +43,7 @@ make_udf_function!(width_bucket::SparkWidthBucket, width_bucket); make_udf_function!(trigonometry::SparkCsc, csc); make_udf_function!(trigonometry::SparkSec, sec); make_udf_function!(negative::SparkNegative, negative); +make_udf_function!(bin::SparkBin, bin); pub mod expr_fn { use datafusion_functions::export_functions; @@ -70,6 +72,11 @@ pub mod expr_fn { "Returns the negation of expr (unary minus).", arg1 )); + export_functions!(( + bin, + "Returns the string representation of the long value represented in binary.", + arg1 + )); } pub fn functions() -> Vec> { @@ -86,5 +93,6 @@ pub fn functions() -> Vec> { csc(), sec(), negative(), + bin(), ] } diff --git a/datafusion/sqllogictest/test_files/spark/math/bin.slt b/datafusion/sqllogictest/test_files/spark/math/bin.slt index 1fa24e6cda6b0..b2e2aadde44b6 100644 --- a/datafusion/sqllogictest/test_files/spark/math/bin.slt +++ b/datafusion/sqllogictest/test_files/spark/math/bin.slt @@ -15,23 +15,62 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. -# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT bin(-13); -## PySpark 3.5.5 Result: {'bin(-13)': '1111111111111111111111111111111111111111111111111111111111110011', 'typeof(bin(-13))': 'string', 'typeof(-13)': 'int'} -#query -#SELECT bin(-13::int); - -## Original Query: SELECT bin(13); -## PySpark 3.5.5 Result: {'bin(13)': '1101', 'typeof(bin(13))': 'string', 'typeof(13)': 'int'} -#query -#SELECT bin(13::int); - -## Original Query: SELECT bin(13.3); -## PySpark 3.5.5 Result: {'bin(13.3)': '1101', 'typeof(bin(13.3))': 'string', 'typeof(13.3)': 'decimal(3,1)'} -#query -#SELECT bin(13.3::decimal(3,1)); +query T +SELECT bin(arrow_cast(NULL, 'Int8')); +---- +NULL + +query T +SELECT bin(arrow_cast(0, 'Int8')); +---- +0 + +query T +SELECT bin(arrow_cast(13, 'Int8')); +---- +1101 + +query T +SELECT bin(arrow_cast(13.36, 'Float16')); +---- +1101 + +query T +SELECT bin(13.3::decimal(3,1)); +---- +1101 + +query T +SELECT bin(arrow_cast(-13, 'Int8')); +---- +1111111111111111111111111111111111111111111111111111111111110011 + +query T +SELECT bin(arrow_cast(256, 'Int16')); +---- +100000000 + +query T +SELECT bin(arrow_cast(-32768, 'Int16')); +---- +1111111111111111111111111111111111111111111111111000000000000000 + +query T +SELECT bin(arrow_cast(-2147483648, 'Int32')); +---- +1111111111111111111111111111111110000000000000000000000000000000 + +query T +SELECT bin(arrow_cast(1073741824, 'Int32')); +---- +1000000000000000000000000000000 + +query T +SELECT bin(arrow_cast(-9223372036854775808, 'Int64')); +---- +1000000000000000000000000000000000000000000000000000000000000000 + +query T +SELECT bin(arrow_cast(9223372036854775807, 'Int64')); +---- +111111111111111111111111111111111111111111111111111111111111111