Skip to content

Commit 99147bf

Browse files
dnhatnywangd
authored andcommitted
Optimize ordinal inputs in Values aggregation (elastic#127849)
Currently, time-series aggregations use the `values` aggregation to collect dimension values. While we might introduce a specialized aggregation for this in the future, for now, we are using `values`, and the inputs are likely ordinal blocks. This change speeds up the `values` aggregation when the inputs are ordinal-based. Execution time reduced from 461ms to 192ms for 1000 groups. ``` ValuesAggregatorBenchmark.run BytesRef 10000 avgt 7 461.938 ± 6.089 ms/op ValuesAggregatorBenchmark.run BytesRef 10000 avgt 7 192.898 ± 1.781 ms/op ```
1 parent 6230abd commit 99147bf

File tree

8 files changed

+284
-18
lines changed

8 files changed

+284
-18
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,13 @@
2121
import org.elasticsearch.compute.data.Block;
2222
import org.elasticsearch.compute.data.BlockFactory;
2323
import org.elasticsearch.compute.data.BytesRefBlock;
24+
import org.elasticsearch.compute.data.BytesRefVector;
2425
import org.elasticsearch.compute.data.ElementType;
2526
import org.elasticsearch.compute.data.IntBlock;
27+
import org.elasticsearch.compute.data.IntVector;
2628
import org.elasticsearch.compute.data.LongBlock;
2729
import org.elasticsearch.compute.data.LongVector;
30+
import org.elasticsearch.compute.data.OrdinalBytesRefVector;
2831
import org.elasticsearch.compute.data.Page;
2932
import org.elasticsearch.compute.operator.AggregationOperator;
3033
import org.elasticsearch.compute.operator.DriverContext;
@@ -282,11 +285,18 @@ private static Block dataBlock(int groups, String dataType) {
282285
int blockLength = blockLength(groups);
283286
return switch (dataType) {
284287
case BYTES_REF -> {
285-
try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(blockLength)) {
288+
try (
289+
BytesRefVector.Builder dict = blockFactory.newBytesRefVectorBuilder(blockLength);
290+
IntVector.Builder ords = blockFactory.newIntVectorBuilder(blockLength)
291+
) {
292+
final int dictLength = Math.min(blockLength, KEYWORDS.length);
293+
for (int i = 0; i < dictLength; i++) {
294+
dict.appendBytesRef(KEYWORDS[i]);
295+
}
286296
for (int i = 0; i < blockLength; i++) {
287-
builder.appendBytesRef(KEYWORDS[i % KEYWORDS.length]);
297+
ords.appendInt(i % dictLength);
288298
}
289-
yield builder.build();
299+
yield new OrdinalBytesRefVector(ords.build(), dict.build()).asBlock();
290300
}
291301
}
292302
case INT -> {

docs/changelog/127849.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 127849
2+
summary: Optimize ordinal inputs in Values aggregation
3+
area: "ES|QL"
4+
type: enhancement
5+
issues: []

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
import static java.util.stream.Collectors.joining;
3737
import static org.elasticsearch.compute.gen.AggregatorImplementer.capitalize;
38+
import static org.elasticsearch.compute.gen.Methods.optionalStaticMethod;
3839
import static org.elasticsearch.compute.gen.Methods.requireAnyArgs;
3940
import static org.elasticsearch.compute.gen.Methods.requireAnyType;
4041
import static org.elasticsearch.compute.gen.Methods.requireArgs;
@@ -336,10 +337,32 @@ private MethodSpec prepareProcessPage() {
336337
builder.beginControlFlow("if (valuesBlock.mayHaveNulls())");
337338
builder.addStatement("state.enableGroupIdTracking(seenGroupIds)");
338339
builder.endControlFlow();
339-
builder.addStatement("return $L", addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesBlock$L)", extra)));
340+
if (shouldWrapAddInput(blockType(aggParam.type()))) {
341+
builder.addStatement(
342+
"var addInput = $L",
343+
addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesBlock$L)", extra))
344+
);
345+
builder.addStatement("return $T.wrapAddInput(addInput, state, valuesBlock)", declarationType);
346+
} else {
347+
builder.addStatement(
348+
"return $L",
349+
addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesBlock$L)", extra))
350+
);
351+
}
340352
}
341353
builder.endControlFlow();
342-
builder.addStatement("return $L", addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesVector$L)", extra)));
354+
if (shouldWrapAddInput(vectorType(aggParam.type()))) {
355+
builder.addStatement(
356+
"var addInput = $L",
357+
addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesVector$L)", extra))
358+
);
359+
builder.addStatement("return $T.wrapAddInput(addInput, state, valuesVector)", declarationType);
360+
} else {
361+
builder.addStatement(
362+
"return $L",
363+
addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesVector$L)", extra))
364+
);
365+
}
343366
return builder.build();
344367
}
345368

