Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,21 @@

package org.opensearch.sql.api.spec.search;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.calcite.sql.SqlBasicTypeNameSpec;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlDataTypeSpec;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.util.SqlShuttle;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.opensearch.sql.api.spec.UnifiedFunctionSpec;
Expand All @@ -31,6 +36,10 @@ public final class NamedArgRewriter extends SqlShuttle {

public static final NamedArgRewriter INSTANCE = new NamedArgRewriter();

private static final SqlDataTypeSpec VARCHAR_TYPE =
new SqlDataTypeSpec(
new SqlBasicTypeNameSpec(SqlTypeName.VARCHAR, SqlParserPos.ZERO), SqlParserPos.ZERO);

@Override
public @Nullable SqlNode visit(SqlCall call) {
SqlCall visited = (SqlCall) super.visit(call);
Expand All @@ -44,6 +53,7 @@ public final class NamedArgRewriter extends SqlShuttle {
* Rewrites each argument into a MAP entry. For match(name, 'John', operator='AND'):
* <li>Positional arg: name → MAP('field', name)
* <li>Named arg: operator='AND' → MAP('operator', 'AND')
* <li>ARRAY arg: ARRAY['f1','f2'] → MAP('fields', MAP(CAST('f1' AS VARCHAR), 1, ...))
*/
private static SqlCall rewriteToMaps(SqlCall call, List<String> paramNames) {
List<SqlNode> operands = call.getOperandList();
Expand All @@ -62,12 +72,34 @@ private static SqlCall rewriteToMaps(SqlCall call, List<String> paramNames) {
throw new IllegalArgumentException(
String.format("Invalid arguments for function '%s'", call.getOperator().getName()));
}
maps[i] = toMap(paramNames.get(i), op);
String paramName = paramNames.get(i);
if ("fields".equals(paramName)
&& op instanceof SqlCall arrayCall
&& arrayCall.getOperator() == SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR) {
maps[i] = toMap(paramName, expandFieldsArray(arrayCall));
} else {
maps[i] = toMap(paramName, op);
}
}
}
return call.getOperator().createCall(call.getParserPosition(), maps);
}

/** Expands ARRAY['f1', 'f2'] into MAP(CAST('f1' AS VARCHAR), 1, CAST('f2' AS VARCHAR), 1). */
private static SqlNode expandFieldsArray(SqlCall arrayCall) {
List<SqlNode> mapArgs = new ArrayList<>();
for (SqlNode element : arrayCall.getOperandList()) {
mapArgs.add(castToVarchar(element));
mapArgs.add(SqlLiteral.createExactNumeric(BigDecimal.ONE.toPlainString(), SqlParserPos.ZERO));
}
return SqlStdOperatorTable.MAP_VALUE_CONSTRUCTOR.createCall(
SqlParserPos.ZERO, mapArgs.toArray(SqlNode[]::new));
}

private static SqlNode castToVarchar(SqlNode node) {
return SqlStdOperatorTable.CAST.createCall(SqlParserPos.ZERO, node, VARCHAR_TYPE);
}

private static SqlNode toMap(String key, SqlNode value) {
return SqlStdOperatorTable.MAP_VALUE_CONSTRUCTOR.createCall(
SqlParserPos.ZERO, SqlLiteral.createCharString(key, SqlParserPos.ZERO), value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,42 @@ SELECT upper(name) FROM catalog.employees\
// FIXME: Calcite's SQL parser does not support V2 bracket field list syntax ['field1', 'field2'].
// Multi-field relevance functions only accept a single column reference in the Calcite SQL path.

@Test
public void testMultiMatchArraySyntax() {
givenQuery(
"""
SELECT * FROM catalog.employees
WHERE multi_match(ARRAY['name', 'department'], 'John')\
""")
.assertPlanContains(
"multi_match(MAP('fields', MAP('name':VARCHAR, 1, 'department':VARCHAR, 1)),"
+ " MAP('query', 'John'))");
}

@Test
public void testSimpleQueryStringArraySyntax() {
givenQuery(
"""
SELECT * FROM catalog.employees
WHERE simple_query_string(ARRAY['name', 'department'], 'John')\
""")
.assertPlanContains(
"simple_query_string(MAP('fields', MAP('name':VARCHAR, 1,"
+ " 'department':VARCHAR, 1)), MAP('query', 'John'))");
}

@Test
public void testQueryStringArraySyntax() {
givenQuery(
"""
SELECT * FROM catalog.employees
WHERE query_string(ARRAY['name', 'department'], 'John')\
""")
.assertPlanContains(
"query_string(MAP('fields', MAP('name':VARCHAR, 1,"
+ " 'department':VARCHAR, 1)), MAP('query', 'John'))");
}

@Test
public void testMultiMatchBracketSyntaxNotSupported() {
givenInvalidQuery(
Expand Down
Loading