Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,6 @@
import io.trino.sql.tree.ResetSessionAuthorization;
import io.trino.sql.tree.Revoke;
import io.trino.sql.tree.Rollback;
import io.trino.sql.tree.Row;
import io.trino.sql.tree.RowPattern;
import io.trino.sql.tree.SampledRelation;
import io.trino.sql.tree.SecurityCharacteristic;
Expand Down Expand Up @@ -4082,31 +4081,15 @@ protected Scope visitValues(Values node, Optional<Scope> scope)
// add coercions
for (Expression row : node.getRows()) {
Type actualType = analysis.getType(row);
if (row instanceof Row value) {
// coerce Row by fields to preserve Row structure and enable optimizations based on this structure, e.g. pruning, predicate extraction
// TODO coerce the whole Row and add an Optimizer rule that converts CAST(ROW(...) AS ...) into ROW(CAST(...), CAST(...), ...).
// The rule would also handle Row-type expressions that were specified as CAST(ROW). It should support multiple casts over a ROW.
for (int i = 0; i < actualType.getTypeParameters().size(); i++) {
Expression item = value.getFields().get(i).getExpression();
Type actualItemType = actualType.getTypeParameters().get(i);
Type expectedItemType = commonSuperType.getTypeParameters().get(i);
if (!actualItemType.equals(expectedItemType)) {
analysis.addCoercion(item, expectedItemType);
}
}
}
else if (actualType instanceof RowType) {
// coerce row-type expression as a whole
if (!actualType.equals(commonSuperType)) {
analysis.addCoercion(row, commonSuperType);
}
}
else {

Type targetType = commonSuperType;
if (!(actualType instanceof RowType)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please retain the comment explaining why we have to unpack the supertype in this case.

// coerce field. it will be wrapped in Row by Planner
Type superType = getOnlyElement(commonSuperType.getTypeParameters());
if (!actualType.equals(superType)) {
analysis.addCoercion(row, superType);
}
targetType = getOnlyElement(commonSuperType.getTypeParameters());
}

if (!actualType.equals(targetType)) {
analysis.addCoercion(row, targetType);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1757,24 +1757,12 @@ protected RelationPlan visitValues(Values node, Void context)
TranslationMap translationMap = new TranslationMap(outerContext, analysis.getScope(node), analysis, lambdaDeclarationToSymbolMap, outputSymbols, session, plannerContext);

ImmutableList.Builder<Expression> rows = ImmutableList.builder();
for (io.trino.sql.tree.Expression rowExpression : node.getRows()) {
if (rowExpression instanceof io.trino.sql.tree.Row row) {
ImmutableList.Builder<Expression> rowValues = ImmutableList.builder();
ImmutableList.Builder<RowType.Field> fieldTypes = ImmutableList.builder();
for (int i = 0; i < row.getFields().size(); i++) {
io.trino.sql.tree.Row.Field field = row.getFields().get(i);
Expression expression = coerceIfNecessary(analysis, field.getExpression(), translationMap.rewrite(field.getExpression()));
rowValues.add(expression);
fieldTypes.add(new RowType.Field(field.getName().map(Identifier::getCanonicalValue), expression.type()));
}
rows.add(new Row(rowValues.build(), RowType.from(fieldTypes.build())));
}
else if (analysis.getType(rowExpression) instanceof RowType) {
rows.add(coerceIfNecessary(analysis, rowExpression, translationMap.rewrite(rowExpression)));
}
else {
rows.add(new Row(ImmutableList.of(coerceIfNecessary(analysis, rowExpression, translationMap.rewrite(rowExpression)))));
for (io.trino.sql.tree.Expression row : node.getRows()) {
Expression rewritten = coerceIfNecessary(analysis, row, translationMap.rewrite(row));
if (!(analysis.getType(row) instanceof RowType)) {
rewritten = new Row(ImmutableList.of(rewritten));
}
rows.add(rewritten);
}

ValuesNode valuesNode = new ValuesNode(idAllocator.getNextId(), outputSymbols, rows.build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.block.SqlRow;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import io.trino.sql.ir.Row;
Expand All @@ -40,6 +44,7 @@
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.SystemSessionProperties.isMergeProjectWithValues;
import static io.trino.matching.Capture.newCapture;
import static io.trino.spi.type.TypeUtils.readNativeValue;
import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic;
import static io.trino.sql.planner.ExpressionNodeInliner.replaceExpression;
import static io.trino.sql.planner.plan.Patterns.project;
Expand Down Expand Up @@ -129,10 +134,12 @@ public Result apply(ProjectNode node, Captures captures, Context context)
// do not proceed if ValuesNode contains a non-deterministic expression and it is referenced more than once by the projection
Set<Symbol> nonDeterministicValuesOutputs = new HashSet<>();
for (Expression rowExpression : valuesNode.getRows().get()) {
Row row = (Row) rowExpression;
for (int i = 0; i < valuesNode.getOutputSymbols().size(); i++) {
if (!isDeterministic(row.items().get(i))) {
nonDeterministicValuesOutputs.add(valuesNode.getOutputSymbols().get(i));
if (!(rowExpression instanceof Constant)) {
Row row = (Row) rowExpression;
for (int i = 0; i < valuesNode.getOutputSymbols().size(); i++) {
if (!isDeterministic(row.items().get(i))) {
nonDeterministicValuesOutputs.add(valuesNode.getOutputSymbols().get(i));
}
}
}
}
Expand All @@ -150,7 +157,12 @@ public Result apply(ProjectNode node, Captures captures, Context context)
// inline values expressions into projection's assignments
ImmutableList.Builder<Expression> projectedRows = ImmutableList.builder();
for (Expression rowExpression : valuesNode.getRows().get()) {
Map<Reference, Expression> mapping = buildMappings(valuesNode.getOutputSymbols(), (Row) rowExpression);
Map<Reference, Expression> mapping = switch (rowExpression) {
case Row row -> buildMappings(valuesNode.getOutputSymbols(), row);
case Constant constant -> buildMappings(valuesNode.getOutputSymbols(), constant);
default -> throw new IllegalStateException("Unexpected expression type in ValuesNode: " + rowExpression.getClass().getName());
};

Row projectedRow = new Row(expressions.stream()
.map(expression -> replaceExpression(expression, mapping))
.collect(toImmutableList()));
Expand All @@ -161,7 +173,8 @@ public Result apply(ProjectNode node, Captures captures, Context context)

private static boolean isSupportedValues(ValuesNode valuesNode)
{
return valuesNode.getRows().isEmpty() || valuesNode.getRows().get().stream().allMatch(Row.class::isInstance);
return valuesNode.getRows().isEmpty() ||
valuesNode.getRows().get().stream().allMatch(row -> row instanceof Row || row instanceof Constant);
}

private Map<Reference, Expression> buildMappings(List<Symbol> symbols, Row row)
Expand All @@ -172,4 +185,21 @@ private Map<Reference, Expression> buildMappings(List<Symbol> symbols, Row row)
}
return mappingBuilder.buildOrThrow();
}

private Map<Reference, Expression> buildMappings(List<Symbol> symbols, Constant row)
{
ImmutableMap.Builder<Reference, Expression> mappingBuilder = ImmutableMap.builder();

RowType type = (RowType) row.type();
SqlRow rowValue = (SqlRow) row.value();
for (int field = 0; field < type.getFields().size(); field++) {
Type fieldType = type.getFields().get(field).getType();

mappingBuilder.put(symbols.get(field).toSymbolReference(), new Constant(
fieldType,
readNativeValue(fieldType, rowValue.getRawFieldBlock(field), rowValue.getRawIndex())));
}

return mappingBuilder.buildOrThrow();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.block.SqlRow;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Row;
import io.trino.sql.planner.Symbol;
Expand All @@ -29,12 +33,15 @@
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ValuesNode;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.SystemSessionProperties.getPushFilterIntoValuesMaxRowCount;
import static io.trino.matching.Capture.newCapture;
import static io.trino.spi.type.TypeUtils.readNativeValue;
import static io.trino.sql.ir.Booleans.FALSE;
import static io.trino.sql.ir.Booleans.NULL_BOOLEAN;
import static io.trino.sql.ir.Booleans.TRUE;
Expand Down Expand Up @@ -114,24 +121,25 @@ public Result apply(FilterNode node, Captures captures, Context context)
boolean optimized = false;
boolean keepFilter = false;
for (Expression expression : valuesNode.getRows().orElseThrow()) {
Row row = (Row) expression;
ImmutableMap.Builder<Symbol, Expression> mapping = ImmutableMap.builder();
for (int i = 0; i < valuesNode.getOutputSymbols().size(); i++) {
mapping.put(valuesNode.getOutputSymbols().get(i), row.items().get(i));
}
Expression rewrittenPredicate = inlineSymbols(mapping.buildOrThrow(), predicate);
Map<Symbol, Expression> mapping = switch (expression) {
case Row row -> buildMappings(valuesNode.getOutputSymbols(), row);
case Constant constant -> buildMappings(valuesNode.getOutputSymbols(), constant);
default -> throw new IllegalStateException("Unexpected expression type in ValuesNode: " + expression.getClass().getName());
};

Expression rewrittenPredicate = inlineSymbols(mapping, predicate);
Optional<Expression> optimizedPredicate = newOptimizer(plannerContext).process(rewrittenPredicate, context.getSession(), ImmutableMap.of());

if (optimizedPredicate.isPresent() && optimizedPredicate.get().equals(TRUE)) {
filteredRows.add(row);
filteredRows.add(expression);
}
else if (optimizedPredicate.isPresent() && (optimizedPredicate.get().equals(FALSE) || optimizedPredicate.get().equals(NULL_BOOLEAN))) {
// skip row
optimized = true;
}
else {
// could not evaluate the predicate for the row
filteredRows.add(row);
filteredRows.add(expression);
keepFilter = true;
}
}
Expand Down Expand Up @@ -159,6 +167,33 @@ private static boolean isSupportedValues(ValuesNode valuesNode)
// so we don't need to handle this case here.
// Also, do not optimize if any values row is not a Row instance, because we cannot easily inline
// the columns of non-row expressions in the filter predicate.
return valuesNode.getRows().isPresent() && valuesNode.getRows().get().stream().allMatch(Row.class::isInstance);
return valuesNode.getRows().isPresent() &&
valuesNode.getRows().get().stream().allMatch(row -> row instanceof Row || row instanceof Constant);
}

private Map<Symbol, Expression> buildMappings(List<Symbol> symbols, Row row)
{
ImmutableMap.Builder<Symbol, Expression> mappingBuilder = ImmutableMap.builder();
for (int i = 0; i < row.items().size(); i++) {
mappingBuilder.put(symbols.get(i), row.items().get(i));
}
return mappingBuilder.buildOrThrow();
}

private Map<Symbol, Expression> buildMappings(List<Symbol> symbols, Constant row)
{
ImmutableMap.Builder<Symbol, Expression> mappingBuilder = ImmutableMap.builder();

RowType type = (RowType) row.type();
SqlRow rowValue = (SqlRow) row.value();
for (int field = 0; field < type.getFields().size(); field++) {
Type fieldType = type.getFields().get(field).getType();

mappingBuilder.put(symbols.get(field), new Constant(
fieldType,
readNativeValue(fieldType, rowValue.getRawFieldBlock(field), rowValue.getRawIndex())));
}

return mappingBuilder.buildOrThrow();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@
import static io.trino.sql.planner.assertions.PlanMatchPattern.topNRanking;
import static io.trino.sql.planner.assertions.PlanMatchPattern.unnest;
import static io.trino.sql.planner.assertions.PlanMatchPattern.values;
import static io.trino.sql.planner.assertions.PlanMatchPattern.valuesOf;
import static io.trino.sql.planner.assertions.PlanMatchPattern.windowFunction;
import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static io.trino.sql.planner.plan.AggregationNode.Step.FINAL;
Expand Down Expand Up @@ -2290,11 +2291,15 @@ public void testValuesCoercions()
assertPlan("VALUES (TINYINT '1', REAL '1'), (DOUBLE '2', SMALLINT '2')",
CREATED,
anyTree(
values(
valuesOf(
ImmutableList.of("field", "field0"),
ImmutableList.of(
ImmutableList.of(new Cast(new Constant(TINYINT, 1L), DOUBLE), new Constant(REAL, Reals.toReal(1f))),
ImmutableList.of(new Constant(DOUBLE, 2.0), new Cast(new Constant(SMALLINT, 2L), REAL))))));
new Cast(
new Row(ImmutableList.of(new Constant(TINYINT, 1L), new Constant(REAL, Reals.toReal(1f)))),
RowType.anonymousRow(DOUBLE, REAL)),
new Cast(
new Row(ImmutableList.of(new Constant(DOUBLE, 2.0), new Constant(SMALLINT, 2L))),
RowType.anonymousRow(DOUBLE, REAL))))));

// entry of type other than Row coerced as a whole
assertPlan("VALUES DOUBLE '1', CAST(ROW(2) AS row(bigint))",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,14 @@ public static PlanMatchPattern values(int rowCount)
return values(ImmutableList.of(), nCopies(rowCount, ImmutableList.of()));
}

public static PlanMatchPattern valuesOf(List<String> aliases, List<Expression> expectedRows)
{
return values(
aliasToIndex(aliases),
Optional.of(aliases.size()),
Optional.of(expectedRows));
}

public static PlanMatchPattern values(List<String> aliases, List<List<Expression>> expectedRows)
{
return values(aliases, Optional.of(expectedRows));
Expand Down