Commit 918ffd5

Parquet reader - avoid redundant dictionary reads

1 parent 33d6fd5 · commit 918ffd5

10 files changed: +437 −263 lines

lib/trino-parquet/src/main/java/io/trino/parquet/predicate/PredicateUtils.java

Lines changed: 15 additions & 192 deletions (large diff not rendered by default)

lib/trino-parquet/src/main/java/io/trino/parquet/reader/AbstractColumnReader.java

Lines changed: 35 additions & 0 deletions
@@ -18,6 +18,10 @@
 import io.trino.parquet.DictionaryPage;
 import io.trino.parquet.ParquetEncoding;
 import io.trino.parquet.PrimitiveField;
+import io.trino.parquet.metadata.ColumnChunkMetadata;
+import io.trino.parquet.metadata.PrunedBlockMetadata;
+import io.trino.parquet.predicate.DictionaryDescriptor;
+import io.trino.parquet.predicate.TupleDomainParquetPredicate;
 import io.trino.parquet.reader.decoders.ValueDecoder;
 import io.trino.parquet.reader.flat.ColumnAdapter;
 import io.trino.parquet.reader.flat.DictionaryDecoder;
@@ -28,13 +32,19 @@
 import io.trino.spi.type.DateType;
 import io.trino.spi.type.Type;
 import jakarta.annotation.Nullable;
+import org.apache.parquet.column.ColumnDescriptor;
+import org.apache.parquet.column.statistics.Statistics;
 import org.apache.parquet.io.ParquetDecodingException;

+import java.io.IOException;
 import java.util.Optional;
 import java.util.OptionalLong;
+import java.util.Set;

+import static com.google.common.base.Preconditions.checkState;
 import static io.trino.parquet.ParquetEncoding.PLAIN_DICTIONARY;
 import static io.trino.parquet.ParquetEncoding.RLE_DICTIONARY;
+import static io.trino.parquet.ParquetReaderUtils.isOnlyDictionaryEncodingPages;
 import static io.trino.parquet.reader.decoders.ValueDecoder.ValueDecodersProvider;
 import static io.trino.parquet.reader.flat.DictionaryDecoder.DictionaryDecoderProvider;
 import static io.trino.parquet.reader.flat.RowRangesIterator.createRowRangesIterator;
@@ -56,6 +66,8 @@ public abstract class AbstractColumnReader<BufferType>
     @Nullable
     protected DictionaryDecoder<BufferType> dictionaryDecoder;
     private boolean produceDictionaryBlock;
+    @Nullable
+    private DictionaryPage dictionaryPage;

     public AbstractColumnReader(
             PrimitiveField field,
@@ -77,6 +89,7 @@ public void setPageReader(PageReader pageReader, Optional<FilteredRowRanges> row
         // if it is partly or completely dictionary encoded. At most one dictionary page
         // can be placed in a column chunk.
         DictionaryPage dictionaryPage = pageReader.readDictionaryPage();
+        this.dictionaryPage = dictionaryPage;

         // For dictionary based encodings - https://github.com/apache/parquet-format/blob/master/Encodings.md
         if (dictionaryPage != null) {
@@ -87,6 +100,28 @@ public void setPageReader(PageReader pageReader, Optional<FilteredRowRanges> row
         this.rowRanges = createRowRangesIterator(rowRanges);
     }

+    public boolean dictionaryPredicateMatch(RowGroupInfo rowGroupInfo)
+            throws IOException
+    {
+        checkState(hasPageReader(), "Don't have a pageReader yet, invoke setPageReader() first");
+        Optional<TupleDomainParquetPredicate> indexPredicate = rowGroupInfo.indexPredicate();
+        Optional<Set<ColumnDescriptor>> candidateColumnsForDictionaryMatching = rowGroupInfo.candidateColumnsForDictionaryMatching();
+        if (indexPredicate.isPresent() && candidateColumnsForDictionaryMatching.isPresent()) {
+            ColumnDescriptor descriptor = field.getDescriptor();
+            PrunedBlockMetadata prunedBlockMetadata = rowGroupInfo.prunedBlockMetadata();
+            ColumnChunkMetadata columnMetaData = prunedBlockMetadata.getColumnChunkMetaData(descriptor);
+            if (candidateColumnsForDictionaryMatching.get().contains(descriptor) && isOnlyDictionaryEncodingPages(columnMetaData)) {
+                Statistics<?> columnStatistics = columnMetaData.getStatistics();
+                boolean nullAllowed = columnStatistics == null || columnStatistics.getNumNulls() != 0;
+                return indexPredicate.get().matches(new DictionaryDescriptor(
+                        descriptor,
+                        nullAllowed,
+                        Optional.ofNullable(dictionaryPage)));
+            }
+        }
+        return true;
+    }
+
     protected abstract boolean isNonNull();

     protected boolean produceDictionaryBlock()
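
For orientation, a minimal sketch of how this new hook is driven, using only the methods visible in this diff (columnReader, pageReader, rowRanges, and rowGroupInfo are hypothetical locals; the call order follows the checkState above, since setPageReader() is what caches the dictionary page):

    // Hypothetical driver for one column chunk of a row group.
    columnReader.setPageReader(pageReader, Optional.of(rowRanges));
    if (!columnReader.dictionaryPredicateMatch(rowGroupInfo)) {
        // Every dictionary value fails the predicate (and statistics rule out nulls),
        // so the whole row group can be skipped without decoding any data pages.
    }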

lib/trino-parquet/src/main/java/io/trino/parquet/reader/ColumnReader.java

Lines changed: 4 additions & 0 deletions
@@ -13,6 +13,7 @@
  */
 package io.trino.parquet.reader;

+import java.io.IOException;
 import java.util.Optional;

 public interface ColumnReader
@@ -21,6 +22,9 @@ public interface ColumnReader

     void setPageReader(PageReader pageReader, Optional<FilteredRowRanges> rowRanges);

+    boolean dictionaryPredicateMatch(RowGroupInfo rowGroupInfo)
+            throws IOException;
+
     void prepareNextRead(int batchSize);

     ColumnChunk readPrimitive();
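
Readers that never prune by dictionary can satisfy the new method trivially. A hedged sketch for a hypothetical implementing class (mirroring AbstractColumnReader, which also defaults to a match when no predicate applies):

    @Override
    public boolean dictionaryPredicateMatch(RowGroupInfo rowGroupInfo)
    {
        // Reporting a match keeps the row group; only a dictionary-only column chunk
        // with an index predicate can justify returning false.
        return true;
    }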

lib/trino-parquet/src/main/java/io/trino/parquet/reader/ParquetReader.java

Lines changed: 91 additions & 50 deletions
@@ -109,6 +109,7 @@ public class ParquetReader
     private static final int BATCH_SIZE_GROWTH_FACTOR = 2;
     public static final String PARQUET_CODEC_METRIC_PREFIX = "ParquetReaderCompressionFormat_";
     public static final String COLUMN_INDEX_ROWS_FILTERED = "ParquetColumnIndexRowsFiltered";
+    public static final String PARQUET_READER_DICTIONARY_FILTERED_ROWGROUPS = "ParquetReaderDictionaryFilteredRowGroups";

     private final Optional<String> fileCreatedBy;
     private final List<RowGroupInfo> rowGroups;
@@ -151,6 +152,7 @@ public class ParquetReader
     private int currentPageId;

     private long columnIndexRowsFiltered = -1;
+    private long dictionaryFilteredRowGroups;
     private final Optional<FileDecryptionContext> decryptionContext;

     public ParquetReader(
@@ -467,38 +469,67 @@ private int nextBatch()
     private boolean advanceToNextRowGroup()
             throws IOException
     {
-        currentRowGroupMemoryContext.close();
-        currentRowGroupMemoryContext = memoryContext.newAggregatedMemoryContext();
-        freeCurrentRowGroupBuffers();
-
-        if (currentRowGroup >= 0 && rowGroupStatisticsValidation.isPresent()) {
-            StatisticsValidation statisticsValidation = rowGroupStatisticsValidation.get();
-            writeValidation.orElseThrow().validateRowGroupStatistics(dataSource.getId(), currentBlockMetadata, statisticsValidation.build());
-            statisticsValidation.reset();
-        }
-
-        currentRowGroup++;
-        if (currentRowGroup == rowGroups.size()) {
-            return false;
-        }
-        RowGroupInfo rowGroupInfo = rowGroups.get(currentRowGroup);
-        currentBlockMetadata = rowGroupInfo.prunedBlockMetadata();
-        firstRowIndexInGroup = rowGroupInfo.fileRowOffset();
-        currentGroupRowCount = currentBlockMetadata.getRowCount();
-        FilteredRowRanges currentGroupRowRanges = blockRowRanges[currentRowGroup];
-        log.debug("advanceToNextRowGroup dataSource %s, currentRowGroup %d, rowRanges %s, currentBlockMetadata %s", dataSource.getId(), currentRowGroup, currentGroupRowRanges, currentBlockMetadata);
-        if (currentGroupRowRanges != null) {
-            long rowCount = currentGroupRowRanges.getRowCount();
-            columnIndexRowsFiltered += currentGroupRowCount - rowCount;
-            if (rowCount == 0) {
-                // Filters on multiple columns with page indexes may yield non-overlapping row ranges and eliminate the entire row group.
-                // Advance to next row group to ensure that we don't return a null Page and close the page source before all row groups are processed
-                return advanceToNextRowGroup();
+        while (currentRowGroup < rowGroups.size()) {
+            currentRowGroupMemoryContext.close();
+            currentRowGroupMemoryContext = memoryContext.newAggregatedMemoryContext();
+            freeCurrentRowGroupBuffers();
+
+            if (currentRowGroup >= 0 && rowGroupStatisticsValidation.isPresent()) {
+                StatisticsValidation statisticsValidation = rowGroupStatisticsValidation.get();
+                writeValidation.orElseThrow().validateRowGroupStatistics(dataSource.getId(), currentBlockMetadata, statisticsValidation.build());
+                statisticsValidation.reset();
+            }
+
+            currentRowGroup++;
+            if (currentRowGroup == rowGroups.size()) {
+                return false;
+            }
+            RowGroupInfo rowGroupInfo = rowGroups.get(currentRowGroup);
+            currentBlockMetadata = rowGroupInfo.prunedBlockMetadata();
+            firstRowIndexInGroup = rowGroupInfo.fileRowOffset();
+            currentGroupRowCount = currentBlockMetadata.getRowCount();
+            FilteredRowRanges currentGroupRowRanges = blockRowRanges[currentRowGroup];
+            log.debug("advanceToNextRowGroup dataSource %s, currentRowGroup %d, rowRanges %s, currentBlockMetadata %s", dataSource.getId(), currentRowGroup, currentGroupRowRanges, currentBlockMetadata);
+            if (currentGroupRowRanges != null) {
+                long rowCount = currentGroupRowRanges.getRowCount();
+                columnIndexRowsFiltered += currentGroupRowCount - rowCount;
+                if (rowCount == 0) {
+                    // Filters on multiple columns with page indexes may yield non-overlapping row ranges and eliminate the entire row group.
+                    // Advance to next row group to ensure that we don't return a null Page and close the page source before all row groups are processed
+                    continue;
+                }
+                currentGroupRowCount = rowCount;
+            }
+            nextRowInGroup = 0L;
+            initializeColumnReaders();
+
+            // check dictionary predicate matches, or skip row group
+            if (!dictionaryPredicateMatch(rowGroupInfo)) {
+                dictionaryFilteredRowGroups++;
+                continue;
+            }
+            return true;
+        }
+        return false;
+    }
+
+    private boolean dictionaryPredicateMatch(RowGroupInfo rowGroupInfo)
+    {
+        for (PrimitiveField field : primitiveFields) {
+            // check presence of indexPredicate and don't eagerly initializePageReader if it's not present
+            if (rowGroupInfo.indexPredicate().isPresent()) {
+                try {
+                    initializePageReader(field);
+                    boolean match = columnReaders.get(field.getId()).dictionaryPredicateMatch(rowGroupInfo);
+                    if (!match) {
+                        return false;
+                    }
+                }
+                catch (Exception e) {
+                    log.error(e, "Error while matching dictionary predicate for field " + field);
+                }
            }
-            currentGroupRowCount = rowCount;
         }
-        nextRowInGroup = 0L;
-        initializeColumnReaders();
         return true;
     }

@@ -654,29 +685,10 @@ private FilteredOffsetIndex getFilteredOffsetIndex(FilteredRowRanges rowRanges,
     private ColumnChunk readPrimitive(PrimitiveField field)
             throws IOException
     {
-        ColumnDescriptor columnDescriptor = field.getDescriptor();
         int fieldId = field.getId();
         ColumnReader columnReader = columnReaders.get(fieldId);
         if (!columnReader.hasPageReader()) {
-            validateParquet(currentBlockMetadata.getRowCount() > 0, dataSource.getId(), "Row group has 0 rows");
-            ColumnChunkMetadata metadata = currentBlockMetadata.getColumnChunkMetaData(columnDescriptor);
-            FilteredRowRanges rowRanges = blockRowRanges[currentRowGroup];
-            OffsetIndex offsetIndex = null;
-            if (rowRanges != null) {
-                offsetIndex = getFilteredOffsetIndex(rowRanges, currentRowGroup, currentBlockMetadata.getRowCount(), metadata.getPath());
-            }
-            ChunkedInputStream columnChunkInputStream = chunkReaders.get(new ChunkKey(fieldId, currentRowGroup));
-            columnReader.setPageReader(
-                    createPageReader(
-                            dataSource.getId(),
-                            columnChunkInputStream,
-                            metadata,
-                            columnDescriptor,
-                            offsetIndex,
-                            fileCreatedBy,
-                            decryptionContext,
-                            options.getMaxPageReadSize().toBytes()),
-                    Optional.ofNullable(rowRanges));
+            initializePageReader(field);
         }
         ColumnChunk columnChunk = columnReader.readPrimitive();

@@ -692,6 +704,34 @@ private ColumnChunk readPrimitive(PrimitiveField field)
         return columnChunk;
     }

+    private void initializePageReader(PrimitiveField field)
+            throws ParquetCorruptionException
+    {
+        ColumnDescriptor columnDescriptor = field.getDescriptor();
+        int fieldId = field.getId();
+        ColumnReader columnReader = columnReaders.get(fieldId);
+        checkState(!columnReader.hasPageReader(), "Page reader already initialized");
+        validateParquet(currentBlockMetadata.getRowCount() > 0, dataSource.getId(), "Row group has 0 rows");
+        ColumnChunkMetadata metadata = currentBlockMetadata.getColumnChunkMetaData(columnDescriptor);
+        FilteredRowRanges rowRanges = blockRowRanges[currentRowGroup];
+        OffsetIndex offsetIndex = null;
+        if (rowRanges != null) {
+            offsetIndex = getFilteredOffsetIndex(rowRanges, currentRowGroup, currentBlockMetadata.getRowCount(), metadata.getPath());
+        }
+        ChunkedInputStream columnChunkInputStream = chunkReaders.get(new ChunkKey(fieldId, currentRowGroup));
+        columnReader.setPageReader(
+                createPageReader(
+                        dataSource.getId(),
+                        columnChunkInputStream,
+                        metadata,
+                        columnDescriptor,
+                        offsetIndex,
+                        fileCreatedBy,
+                        decryptionContext,
+                        options.getMaxPageReadSize().toBytes()),
+                Optional.ofNullable(rowRanges));
+    }
+
     public List<Column> getColumnFields()
     {
         return columnFields;
@@ -704,6 +744,7 @@ public Metrics getMetrics()
         if (columnIndexRowsFiltered >= 0) {
             metrics.put(COLUMN_INDEX_ROWS_FILTERED, new LongCount(columnIndexRowsFiltered));
         }
+        metrics.put(PARQUET_READER_DICTIONARY_FILTERED_ROWGROUPS, new LongCount(dictionaryFilteredRowGroups));
         metrics.putAll(dataSource.getMetrics().getMetrics());

         return new Metrics(metrics.buildOrThrow());
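
The new counter surfaces through the reader's ordinary metrics map. A minimal sketch of reading it back, assuming Metrics.getMetrics() exposes the name-to-metric map (the helper class and method are hypothetical):

    import io.trino.parquet.reader.ParquetReader;
    import io.trino.spi.metrics.Metric;
    import io.trino.spi.metrics.Metrics;

    import java.util.Map;

    import static io.trino.parquet.reader.ParquetReader.PARQUET_READER_DICTIONARY_FILTERED_ROWGROUPS;

    final class DictionaryFilterMetrics
    {
        private DictionaryFilterMetrics() {}

        // Returns the metric counting row groups dropped by dictionary matching,
        // or null if the reader published no such entry.
        static Metric<?> dictionaryFilteredRowGroups(ParquetReader reader)
        {
            Metrics metrics = reader.getMetrics();
            Map<String, Metric<?>> byName = metrics.getMetrics();
            return byName.get(PARQUET_READER_DICTIONARY_FILTERED_ROWGROUPS);
        }
    }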

lib/trino-parquet/src/main/java/io/trino/parquet/reader/RowGroupInfo.java

Lines changed: 6 additions & 1 deletion
@@ -14,8 +14,13 @@
 package io.trino.parquet.reader;

 import io.trino.parquet.metadata.PrunedBlockMetadata;
+import io.trino.parquet.predicate.TupleDomainParquetPredicate;
+import org.apache.parquet.column.ColumnDescriptor;
 import org.apache.parquet.internal.filter2.columnindex.ColumnIndexStore;

 import java.util.Optional;
+import java.util.Set;

-public record RowGroupInfo(PrunedBlockMetadata prunedBlockMetadata, long fileRowOffset, Optional<ColumnIndexStore> columnIndexStore) {}
+public record RowGroupInfo(PrunedBlockMetadata prunedBlockMetadata, long fileRowOffset, Optional<ColumnIndexStore> columnIndexStore,
+        Optional<TupleDomainParquetPredicate> indexPredicate, Optional<Set<ColumnDescriptor>> candidateColumnsForDictionaryMatching)
+{}
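
A hedged sketch of constructing the extended record (the factory below is hypothetical; the two new components are populated only when the caller wants dictionary-based row-group filtering to run, otherwise Optional.empty() keeps the old behavior, as the ParquetWriter change below shows):

    import io.trino.parquet.metadata.PrunedBlockMetadata;
    import io.trino.parquet.predicate.TupleDomainParquetPredicate;
    import io.trino.parquet.reader.RowGroupInfo;
    import org.apache.parquet.column.ColumnDescriptor;
    import org.apache.parquet.internal.filter2.columnindex.ColumnIndexStore;

    import java.util.Optional;
    import java.util.Set;

    final class RowGroupInfoExamples
    {
        private RowGroupInfoExamples() {}

        // Hypothetical factory: enables dictionary matching for the given candidate columns.
        static RowGroupInfo withDictionaryMatching(
                PrunedBlockMetadata prunedBlockMetadata,
                long fileRowOffset,
                ColumnIndexStore columnIndexStore,
                TupleDomainParquetPredicate indexPredicate,
                Set<ColumnDescriptor> candidates)
        {
            return new RowGroupInfo(
                    prunedBlockMetadata,
                    fileRowOffset,
                    Optional.of(columnIndexStore),
                    Optional.of(indexPredicate),
                    Optional.of(candidates));
        }
    }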

lib/trino-parquet/src/main/java/io/trino/parquet/writer/ParquetWriter.java

Lines changed: 1 addition & 1 deletion
@@ -277,7 +277,7 @@ private ParquetReader createParquetReader(ParquetDataSource input, ParquetMetada
         long nextStart = 0;
         ImmutableList.Builder<RowGroupInfo> rowGroupInfoBuilder = ImmutableList.builder();
         for (BlockMetadata block : parquetMetadata.getBlocks()) {
-            rowGroupInfoBuilder.add(new RowGroupInfo(createPrunedColumnsMetadata(block, input.getId(), descriptorsByPath), nextStart, Optional.empty()));
+            rowGroupInfoBuilder.add(new RowGroupInfo(createPrunedColumnsMetadata(block, input.getId(), descriptorsByPath), nextStart, Optional.empty(), Optional.empty(), Optional.empty()));
             nextStart += block.rowCount();
         }
         return new ParquetReader(

lib/trino-parquet/src/test/java/io/trino/parquet/ParquetTestUtils.java

Lines changed: 14 additions & 0 deletions
@@ -205,6 +205,20 @@ public static List<io.trino.spi.Page> generateInputPages(List<Type> types, int p
         return pagesBuilder.build();
     }

+    public static List<io.trino.spi.Page> generateInputPagesWithBlockData(List<Type> types, List<? extends List<?>> blockData, int pageCount)
+    {
+        checkArgument(blockData.size() == types.size());
+        ImmutableList.Builder<io.trino.spi.Page> pagesBuilder = ImmutableList.builder();
+        for (int pageIndex = 0; pageIndex < pageCount; pageIndex++) {
+            Block[] blocks = new Block[types.size()];
+            for (int i = 0; i < types.size(); i++) {
+                blocks[i] = generateBlock(types.get(i), blockData.get(i));
+            }
+            pagesBuilder.add(new Page(blocks));
+        }
+        return pagesBuilder.build();
+    }
+
     public static List<Integer> generateGroupSizes(int positionsCount)
     {
         int maxGroupSize = 17;
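
A possible use of the new helper, assuming generateBlock accepts boxed Integer values for INTEGER columns (the wrapper class is hypothetical). Fixing the block data across pages yields highly repetitive columns, which is exactly what dictionary-encoding-dependent tests need:

    import com.google.common.collect.ImmutableList;
    import io.trino.spi.Page;
    import io.trino.spi.type.Type;

    import java.util.List;

    import static io.trino.parquet.ParquetTestUtils.generateInputPagesWithBlockData;
    import static io.trino.spi.type.IntegerType.INTEGER;

    final class RepetitivePages
    {
        private RepetitivePages() {}

        // Three identical pages, each with a single INTEGER column holding 1, 2, 3.
        static List<Page> repeatedAgePages()
        {
            List<Type> types = ImmutableList.of(INTEGER);
            List<List<Integer>> blockData = ImmutableList.of(List.of(1, 2, 3));
            return generateInputPagesWithBlockData(types, blockData, 3);
        }
    }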

lib/trino-parquet/src/test/java/io/trino/parquet/crypto/TestParquetEncryption.java

Lines changed: 22 additions & 18 deletions
@@ -475,20 +475,10 @@ void encryptedDictionaryPruningTwoColumns()
         TupleDomainParquetPredicate predicateAge = new TupleDomainParquetPredicate(
                 domainAge, ImmutableList.of(age), UTC);

-        List<RowGroupInfo> groupsAge = getFilteredRowGroups(
-                0,
-                source.getEstimatedSize(),
-                source,
-                metadata,
-                List.of(domainAge),
-                List.of(predicateAge),
-                ImmutableMap.of(ImmutableList.of("age"), age),
-                UTC,
-                200,
-                ParquetReaderOptions.builder().build());
+        Map<String, List<Integer>> data = readTwoColumnFile(file, new TestingKeyRetriever(Optional.of(KEY_FOOT), Optional.of(KEY_AGE), Optional.of(KEY_ID)), domainAge, predicateAge);

-        // No row-groups should pass after dictionary pruning
-        assertThat(groupsAge).isEmpty();
+        // Should be filtered by dictionary filtering in reader
+        assertThat(data).containsValues(List.of(), List.of());

         // ——— Predicate on inaccessible column (id = missingId) → should fail (no column key) ———
         TupleDomain<ColumnDescriptor> domainId = TupleDomain.withColumnDomains(ImmutableMap.of(id, singleValue(INTEGER, (long) missingId)));
@@ -905,11 +895,28 @@ private static List<Integer> readSingleColumnFile(
         }
     }

+    private static Map<String, List<Integer>> readTwoColumnFile(
+            File file, DecryptionKeyRetriever retriever)
+            throws IOException
+    {
+        ColumnDescriptor ageDescriptor = new ColumnDescriptor(
+                new String[] {"age"},
+                Types.required(PrimitiveType.PrimitiveTypeName.INT32).named("age"), 0, 0);
+
+        ColumnDescriptor idDescriptor = new ColumnDescriptor(
+                new String[] {"id"},
+                Types.required(PrimitiveType.PrimitiveTypeName.INT32).named("id"), 0, 0);
+
+        TupleDomainParquetPredicate allPredicate = new TupleDomainParquetPredicate(
+                TupleDomain.all(), ImmutableList.of(ageDescriptor, idDescriptor), UTC);
+        return readTwoColumnFile(file, retriever, TupleDomain.all(), allPredicate);
+    }
+
     /**
      * Reads both columns and returns a map “age” → values, “id” → values.
      */
     private static Map<String, List<Integer>> readTwoColumnFile(
-            File file, DecryptionKeyRetriever retriever)
+            File file, DecryptionKeyRetriever retriever, TupleDomain<ColumnDescriptor> domain, TupleDomainParquetPredicate predicate)
             throws IOException
     {
         ParquetDataSource source = new FileParquetDataSource(file, ParquetReaderOptions.builder().build());
@@ -931,12 +938,9 @@ private static Map<String, List<Integer>> readTwoColumnFile(
                 ImmutableList.of("age"), ageDescriptor,
                 ImmutableList.of("id"), idDescriptor);

-        TupleDomainParquetPredicate predicate = new TupleDomainParquetPredicate(
-                TupleDomain.all(), ImmutableList.of(ageDescriptor, idDescriptor), UTC);
-
         List<RowGroupInfo> groups = getFilteredRowGroups(
                 0, source.getEstimatedSize(), source, metadata,
-                List.of(TupleDomain.all()), List.of(predicate),
+                List.of(domain), List.of(predicate),
                 byPath, UTC, 200, ParquetReaderOptions.builder().build());

         PrimitiveField ageField = new PrimitiveField(INTEGER, true, ageDescriptor, 0);
