@@ -66,7 +66,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.execution.streaming.StreamingRelation
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, MapType, StructField, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.CaseInsensitiveStringMap

 /**
@@ -81,8 +81,8 @@ class DeltaAnalysis(session: SparkSession)
   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsDown {
     // INSERT INTO by ordinal and df.insertInto()
     case a @ AppendDelta(r, d) if !a.isByName &&
-        needsSchemaAdjustmentByOrdinal(d.name(), a.query, r.schema) =>
-      val projection = resolveQueryColumnsByOrdinal(a.query, r.output, d.name())
+        needsSchemaAdjustmentByOrdinal(d, a.query, r.schema) =>
+      val projection = resolveQueryColumnsByOrdinal(a.query, r.output, d)
       if (projection != a.query) {
         a.copy(query = projection)
       } else {
@@ -208,8 +208,8 @@ class DeltaAnalysis(session: SparkSession)

     // INSERT OVERWRITE by ordinal and df.insertInto()
     case o @ OverwriteDelta(r, d) if !o.isByName &&
-        needsSchemaAdjustmentByOrdinal(d.name(), o.query, r.schema) =>
-      val projection = resolveQueryColumnsByOrdinal(o.query, r.output, d.name())
+        needsSchemaAdjustmentByOrdinal(d, o.query, r.schema) =>
+      val projection = resolveQueryColumnsByOrdinal(o.query, r.output, d)
       if (projection != o.query) {
         val aliases = AttributeMap(o.query.output.zip(projection.output).collect {
           case (l: AttributeReference, r: AttributeReference) if !l.sameRef(r) => (l, r)
@@ -245,9 +245,9 @@ class DeltaAnalysis(session: SparkSession)
     case o @ DynamicPartitionOverwriteDelta(r, d) if o.resolved
         =>
       val adjustedQuery = if (!o.isByName &&
-          needsSchemaAdjustmentByOrdinal(d.name(), o.query, r.schema)) {
+          needsSchemaAdjustmentByOrdinal(d, o.query, r.schema)) {
         // INSERT OVERWRITE by ordinal and df.insertInto()
-        resolveQueryColumnsByOrdinal(o.query, r.output, d.name())
+        resolveQueryColumnsByOrdinal(o.query, r.output, d)
       } else if (o.isByName && o.origin.sqlText.nonEmpty &&
           needsSchemaAdjustmentByName(o.query, r.output, d)) {
         // INSERT OVERWRITE by name
@@ -850,12 +850,14 @@ class DeltaAnalysis(session: SparkSession)
    * type column/field.
    */
   private def resolveQueryColumnsByOrdinal(
-      query: LogicalPlan, targetAttrs: Seq[Attribute], tblName: String): LogicalPlan = {
+      query: LogicalPlan, targetAttrs: Seq[Attribute], deltaTable: DeltaTableV2): LogicalPlan = {
     // always add a Cast. it will be removed in the optimizer if it is unnecessary.
     val project = query.output.zipWithIndex.map { case (attr, i) =>
       if (i < targetAttrs.length) {
         val targetAttr = targetAttrs(i)
-        addCastToColumn(attr, targetAttr, tblName)
+        addCastToColumn(attr, targetAttr, deltaTable.name(),
+          allowTypeWidening = allowTypeWidening(deltaTable)
+        )
       } else {
         attr
       }
@@ -890,47 +892,69 @@ class DeltaAnalysis(session: SparkSession)
       .getOrElse {
         throw DeltaErrors.missingColumn(attr, targetAttrs)
       }
-      addCastToColumn(attr, targetAttr, deltaTable.name())
+      addCastToColumn(attr, targetAttr, deltaTable.name(),
+        allowTypeWidening = allowTypeWidening(deltaTable)
+      )
     }
     Project(project, query)
   }

   private def addCastToColumn(
       attr: Attribute,
       targetAttr: Attribute,
-      tblName: String): NamedExpression = {
+      tblName: String,
+      allowTypeWidening: Boolean): NamedExpression = {
     val expr = (attr.dataType, targetAttr.dataType) match {
       case (s, t) if s == t =>
         attr
       case (s: StructType, t: StructType) if s != t =>
-        addCastsToStructs(tblName, attr, s, t)
+        addCastsToStructs(tblName, attr, s, t, allowTypeWidening)
       case (ArrayType(s: StructType, sNull: Boolean), ArrayType(t: StructType, tNull: Boolean))
           if s != t && sNull == tNull =>
-        addCastsToArrayStructs(tblName, attr, s, t, sNull)
+        addCastsToArrayStructs(tblName, attr, s, t, sNull, allowTypeWidening)
+      case (s: AtomicType, t: AtomicType)
+          if allowTypeWidening && TypeWidening.isTypeChangeSupportedForSchemaEvolution(t, s) =>
+        // Keep the type from the query, the target schema will be updated to widen the existing
+        // type to match it.
+        attr
       case _ =>
         getCastFunction(attr, targetAttr.dataType, targetAttr.name)
     }
     Alias(expr, targetAttr.name)(explicitMetadata = Option(targetAttr.metadata))
   }

+  /**
+   * Whether inserting values that have a wider type than the table has is allowed. In that case,
+   * values are not downcasted to the current table type and the table schema is updated instead to
+   * use the wider type.
+   */
+  private def allowTypeWidening(deltaTable: DeltaTableV2): Boolean = {
+    val options = new DeltaOptions(Map.empty[String, String], conf)
+    options.canMergeSchema && TypeWidening.isEnabled(
+      deltaTable.initialSnapshot.protocol,
+      deltaTable.initialSnapshot.metadata
+    )
+  }
+
   /**
    * With Delta, we ACCEPT_ANY_SCHEMA, meaning that Spark doesn't automatically adjust the schema
    * of INSERT INTO. This allows us to perform better schema enforcement/evolution. Since Spark
    * skips this step, we see if we need to perform any schema adjustment here.
    */
   private def needsSchemaAdjustmentByOrdinal(
-      tableName: String,
+      deltaTable: DeltaTableV2,
       query: LogicalPlan,
       schema: StructType): Boolean = {
     val output = query.output
     if (output.length < schema.length) {
-      throw DeltaErrors.notEnoughColumnsInInsert(tableName, output.length, schema.length)
+      throw DeltaErrors.notEnoughColumnsInInsert(deltaTable.name(), output.length, schema.length)
     }
     // Now we should try our best to match everything that already exists, and leave the rest
     // for schema evolution to WriteIntoDelta
     val existingSchemaOutput = output.take(schema.length)
     existingSchemaOutput.map(_.name) != schema.map(_.name) ||
-      !SchemaUtils.isReadCompatible(schema.asNullable, existingSchemaOutput.toStructType)
+      !SchemaUtils.isReadCompatible(schema.asNullable, existingSchemaOutput.toStructType,
+        allowTypeWidening = allowTypeWidening(deltaTable))
   }

   /**
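
The new `allowTypeWidening` helper above gates widening on two independent switches: schema evolution must be requested (`DeltaOptions.canMergeSchema`) and the type widening table feature must be enabled on the table's snapshot. Note the argument order in `addCastToColumn`: the check runs as `isTypeChangeSupportedForSchemaEvolution(t, s)`, i.e. it asks whether the table type `t` can be widened to the query type `s`. A minimal standalone sketch of that decision, with a hypothetical `supportedWidenings` set standing in for `TypeWidening.isTypeChangeSupportedForSchemaEvolution` (only the byte → short → int chain is modeled here, not the feature's full matrix):

```scala
import org.apache.spark.sql.types._

// Illustration only: a stand-in for the gate added in this commit, not Delta's actual API.
object TypeWideningSketch {
  // Hypothetical subset of the widenings the real helper accepts.
  private val supportedWidenings: Set[(DataType, DataType)] = Set(
    ByteType -> ShortType,
    ByteType -> IntegerType,
    ShortType -> IntegerType)

  // Can the existing table type be widened to the (wider) type coming from the query?
  def isWideningSupported(tableType: DataType, queryType: DataType): Boolean =
    supportedWidenings.contains(tableType -> queryType)

  // Mirrors the intent of allowTypeWidening: both switches must be on.
  def allowTypeWidening(canMergeSchema: Boolean, typeWideningEnabled: Boolean): Boolean =
    canMergeSchema && typeWideningEnabled

  def main(args: Array[String]): Unit = {
    // INT values inserted into a SMALLINT column: keep INT instead of casting down,
    // but only when both schema merging and the table feature are enabled.
    val keepWiderType =
      allowTypeWidening(canMergeSchema = true, typeWideningEnabled = true) &&
        isWideningSupported(tableType = ShortType, queryType = IntegerType)
    println(s"keep the wider query type: $keepWiderType") // prints true
  }
}
```
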
@@ -984,7 +1008,10 @@ class DeltaAnalysis(session: SparkSession)
     }
     val specifiedTargetAttrs = targetAttrs.filter(col => userSpecifiedNames.contains(col.name))
     !SchemaUtils.isReadCompatible(
-      specifiedTargetAttrs.toStructType.asNullable, query.output.toStructType)
+      specifiedTargetAttrs.toStructType.asNullable,
+      query.output.toStructType,
+      allowTypeWidening = allowTypeWidening(deltaTable)
+    )
   }

   // Get cast operation for the level of strictness in the schema a user asked for
@@ -1014,7 +1041,8 @@ class DeltaAnalysis(session: SparkSession)
       tableName: String,
       parent: NamedExpression,
       source: StructType,
-      target: StructType): NamedExpression = {
+      target: StructType,
+      allowTypeWidening: Boolean): NamedExpression = {
     if (source.length < target.length) {
       throw DeltaErrors.notEnoughColumnsInInsert(
         tableName, source.length, target.length, Some(parent.qualifiedName))
@@ -1025,12 +1053,20 @@ class DeltaAnalysis(session: SparkSession)
           case t: StructType =>
             val subField = Alias(GetStructField(parent, i, Option(name)), target(i).name)(
               explicitMetadata = Option(metadata))
-            addCastsToStructs(tableName, subField, nested, t)
+            addCastsToStructs(tableName, subField, nested, t, allowTypeWidening)
           case o =>
             val field = parent.qualifiedName + "." + name
             val targetName = parent.qualifiedName + "." + target(i).name
             throw DeltaErrors.cannotInsertIntoColumn(tableName, field, targetName, o.simpleString)
         }
+
+      case (StructField(name, dt: AtomicType, _, _), i) if i < target.length && allowTypeWidening &&
+        TypeWidening.isTypeChangeSupportedForSchemaEvolution(
+          target(i).dataType.asInstanceOf[AtomicType], dt) =>
+        val targetAttr = target(i)
+        Alias(
+          GetStructField(parent, i, Option(name)),
+          targetAttr.name)(explicitMetadata = Option(targetAttr.metadata))
       case (other, i) if i < target.length =>
         val targetAttr = target(i)
         Alias(
@@ -1054,9 +1090,11 @@ class DeltaAnalysis(session: SparkSession)
       parent: NamedExpression,
       source: StructType,
       target: StructType,
-      sourceNullable: Boolean): Expression = {
+      sourceNullable: Boolean,
+      allowTypeWidening: Boolean): Expression = {
     val structConverter: (Expression, Expression) => Expression = (_, i) =>
-      addCastsToStructs(tableName, Alias(GetArrayItem(parent, i), i.toString)(), source, target)
+      addCastsToStructs(
+        tableName, Alias(GetArrayItem(parent, i), i.toString)(), source, target, allowTypeWidening)
     val transformLambdaFunc = {
       val elementVar = NamedLambdaVariable("elementVar", source, sourceNullable)
       val indexVar = NamedLambdaVariable("indexVar", IntegerType, false)
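
Taken together, these hunks mean an INSERT whose source columns are wider than the table's columns no longer gets down-cast when schema evolution and type widening are both on; the projection keeps the query's type and the table schema is widened downstream. A rough end-to-end sketch of that user-visible behavior, assuming the usual switches (the session config `spark.databricks.delta.schema.autoMerge.enabled` and the table property `delta.enableTypeWidening`) and a made-up table name `events`:

```scala
import org.apache.spark.sql.SparkSession

object TypeWideningInsertExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("delta-type-widening-insert")
      .master("local[*]")
      .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
      .config("spark.sql.catalog.spark_catalog",
        "org.apache.spark.sql.delta.catalog.DeltaCatalog")
      // DeltaOptions.canMergeSchema picks this up, so allowTypeWidening() can return true.
      .config("spark.databricks.delta.schema.autoMerge.enabled", "true")
      .getOrCreate()

    // Table starts with a narrow SMALLINT column and (assumed property name) type widening enabled.
    spark.sql("""
      CREATE TABLE events (id SMALLINT) USING delta
      TBLPROPERTIES ('delta.enableTypeWidening' = 'true')
    """)

    // The literal is an INT. Previously resolveQueryColumnsByOrdinal would cast it down to
    // SMALLINT; with this change the INT type is kept and the table schema is widened instead.
    spark.sql("INSERT INTO events VALUES (40000)")

    spark.sql("DESCRIBE TABLE events").show() // id should now report as int

    spark.stop()
  }
}
```
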