diff --git a/core/trino-main/src/test/java/io/trino/operator/scalar/TestStringFunctions.java b/core/trino-main/src/test/java/io/trino/operator/scalar/TestStringFunctions.java index f2f0c75c7be8..188e062525e6 100644 --- a/core/trino-main/src/test/java/io/trino/operator/scalar/TestStringFunctions.java +++ b/core/trino-main/src/test/java/io/trino/operator/scalar/TestStringFunctions.java @@ -1041,7 +1041,7 @@ public void testCharConcat() assertFunction("concat('hello na\u00EFve', cast(' world' as char(6)))", createCharType(17), "hello na\u00EFve world"); - assertInvalidFunction("concat(cast('ab ' as char(40000)), cast('' as char(40000)))", "line 1:1: CHAR length scale must be in range [0, 65536]"); + assertInvalidFunction("concat(cast('ab ' as char(40000)), cast('' as char(40000)))", "line 1:1: CHAR length must be in range [0, 65536], got 80000"); assertFunction("concat(cast(null as char(1)), cast(' ' as char(1)))", createCharType(2), null); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/type/CharType.java b/core/trino-spi/src/main/java/io/trino/spi/type/CharType.java index d846d4f60ecf..a0cfd5fd0080 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/type/CharType.java +++ b/core/trino-spi/src/main/java/io/trino/spi/type/CharType.java @@ -63,7 +63,7 @@ private CharType(long length) Slice.class); if (length < 0 || length > MAX_LENGTH) { - throw new TrinoException(INVALID_FUNCTION_ARGUMENT, format("CHAR length scale must be in range [0, %s]", MAX_LENGTH)); + throw new TrinoException(INVALID_FUNCTION_ARGUMENT, format("CHAR length must be in range [0, %s], got %s", MAX_LENGTH, length)); } this.length = (int) length; } diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index c3de109c813f..fc1c86f886d9 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -33,6 +33,7 @@ import io.trino.plugin.jdbc.LongWriteFunction; import io.trino.plugin.jdbc.ObjectReadFunction; import io.trino.plugin.jdbc.ObjectWriteFunction; +import io.trino.plugin.jdbc.PredicatePushdownController; import io.trino.plugin.jdbc.ReadFunction; import io.trino.plugin.jdbc.SliceReadFunction; import io.trino.plugin.jdbc.SliceWriteFunction; @@ -122,16 +123,16 @@ import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.getDecimalRoundingMode; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; import static io.trino.plugin.jdbc.PredicatePushdownController.DISABLE_PUSHDOWN; +import static io.trino.plugin.jdbc.PredicatePushdownController.FULL_PUSHDOWN; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.booleanColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.booleanWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.charReadFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.charWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.dateColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.dateWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.decimalColumnMapping; -import static io.trino.plugin.jdbc.StandardColumnMappings.defaultCharColumnMapping; -import static io.trino.plugin.jdbc.StandardColumnMappings.defaultVarcharColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.doubleColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.doubleWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.fromTrinoTimestamp; @@ -147,6 +148,7 @@ import static io.trino.plugin.jdbc.StandardColumnMappings.tinyintWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.varbinaryColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.varbinaryWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.varcharReadFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.varcharWriteFunction; import static io.trino.plugin.jdbc.TypeHandlingJdbcSessionProperties.getUnsupportedTypeHandling; import static io.trino.plugin.jdbc.UnsupportedTypeHandling.CONVERT_TO_VARCHAR; @@ -164,6 +166,7 @@ import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.CharType.createCharType; import static io.trino.spi.type.DateTimeEncoding.packDateTimeWithZone; import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc; import static io.trino.spi.type.DateType.DATE; @@ -188,6 +191,8 @@ import static io.trino.spi.type.TypeSignature.mapType; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; +import static io.trino.spi.type.VarcharType.createVarcharType; import static java.lang.Math.max; import static java.lang.Math.min; import static java.lang.String.format; @@ -216,6 +221,22 @@ public class PostgreSqlClient private final List tableTypes; private final AggregateFunctionRewriter aggregateFunctionRewriter; + private static final PredicatePushdownController POSTGRESQL_CHARACTER_PUSHDOWN = (session, domain) -> { + checkArgument( + domain.getType() instanceof VarcharType || domain.getType() instanceof CharType, + "This PredicatePushdownController can be used only for chars and varchars"); + + if (domain.isOnlyNull() || + // PostgreSQL is case sensitive by default + domain.getValues().isDiscreteSet()) { + return FULL_PUSHDOWN.apply(session, domain); + } + + // PostgreSQL by default orders lowercase letters before uppercase, which is different from Trino + // TODO We could still push the predicates down if we could inject a PostgreSQL-specific syntax for selecting a collation for given comparison. + return DISABLE_PUSHDOWN.apply(session, domain); + }; + @Inject public PostgreSqlClient( BaseJdbcConfig config, @@ -455,14 +476,14 @@ public Optional toColumnMapping(ConnectorSession session, Connect } case Types.CHAR: - return Optional.of(defaultCharColumnMapping(typeHandle.getRequiredColumnSize(), true)); + return Optional.of(charColumnMapping(typeHandle.getRequiredColumnSize())); case Types.VARCHAR: if (!jdbcTypeName.equals("varchar")) { // This can be e.g. an ENUM return Optional.of(typedVarcharColumnMapping(jdbcTypeName)); } - return Optional.of(defaultVarcharColumnMapping(typeHandle.getRequiredColumnSize(), true)); + return Optional.of(varcharColumnMapping(typeHandle.getRequiredColumnSize())); case Types.BINARY: return Optional.of(varbinaryColumnMapping()); @@ -689,6 +710,31 @@ public boolean isLimitGuaranteed(ConnectorSession session) return true; } + private static ColumnMapping charColumnMapping(int charLength) + { + if (charLength > CharType.MAX_LENGTH) { + return varcharColumnMapping(charLength); + } + CharType charType = createCharType(charLength); + return ColumnMapping.sliceMapping( + charType, + charReadFunction(charType), + charWriteFunction(), + POSTGRESQL_CHARACTER_PUSHDOWN); + } + + private static ColumnMapping varcharColumnMapping(int varcharLength) + { + VarcharType varcharType = varcharLength <= VarcharType.MAX_LENGTH + ? createVarcharType(varcharLength) + : createUnboundedVarcharType(); + return ColumnMapping.sliceMapping( + varcharType, + varcharReadFunction(varcharType), + varcharWriteFunction(), + POSTGRESQL_CHARACTER_PUSHDOWN); + } + private static ColumnMapping timeColumnMapping(int precision) { verify(precision <= 6, "Unsupported precision: %s", precision); // PostgreSQL limit but also assumption within this method diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java index cc9ca6c27855..1fe848bd6b6d 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java @@ -25,7 +25,6 @@ import io.trino.testing.sql.JdbcSqlExecutor; import io.trino.testing.sql.TestTable; import org.intellij.lang.annotations.Language; -import org.testng.SkipException; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -245,7 +244,7 @@ public void testPredicatePushdown() // varchar range assertThat(query("SELECT regionkey, nationkey, name FROM nation WHERE name BETWEEN 'POLAND' AND 'RPA'")) .matches("VALUES (BIGINT '3', BIGINT '19', CAST('ROMANIA' AS varchar(25)))") - .isFullyPushedDown(); + .isNotFullyPushedDown(FilterNode.class); // varchar different case assertThat(query("SELECT regionkey, nationkey, name FROM nation WHERE name = 'romania'")) @@ -475,7 +474,7 @@ public void testAggregationPushdown() // GROUP BY and WHERE on varchar column // GROUP BY and WHERE on "other" (not aggregation key, not aggregation input) - assertThat(query("SELECT regionkey, sum(nationkey) FROM nation WHERE regionkey < 4 AND name > 'AAA' GROUP BY regionkey")).isFullyPushedDown(); + assertThat(query("SELECT regionkey, sum(nationkey) FROM nation WHERE regionkey < 4 AND name > 'AAA' GROUP BY regionkey")).isNotFullyPushedDown(FilterNode.class); // GROUP BY above WHERE and LIMIT assertThat(query("" + @@ -750,7 +749,7 @@ public void testLimitPushdown() assertThat(query("SELECT name FROM nation WHERE regionkey = 3 LIMIT 5")).isFullyPushedDown(); // with filter over varchar column - assertThat(query("SELECT name FROM nation WHERE name < 'EEE' LIMIT 5")).isFullyPushedDown(); + assertThat(query("SELECT name FROM nation WHERE name < 'EEE' LIMIT 5")).isNotFullyPushedDown(FilterNode.class); // with aggregation assertThat(query("SELECT max(regionkey) FROM nation LIMIT 5")).isFullyPushedDown(); // global aggregation, LIMIT removed @@ -759,7 +758,7 @@ public void testLimitPushdown() // with filter and aggregation assertThat(query("SELECT regionkey, count(*) FROM nation WHERE nationkey < 5 GROUP BY regionkey LIMIT 3")).isFullyPushedDown(); - assertThat(query("SELECT regionkey, count(*) FROM nation WHERE name < 'EGYPT' GROUP BY regionkey LIMIT 3")).isFullyPushedDown(); + assertThat(query("SELECT regionkey, count(*) FROM nation WHERE name < 'EGYPT' GROUP BY regionkey LIMIT 3")).isNotFullyPushedDown(FilterNode.class); // with TopN assertThat(query("SELECT * FROM (SELECT regionkey FROM nation ORDER BY name ASC LIMIT 10) LIMIT 5")).isFullyPushedDown(); @@ -811,13 +810,6 @@ public void testTimestampColumnAndTimestampWithTimeZoneConstant() } } - @Override - public void testCaseSensitiveDataMapping(DataMappingTestSetup dataMappingTestSetup) - { - // TODO - https://github.com/trinodb/trino/issues/3645 - throw new SkipException("PostgreSQL has different collation than Trino"); - } - private String getLongInClause(int start, int length) { String longValues = range(start, start + length)