
Commit 2937bc8

[Spark] Fix dependent constraints/generated columns checker for type widening (delta-io#3912)
#### Which Delta project/connector is this regarding?

- [X] Spark
- [ ] Standalone
- [ ] Flink
- [ ] Kernel
- [ ] Other (fill in here)

## Description

The current checker of dependent expressions doesn't validate changes to array and map types. For example, type widening could break existing CHECK constraints:

```
scala> sql("CREATE TABLE table (a array<byte>) USING DELTA")
scala> sql("INSERT INTO table VALUES (array(1, -2, 3))")
scala> sql("SELECT hash(a[1]) FROM table").show()
+-----------+
| hash(a[1])|
+-----------+
|-1160545675|
+-----------+

scala> sql("ALTER TABLE table ADD CONSTRAINT ch1 CHECK (hash(a[1]) = -1160545675)")
scala> sql("ALTER TABLE table SET TBLPROPERTIES('delta.enableTypeWidening' = true)")
scala> sql("ALTER TABLE table CHANGE COLUMN a.element TYPE BIGINT")
scala> sql("SELECT hash(a[1]) FROM table").show()
+----------+
|hash(a[1])|
+----------+
|-981642528|
+----------+

scala> sql("INSERT INTO table VALUES (array(1, -2, 3))")
24/11/15 12:53:23 ERROR Utils: Aborting task
com.databricks.sql.transaction.tahoe.schema.DeltaInvariantViolationException: [DELTA_VIOLATE_CONSTRAINT_WITH_VALUES] CHECK constraint ch1 (hash(a[1]) = -1160545675) violated by row with values:
```

The proposed algorithm is stricter and takes maps, arrays, and structs into account when resolving constraint and generated-column dependencies.

## How was this patch tested?

Added new tests for constraints and generated columns used with the type widening feature.

## Does this PR introduce _any_ user-facing changes?

Yes. Because the checker is stricter, potentially dangerous type changes will now be rejected; the example above throws an exception instead of silently invalidating the constraint. Such changes mostly come through the recently introduced schema evolution feature, so this should not affect many users.
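
To make the dependency rule concrete, here is a minimal, self-contained Scala sketch (not the Delta implementation; names are illustrative) of the prefix-overlap check the stricter algorithm relies on: a changed column path conflicts with a path referenced by a constraint or generated column when either path is a prefix of the other, or they are equal.

```scala
// Minimal sketch, not Delta code: `hash(a[1])` references path Seq("a", "element"),
// so widening `a.element` must be flagged. Seq.zip truncates to the shorter path,
// which is exactly the prefix-overlap rule.
def conflicts(changed: Seq[String], referenced: Seq[String]): Boolean =
  changed.zip(referenced).forall { case (c, r) => c.equalsIgnoreCase(r) }

conflicts(Seq("a", "element"), Seq("a", "element"))  // true: equal paths
conflicts(Seq("a"),            Seq("a", "element"))  // true: changed column is a prefix
conflicts(Seq("a", "element"), Seq("a"))             // true: referenced column is a prefix
conflicts(Seq("b"),            Seq("a", "element"))  // false: unrelated columns
```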
1 parent 81f27b3 commit 2937bc8

File tree: 8 files changed (+367 −31 lines)

spark/src/main/scala/org/apache/spark/sql/delta/commands/alterDeltaTableCommands.scala

Lines changed: 5 additions & 5 deletions
```diff
@@ -78,22 +78,22 @@ trait AlterDeltaTableCommand extends DeltaCommand {
   protected def checkDependentExpressions(
       sparkSession: SparkSession,
       columnParts: Seq[String],
-      newMetadata: actions.Metadata,
+      oldMetadata: actions.Metadata,
       protocol: Protocol): Unit = {
     if (!sparkSession.sessionState.conf.getConf(
         DeltaSQLConf.DELTA_ALTER_TABLE_CHANGE_COLUMN_CHECK_EXPRESSIONS)) {
       return
     }
     // check if the column to change is referenced by check constraints
     val dependentConstraints =
-      Constraints.findDependentConstraints(sparkSession, columnParts, newMetadata)
+      Constraints.findDependentConstraints(sparkSession, columnParts, oldMetadata)
     if (dependentConstraints.nonEmpty) {
       throw DeltaErrors.foundViolatingConstraintsForColumnChange(
         UnresolvedAttribute(columnParts).name, dependentConstraints)
     }
     // check if the column to change is referenced by any generated columns
     val dependentGenCols = SchemaUtils.findDependentGeneratedColumns(
-      sparkSession, columnParts, protocol, newMetadata.schema)
+      sparkSession, columnParts, protocol, oldMetadata.schema)
     if (dependentGenCols.nonEmpty) {
       throw DeltaErrors.foundViolatingGeneratedColumnsForColumnChange(
         UnresolvedAttribute(columnParts).name, dependentGenCols)
@@ -768,7 +768,7 @@ case class AlterTableDropColumnsDeltaCommand(
       configuration = newConfiguration
     )
     columnsToDrop.foreach { columnParts =>
-      checkDependentExpressions(sparkSession, columnParts, newMetadata, txn.protocol)
+      checkDependentExpressions(sparkSession, columnParts, metadata, txn.protocol)
     }

     txn.updateMetadata(newMetadata)
@@ -927,7 +927,7 @@ case class AlterTableChangeColumnDeltaCommand(
     if (newColumn.name != columnName) {
       // need to validate the changes if the column is renamed
       checkDependentExpressions(
-        sparkSession, columnPath :+ columnName, newMetadata, txn.protocol)
+        sparkSession, columnPath :+ columnName, metadata, txn.protocol)
     }

```
spark/src/main/scala/org/apache/spark/sql/delta/constraints/Constraints.scala

Lines changed: 5 additions & 1 deletion
```diff
@@ -106,7 +106,11 @@ object Constraints {
     metadata.configuration.filter {
       case (key, constraint) if key.toLowerCase(Locale.ROOT).startsWith("delta.constraints.") =>
         SchemaUtils.containsDependentExpression(
-          sparkSession, columnName, constraint, sparkSession.sessionState.conf.resolver)
+          sparkSession,
+          columnName,
+          constraint,
+          metadata.schema,
+          sparkSession.sessionState.conf.resolver)
       case _ => false
     }
   }
```
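
For context, CHECK constraints are stored in the table metadata configuration under `delta.constraints.<name>` keys, so `findDependentConstraints` is essentially filtering a map like the hedged illustration below (entries taken from the PR description's example):

```scala
// Illustration only: the metadata configuration this filter runs over.
val configuration = Map(
  "delta.constraints.ch1" -> "hash(a[1]) = -1160545675",
  "delta.enableTypeWidening" -> "true"
)
// Only the `delta.constraints.*` entries are analyzed; each constraint expression is
// now resolved against `metadata.schema` so nested array/map references can be detected.
```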

spark/src/main/scala/org/apache/spark/sql/delta/schema/ImplicitMetadataOperation.scala

Lines changed: 13 additions & 13 deletions
```diff
@@ -29,7 +29,7 @@ import org.apache.spark.internal.MDC
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.FileSourceGeneratedMetadataStructField
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}

 /**
  * A trait that writers into Delta can extend to update the schema and/or partitioning of the table.
@@ -309,19 +309,19 @@ object ImplicitMetadataOperation {
       currentDt: DataType,
       updateDt: DataType): Unit = (currentDt, updateDt) match {
     // we explicitly ignore the check for `StructType` here.
-    case (StructType(_), StructType(_)) =>
-
-    // FIXME: we intentionally incorporate the pattern match for `ArrayType` and `MapType`
-    // here mainly due to the field paths for maps/arrays in constraints/generated columns
-    // are *NOT* consistent with regular field paths,
-    // e.g., `hash(a.arr[0].x)` vs. `hash(a.element.x)`.
-    // this makes it hard to recurse into maps/arrays and check for the corresponding
-    // fields - thus we can not actually block the operation even if the updated field
-    // is being referenced by any CHECK constraints or generated columns.
-    case (from, to) =>
+    case (_: StructType, _: StructType) =>
+    case (current: ArrayType, update: ArrayType) =>
+      checkConstraintsOrGeneratedColumnsOnStructField(
+        spark, path :+ "element", protocol, metadata, current.elementType, update.elementType)
+    case (current: MapType, update: MapType) =>
+      checkConstraintsOrGeneratedColumnsOnStructField(
+        spark, path :+ "key", protocol, metadata, current.keyType, update.keyType)
+      checkConstraintsOrGeneratedColumnsOnStructField(
+        spark, path :+ "value", protocol, metadata, current.valueType, update.valueType)
+    case (_, _) =>
       if (currentDt != updateDt) {
-        checkDependentConstraints(spark, path, metadata, from, to)
-        checkDependentGeneratedColumns(spark, path, protocol, metadata, from, to)
+        checkDependentConstraints(spark, path, metadata, currentDt, updateDt)
+        checkDependentGeneratedColumns(spark, path, protocol, metadata, currentDt, updateDt)
       }
   }
```
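
To illustrate the path components the recursion now passes down, here is a small self-contained sketch (a hypothetical helper, not the Delta method; struct recursion is included only to show full leaf paths, whereas the real checker leaves struct fields to its caller):

```scala
import org.apache.spark.sql.types._

// Hypothetical helper: enumerate every leaf of a type together with the path the
// checker would build, using the same "element" / "key" / "value" markers as above.
def leafPaths(path: Seq[String], dt: DataType): Seq[Seq[String]] = dt match {
  case s: StructType => s.fields.toSeq.flatMap(f => leafPaths(path :+ f.name, f.dataType))
  case a: ArrayType  => leafPaths(path :+ "element", a.elementType)
  case m: MapType    => leafPaths(path :+ "key", m.keyType) ++ leafPaths(path :+ "value", m.valueType)
  case _             => Seq(path)
}

// leafPaths(Seq("a"), ArrayType(StructType(Seq(StructField("x", IntegerType)))))
//   => Seq(Seq("a", "element", "x"))
// leafPaths(Seq("m"), MapType(IntegerType, LongType))
//   => Seq(Seq("m", "key"), Seq("m", "value"))
```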

spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala

Lines changed: 95 additions & 11 deletions
```diff
@@ -35,7 +35,8 @@ import org.apache.spark.internal.MDC
 import org.apache.spark.sql._
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, GetArrayItem, GetArrayStructFields, GetMapValue, GetStructField}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.functions.{col, struct}
 import org.apache.spark.sql.internal.SQLConf
@@ -1269,20 +1270,58 @@ def normalizeColumnNamesInDataType(
   // identifier with back-ticks.
   def quoteIdentifier(part: String): String = s"`${part.replace("`", "``")}`"

+  private def analyzeExpression(
+      spark: SparkSession,
+      expr: Expression,
+      schema: StructType): Expression = {
+    // Workaround for `exp` analyze
+    val relation = LocalRelation(schema)
+    val relationWithExp = Project(Seq(Alias(expr, "validate_column")()), relation)
+    val analyzedPlan = spark.sessionState.analyzer.execute(relationWithExp)
+    analyzedPlan.collectFirst {
+      case Project(Seq(a: Alias), _: LocalRelation) => a.child
+    }.get
+  }
+
   /**
-   * Will a column change, e.g., rename, need to be populated to the expression. This is true when
-   * the column to change itself or any of its descendent column is referenced by expression.
+   * Collects all attribute references in the given expression tree as a list of paths.
+   * In particular, generates paths for nested fields accessed using extraction expressions.
    * For example:
-   * - a, length(a) -> true
-   * - b, (b.c + 1) -> true, because renaming b1 will need to change the expr to (b1.c + 1).
-   * - b.c, (cast b as string) -> false, because you can change b.c to b.c1 without affecting b.
+   * - GetStructField(AttributeReference("struct"), "a") -> ["struct.a"]
+   * - Size(AttributeReference("array")) -> ["array"]
    */
-  def containsDependentExpression(
-      spark: SparkSession,
+  private def collectUsedColumns(expression: Expression): Seq[Seq[String]] = {
+    val result = new collection.mutable.ArrayBuffer[Seq[String]]()
+
+    // Firstly, try to get referenced column for a child's expression.
+    // If it exists then we try to extend it by current expression.
+    // In case if we cannot extend one, we save the received column path (it's as long as possible).
+    def traverseAllPaths(exp: Expression): Option[Seq[String]] = exp match {
+      case GetStructField(child, _, Some(name)) => traverseAllPaths(child).map(_ :+ name)
+      case GetMapValue(child, key) =>
+        traverseAllPaths(key).foreach(result += _)
+        traverseAllPaths(child).map { childPath =>
+          result += childPath :+ "key"
+          childPath :+ "value"
+        }
+      case arrayExtract: GetArrayItem => traverseAllPaths(arrayExtract.child).map(_ :+ "element")
+      case arrayExtract: GetArrayStructFields =>
+        traverseAllPaths(arrayExtract.child).map(_ :+ "element" :+ arrayExtract.field.name)
+      case refCol: AttributeReference => Some(Seq(refCol.name))
+      case _ =>
+        exp.children.foreach(child => traverseAllPaths(child).foreach(result += _))
+        None
+    }
+
+    traverseAllPaths(expression).foreach(result += _)
+
+    result.toSeq
+  }
+
+  private def fallbackContainsDependentExpression(
+      expression: Expression,
       columnToChange: Seq[String],
-      exprString: String,
       resolver: Resolver): Boolean = {
-    val expression = spark.sessionState.sqlParser.parseExpression(exprString)
     expression.foreach {
       case refCol: UnresolvedAttribute =>
         // columnToChange is the referenced column or its prefix
@@ -1294,6 +1333,51 @@ def normalizeColumnNamesInDataType(
     false
   }

+  /**
+   * Will a column change, e.g., rename, need to be populated to the expression. This is true when
+   * the column to change itself or any of its descendent column is referenced by expression.
+   * For example:
+   * - a, length(a) -> true
+   * - b, (b.c + 1) -> true, because renaming b1 will need to change the expr to (b1.c + 1).
+   * - b.c, (cast b as string) -> true, because change b.c to b.c1 affects (b as string) result.
+   */
+  def containsDependentExpression(
+      spark: SparkSession,
+      columnToChange: Seq[String],
+      exprString: String,
+      schema: StructType,
+      resolver: Resolver): Boolean = {
+    val expression = spark.sessionState.sqlParser.parseExpression(exprString)
+    if (spark.sessionState.conf.getConf(
+        DeltaSQLConf.DELTA_CHANGE_COLUMN_CHECK_DEPENDENT_EXPRESSIONS_USE_V2)) {
+      try {
+        val analyzedExpr = analyzeExpression(spark, expression, schema)
+        val exprColumns = collectUsedColumns(analyzedExpr)
+        exprColumns.exists { exprColumn =>
+          // Changed column violates expression's column only when:
+          // 1) the changed column is a prefix of the referenced column,
+          //    for example changing type of `col` affects `hash(col[0]) == 0`;
+          // 2) or the referenced column is a prefix of the changed column,
+          //    for example changing type of `col.element` affects `concat_ws('', col) == 'abc'`;
+          // 3) or they are equal.
+          exprColumn.zip(columnToChange).forall {
+            case (exprFieldName, changedFieldName) => resolver(exprFieldName, changedFieldName)
+          }
+        }
+      } catch {
+        case NonFatal(e) =>
+          deltaAssert(
+            check = false,
+            name = "containsDependentExpression.checkV2Error",
+            msg = "Exception during dependent expression V2 checking: " + e.getMessage
+          )
+          fallbackContainsDependentExpression(expression, columnToChange, resolver)
+      }
+    } else {
+      fallbackContainsDependentExpression(expression, columnToChange, resolver)
+    }
+  }
+
   /**
    * Find the unsupported data type in a table schema. Return all columns that are using unsupported
    * data types. For example,
@@ -1402,7 +1486,7 @@ def normalizeColumnNamesInDataType(
     SchemaMergingUtils.transformColumns(schema) { (_, field, _) =>
       GeneratedColumn.getGenerationExpressionStr(field.metadata).foreach { exprStr =>
         val needsToChangeExpr = SchemaUtils.containsDependentExpression(
-          sparkSession, targetColumn, exprStr, sparkSession.sessionState.conf.resolver)
+          sparkSession, targetColumn, exprStr, schema, sparkSession.sessionState.conf.resolver)
         if (needsToChangeExpr) dependentGenCols += field.name -> exprStr
       }
       field
```
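
For orientation, here is a hedged usage sketch of the new `containsDependentExpression` signature, mirroring the test suite below and the PR description's table; it assumes a `SparkSession` named `spark` is in scope.

```scala
import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
import org.apache.spark.sql.delta.schema.SchemaUtils
import org.apache.spark.sql.types.{ArrayType, ByteType, StructField, StructType}

// Table from the PR description: a array<byte> with CHECK (hash(a[1]) = -1160545675).
val schema = StructType(Seq(StructField("a", ArrayType(ByteType))))

// `hash(a[1])` is analyzed to the column path Seq("a", "element"), which overlaps with
// the column being widened, so the type change is flagged as dependent.
SchemaUtils.containsDependentExpression(
  spark,
  columnToChange = Seq("a", "element"),
  exprString = "hash(a[1]) = -1160545675",
  schema,
  caseInsensitiveResolution)   // expected: true
```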

spark/src/main/scala/org/apache/spark/sql/delta/sources/DeltaSQLConf.scala

Lines changed: 15 additions & 0 deletions
```diff
@@ -1580,6 +1580,21 @@ trait DeltaSQLConfBase {
       .booleanConf
       .createWithDefault(true)

+  val DELTA_CHANGE_COLUMN_CHECK_DEPENDENT_EXPRESSIONS_USE_V2 =
+    buildConf("changeColumn.checkDependentExpressionsUseV2")
+      .internal()
+      .doc(
+        """
+          |More accurate implementation of checker for altering/renaming/dropping columns
+          |that might be referenced by constraints or generation rules.
+          |It respects nested arrays and maps, unlike the V1 checker.
+          |
+          |This is a safety switch - we should only turn this off when there is an issue with
+          |expression checking logic that prevents a valid column change from going through.
+          |""".stripMargin)
+      .booleanConf
+      .createWithDefault(true)
+
   val DELTA_ALTER_TABLE_DROP_COLUMN_ENABLED =
     buildConf("alterTable.dropColumn.enabled")
       .internal()
```
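
Because the new check is guarded by an internal flag that defaults to true, the escape hatch is a session-level override. Assuming `buildConf` applies DeltaSQLConf's usual `spark.databricks.delta.` prefix, that would look like:

```scala
// Assumption: buildConf prepends the standard "spark.databricks.delta." prefix.
spark.conf.set("spark.databricks.delta.changeColumn.checkDependentExpressionsUseV2", "false")
// or, in SQL:
//   SET spark.databricks.delta.changeColumn.checkDependentExpressionsUseV2 = false;
```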

spark/src/test/scala/org/apache/spark/sql/delta/schema/SchemaUtilsSuite.scala

Lines changed: 95 additions & 1 deletion
```diff
@@ -36,7 +36,7 @@ import org.scalatest.GivenWhenThen

 import org.apache.spark.sql.{AnalysisException, Column, DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Expression}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
 import org.apache.spark.sql.functions._
@@ -3048,6 +3048,100 @@ class SchemaUtilsSuite extends QueryTest
     assert(udts.map(_.getClass.getName).toSet == Set(classOf[PointUDT].getName))
   }

+
+  test("check if column affects given dependent expressions") {
+    val schema = StructType(Seq(
+      StructField("cArray", ArrayType(StringType)),
+      StructField("cStruct", StructType(Seq(
+        StructField("cMap", MapType(IntegerType, ArrayType(BooleanType))),
+        StructField("cMapWithComplexKey", MapType(StructType(Seq(
+          StructField("a", ArrayType(StringType)),
+          StructField("b", BooleanType)
+        )), IntegerType))
+      )))
+    ))
+    assert(
+      SchemaUtils.containsDependentExpression(
+        spark,
+        columnToChange = Seq("cArray"),
+        exprString = "cast(cStruct.cMap as string) == '{}'",
+        schema,
+        caseInsensitiveResolution) === false
+    )
+    // Extracting value from map uses key type as well.
+    assert(
+      SchemaUtils.containsDependentExpression(
+        spark,
+        columnToChange = Seq("cStruct", "cMap", "key"),
+        exprString = "cStruct.cMap['random_key'] == 'string'",
+        schema,
+        caseInsensitiveResolution) === true
+    )
+    assert(
+      SchemaUtils.containsDependentExpression(
+        spark,
+        columnToChange = Seq("cstruct"),
+        exprString = "size(cStruct.cMap) == 0",
+        schema,
+        caseSensitiveResolution) === false
+    )
+    assert(
+      SchemaUtils.containsDependentExpression(
+        spark,
+        columnToChange = Seq("cStruct", "cMap", "key"),
+        exprString = "size(cArray) == 1",
+        schema,
+        caseInsensitiveResolution) === false
+    )
+    assert(
+      SchemaUtils.containsDependentExpression(
+        spark,
+        columnToChange = Seq("cStruct", "cMap", "key"),
+        exprString = "cStruct.cMapWithComplexKey[struct(cArray, false)] == 0",
+        schema,
+        caseInsensitiveResolution) === false
+    )
+    assert(
+      SchemaUtils.containsDependentExpression(
+        spark,
+        columnToChange = Seq("cArray", "element"),
+        exprString = "cStruct.cMapWithComplexKey[struct(cArray, false)] == 0",
+        schema,
+        caseInsensitiveResolution) === true
+    )
+    assert(
+      SchemaUtils.containsDependentExpression(
+        spark,
+        columnToChange = Seq("cStruct", "cMapWithComplexKey", "key", "b"),
+        exprString = "cStruct.cMapWithComplexKey[struct(cArray, false)] == 0",
+        schema,
+        caseInsensitiveResolution) === true
+    )
+    assert(
+      SchemaUtils.containsDependentExpression(
+        spark,
+        columnToChange = Seq("cArray", "element"),
+        exprString = "concat_ws('', cArray) == 'string'",
+        schema,
+        caseInsensitiveResolution) === true
+    )
+    assert(
+      SchemaUtils.containsDependentExpression(
+        spark,
+        columnToChange = Seq("CARRAY"),
+        exprString = "cArray[0] > 'a'",
+        schema,
+        caseInsensitiveResolution) === true
+    )
+    assert(
+      SchemaUtils.containsDependentExpression(
+        spark,
+        columnToChange = Seq("CARRAY", "element"),
+        exprString = "cArray[0] > 'a'",
+        schema,
+        caseSensitiveResolution) === false
+    )
+  }
 }

 object UnsupportedDataType extends DataType {
```
