Skip to content

Commit

Permalink
Add missing predicate rewrites for Redshift
Browse files Browse the repository at this point in the history
  • Loading branch information
SemionPar authored and wendigo committed Aug 7, 2024
1 parent 2ceefd0 commit 3b0009d
Showing 1 changed file with 16 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,10 @@
import io.trino.plugin.jdbc.aggregation.ImplementSum;
import io.trino.plugin.jdbc.aggregation.ImplementVariancePop;
import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp;
import io.trino.plugin.jdbc.expression.ComparisonOperator;
import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.plugin.jdbc.expression.RewriteComparison;
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.AggregateFunction;
Expand All @@ -67,6 +69,7 @@
import io.trino.spi.connector.JoinCondition;
import io.trino.spi.connector.JoinStatistics;
import io.trino.spi.connector.JoinType;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.statistics.TableStatistics;
import io.trino.spi.type.CharType;
import io.trino.spi.type.Chars;
Expand Down Expand Up @@ -224,6 +227,7 @@ public class RedshiftClient
private final AggregateFunctionRewriter<JdbcExpression, ?> aggregateFunctionRewriter;
private final boolean statisticsEnabled;
private final RedshiftTableStatisticsReader statisticsReader;
private final ConnectorExpressionRewriter<ParameterizedExpression> connectorExpressionRewriter;

@Inject
public RedshiftClient(
Expand All @@ -235,8 +239,13 @@ public RedshiftClient(
RemoteQueryModifier queryModifier)
{
super("\"", connectionFactory, queryBuilder, config.getJdbcTypesMappedToVarchar(), identifierMapping, queryModifier, true);
ConnectorExpressionRewriter<ParameterizedExpression> connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder()
connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder()
.addStandardRules(this::quoted)
.add(new RewriteComparison(ImmutableSet.of(ComparisonOperator.EQUAL, ComparisonOperator.NOT_EQUAL)))
.map("$less_than(left, right)").to("left < right")
.map("$less_than_or_equal(left, right)").to("left <= right")
.map("$greater_than(left, right)").to("left > right")
.map("$greater_than_or_equal(left, right)").to("left >= right")
.build();

JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
Expand Down Expand Up @@ -349,6 +358,12 @@ public Optional<JdbcExpression> implementAggregation(ConnectorSession session, A
return aggregateFunctionRewriter.rewrite(session, aggregate, assignments);
}

@Override
public Optional<ParameterizedExpression> convertPredicate(ConnectorSession session, ConnectorExpression expression, Map<String, ColumnHandle> assignments)
{
return connectorExpressionRewriter.rewrite(session, expression, assignments);
}

@Override
public TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHandle handle)
{
Expand Down

0 comments on commit 3b0009d

Please # to comment.