diff --git a/api/src/main/java/org/opensearch/sql/api/spec/search/NamedArgRewriter.java b/api/src/main/java/org/opensearch/sql/api/spec/search/NamedArgRewriter.java index 8627a76f2cf..14cd6852b1a 100644 --- a/api/src/main/java/org/opensearch/sql/api/spec/search/NamedArgRewriter.java +++ b/api/src/main/java/org/opensearch/sql/api/spec/search/NamedArgRewriter.java @@ -5,16 +5,21 @@ package org.opensearch.sql.api.spec.search; +import java.math.BigDecimal; +import java.util.ArrayList; import java.util.List; import lombok.AccessLevel; import lombok.NoArgsConstructor; +import org.apache.calcite.sql.SqlBasicTypeNameSpec; import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlDataTypeSpec; import org.apache.calcite.sql.SqlIdentifier; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlLiteral; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.util.SqlShuttle; import org.checkerframework.checker.nullness.qual.Nullable; import org.opensearch.sql.api.spec.UnifiedFunctionSpec; @@ -31,6 +36,10 @@ public final class NamedArgRewriter extends SqlShuttle { public static final NamedArgRewriter INSTANCE = new NamedArgRewriter(); + private static final SqlDataTypeSpec VARCHAR_TYPE = + new SqlDataTypeSpec( + new SqlBasicTypeNameSpec(SqlTypeName.VARCHAR, SqlParserPos.ZERO), SqlParserPos.ZERO); + @Override public @Nullable SqlNode visit(SqlCall call) { SqlCall visited = (SqlCall) super.visit(call); @@ -44,6 +53,7 @@ public final class NamedArgRewriter extends SqlShuttle { * Rewrites each argument into a MAP entry. For match(name, 'John', operator='AND'): *
  • Positional arg: name → MAP('field', name) *
  • Named arg: operator='AND' → MAP('operator', 'AND') + *
  • ARRAY arg: ARRAY['f1','f2'] → MAP('fields', MAP(CAST('f1' AS VARCHAR), 1, ...)) */ private static SqlCall rewriteToMaps(SqlCall call, List paramNames) { List operands = call.getOperandList(); @@ -62,12 +72,34 @@ private static SqlCall rewriteToMaps(SqlCall call, List paramNames) { throw new IllegalArgumentException( String.format("Invalid arguments for function '%s'", call.getOperator().getName())); } - maps[i] = toMap(paramNames.get(i), op); + String paramName = paramNames.get(i); + if ("fields".equals(paramName) + && op instanceof SqlCall arrayCall + && arrayCall.getOperator() == SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR) { + maps[i] = toMap(paramName, expandFieldsArray(arrayCall)); + } else { + maps[i] = toMap(paramName, op); + } } } return call.getOperator().createCall(call.getParserPosition(), maps); } + /** Expands ARRAY['f1', 'f2'] into MAP(CAST('f1' AS VARCHAR), 1, CAST('f2' AS VARCHAR), 1). */ + private static SqlNode expandFieldsArray(SqlCall arrayCall) { + List mapArgs = new ArrayList<>(); + for (SqlNode element : arrayCall.getOperandList()) { + mapArgs.add(castToVarchar(element)); + mapArgs.add(SqlLiteral.createExactNumeric(BigDecimal.ONE.toPlainString(), SqlParserPos.ZERO)); + } + return SqlStdOperatorTable.MAP_VALUE_CONSTRUCTOR.createCall( + SqlParserPos.ZERO, mapArgs.toArray(SqlNode[]::new)); + } + + private static SqlNode castToVarchar(SqlNode node) { + return SqlStdOperatorTable.CAST.createCall(SqlParserPos.ZERO, node, VARCHAR_TYPE); + } + private static SqlNode toMap(String key, SqlNode value) { return SqlStdOperatorTable.MAP_VALUE_CONSTRUCTOR.createCall( SqlParserPos.ZERO, SqlLiteral.createCharString(key, SqlParserPos.ZERO), value); diff --git a/api/src/test/java/org/opensearch/sql/api/UnifiedRelevanceSearchSqlTest.java b/api/src/test/java/org/opensearch/sql/api/UnifiedRelevanceSearchSqlTest.java index 66df9c2e075..df65061aff2 100644 --- a/api/src/test/java/org/opensearch/sql/api/UnifiedRelevanceSearchSqlTest.java +++ b/api/src/test/java/org/opensearch/sql/api/UnifiedRelevanceSearchSqlTest.java @@ -145,6 +145,42 @@ SELECT upper(name) FROM catalog.employees\ // FIXME: Calcite's SQL parser does not support V2 bracket field list syntax ['field1', 'field2']. // Multi-field relevance functions only accept a single column reference in the Calcite SQL path. + @Test + public void testMultiMatchArraySyntax() { + givenQuery( + """ + SELECT * FROM catalog.employees + WHERE multi_match(ARRAY['name', 'department'], 'John')\ + """) + .assertPlanContains( + "multi_match(MAP('fields', MAP('name':VARCHAR, 1, 'department':VARCHAR, 1))," + + " MAP('query', 'John'))"); + } + + @Test + public void testSimpleQueryStringArraySyntax() { + givenQuery( + """ + SELECT * FROM catalog.employees + WHERE simple_query_string(ARRAY['name', 'department'], 'John')\ + """) + .assertPlanContains( + "simple_query_string(MAP('fields', MAP('name':VARCHAR, 1," + + " 'department':VARCHAR, 1)), MAP('query', 'John'))"); + } + + @Test + public void testQueryStringArraySyntax() { + givenQuery( + """ + SELECT * FROM catalog.employees + WHERE query_string(ARRAY['name', 'department'], 'John')\ + """) + .assertPlanContains( + "query_string(MAP('fields', MAP('name':VARCHAR, 1," + + " 'department':VARCHAR, 1)), MAP('query', 'John'))"); + } + @Test public void testMultiMatchBracketSyntaxNotSupported() { givenInvalidQuery(