diff --git a/api/src/main/java/org/apache/iceberg/expressions/UnboundPredicate.java b/api/src/main/java/org/apache/iceberg/expressions/UnboundPredicate.java index 75ca9d5835bc..75580c1225c6 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/UnboundPredicate.java +++ b/api/src/main/java/org/apache/iceberg/expressions/UnboundPredicate.java @@ -26,6 +26,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.Iterables; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Sets; +import org.apache.iceberg.transforms.Transforms; import org.apache.iceberg.types.Type; import org.apache.iceberg.types.TypeUtil; import org.apache.iceberg.types.Types; @@ -207,10 +208,42 @@ private Expression bindLiteralOperation(BoundTerm boundTerm) { } } - // TODO: translate truncate(col) == value to startsWith(value) + Expression boundTransformExpression = bindTransformExpression(boundTerm, lit); + if (boundTransformExpression != null) { + return boundTransformExpression; + } + return new BoundLiteralPredicate<>(op(), boundTerm, lit); } + private Expression bindTransformExpression(BoundTerm boundTerm, Literal lit) { + if (op() == Operation.EQ + && boundTerm instanceof BoundTransform + && boundTerm.type().equals(Types.StringType.get())) { + BoundTransform boundTransform = (BoundTransform) boundTerm; + Integer width = Transforms.truncateWidth(boundTransform.transform()); + if (width == null) { + return null; + } + + String value = lit.value().toString(); + int length = value.codePointCount(0, value.length()); + if (length == width) { + return startsWithPredicate(boundTransform.ref(), lit); + } else if (length > width) { + return Expressions.alwaysFalse(); + } + } + + return null; + } + + @SuppressWarnings("unchecked") + private Expression startsWithPredicate(BoundReference ref, Literal lit) { + return new BoundLiteralPredicate<>( + Operation.STARTS_WITH, (BoundReference) ref, (Literal) lit); + } + private Expression bindInOperation(BoundTerm boundTerm) { List> convertedLiterals = Lists.newArrayList( diff --git a/api/src/main/java/org/apache/iceberg/transforms/Transforms.java b/api/src/main/java/org/apache/iceberg/transforms/Transforms.java index a3a6a3f6321d..e2fe8e044896 100644 --- a/api/src/main/java/org/apache/iceberg/transforms/Transforms.java +++ b/api/src/main/java/org/apache/iceberg/transforms/Transforms.java @@ -271,6 +271,20 @@ public static Transform truncate(int width) { return Truncate.get(width); } + /** + * Returns the width of a truncate transform, or null if the transform is not truncate. + * + * @param transform a transform + * @return the width of the truncate transform, or null + */ + public static Integer truncateWidth(Transform transform) { + if (transform instanceof Truncate) { + return ((Truncate) transform).width(); + } + + return null; + } + /** * Returns a {@link Transform} that always produces null. * diff --git a/api/src/test/java/org/apache/iceberg/expressions/TestExpressionBinding.java b/api/src/test/java/org/apache/iceberg/expressions/TestExpressionBinding.java index 24e58ad1e808..25a4563ed785 100644 --- a/api/src/test/java/org/apache/iceberg/expressions/TestExpressionBinding.java +++ b/api/src/test/java/org/apache/iceberg/expressions/TestExpressionBinding.java @@ -230,6 +230,44 @@ public void testTransformExpressionBinding() { .hasToString("bucket[16]"); } + @Test + public void testTruncateStringEqualBinding() { + Expression bound = Binder.bind(STRUCT, equal(truncate("data", 3), "abc")); + TestHelpers.assertAllReferencesBound("TruncateEquals", bound); + BoundPredicate pred = TestHelpers.assertAndUnwrap(bound); + assertThat(pred.op()).isEqualTo(Expression.Operation.STARTS_WITH); + assertThat(pred.term()).isInstanceOf(BoundReference.class); + assertThat(pred.term().ref().fieldId()).isEqualTo(3); + assertThat(pred.asLiteralPredicate().literal().value()).isEqualTo("abc"); + } + + @Test + public void testTruncateStringEqualBindingWithUnicode() { + Expression bound = Binder.bind(STRUCT, equal(truncate("data", 2), "a😀")); + TestHelpers.assertAllReferencesBound("TruncateEquals", bound); + BoundPredicate pred = TestHelpers.assertAndUnwrap(bound); + assertThat(pred.op()).isEqualTo(Expression.Operation.STARTS_WITH); + assertThat(pred.term()).isInstanceOf(BoundReference.class); + assertThat(pred.term().ref().fieldId()).isEqualTo(3); + assertThat(pred.asLiteralPredicate().literal().value()).isEqualTo("a😀"); + } + + @Test + public void testTruncateStringEqualBindingWithShortLiteral() { + Expression bound = Binder.bind(STRUCT, equal(truncate("data", 3), "ab")); + TestHelpers.assertAllReferencesBound("TruncateEquals", bound); + BoundPredicate pred = TestHelpers.assertAndUnwrap(bound); + assertThat(pred.op()).isEqualTo(Expression.Operation.EQ); + assertThat(pred.term()).isInstanceOf(BoundTransform.class); + } + + @Test + public void testTruncateStringEqualBindingWithLongLiteral() { + assertThat(Binder.bind(STRUCT, equal(truncate("data", 3), "abcd"))) + .as("A truncate result cannot be longer than the truncate width") + .isEqualTo(alwaysFalse()); + } + @Test public void testIsNullWithUnknown() { Expression bound = Binder.bind(STRUCT, isNull("always_null")); diff --git a/data/src/test/java/org/apache/iceberg/data/TestMetricsRowGroupFilter.java b/data/src/test/java/org/apache/iceberg/data/TestMetricsRowGroupFilter.java index 3cb46b309d82..1262fdc7c4a5 100644 --- a/data/src/test/java/org/apache/iceberg/data/TestMetricsRowGroupFilter.java +++ b/data/src/test/java/org/apache/iceberg/data/TestMetricsRowGroupFilter.java @@ -1005,7 +1005,7 @@ public void testTransformFilter() { assumeThat(format).isEqualTo(FileFormat.PARQUET); boolean shouldRead = - new ParquetMetricsRowGroupFilter(SCHEMA, equal(truncate("required", 2), "some_value"), true) + new ParquetMetricsRowGroupFilter(SCHEMA, equal(truncate("required", 2), "s"), true) .shouldRead(parquetSchema, rowGroupMetadata); assertThat(shouldRead) .as("Should read: filter contains non-reference evaluate as True") diff --git a/parquet/src/test/java/org/apache/iceberg/parquet/TestDictionaryRowGroupFilter.java b/parquet/src/test/java/org/apache/iceberg/parquet/TestDictionaryRowGroupFilter.java index 22f8068c0fa3..e48e9e046f01 100644 --- a/parquet/src/test/java/org/apache/iceberg/parquet/TestDictionaryRowGroupFilter.java +++ b/parquet/src/test/java/org/apache/iceberg/parquet/TestDictionaryRowGroupFilter.java @@ -1267,7 +1267,7 @@ SCHEMA, greaterThanOrEqual("decimal_fixed", BigDecimal.ZERO)) @TestTemplate public void testTransformFilter() { boolean shouldRead = - new ParquetDictionaryRowGroupFilter(SCHEMA, equal(truncate("required", 2), "some_value")) + new ParquetDictionaryRowGroupFilter(SCHEMA, equal(truncate("required", 2), "s")) .shouldRead(parquetSchema, rowGroupMetadata, dictionaryStore); assertThat(shouldRead) .as("Should read: filter contains non-reference evaluate as True")