From 8be85dbf3975f985a6824b1c8a7803b65731a2b5 Mon Sep 17 00:00:00 2001 From: juruo-c <2915601267@qq.com> Date: Sun, 21 Dec 2025 21:26:59 +0800 Subject: [PATCH 1/2] [Feature](function) Support function cross_product of DuckDB --- .../array/function_array_cross_product.cpp | 28 +++ .../array/function_array_cross_product.h | 204 ++++++++++++++++++ .../array/function_array_register.cpp | 2 + .../doris/catalog/BuiltinScalarFunctions.java | 2 + .../functions/scalar/CrossProduct.java | 75 +++++++ .../visitor/ScalarFunctionVisitor.java | 5 + .../test_array_cross_product_function.out | 15 ++ .../test_array_cross_product_function.groovy | 62 ++++++ 8 files changed, 393 insertions(+) create mode 100644 be/src/vec/functions/array/function_array_cross_product.cpp create mode 100644 be/src/vec/functions/array/function_array_cross_product.h create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CrossProduct.java create mode 100644 regression-test/data/query_p0/sql_functions/array_functions/test_array_cross_product_function.out create mode 100644 regression-test/suites/query_p0/sql_functions/array_functions/test_array_cross_product_function.groovy diff --git a/be/src/vec/functions/array/function_array_cross_product.cpp b/be/src/vec/functions/array/function_array_cross_product.cpp new file mode 100644 index 00000000000000..3352316145df71 --- /dev/null +++ b/be/src/vec/functions/array/function_array_cross_product.cpp @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "vec/functions/array/function_array_cross_product.h" + +#include "vec/functions/simple_function_factory.h" + +namespace doris::vectorized { + +void register_function_array_cross_product(SimpleFunctionFactory& factory) { + factory.register_function(); +} + +} // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/functions/array/function_array_cross_product.h b/be/src/vec/functions/array/function_array_cross_product.h new file mode 100644 index 00000000000000..7b05baba134f00 --- /dev/null +++ b/be/src/vec/functions/array/function_array_cross_product.h @@ -0,0 +1,204 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include + +#include "common/exception.h" +#include "common/status.h" +#include "runtime/primitive_type.h" +#include "vec/columns/column.h" +#include "vec/columns/column_array.h" +#include "vec/columns/column_nullable.h" +#include "vec/common/assert_cast.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_array.h" +#include "vec/data_types/data_type_nullable.h" +#include "vec/data_types/data_type_number.h" +#include "vec/functions/array/function_array_utils.h" +#include "vec/functions/function.h" +#include "vec/utils/util.hpp" + +namespace doris::vectorized { + +class FunctionArrayCrossProduct : public IFunction { +public: + using DataType = PrimitiveTypeTraits::DataType; + using ColumnType = PrimitiveTypeTraits::ColumnType; + + static constexpr auto name = "cross_product"; + String get_name() const override { return name; } + static FunctionPtr create() { return std::make_shared(); } + size_t get_number_of_arguments() const override { return 2; } + + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { + if (arguments.size() != 2) { + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + "Invalid number of arguments for function {}", get_name()); + } + + if (arguments[0]->get_primitive_type() != TYPE_ARRAY || + arguments[1]->get_primitive_type() != TYPE_ARRAY) { + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + "Arguments for function {} must be arrays", get_name()); + } + + // return ARRAY + return std::make_shared( + std::make_shared()); + } + + // strict semantics: do not allow NULL + bool use_default_implementation_for_nulls() const override { return false; } + + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + uint32_t result, size_t input_rows_count) const override { + const auto& arg1 = block.get_by_position(arguments[0]); + const auto& arg2 = block.get_by_position(arguments[1]); + + auto col1 = 
arg1.column->convert_to_full_column_if_const(); + auto col2 = arg2.column->convert_to_full_column_if_const(); + + if (col1->size() != col2->size()) { + return Status::RuntimeError( + fmt::format("function {} have different input array sizes: {} and {}", + get_name(), col1->size(), col2->size())); + } + + const ColumnArray* arr1 = nullptr; + const ColumnArray* arr2 = nullptr; + + if (const auto* nullable = + typeid_cast(col1.get())) { + if (nullable->has_null()) { + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + "First argument for function {} cannot be null", get_name()); + } + arr1 = assert_cast(nullable->get_nested_column_ptr().get()); + } else { + arr1 = assert_cast(col1.get()); + } + + if (const auto* nullable = + typeid_cast(col2.get())) { + if (nullable->has_null()) { + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + "Second argument for function {} cannot be null", + get_name()); + } + arr2 = assert_cast(nullable->get_nested_column_ptr().get()); + } else { + arr2 = assert_cast(col2.get()); + } + + const ColumnFloat32* float1 = nullptr; + const ColumnFloat32* float2 = nullptr; + + if (const auto* nullable = + typeid_cast(arr1->get_data_ptr().get())) { + if (nullable->has_null()) { + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + "First argument for function {} cannot have null elements", + get_name()); + } + float1 = assert_cast(nullable->get_nested_column_ptr().get()); + } else { + float1 = assert_cast(arr1->get_data_ptr().get()); + } + + if (const auto* nullable = + typeid_cast(arr2->get_data_ptr().get())) { + if (nullable->has_null()) { + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + "Second argument for function {} cannot have null elements", + get_name()); + } + float2 = assert_cast(nullable->get_nested_column_ptr().get()); + } else { + float2 = assert_cast(arr2->get_data_ptr().get()); + } + + const auto* offset1 = + assert_cast(arr1->get_offsets_ptr().get()); + const auto* offset2 = + 
assert_cast(arr2->get_offsets_ptr().get()); + + // prepare result data + auto nested_res = ColumnFloat32::create(); + auto offsets_res = ColumnArray::ColumnOffsets::create(); + auto& offsets_data = offsets_res->get_data(); + offsets_data.reserve(input_rows_count); + size_t current_offset = 0; + + size_t row_cnt = offset1->size(); + size_t prev_offset1 = 0; + size_t prev_offset2 = 0; + for (ssize_t row = 0; row < row_cnt; ++row) { + ssize_t size1 = offset1->get_data()[row] - prev_offset1; + ssize_t size2 = offset2->get_data()[row] - prev_offset2; + + if (size1 != size2) [[unlikely]] { + return Status::InvalidArgument( + "function {} have different input element sizes of array: {} and {}", + get_name(), size1, size2); + } + + if (size1 != 3 || size2 != 3) { + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + "function {} requires arrays of size 3", + get_name()); + } + + ssize_t base1 = prev_offset1; + ssize_t base2 = prev_offset2; + + float a1 = float1->get_data()[base1]; + float a2 = float1->get_data()[base1 + 1]; + float a3 = float1->get_data()[base1 + 2]; + + float b1 = float2->get_data()[base2]; + float b2 = float2->get_data()[base2 + 1]; + float b3 = float2->get_data()[base2 + 2]; + + float c1 = a2 * b3 - a3 * b2; + float c2 = a3 * b1 - a1 * b3; + float c3 = a1 * b2 - a2 * b1; + + nested_res->insert_value(c1); + nested_res->insert_value(c2); + nested_res->insert_value(c3); + + current_offset += 3; + offsets_data.push_back(current_offset); + + prev_offset1 = offset1->get_data()[row]; + prev_offset2 = offset2->get_data()[row]; + } + + auto result_col = ColumnArray::create( + ColumnNullable::create(std::move(nested_res), + ColumnUInt8::create(nested_res->size(), 0)), + std::move(offsets_res)); + + block.replace_by_position(result, std::move(result_col)); + return Status::OK(); + } +}; + +} // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/functions/array/function_array_register.cpp 
b/be/src/vec/functions/array/function_array_register.cpp index aa92e89128fec3..a3689d6a579e04 100644 --- a/be/src/vec/functions/array/function_array_register.cpp +++ b/be/src/vec/functions/array/function_array_register.cpp @@ -53,6 +53,7 @@ void register_function_array_pushback(SimpleFunctionFactory& factory); void register_function_array_first_or_last_index(SimpleFunctionFactory& factory); void register_function_array_cum_sum(SimpleFunctionFactory& factory); void register_function_array_count(SimpleFunctionFactory&); +void register_function_array_cross_product(SimpleFunctionFactory& factory); void register_function_array_filter_function(SimpleFunctionFactory&); void register_function_array_splits(SimpleFunctionFactory&); void register_function_array_contains_all(SimpleFunctionFactory&); @@ -91,6 +92,7 @@ void register_function_array(SimpleFunctionFactory& factory) { register_function_array_first_or_last_index(factory); register_function_array_cum_sum(factory); register_function_array_count(factory); + register_function_array_cross_product(factory); register_function_array_filter_function(factory); register_function_array_splits(factory); register_function_array_contains_all(factory); diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java index b7970c448cf0ca..a40df206f079e0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java @@ -153,6 +153,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateMap; import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateNamedStruct; import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateStruct; +import org.apache.doris.nereids.trees.expressions.functions.scalar.CrossProduct; import 
org.apache.doris.nereids.trees.expressions.functions.scalar.Csc; import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentCatalog; import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentDate; @@ -695,6 +696,7 @@ public class BuiltinScalarFunctions implements FunctionHelper { scalar(CreateMap.class, "map"), scalar(CreateStruct.class, "struct"), scalar(CreateNamedStruct.class, "named_struct"), + scalar(CrossProduct.class, "cross_product"), scalar(CurrentCatalog.class, "current_catalog"), scalar(CurrentDate.class, "curdate", "current_date"), scalar(CurrentTime.class, "curtime", "current_time"), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CrossProduct.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CrossProduct.java new file mode 100644 index 00000000000000..cf32ca3b8ebd90 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CrossProduct.java @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.trees.expressions.functions.scalar; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.FloatType; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * cosine_distance function + */ +public class CrossProduct extends ScalarFunction implements ExplicitlyCastableSignature, + BinaryExpression, AlwaysNotNullable { + + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(ArrayType.of(FloatType.INSTANCE)) + .args(ArrayType.of(FloatType.INSTANCE), ArrayType.of(FloatType.INSTANCE)) + ); + + /** + * constructor with 2 arguments. + */ + public CrossProduct(Expression arg0, Expression arg1) { + super("cross_product", arg0, arg1); + } + + /** constructor for withChildren and reuse signature */ + private CrossProduct(ScalarFunctionParams functionParams) { + super(functionParams); + } + + /** + * withChildren. 
+ */ + @Override + public CrossProduct withChildren(List children) { + Preconditions.checkArgument(children.size() == 2); + return new CrossProduct(getFunctionParams(children)); + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitCrossProduct(this, context); + } + + @Override + public List getSignatures() { + return SIGNATURES; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java index 5a556add8376b3..9edd4170b6dfcf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java @@ -164,6 +164,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateMap; import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateNamedStruct; import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateStruct; +import org.apache.doris.nereids.trees.expressions.functions.scalar.CrossProduct; import org.apache.doris.nereids.trees.expressions.functions.scalar.Csc; import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentCatalog; import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentDate; @@ -1081,6 +1082,10 @@ default R visitCountSubstring(CountSubstring countSubstring, C context) { return visitScalarFunction(countSubstring, context); } + default R visitCrossProduct(CrossProduct crossProduct, C context) { + return visitScalarFunction(crossProduct, context); + } + default R visitCurrentCatalog(CurrentCatalog currentCatalog, C context) { return visitScalarFunction(currentCatalog, context); } diff --git a/regression-test/data/query_p0/sql_functions/array_functions/test_array_cross_product_function.out 
b/regression-test/data/query_p0/sql_functions/array_functions/test_array_cross_product_function.out new file mode 100644 index 00000000000000..6720f674adfb96 --- /dev/null +++ b/regression-test/data/query_p0/sql_functions/array_functions/test_array_cross_product_function.out @@ -0,0 +1,15 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !sql -- +[-1, 2, -1] + +-- !sql -- +[0, 0, 0] + +-- !sql -- +[0, 0, 0] + +-- !sql -- +[0, 0, 1] + +-- !sql -- +[0, 0, -1] \ No newline at end of file diff --git a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_cross_product_function.groovy b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_cross_product_function.groovy new file mode 100644 index 00000000000000..20d4ddfdada3b6 --- /dev/null +++ b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_cross_product_function.groovy @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +suite("test_array_cross_product_function") { + // normal test cases + qt_sql "SELECT cross_product([1, 2, 3], [2, 3, 4])" + qt_sql "SELECT cross_product([1, 2, 3], [0, 0, 0])" + qt_sql "SELECT cross_product([0, 0, 0], [1, 2, 3])" + qt_sql "SELECT cross_product([1, 0, 0], [0, 1, 0])" + qt_sql "SELECT cross_product([0, 1, 0], [1, 0, 0])" + + // abnormal test cases + try { + sql "SELECT cross_product(NULL, [1, 2, 3])" + } catch (Exception ex) { + assert("${ex}".contains("First argument for function cross_product cannot be null")) + } + try { + sql "SELECT cross_product([1, 2, 3], NULL)" + } catch (Exception ex) { + assert("${ex}".contains("Second argument for function cross_product cannot be null")) + } + try { + sql "SELECT cross_product([1, NULL, 3], [1, 2, 3])" + } catch (Exception ex) { + assert("${ex}".contains("First argument for function cross_product cannot have null elements")) + } + try { + sql "SELECT cross_product([1, 2, 3], [NULL, 2, 3])" + } catch (Exception ex) { + assert("${ex}".contains("Second argument for function cross_product cannot have null elements")) + } + try { + sql "SELECT cross_product([1, 2, 3], [1, 2])" + } catch (Exception ex) { + assert("${ex}".contains("function cross_product have different input element sizes of array: 3 and 2")) + } + try { + sql "SELECT cross_product([1, 2], [3, 4])" + } catch (Exception ex) { + assert("${ex}".contains("function cross_product requires arrays of size 3")) + } + try { + sql "SELECT cross_product([1, 2, 3, 4], [1, 2, 3, 4])" + } catch (Exception ex) { + assert("${ex}".contains("function cross_product requires arrays of size 3")) + } +} From 402901baac787866f43c9818ea258bd52230bcf1 Mon Sep 17 00:00:00 2001 From: juruo-c <2915601267@qq.com> Date: Tue, 23 Dec 2025 00:09:15 +0800 Subject: [PATCH 2/2] fix some error and add table test --- .../array/function_array_cross_product.cpp | 172 ++++++++++++++- .../array/function_array_cross_product.h | 204 ------------------ 
.../functions/scalar/CrossProduct.java | 10 +- .../test_array_cross_product_function.out | 16 +- .../test_array_cross_product_function.groovy | 62 +++--- 5 files changed, 228 insertions(+), 236 deletions(-) delete mode 100644 be/src/vec/functions/array/function_array_cross_product.h diff --git a/be/src/vec/functions/array/function_array_cross_product.cpp b/be/src/vec/functions/array/function_array_cross_product.cpp index 3352316145df71..1f809ac09c45ee 100644 --- a/be/src/vec/functions/array/function_array_cross_product.cpp +++ b/be/src/vec/functions/array/function_array_cross_product.cpp @@ -15,12 +15,182 @@ // specific language governing permissions and limitations // under the License. -#include "vec/functions/array/function_array_cross_product.h" +#include +#include "common/exception.h" +#include "common/status.h" +#include "runtime/primitive_type.h" +#include "vec/columns/column.h" +#include "vec/columns/column_array.h" +#include "vec/columns/column_nullable.h" +#include "vec/common/assert_cast.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_array.h" +#include "vec/data_types/data_type_nullable.h" +#include "vec/data_types/data_type_number.h" +#include "vec/functions/array/function_array_utils.h" +#include "vec/functions/function.h" #include "vec/functions/simple_function_factory.h" +#include "vec/utils/util.hpp" namespace doris::vectorized { +class FunctionArrayCrossProduct : public IFunction { +public: + using DataType = PrimitiveTypeTraits::DataType; + using ColumnType = PrimitiveTypeTraits::ColumnType; + + static constexpr auto name = "cross_product"; + String get_name() const override { return name; } + static FunctionPtr create() { return std::make_shared(); } + size_t get_number_of_arguments() const override { return 2; } + + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { + if (arguments.size() != 2) { + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + "Invalid 
number of arguments for function {}", get_name()); + } + + if (arguments[0]->get_primitive_type() != TYPE_ARRAY || + arguments[1]->get_primitive_type() != TYPE_ARRAY) { + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + "Arguments for function {} must be arrays", get_name()); + } + + // return ARRAY + return std::make_shared(std::make_shared()); + } + + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + uint32_t result, size_t input_rows_count) const override { + const auto& arg1 = block.get_by_position(arguments[0]); + const auto& arg2 = block.get_by_position(arguments[1]); + + const IColumn* col1 = arg1.column.get(); + const IColumn* col2 = arg2.column.get(); + + const ColumnConst* col1_const = nullptr; + const ColumnConst* col2_const = nullptr; + + if (is_column_const(*col1)) { + col1_const = assert_cast(col1); + col1 = &col1_const->get_data_column(); + } + if (is_column_const(*col2)) { + col2_const = assert_cast(col2); + col2 = &col2_const->get_data_column(); + } + + const ColumnArray* arr1 = nullptr; + const ColumnArray* arr2 = nullptr; + + if (col1->is_nullable()) { + auto nullable1 = assert_cast(col1); + arr1 = assert_cast(nullable1->get_nested_column_ptr().get()); + } else { + arr1 = assert_cast(col1); + } + if (col2->is_nullable()) { + auto nullable2 = assert_cast(col2); + arr2 = assert_cast(nullable2->get_nested_column_ptr().get()); + } else { + arr2 = assert_cast(col2); + } + + const ColumnFloat64* float1 = nullptr; + const ColumnFloat64* float2 = nullptr; + + if (arr1->get_data_ptr()->is_nullable()) { + if (arr1->get_data_ptr()->has_null()) { + return Status::InvalidArgument( + "First argument for function {} cannot have null elements", get_name()); + } + auto nullable1 = assert_cast(arr1->get_data_ptr().get()); + float1 = assert_cast(nullable1->get_nested_column_ptr().get()); + } else { + float1 = assert_cast(arr1->get_data_ptr().get()); + } + + if (arr2->get_data_ptr()->is_nullable()) { + if 
(arr2->get_data_ptr()->has_null()) { + return Status::InvalidArgument( + "Second argument for function {} cannot have null elements", get_name()); + } + auto nullable2 = assert_cast(arr2->get_data_ptr().get()); + float2 = assert_cast(nullable2->get_nested_column_ptr().get()); + } else { + float2 = assert_cast(arr2->get_data_ptr().get()); + } + + const auto* offset1 = + assert_cast(arr1->get_offsets_ptr().get()); + const auto* offset2 = + assert_cast(arr2->get_offsets_ptr().get()); + + // prepare result data + auto nested_res = ColumnFloat64::create(); + auto& nested_data = nested_res->get_data(); + nested_data.resize(3 * input_rows_count); + + auto offsets_res = ColumnArray::ColumnOffsets::create(); + auto& offsets_data = offsets_res->get_data(); + offsets_data.resize(input_rows_count); + + size_t current_offset = 0; + for (ssize_t row = 0; row < input_rows_count; ++row) { + ssize_t row1 = col1_const ? 0 : row; // const input: always read physical row 0 + ssize_t row2 = col2_const ? 0 : row; // const input: always read physical row 0 + + ssize_t prev_offset1 = (row1 == 0) ? 0 : offset1->get_data()[row1 - 1]; + ssize_t prev_offset2 = (row2 == 0) ? 
0 : offset2->get_data()[row2 - 1]; + + ssize_t size1 = offset1->get_data()[row1] - prev_offset1; + ssize_t size2 = offset2->get_data()[row2] - prev_offset2; + + if (size1 == 0 || size2 == 0) { + nested_data[current_offset] = 0; + nested_data[current_offset + 1] = 0; + nested_data[current_offset + 2] = 0; + + current_offset += 3; + offsets_data[row] = current_offset; + continue; + } + + if (size1 != 3 || size2 != 3) { + return Status::InvalidArgument("function {} requires arrays of size 3", get_name()); + } + + ssize_t base1 = prev_offset1; + ssize_t base2 = prev_offset2; + + double a1 = float1->get_data()[base1]; + double a2 = float1->get_data()[base1 + 1]; + double a3 = float1->get_data()[base1 + 2]; + + double b1 = float2->get_data()[base2]; + double b2 = float2->get_data()[base2 + 1]; + double b3 = float2->get_data()[base2 + 2]; + + nested_data[current_offset] = a2 * b3 - a3 * b2; + nested_data[current_offset + 1] = a3 * b1 - a1 * b3; + nested_data[current_offset + 2] = a1 * b2 - a2 * b1; + + current_offset += 3; + offsets_data[row] = current_offset; + } + + auto result_col = ColumnArray::create( + ColumnNullable::create(std::move(nested_res), + ColumnUInt8::create(nested_res->size(), 0)), + std::move(offsets_res)); + + block.replace_by_position(result, std::move(result_col)); + return Status::OK(); + } +}; + void register_function_array_cross_product(SimpleFunctionFactory& factory) { factory.register_function(); } diff --git a/be/src/vec/functions/array/function_array_cross_product.h b/be/src/vec/functions/array/function_array_cross_product.h deleted file mode 100644 index 7b05baba134f00..00000000000000 --- a/be/src/vec/functions/array/function_array_cross_product.h +++ /dev/null @@ -1,204 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include - -#include "common/exception.h" -#include "common/status.h" -#include "runtime/primitive_type.h" -#include "vec/columns/column.h" -#include "vec/columns/column_array.h" -#include "vec/columns/column_nullable.h" -#include "vec/common/assert_cast.h" -#include "vec/core/types.h" -#include "vec/data_types/data_type.h" -#include "vec/data_types/data_type_array.h" -#include "vec/data_types/data_type_nullable.h" -#include "vec/data_types/data_type_number.h" -#include "vec/functions/array/function_array_utils.h" -#include "vec/functions/function.h" -#include "vec/utils/util.hpp" - -namespace doris::vectorized { - -class FunctionArrayCrossProduct : public IFunction { -public: - using DataType = PrimitiveTypeTraits::DataType; - using ColumnType = PrimitiveTypeTraits::ColumnType; - - static constexpr auto name = "cross_product"; - String get_name() const override { return name; } - static FunctionPtr create() { return std::make_shared(); } - size_t get_number_of_arguments() const override { return 2; } - - DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { - if (arguments.size() != 2) { - throw doris::Exception(ErrorCode::INVALID_ARGUMENT, - "Invalid number of arguments for function {}", get_name()); - } - - if (arguments[0]->get_primitive_type() != TYPE_ARRAY || - arguments[1]->get_primitive_type() != TYPE_ARRAY) { - throw 
doris::Exception(ErrorCode::INVALID_ARGUMENT, - "Arguments for function {} must be arrays", get_name()); - } - - // return ARRAY - return std::make_shared( - std::make_shared()); - } - - // strict semantics: do not allow NULL - bool use_default_implementation_for_nulls() const override { return false; } - - Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, - uint32_t result, size_t input_rows_count) const override { - const auto& arg1 = block.get_by_position(arguments[0]); - const auto& arg2 = block.get_by_position(arguments[1]); - - auto col1 = arg1.column->convert_to_full_column_if_const(); - auto col2 = arg2.column->convert_to_full_column_if_const(); - - if (col1->size() != col2->size()) { - return Status::RuntimeError( - fmt::format("function {} have different input array sizes: {} and {}", - get_name(), col1->size(), col2->size())); - } - - const ColumnArray* arr1 = nullptr; - const ColumnArray* arr2 = nullptr; - - if (const auto* nullable = - typeid_cast(col1.get())) { - if (nullable->has_null()) { - throw doris::Exception(ErrorCode::INVALID_ARGUMENT, - "First argument for function {} cannot be null", get_name()); - } - arr1 = assert_cast(nullable->get_nested_column_ptr().get()); - } else { - arr1 = assert_cast(col1.get()); - } - - if (const auto* nullable = - typeid_cast(col2.get())) { - if (nullable->has_null()) { - throw doris::Exception(ErrorCode::INVALID_ARGUMENT, - "Second argument for function {} cannot be null", - get_name()); - } - arr2 = assert_cast(nullable->get_nested_column_ptr().get()); - } else { - arr2 = assert_cast(col2.get()); - } - - const ColumnFloat32* float1 = nullptr; - const ColumnFloat32* float2 = nullptr; - - if (const auto* nullable = - typeid_cast(arr1->get_data_ptr().get())) { - if (nullable->has_null()) { - throw doris::Exception(ErrorCode::INVALID_ARGUMENT, - "First argument for function {} cannot have null elements", - get_name()); - } - float1 = 
assert_cast(nullable->get_nested_column_ptr().get()); - } else { - float1 = assert_cast(arr1->get_data_ptr().get()); - } - - if (const auto* nullable = - typeid_cast(arr2->get_data_ptr().get())) { - if (nullable->has_null()) { - throw doris::Exception(ErrorCode::INVALID_ARGUMENT, - "Second argument for function {} cannot have null elements", - get_name()); - } - float2 = assert_cast(nullable->get_nested_column_ptr().get()); - } else { - float2 = assert_cast(arr2->get_data_ptr().get()); - } - - const auto* offset1 = - assert_cast(arr1->get_offsets_ptr().get()); - const auto* offset2 = - assert_cast(arr2->get_offsets_ptr().get()); - - // prepare result data - auto nested_res = ColumnFloat32::create(); - auto offsets_res = ColumnArray::ColumnOffsets::create(); - auto& offsets_data = offsets_res->get_data(); - offsets_data.reserve(input_rows_count); - size_t current_offset = 0; - - size_t row_cnt = offset1->size(); - size_t prev_offset1 = 0; - size_t prev_offset2 = 0; - for (ssize_t row = 0; row < row_cnt; ++row) { - ssize_t size1 = offset1->get_data()[row] - prev_offset1; - ssize_t size2 = offset2->get_data()[row] - prev_offset2; - - if (size1 != size2) [[unlikely]] { - return Status::InvalidArgument( - "function {} have different input element sizes of array: {} and {}", - get_name(), size1, size2); - } - - if (size1 != 3 || size2 != 3) { - throw doris::Exception(ErrorCode::INVALID_ARGUMENT, - "function {} requires arrays of size 3", - get_name()); - } - - ssize_t base1 = prev_offset1; - ssize_t base2 = prev_offset2; - - float a1 = float1->get_data()[base1]; - float a2 = float1->get_data()[base1 + 1]; - float a3 = float1->get_data()[base1 + 2]; - - float b1 = float2->get_data()[base2]; - float b2 = float2->get_data()[base2 + 1]; - float b3 = float2->get_data()[base2 + 2]; - - float c1 = a2 * b3 - a3 * b2; - float c2 = a3 * b1 - a1 * b3; - float c3 = a1 * b2 - a2 * b1; - - nested_res->insert_value(c1); - nested_res->insert_value(c2); - nested_res->insert_value(c3); - 
- current_offset += 3; - offsets_data.push_back(current_offset); - - prev_offset1 = offset1->get_data()[row]; - prev_offset2 = offset2->get_data()[row]; - } - - auto result_col = ColumnArray::create( - ColumnNullable::create(std::move(nested_res), - ColumnUInt8::create(nested_res->size(), 0)), - std::move(offsets_res)); - - block.replace_by_position(result, std::move(result_col)); - return Status::OK(); - } -}; - -} // namespace doris::vectorized \ No newline at end of file diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CrossProduct.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CrossProduct.java index cf32ca3b8ebd90..46854a6d56a4f6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CrossProduct.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CrossProduct.java @@ -19,12 +19,12 @@ import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.ArrayType; -import org.apache.doris.nereids.types.FloatType; +import org.apache.doris.nereids.types.DoubleType; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -35,11 +35,11 @@ * cosine_distance function */ public class CrossProduct extends ScalarFunction implements ExplicitlyCastableSignature, - BinaryExpression, AlwaysNotNullable { + BinaryExpression, PropagateNullable { public static final List SIGNATURES = 
ImmutableList.of( - FunctionSignature.ret(ArrayType.of(FloatType.INSTANCE)) - .args(ArrayType.of(FloatType.INSTANCE), ArrayType.of(FloatType.INSTANCE)) + FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)) + .args(ArrayType.of(DoubleType.INSTANCE), ArrayType.of(DoubleType.INSTANCE)) ); /** diff --git a/regression-test/data/query_p0/sql_functions/array_functions/test_array_cross_product_function.out b/regression-test/data/query_p0/sql_functions/array_functions/test_array_cross_product_function.out index 6720f674adfb96..85d0af0564d8de 100644 --- a/regression-test/data/query_p0/sql_functions/array_functions/test_array_cross_product_function.out +++ b/regression-test/data/query_p0/sql_functions/array_functions/test_array_cross_product_function.out @@ -12,4 +12,18 @@ [0, 0, 1] -- !sql -- -[0, 0, -1] \ No newline at end of file +[0, 0, -1] + +-- !sql -- +\N + +-- !sql -- +\N + +-- !sql -- +1 [-1, 2, -1] +2 [0, 0, 0] +3 [0, 0, 1] +4 [0, 0, -1] +5 \N + diff --git a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_cross_product_function.groovy b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_cross_product_function.groovy index 20d4ddfdada3b6..610d8135a5bd69 100644 --- a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_cross_product_function.groovy +++ b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_cross_product_function.groovy @@ -22,41 +22,53 @@ suite("test_array_cross_product_function") { qt_sql "SELECT cross_product([0, 0, 0], [1, 2, 3])" qt_sql "SELECT cross_product([1, 0, 0], [0, 1, 0])" qt_sql "SELECT cross_product([0, 1, 0], [1, 0, 0])" + qt_sql "SELECT cross_product(NULL, [1, 2, 3])" + qt_sql "SELECT cross_product([1, 2, 3], NULL)" + + def tableName = "array_cross_product_test" + sql "DROP TABLE IF EXISTS ${tableName}" + sql """ + CREATE TABLE ${tableName} ( + id INT, + vec1 ARRAY, + vec2 ARRAY + ) + DUPLICATE KEY(id) + DISTRIBUTED BY HASH(id) BUCKETS 3 + 
PROPERTIES ( + "replication_num" = "1" + ) + """ + sql """ + INSERT INTO ${tableName} values + (1, [1, 2, 3], [2, 3, 4]), + (2, [1, 2, 3], [0, 0, 0]), + (3, [1, 0, 0], [0, 1, 0]), + (4, [0, 1, 0], [1, 0, 0]), + (5, NULL, [1, 0, 0]) + """ + qt_sql "SELECT id, CROSS_PRODUCT(vec1, vec2) from ${tableName} ORDER BY id" + sql "DROP TABLE IF EXISTS ${tableName}" // abnormal test cases - try { - sql "SELECT cross_product(NULL, [1, 2, 3])" - } catch (Exception ex) { - assert("${ex}".contains("First argument for function cross_product cannot be null")) - } - try { - sql "SELECT cross_product([1, 2, 3], NULL)" - } catch (Exception ex) { - assert("${ex}".contains("Second argument for function cross_product cannot be null")) - } - try { + test { sql "SELECT cross_product([1, NULL, 3], [1, 2, 3])" - } catch (Exception ex) { - assert("${ex}".contains("First argument for function cross_product cannot have null elements")) + exception "First argument for function cross_product cannot have null elements" } - try { + test { sql "SELECT cross_product([1, 2, 3], [NULL, 2, 3])" - } catch (Exception ex) { - assert("${ex}".contains("Second argument for function cross_product cannot have null elements")) + exception "Second argument for function cross_product cannot have null elements" } - try { + test { sql "SELECT cross_product([1, 2, 3], [1, 2])" - } catch (Exception ex) { - assert("${ex}".contains("function cross_product have different input element sizes of array: 3 and 2")) + exception "function cross_product requires arrays of size 3" } - try { + test { sql "SELECT cross_product([1, 2], [3, 4])" - } catch (Exception ex) { - assert("${ex}".contains("function cross_product requires arrays of size 3")) + exception "function cross_product requires arrays of size 3" } - try { + test { sql "SELECT cross_product([1, 2, 3, 4], [1, 2, 3, 4])" - } catch (Exception ex) { - assert("${ex}".contains("function cross_product requires arrays of size 3")) + exception "function cross_product requires 
arrays of size 3" } }