diff --git a/be/src/vec/functions/array/function_array_cross_product.cpp b/be/src/vec/functions/array/function_array_cross_product.cpp new file mode 100644 index 00000000000000..1f809ac09c45ee --- /dev/null +++ b/be/src/vec/functions/array/function_array_cross_product.cpp @@ -0,0 +1,198 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "common/exception.h" +#include "common/status.h" +#include "runtime/primitive_type.h" +#include "vec/columns/column.h" +#include "vec/columns/column_array.h" +#include "vec/columns/column_nullable.h" +#include "vec/common/assert_cast.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_array.h" +#include "vec/data_types/data_type_nullable.h" +#include "vec/data_types/data_type_number.h" +#include "vec/functions/array/function_array_utils.h" +#include "vec/functions/function.h" +#include "vec/functions/simple_function_factory.h" +#include "vec/utils/util.hpp" + +namespace doris::vectorized { + +class FunctionArrayCrossProduct : public IFunction { +public: + using DataType = PrimitiveTypeTraits::DataType; + using ColumnType = PrimitiveTypeTraits::ColumnType; + + static constexpr auto name = "cross_product"; + String get_name() const override { return name; } + static FunctionPtr create() { return std::make_shared(); } + size_t get_number_of_arguments() const override { return 2; } + + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { + if (arguments.size() != 2) { + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + "Invalid number of arguments for function {}", get_name()); + } + + if (arguments[0]->get_primitive_type() != TYPE_ARRAY || + arguments[1]->get_primitive_type() != TYPE_ARRAY) { + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + "Arguments for function {} must be arrays", get_name()); + } + + // return ARRAY + return std::make_shared(std::make_shared()); + } + + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + uint32_t result, size_t input_rows_count) const override { + const auto& arg1 = block.get_by_position(arguments[0]); + const auto& arg2 = block.get_by_position(arguments[1]); + + const IColumn* col1 = arg1.column.get(); + const IColumn* col2 = arg2.column.get(); + + const ColumnConst* col1_const = nullptr; + const ColumnConst* col2_const = nullptr; + + if (is_column_const(*col1)) { + col1_const = assert_cast(col1); + col1 = &col1_const->get_data_column(); + } + if (is_column_const(*col2)) { + col2_const = assert_cast(col2); + col2 = &col2_const->get_data_column(); + } + + const ColumnArray* arr1 = nullptr; + const ColumnArray* arr2 = nullptr; + + if (col1->is_nullable()) { + auto nullable1 = assert_cast(col1); + arr1 = assert_cast(nullable1->get_nested_column_ptr().get()); + } else { + arr1 = assert_cast(col1); + } + if (col2->is_nullable()) { + auto nullable2 = assert_cast(col2); + arr2 = assert_cast(nullable2->get_nested_column_ptr().get()); + } else { + arr2 = assert_cast(col2); + } + + const ColumnFloat64* float1 = nullptr; + const ColumnFloat64* float2 = nullptr; + + if (arr1->get_data_ptr()->is_nullable()) { + if (arr1->get_data_ptr()->has_null()) { + return Status::InvalidArgument( + "First argument for function {} cannot have null elements", get_name()); + } + auto nullable1 = assert_cast(arr1->get_data_ptr().get()); + float1 = assert_cast(nullable1->get_nested_column_ptr().get()); + } else { + float1 = assert_cast(arr1->get_data_ptr().get()); + } + + if (arr2->get_data_ptr()->is_nullable()) { + if (arr2->get_data_ptr()->has_null()) { + return Status::InvalidArgument( + "Second argument for function {} cannot have null elements", get_name()); + } + auto nullable2 = assert_cast(arr2->get_data_ptr().get()); + float2 = assert_cast(nullable2->get_nested_column_ptr().get()); + } else { + float2 = assert_cast(arr2->get_data_ptr().get()); + } + + const auto* offset1 = + assert_cast(arr1->get_offsets_ptr().get()); + const auto* offset2 = + assert_cast(arr2->get_offsets_ptr().get()); + + // prepare result data + auto nested_res = ColumnFloat64::create(); + auto& nested_data = nested_res->get_data(); + nested_data.resize(3 * input_rows_count); + + auto offsets_res = ColumnArray::ColumnOffsets::create(); + auto& offsets_data = offsets_res->get_data(); + offsets_data.resize(input_rows_count); + + size_t current_offset = 0; + for (ssize_t row = 0; row < input_rows_count; ++row) { + ssize_t row1 = col1_const ? 0 : row; + ssize_t row2 = col2_const ? 0 : row; + + ssize_t prev_offset1 = (row1 == 0) ? 0 : offset1->get_data()[row1 - 1]; + ssize_t prev_offset2 = (row2 == 0) ? 0 : offset2->get_data()[row2 - 1]; + + ssize_t size1 = offset1->get_data()[row] - prev_offset1; + ssize_t size2 = offset2->get_data()[row] - prev_offset2; + + if (size1 == 0 || size2 == 0) { + nested_data[current_offset] = 0; + nested_data[current_offset + 1] = 0; + nested_data[current_offset + 2] = 0; + + current_offset += 3; + offsets_data[row] = current_offset; + continue; + } + + if (size1 != 3 || size2 != 3) { + return Status::InvalidArgument("function {} requires arrays of size 3", get_name()); + } + + ssize_t base1 = prev_offset1; + ssize_t base2 = prev_offset2; + + double a1 = float1->get_data()[base1]; + double a2 = float1->get_data()[base1 + 1]; + double a3 = float1->get_data()[base1 + 2]; + + double b1 = float2->get_data()[base2]; + double b2 = float2->get_data()[base2 + 1]; + double b3 = float2->get_data()[base2 + 2]; + + nested_data[current_offset] = a2 * b3 - a3 * b2; + nested_data[current_offset + 1] = a3 * b1 - a1 * b3; + nested_data[current_offset + 2] = a1 * b2 - a2 * b1; + + current_offset += 3; + offsets_data[row] = current_offset; + } + + auto result_col = ColumnArray::create( + ColumnNullable::create(std::move(nested_res), + ColumnUInt8::create(nested_res->size(), 0)), + std::move(offsets_res)); + + block.replace_by_position(result, std::move(result_col)); + return Status::OK(); + } +}; + +void register_function_array_cross_product(SimpleFunctionFactory& factory) { + factory.register_function(); +} + +} // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/functions/array/function_array_register.cpp b/be/src/vec/functions/array/function_array_register.cpp index aa92e89128fec3..a3689d6a579e04 100644 --- a/be/src/vec/functions/array/function_array_register.cpp +++ b/be/src/vec/functions/array/function_array_register.cpp @@ -53,6 +53,7 @@ void register_function_array_pushback(SimpleFunctionFactory& factory); void register_function_array_first_or_last_index(SimpleFunctionFactory& factory); void register_function_array_cum_sum(SimpleFunctionFactory& factory); void register_function_array_count(SimpleFunctionFactory&); +void register_function_array_cross_product(SimpleFunctionFactory& factory); void register_function_array_filter_function(SimpleFunctionFactory&); void register_function_array_splits(SimpleFunctionFactory&); void register_function_array_contains_all(SimpleFunctionFactory&); @@ -91,6 +92,7 @@ void register_function_array(SimpleFunctionFactory& factory) { register_function_array_first_or_last_index(factory); register_function_array_cum_sum(factory); register_function_array_count(factory); + register_function_array_cross_product(factory); register_function_array_filter_function(factory); register_function_array_splits(factory); register_function_array_contains_all(factory); diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java index b7970c448cf0ca..a40df206f079e0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinScalarFunctions.java @@ -153,6 +153,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateMap; import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateNamedStruct; import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateStruct; +import org.apache.doris.nereids.trees.expressions.functions.scalar.CrossProduct; import org.apache.doris.nereids.trees.expressions.functions.scalar.Csc; import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentCatalog; import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentDate; @@ -695,6 +696,7 @@ public class BuiltinScalarFunctions implements FunctionHelper { scalar(CreateMap.class, "map"), scalar(CreateStruct.class, "struct"), scalar(CreateNamedStruct.class, "named_struct"), + scalar(CrossProduct.class, "cross_product"), scalar(CurrentCatalog.class, "current_catalog"), scalar(CurrentDate.class, "curdate", "current_date"), scalar(CurrentTime.class, "curtime", "current_time"), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CrossProduct.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CrossProduct.java new file mode 100644 index 00000000000000..46854a6d56a4f6 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/CrossProduct.java @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.scalar; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; +import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.DoubleType; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * cosine_distance function + */ +public class CrossProduct extends ScalarFunction implements ExplicitlyCastableSignature, + BinaryExpression, PropagateNullable { + + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)) + .args(ArrayType.of(DoubleType.INSTANCE), ArrayType.of(DoubleType.INSTANCE)) + ); + + /** + * constructor with 1 argument. + */ + public CrossProduct(Expression arg0, Expression arg1) { + super("cross_product", arg0, arg1); + } + + /** constructor for withChildren and reuse signature */ + private CrossProduct(ScalarFunctionParams functionParams) { + super(functionParams); + } + + /** + * withChildren. + */ + @Override + public CrossProduct withChildren(List children) { + Preconditions.checkArgument(children.size() == 2); + return new CrossProduct(getFunctionParams(children)); + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitCrossProduct(this, context); + } + + @Override + public List getSignatures() { + return SIGNATURES; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java index 5a556add8376b3..9edd4170b6dfcf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java @@ -164,6 +164,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateMap; import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateNamedStruct; import org.apache.doris.nereids.trees.expressions.functions.scalar.CreateStruct; +import org.apache.doris.nereids.trees.expressions.functions.scalar.CrossProduct; import org.apache.doris.nereids.trees.expressions.functions.scalar.Csc; import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentCatalog; import org.apache.doris.nereids.trees.expressions.functions.scalar.CurrentDate; @@ -1081,6 +1082,10 @@ default R visitCountSubstring(CountSubstring countSubstring, C context) { return visitScalarFunction(countSubstring, context); } + default R visitCrossProduct(CrossProduct crossProduct, C context) { + return visitScalarFunction(crossProduct, context); + } + default R visitCurrentCatalog(CurrentCatalog currentCatalog, C context) { return visitScalarFunction(currentCatalog, context); } diff --git a/regression-test/data/query_p0/sql_functions/array_functions/test_array_cross_product_function.out b/regression-test/data/query_p0/sql_functions/array_functions/test_array_cross_product_function.out new file mode 100644 index 00000000000000..85d0af0564d8de --- /dev/null +++ b/regression-test/data/query_p0/sql_functions/array_functions/test_array_cross_product_function.out @@ -0,0 +1,29 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !sql -- +[-1, 2, -1] + +-- !sql -- +[0, 0, 0] + +-- !sql -- +[0, 0, 0] + +-- !sql -- +[0, 0, 1] + +-- !sql -- +[0, 0, -1] + +-- !sql -- +\N + +-- !sql -- +\N + +-- !sql -- +1 [-1, 2, -1] +2 [0, 0, 0] +3 [0, 0, 1] +4 [0, 0, -1] +5 \N + diff --git a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_cross_product_function.groovy b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_cross_product_function.groovy new file mode 100644 index 00000000000000..610d8135a5bd69 --- /dev/null +++ b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_cross_product_function.groovy @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_array_cross_product_function") { + // normal test cases + qt_sql "SELECT cross_product([1, 2, 3], [2, 3, 4])" + qt_sql "SELECT cross_product([1, 2, 3], [0, 0, 0])" + qt_sql "SELECT cross_product([0, 0, 0], [1, 2, 3])" + qt_sql "SELECT cross_product([1, 0, 0], [0, 1, 0])" + qt_sql "SELECT cross_product([0, 1, 0], [1, 0, 0])" + qt_sql "SELECT cross_product(NULL, [1, 2, 3])" + qt_sql "SELECT cross_product([1, 2, 3], NULL)" + + def tableName = "array_cross_product_test" + sql "DROP TABLE IF EXISTS ${tableName}" + sql """ + CREATE TABLE ${tableName} ( + id INT, + vec1 ARRAY, + vec2 ARRAY + ) + DUPLICATE KEY(id) + DISTRIBUTED BY HASH(id) BUCKETS 3 + PROPERTIES ( + "replication_num" = "1" + ) + """ + sql """ + INSERT INTO ${tableName} values + (1, [1, 2, 3], [2, 3, 4]), + (2, [1, 2, 3], [0, 0, 0]), + (3, [1, 0, 0], [0, 1, 0]), + (4, [0, 1, 0], [1, 0, 0]), + (5, NULL, [1, 0, 0]) + """ + qt_sql "SELECT id, CROSS_PRODUCT(vec1, vec2) from ${tableName} ORDER BY id" + sql "DROP TABLE IF EXISTS ${tableName}" + + // abnormal test cases + test { + sql "SELECT cross_product([1, NULL, 3], [1, 2, 3])" + exception "First argument for function cross_product cannot have null elements" + } + test { + sql "SELECT cross_product([1, 2, 3], [NULL, 2, 3])" + exception "Second argument for function cross_product cannot have null elements" + } + test { + sql "SELECT cross_product([1, 2, 3], [1, 2])" + exception "function cross_product requires arrays of size 3" + } + test { + sql "SELECT cross_product([1, 2], [3, 4])" + exception "function cross_product requires arrays of size 3" + } + test { + sql "SELECT cross_product([1, 2, 3, 4], [1, 2, 3, 4])" + exception "function cross_product requires arrays of size 3" + } +}