@@ -526,6 +549,15 @@ private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVar
526549
warningsBlock(builder, () -> builder.addStatement("$T.combine(state, groupId, $L)", declarationType, arrayVariable));
527550
}
528551

552+
private boolean shouldWrapAddInput(TypeName valuesType) {
553+
return optionalStaticMethod(
554+
declarationType,
555+
requireType(GROUPING_AGGREGATOR_FUNCTION_ADD_INPUT),
556+
requireName("wrapAddInput"),
557+
requireArgs(requireType(GROUPING_AGGREGATOR_FUNCTION_ADD_INPUT), requireType(aggState.declaredType()), requireType(valuesType))
558+
) != null;
559+
}
560+
529561
private void warningsBlock(MethodSpec.Builder builder, Runnable block) {
530562
if (warnExceptions.isEmpty() == false) {
531563
builder.beginControlFlow("try");

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,20 +59,31 @@ static ExecutableElement requireStaticMethod(
5959
TypeMatcher returnTypeMatcher,
6060
NameMatcher nameMatcher,
6161
ArgumentMatcher argumentMatcher
62+
) {
63+
ExecutableElement method = optionalStaticMethod(declarationType, returnTypeMatcher, nameMatcher, argumentMatcher);
64+
if (method == null) {
65+
var message = nameMatcher.names.size() == 1 ? "Requires method: " : "Requires one of methods: ";
66+
var signatures = nameMatcher.names.stream()
67+
.map(name -> "public static " + returnTypeMatcher + " " + declarationType + "#" + name + "(" + argumentMatcher + ")")
68+
.collect(joining(" or "));
69+
throw new IllegalArgumentException(message + signatures);
70+
}
71+
return method;
72+
}
73+
74+
static ExecutableElement optionalStaticMethod(
75+
TypeElement declarationType,
76+
TypeMatcher returnTypeMatcher,
77+
NameMatcher nameMatcher,
78+
ArgumentMatcher argumentMatcher
6279
) {
6380
return typeAndSuperType(declarationType).flatMap(type -> ElementFilter.methodsIn(type.getEnclosedElements()).stream())
6481
.filter(method -> method.getModifiers().contains(Modifier.STATIC))
6582
.filter(method -> nameMatcher.test(method.getSimpleName().toString()))
6683
.filter(method -> returnTypeMatcher.test(TypeName.get(method.getReturnType())))
6784
.filter(method -> argumentMatcher.test(method.getParameters().stream().map(it -> TypeName.get(it.asType())).toList()))
6885
.findFirst()
69-
.orElseThrow(() -> {
70-
var message = nameMatcher.names.size() == 1 ? "Requires method: " : "Requires one of methods: ";
71-
var signatures = nameMatcher.names.stream()
72-
.map(name -> "public static " + returnTypeMatcher + " " + declarationType + "#" + name + "(" + argumentMatcher + ")")
73-
.collect(joining(" or "));
74-
return new IllegalArgumentException(message + signatures);
75-
});
86+
.orElse(null);
7687
}
7788

7889
static NameMatcher requireName(String... names) {

x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java

Lines changed: 18 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/ValuesBytesRefGroupingAggregatorFunction.java

Lines changed: 4 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.compute.aggregation;
9+
10+
import org.apache.lucene.util.BytesRef;
11+
import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
12+
import org.elasticsearch.compute.data.BytesRefBlock;
13+
import org.elasticsearch.compute.data.BytesRefVector;
14+
import org.elasticsearch.compute.data.IntArrayBlock;
15+
import org.elasticsearch.compute.data.IntBigArrayBlock;
16+
import org.elasticsearch.compute.data.IntBlock;
17+
import org.elasticsearch.compute.data.IntVector;
18+
import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
19+
import org.elasticsearch.core.Releasables;
20+
21+
final class ValuesBytesRefAggregators {
22+
static GroupingAggregatorFunction.AddInput wrapAddInput(
23+
GroupingAggregatorFunction.AddInput delegate,
24+
ValuesBytesRefAggregator.GroupingState state,
25+
BytesRefBlock values
26+
) {
27+
OrdinalBytesRefBlock valuesOrdinal = values.asOrdinals();
28+
if (valuesOrdinal == null) {
29+
return delegate;
30+
}
31+
BytesRefVector dict = valuesOrdinal.getDictionaryVector();
32+
final IntVector hashIds;
33+
BytesRef spare = new BytesRef();
34+
try (var hashIdsBuilder = values.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) {
35+
for (int p = 0; p < dict.getPositionCount(); p++) {
36+
hashIdsBuilder.appendInt(Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, spare)))));
37+
}
38+
hashIds = hashIdsBuilder.build();
39+
}
40+
IntBlock ordinalIds = valuesOrdinal.getOrdinalsBlock();
41+
return new GroupingAggregatorFunction.AddInput() {
42+
@Override
43+
public void add(int positionOffset, IntArrayBlock groupIds) {
44+
for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) {
45+
if (groupIds.isNull(groupPosition)) {
46+
continue;
47+
}
48+
int groupStart = groupIds.getFirstValueIndex(groupPosition);
49+
int groupEnd = groupStart + groupIds.getValueCount(groupPosition);
50+
for (int g = groupStart; g < groupEnd; g++) {
51+
int groupId = groupIds.getInt(g);
52+
if (ordinalIds.isNull(groupPosition + positionOffset)) {
53+
continue;
54+
}
55+
int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset);
56+
int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset);
57+
for (int v = valuesStart; v < valuesEnd; v++) {
58+
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v)));
59+
}
60+
}
61+
}
62+
}
63+
64+
@Override
65+
public void add(int positionOffset, IntBigArrayBlock groupIds) {
66+
for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) {
67+
if (groupIds.isNull(groupPosition)) {
68+
continue;
69+
}
70+
int groupStart = groupIds.getFirstValueIndex(groupPosition);
71+
int groupEnd = groupStart + groupIds.getValueCount(groupPosition);
72+
for (int g = groupStart; g < groupEnd; g++) {
73+
int groupId = groupIds.getInt(g);
74+
if (ordinalIds.isNull(groupPosition + positionOffset)) {
75+
continue;
76+
}
77+
int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset);
78+
int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset);
79+
for (int v = valuesStart; v < valuesEnd; v++) {
80+
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v)));
81+
}
82+
}
83+
}
84+
}
85+
86+
@Override
87+
public void add(int positionOffset, IntVector groupIds) {
88+
for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) {
89+
int groupId = groupIds.getInt(groupPosition);
90+
if (ordinalIds.isNull(groupPosition + positionOffset)) {
91+
continue;
92+
}
93+
int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset);
94+
int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset);
95+
for (int v = valuesStart; v < valuesEnd; v++) {
96+
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v)));
97+
}
98+
}
99+
}
100+
101+
@Override
102+
public void close() {
103+
Releasables.close(hashIds, delegate);
104+
}
105+
};
106+
}
107+
108+
static GroupingAggregatorFunction.AddInput wrapAddInput(
109+
GroupingAggregatorFunction.AddInput delegate,
110+
ValuesBytesRefAggregator.GroupingState state,
111+
BytesRefVector values
112+
) {
113+
var valuesOrdinal = values.asOrdinals();
114+
if (valuesOrdinal == null) {
115+
return delegate;
116+
}
117+
BytesRefVector dict = valuesOrdinal.getDictionaryVector();
118+
final IntVector hashIds;
119+
BytesRef spare = new BytesRef();
120+
try (var hashIdsBuilder = values.blockFactory().newIntVectorFixedBuilder(dict.getPositionCount())) {
121+
for (int p = 0; p < dict.getPositionCount(); p++) {
122+
hashIdsBuilder.appendInt(Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(dict.getBytesRef(p, spare)))));
123+
}
124+
hashIds = hashIdsBuilder.build();
125+
}
126+
var ordinalIds = valuesOrdinal.getOrdinalsVector();
127+
return new GroupingAggregatorFunction.AddInput() {
128+
@Override
129+
public void add(int positionOffset, IntArrayBlock groupIds) {
130+
for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) {
131+
if (groupIds.isNull(groupPosition)) {
132+
continue;
133+
}
134+
int groupStart = groupIds.getFirstValueIndex(groupPosition);
135+
int groupEnd = groupStart + groupIds.getValueCount(groupPosition);
136+
for (int g = groupStart; g < groupEnd; g++) {
137+
int groupId = groupIds.getInt(g);
138+
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
139+
}
140+
}
141+
}
142+
143+
@Override
144+
public void add(int positionOffset, IntBigArrayBlock groupIds) {
145+
for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) {
146+
if (groupIds.isNull(groupPosition)) {
147+
continue;
148+
}
149+
int groupStart = groupIds.getFirstValueIndex(groupPosition);
150+
int groupEnd = groupStart + groupIds.getValueCount(groupPosition);
151+
for (int g = groupStart; g < groupEnd; g++) {
152+
int groupId = groupIds.getInt(g);
153+
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
154+
}
155+
}
156+
}
157+
158+
@Override
159+
public void add(int positionOffset, IntVector groupIds) {
160+
for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) {
161+
int groupId = groupIds.getInt(groupPosition);
162+
state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset)));
163+
}
164+
}
165+
166+
@Override
167+
public void close() {
168+
Releasables.close(hashIds, delegate);
169+
}
170+
};
171+
}
172+
}

0 commit comments

Comments
 (0)