Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ public List<RelShuttle> postAnalysisRules() {
return List.of(DatetimeUdtNormalizeRule.INSTANCE, DatetimeOutputCastRule.INSTANCE);
}

@Override
public List<RelShuttle> preCompilationRules() {
return List.of(DatetimeUdfCompilationAdapterRule.INSTANCE);
}

/** Maps datetime UDT types to their standard Calcite equivalents. */
@Getter
@RequiredArgsConstructor
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.api.spec.datetime;

import static org.opensearch.sql.api.spec.datetime.DatetimeExtension.UdtMapping.isDatetimeType;

import java.util.ArrayList;
import java.util.List;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.calcite.rel.RelHomogeneousShuttle;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;

/**
* Adapts datetime UDF calls for Enumerable compilation. PPL UDF implementors expect String
* input/output, but after normalization the plan uses standard DATE/TIME/TIMESTAMP types
* (int/long). This rule inserts CASTs to bridge the mismatch:
*
* <pre>
* Before: LAST_DAY($2:DATE) : DATE
* After: CAST(LAST_DAY(CAST($2 AS VARCHAR)):VARCHAR AS DATE)
* </pre>
*/
@NoArgsConstructor(access = AccessLevel.PACKAGE)
class DatetimeUdfCompilationAdapterRule extends RelHomogeneousShuttle {

static final DatetimeUdfCompilationAdapterRule INSTANCE = new DatetimeUdfCompilationAdapterRule();

@Override
public RelNode visit(RelNode other) {
RelNode visited = super.visit(other);
RexBuilder rexBuilder = visited.getCluster().getRexBuilder();
RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory();
return visited.accept(
new RexShuttle() {
@Override
public RexNode visitCall(RexCall call) {
call = (RexCall) super.visitCall(call);
if (!(call.getOperator() instanceof SqlUserDefinedFunction)) {
return call;
}

// Adapt operands: CAST(datetime_operand AS VARCHAR) for UDF implementor
List<RexNode> adapted = new ArrayList<>(call.getOperands().size());
boolean operandsChanged = false;
for (RexNode operand : call.getOperands()) {
if (isDatetimeType(operand.getType().getSqlTypeName())) {
RelDataType varcharType =
typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.VARCHAR),
operand.getType().isNullable());
adapted.add(rexBuilder.makeCast(varcharType, operand));
operandsChanged = true;
} else {
adapted.add(operand);
}
}

// Adapt result: if return type is datetime, wrap call with VARCHAR return + CAST back
if (isDatetimeType(call.getType().getSqlTypeName())) {
RelDataType declaredType = call.getType();
RelDataType varcharType =
typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.VARCHAR), declaredType.isNullable());
RexCall varcharCall =
call.clone(varcharType, operandsChanged ? adapted : call.getOperands());
return rexBuilder.makeCast(declaredType, varcharCall);
}

return operandsChanged ? call.clone(call.getType(), adapted) : call;
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,21 @@

package org.opensearch.sql.api.spec.datetime;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.calcite.rel.RelHomogeneousShuttle;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.sql.type.SqlTypeName;
Expand All @@ -30,10 +36,82 @@ class DatetimeUdtNormalizeRule extends RelHomogeneousShuttle {

@Override
public RelNode visit(RelNode other) {
RelNode visited = super.visit(other);
RexBuilder rexBuilder = visited.getCluster().getRexBuilder();
// Visit children first
List<RelNode> newInputs = new ArrayList<>();
boolean childChanged = false;
for (RelNode input : other.getInputs()) {
RelNode newInput = input.accept(this);
newInputs.add(newInput);
if (newInput != input) {
childChanged = true;
}
}

// Rebuild current node if children changed
RelNode current = other;
if (childChanged) {
if (current instanceof LogicalAggregate agg) {
// Aggregate needs AggregateCall types rebuilt
RelNode newInput = newInputs.get(0);
List<AggregateCall> newAggCalls =
agg.getAggCallList().stream()
.map(
call ->
AggregateCall.create(
call.getAggregation(),
call.isDistinct(),
call.isApproximate(),
call.ignoreNulls(),
call.rexList,
call.getArgList(),
call.filterArg,
call.distinctKeys,
call.collation,
agg.getGroupCount(),
newInput,
null,
call.getName()))
.toList();
current =
agg.copy(
agg.getTraitSet(), newInput, agg.getGroupSet(), agg.getGroupSets(), newAggCalls);
} else if (current instanceof LogicalProject proj) {
// Project needs RexInputRef types refreshed from new child
RelNode newInput = newInputs.get(0);
RexBuilder rexBuilder = proj.getCluster().getRexBuilder();
List<RexNode> newProjects =
proj.getProjects().stream()
.map(
expr ->
expr.accept(
new RexShuttle() {
@Override
public RexNode visitInputRef(RexInputRef ref) {
RelDataType newType =
newInput
.getRowType()
.getFieldList()
.get(ref.getIndex())
.getType();
if (!newType.equals(ref.getType())) {
return rexBuilder.makeInputRef(newType, ref.getIndex());
}
return ref;
}
}))
.toList();
current =
LogicalProject.create(
newInput, proj.getHints(), newProjects, proj.getRowType().getFieldNames());
} else {
current = current.copy(current.getTraitSet(), newInputs);
}
}

// Apply RexShuttle to normalize UDT types in this node's expressions
RexBuilder rexBuilder = current.getCluster().getRexBuilder();
RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory();
return visited.accept(
return current.accept(
new RexShuttle() {
@Override
public RexNode visitCall(RexCall call) {
Expand All @@ -43,7 +121,6 @@ public RexNode visitCall(RexCall call) {
return call;
}

// Normalize UDT return type to standard Calcite DATE/TIME/TIMESTAMP
UdtMapping m = mapping.get();
SqlTypeName stdTypeName = m.getStdType();
RelDataType baseType =
Expand Down
Loading
Loading