Skip to content

Commit 497bb25

Browse files
committed
Remove special casing for explicit row in VALUES
Injecting coerciones field by field is no longer needed. This is now handled by PushCastIntoRow.
1 parent e6b50a0 commit 497bb25

File tree

5 files changed

+65
-51
lines changed

5 files changed

+65
-51
lines changed

core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,6 @@
227227
import io.trino.sql.tree.ResetSessionAuthorization;
228228
import io.trino.sql.tree.Revoke;
229229
import io.trino.sql.tree.Rollback;
230-
import io.trino.sql.tree.Row;
231230
import io.trino.sql.tree.RowPattern;
232231
import io.trino.sql.tree.SampledRelation;
233232
import io.trino.sql.tree.SecurityCharacteristic;
@@ -4082,31 +4081,15 @@ protected Scope visitValues(Values node, Optional<Scope> scope)
40824081
// add coercions
40834082
for (Expression row : node.getRows()) {
40844083
Type actualType = analysis.getType(row);
4085-
if (row instanceof Row value) {
4086-
// coerce Row by fields to preserve Row structure and enable optimizations based on this structure, e.g. pruning, predicate extraction
4087-
// TODO coerce the whole Row and add an Optimizer rule that converts CAST(ROW(...) AS ...) into ROW(CAST(...), CAST(...), ...).
4088-
// The rule would also handle Row-type expressions that were specified as CAST(ROW). It should support multiple casts over a ROW.
4089-
for (int i = 0; i < actualType.getTypeParameters().size(); i++) {
4090-
Expression item = value.getFields().get(i).getExpression();
4091-
Type actualItemType = actualType.getTypeParameters().get(i);
4092-
Type expectedItemType = commonSuperType.getTypeParameters().get(i);
4093-
if (!actualItemType.equals(expectedItemType)) {
4094-
analysis.addCoercion(item, expectedItemType);
4095-
}
4096-
}
4097-
}
4098-
else if (actualType instanceof RowType) {
4099-
// coerce row-type expression as a whole
4100-
if (!actualType.equals(commonSuperType)) {
4101-
analysis.addCoercion(row, commonSuperType);
4102-
}
4103-
}
4104-
else {
4084+
4085+
Type targetType = commonSuperType;
4086+
if (!(actualType instanceof RowType)) {
41054087
// coerce field. it will be wrapped in Row by Planner
4106-
Type superType = getOnlyElement(commonSuperType.getTypeParameters());
4107-
if (!actualType.equals(superType)) {
4108-
analysis.addCoercion(row, superType);
4109-
}
4088+
targetType = getOnlyElement(commonSuperType.getTypeParameters());
4089+
}
4090+
4091+
if (!actualType.equals(targetType)) {
4092+
analysis.addCoercion(row, targetType);
41104093
}
41114094
}
41124095

core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1757,24 +1757,12 @@ protected RelationPlan visitValues(Values node, Void context)
17571757
TranslationMap translationMap = new TranslationMap(outerContext, analysis.getScope(node), analysis, lambdaDeclarationToSymbolMap, outputSymbols, session, plannerContext);
17581758

17591759
ImmutableList.Builder<Expression> rows = ImmutableList.builder();
1760-
for (io.trino.sql.tree.Expression rowExpression : node.getRows()) {
1761-
if (rowExpression instanceof io.trino.sql.tree.Row row) {
1762-
ImmutableList.Builder<Expression> rowValues = ImmutableList.builder();
1763-
ImmutableList.Builder<RowType.Field> fieldTypes = ImmutableList.builder();
1764-
for (int i = 0; i < row.getFields().size(); i++) {
1765-
io.trino.sql.tree.Row.Field field = row.getFields().get(i);
1766-
Expression expression = coerceIfNecessary(analysis, field.getExpression(), translationMap.rewrite(field.getExpression()));
1767-
rowValues.add(expression);
1768-
fieldTypes.add(new RowType.Field(field.getName().map(Identifier::getCanonicalValue), expression.type()));
1769-
}
1770-
rows.add(new Row(rowValues.build(), RowType.from(fieldTypes.build())));
1771-
}
1772-
else if (analysis.getType(rowExpression) instanceof RowType) {
1773-
rows.add(coerceIfNecessary(analysis, rowExpression, translationMap.rewrite(rowExpression)));
1774-
}
1775-
else {
1776-
rows.add(new Row(ImmutableList.of(coerceIfNecessary(analysis, rowExpression, translationMap.rewrite(rowExpression)))));
1760+
for (io.trino.sql.tree.Expression row : node.getRows()) {
1761+
Expression rewritten = coerceIfNecessary(analysis, row, translationMap.rewrite(row));
1762+
if (!(analysis.getType(row) instanceof RowType)) {
1763+
rewritten = new Row(ImmutableList.of(rewritten));
17771764
}
1765+
rows.add(rewritten);
17781766
}
17791767

17801768
ValuesNode valuesNode = new ValuesNode(idAllocator.getNextId(), outputSymbols, rows.build());

core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MergeProjectWithValues.java

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
import io.trino.matching.Capture;
2121
import io.trino.matching.Captures;
2222
import io.trino.matching.Pattern;
23+
import io.trino.spi.block.SqlRow;
24+
import io.trino.spi.type.RowType;
25+
import io.trino.spi.type.Type;
26+
import io.trino.sql.ir.Constant;
2327
import io.trino.sql.ir.Expression;
2428
import io.trino.sql.ir.Reference;
2529
import io.trino.sql.ir.Row;
@@ -40,6 +44,7 @@
4044
import static com.google.common.collect.ImmutableSet.toImmutableSet;
4145
import static io.trino.SystemSessionProperties.isMergeProjectWithValues;
4246
import static io.trino.matching.Capture.newCapture;
47+
import static io.trino.spi.type.TypeUtils.readNativeValue;
4348
import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic;
4449
import static io.trino.sql.planner.ExpressionNodeInliner.replaceExpression;
4550
import static io.trino.sql.planner.plan.Patterns.project;
@@ -129,10 +134,12 @@ public Result apply(ProjectNode node, Captures captures, Context context)
129134
// do not proceed if ValuesNode contains a non-deterministic expression and it is referenced more than once by the projection
130135
Set<Symbol> nonDeterministicValuesOutputs = new HashSet<>();
131136
for (Expression rowExpression : valuesNode.getRows().get()) {
132-
Row row = (Row) rowExpression;
133-
for (int i = 0; i < valuesNode.getOutputSymbols().size(); i++) {
134-
if (!isDeterministic(row.items().get(i))) {
135-
nonDeterministicValuesOutputs.add(valuesNode.getOutputSymbols().get(i));
137+
if (!(rowExpression instanceof Constant)) {
138+
Row row = (Row) rowExpression;
139+
for (int i = 0; i < valuesNode.getOutputSymbols().size(); i++) {
140+
if (!isDeterministic(row.items().get(i))) {
141+
nonDeterministicValuesOutputs.add(valuesNode.getOutputSymbols().get(i));
142+
}
136143
}
137144
}
138145
}
@@ -150,7 +157,12 @@ public Result apply(ProjectNode node, Captures captures, Context context)
150157
// inline values expressions into projection's assignments
151158
ImmutableList.Builder<Expression> projectedRows = ImmutableList.builder();
152159
for (Expression rowExpression : valuesNode.getRows().get()) {
153-
Map<Reference, Expression> mapping = buildMappings(valuesNode.getOutputSymbols(), (Row) rowExpression);
160+
Map<Reference, Expression> mapping = switch (rowExpression) {
161+
case Row row -> buildMappings(valuesNode.getOutputSymbols(), row);
162+
case Constant constant -> buildMappings(valuesNode.getOutputSymbols(), constant);
163+
default -> throw new IllegalStateException("Unexpected expression type in ValuesNode: " + rowExpression.getClass().getName());
164+
};
165+
154166
Row projectedRow = new Row(expressions.stream()
155167
.map(expression -> replaceExpression(expression, mapping))
156168
.collect(toImmutableList()));
@@ -161,7 +173,8 @@ public Result apply(ProjectNode node, Captures captures, Context context)
161173

162174
private static boolean isSupportedValues(ValuesNode valuesNode)
163175
{
164-
return valuesNode.getRows().isEmpty() || valuesNode.getRows().get().stream().allMatch(Row.class::isInstance);
176+
return valuesNode.getRows().isEmpty() ||
177+
valuesNode.getRows().get().stream().allMatch(row -> row instanceof Row || row instanceof Constant);
165178
}
166179

167180
private Map<Reference, Expression> buildMappings(List<Symbol> symbols, Row row)
@@ -172,4 +185,21 @@ private Map<Reference, Expression> buildMappings(List<Symbol> symbols, Row row)
172185
}
173186
return mappingBuilder.buildOrThrow();
174187
}
188+
189+
private Map<Reference, Expression> buildMappings(List<Symbol> symbols, Constant row)
190+
{
191+
ImmutableMap.Builder<Reference, Expression> mappingBuilder = ImmutableMap.builder();
192+
193+
RowType type = (RowType) row.type();
194+
for (int field = 0; field < type.getFields().size(); field++) {
195+
Type fieldType = type.getFields().get(field).getType();
196+
197+
SqlRow rowValue = (SqlRow) row.value();
198+
mappingBuilder.put(symbols.get(field).toSymbolReference(), new Constant(
199+
fieldType,
200+
readNativeValue(fieldType, rowValue.getRawFieldBlock(field), rowValue.getRawIndex())));
201+
}
202+
203+
return mappingBuilder.buildOrThrow();
204+
}
175205
}

core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@
162162
import static io.trino.sql.planner.assertions.PlanMatchPattern.topNRanking;
163163
import static io.trino.sql.planner.assertions.PlanMatchPattern.unnest;
164164
import static io.trino.sql.planner.assertions.PlanMatchPattern.values;
165+
import static io.trino.sql.planner.assertions.PlanMatchPattern.valuesOf;
165166
import static io.trino.sql.planner.assertions.PlanMatchPattern.windowFunction;
166167
import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
167168
import static io.trino.sql.planner.plan.AggregationNode.Step.FINAL;
@@ -2290,11 +2291,15 @@ public void testValuesCoercions()
22902291
assertPlan("VALUES (TINYINT '1', REAL '1'), (DOUBLE '2', SMALLINT '2')",
22912292
CREATED,
22922293
anyTree(
2293-
values(
2294+
valuesOf(
22942295
ImmutableList.of("field", "field0"),
22952296
ImmutableList.of(
2296-
ImmutableList.of(new Cast(new Constant(TINYINT, 1L), DOUBLE), new Constant(REAL, Reals.toReal(1f))),
2297-
ImmutableList.of(new Constant(DOUBLE, 2.0), new Cast(new Constant(SMALLINT, 2L), REAL))))));
2297+
new Cast(
2298+
new Row(ImmutableList.of(new Constant(TINYINT, 1L), new Constant(REAL, Reals.toReal(1f)))),
2299+
RowType.anonymousRow(DOUBLE, REAL)),
2300+
new Cast(
2301+
new Row(ImmutableList.of(new Constant(DOUBLE, 2.0), new Constant(SMALLINT, 2L))),
2302+
RowType.anonymousRow(DOUBLE, REAL))))));
22982303

22992304
// entry of type other than Row coerced as a whole
23002305
assertPlan("VALUES DOUBLE '1', CAST(ROW(2) AS row(bigint))",

core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,14 @@ public static PlanMatchPattern values(int rowCount)
793793
return values(ImmutableList.of(), nCopies(rowCount, ImmutableList.of()));
794794
}
795795

796+
public static PlanMatchPattern valuesOf(List<String> aliases, List<Expression> expectedRows)
797+
{
798+
return values(
799+
aliasToIndex(aliases),
800+
Optional.of(aliases.size()),
801+
Optional.of(expectedRows));
802+
}
803+
796804
public static PlanMatchPattern values(List<String> aliases, List<List<Expression>> expectedRows)
797805
{
798806
return values(aliases, Optional.of(expectedRows));

0 commit comments

Comments
 (0)