diff --git a/.travis.yml b/.travis.yml index 42137d4c4f..2a5218f60d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,5 @@ language: scala -jdk: oraclejdk8 +jdk: openjdk8 sudo: false before_install: diff --git a/scalding-parquet-scrooge/src/main/java/com/twitter/scalding/parquet/scrooge/ScroogeReadSupport.java b/scalding-parquet-scrooge/src/main/java/com/twitter/scalding/parquet/scrooge/ScroogeReadSupport.java index 12854bbd41..f23a76c8c3 100644 --- a/scalding-parquet-scrooge/src/main/java/com/twitter/scalding/parquet/scrooge/ScroogeReadSupport.java +++ b/scalding-parquet-scrooge/src/main/java/com/twitter/scalding/parquet/scrooge/ScroogeReadSupport.java @@ -132,7 +132,9 @@ public static MessageType getSchemaForRead(MessageType fileMessageType, String p */ public static MessageType getSchemaForRead(MessageType fileMessageType, MessageType projectedMessageType) { assertGroupsAreCompatible(fileMessageType, projectedMessageType); - return projectedMessageType; + return ParquetCollectionFormatCompatibility.projectFileSchema( + fileMessageType, projectedMessageType + ); } /** diff --git a/scalding-parquet-scrooge/src/main/scala/com/twitter/scalding/parquet/scrooge/ParquetCollectionFormatCompatibility.scala b/scalding-parquet-scrooge/src/main/scala/com/twitter/scalding/parquet/scrooge/ParquetCollectionFormatCompatibility.scala new file mode 100644 index 0000000000..669afdf2cc --- /dev/null +++ b/scalding-parquet-scrooge/src/main/scala/com/twitter/scalding/parquet/scrooge/ParquetCollectionFormatCompatibility.scala @@ -0,0 +1,179 @@ +package com.twitter.scalding.parquet.scrooge + +import org.apache.parquet.schema.Type.Repetition +import org.apache.parquet.schema.{ GroupType, MessageType, Type } +import org.apache.parquet.thrift.DecodingSchemaMismatchException +import org.slf4j.LoggerFactory + +import scala.collection.JavaConverters._ + +/** + * Project file schema based on projected read schema which may contain different format + * of collection group--list and map. This is currently used in [[ScroogeReadSupport]] where + * projected read schema can come from: + * 1) Thrift struct via [[org.apache.parquet.thrift.ThriftSchemaConvertVisitor]] which always + * describe list with `_tuple` format, and map which has `MAP_KEY_VALUE` annotation. + * 2) User-supplied schema string via config key + * [[org.apache.parquet.hadoop.api.ReadSupport.PARQUET_READ_SCHEMA]] + * + * By definition, the projected read schema is a "sub-graph" of file schema in terms of field names. + * (We do allow optional field in projected read schema to be in + * the projected file schema, even if file schema may not originally contain it.) + * The graphs of the two schemas may, however, differ for list and map type because of multiple + * legacy formats and the canonical one. This class supports all directions of conversion. + * + * The projection strategy is: + * 1) traverse the two schemas and maintain only the fields in the read schema. + * 2) find collection type indicated by `repeated` type, and delegate it to respective list/map formatter. + * 3) wrap back the formatted repeated type with group type from projected read schema. This + * means the optional/required remains the same as that from projected read schema. + */ +private[scrooge] object ParquetCollectionFormatCompatibility { + + private val logger = LoggerFactory.getLogger(getClass) + + /** + * Project file schema to contain the same fields as the given projected read schema. + * The result is projected file schema with the same optional/required fields as the + * projected read schema, but collection type format as the file schema. + * + * @param fileSchema file schema to be projected + * @param projectedReadSchema read schema specifying field projection + */ + def projectFileSchema(fileSchema: MessageType, projectedReadSchema: MessageType): MessageType = { + val projectedFileSchema = projectFileType(fileSchema, projectedReadSchema, FieldContext()).asGroupType() + logger.debug(s"Projected read schema:\n${projectedReadSchema}\n" + + s"File schema:\n${fileSchema}\n" + + s"Projected file schema:\n${projectedFileSchema}") + new MessageType(projectedFileSchema.getName, projectedFileSchema.getFields) + } + + /** + * Main recursion to get projected file type. Traverse given schemas, filter out unneeded + * fields, and format read schema's list/map node to file schema's structure. + * The formatting of repeated type is not to one-to-one node swapping because we also have to + * handle projection and possible nested collection types in the repeated type. + */ + private def projectFileType(fileType: Type, projectedReadType: Type, fieldContext: FieldContext): Type = { + if (projectedReadType.isPrimitive || fileType.isPrimitive) { + // Base-cases to handle primitive types: + if (projectedReadType.isPrimitive && fileType.isPrimitive) { + // The field is a primitive in both schemas + projectedReadType + } else { + // The field is primitive in one schema but non-primitive in the othe other + throw new DecodingSchemaMismatchException( + s"Found schema mismatch between projected read type:\n$projectedReadType\n" + + s"and file type:\n${fileType}") + } + } else { + // Recursive cases to handle non-primitives (lists, maps, and structs): + (extractCollectionGroup(projectedReadType.asGroupType()), extractCollectionGroup(fileType.asGroupType())) match { + case (Some(projectedReadGroup: ListGroup), Some(fileGroup: ListGroup)) => + projectFileGroup(fileGroup, projectedReadGroup, fieldContext.copy(nestedListLevel = fieldContext.nestedListLevel + 1), formatter = ParquetListFormatter) + case (Some(projectedReadGroup: MapGroup), Some(fileGroup: MapGroup)) => + projectFileGroup(fileGroup, projectedReadGroup, fieldContext, formatter = ParquetMapFormatter) + case _ => // Struct projection + val projectedReadGroupType = projectedReadType.asGroupType + val fileGroupType = fileType.asGroupType + val projectedReadFields = projectedReadGroupType.getFields.asScala.map { projectedReadField => + if (!fileGroupType.containsField(projectedReadField.getName)) { + // The projected read schema includes a field which is missing from the file schema. + if (projectedReadField.isRepetition(Repetition.OPTIONAL)) { + // The missing field is optional in the projected read schema. Since the file schema + // doesn't contain this field there are no collection compatibility concerns to worry + // about and we can simply use the supplied schema: + projectedReadField + } else { + // The missing field is repeated or required, which is an error: + throw new DecodingSchemaMismatchException( + s"Found non-optional projected read field ${projectedReadField.getName}:\n$projectedReadField\n\n" + + s"not present in the given file group type:\n${fileGroupType}") + } + } else { + // The field is present in both schemas, so first check that the schemas specify compatible repetition + // values for the field, then recursively process the fields: + val fileFieldIndex = fileGroupType.getFieldIndex(projectedReadField.getName) + val fileField = fileGroupType.getFields.get(fileFieldIndex) + if (fileField.isRepetition(Repetition.OPTIONAL) && projectedReadField.isRepetition(Repetition.REQUIRED)) { + // The field is optional in the file schema but required in the projected read schema; this is an error: + throw new DecodingSchemaMismatchException( + s"Found required projected read field ${projectedReadField.getName}:\n$projectedReadField\n\n" + + s"on optional file field:\n${fileField}") + } else { + // The field's repetitions are compatible in both schemas (e.g. optional in both schemas or required + // in both), so recursively process the field: + projectFileType(fileField, projectedReadField, FieldContext(projectedReadField.getName)) + } + } + } + projectedReadGroupType.withNewFields(projectedReadFields.asJava) + } + } + } + + private def projectFileGroup(fileGroup: CollectionGroup, + projectedReadGroup: CollectionGroup, + fieldContext: FieldContext, + formatter: ParquetCollectionFormatter): GroupType = { + val projectedFileRepeatedType = formatter.formatCompatibleRepeatedType( + fileGroup.repeatedType, + projectedReadGroup.repeatedType, + fieldContext, + projectFileType) + // Respect optional/required from the projected read group. + projectedReadGroup.groupType.withNewFields(projectedFileRepeatedType) + } + + private def extractCollectionGroup(typ: GroupType): Option[CollectionGroup] = { + ParquetListFormatter.extractGroup(typ).orElse(ParquetMapFormatter.extractGroup(typ)) + } +} + +private[scrooge] trait ParquetCollectionFormatter { + /** + * Format source repeated type in the structure of target repeated type. + * + * @param fileRepeatedType repeated type from which the formatted result get the structure + * @param readRepeatedType repeated type from which the formatted result get content + * @param recursiveSolver solver for the inner content of the repeated type + * @return formatted result + */ + def formatCompatibleRepeatedType(fileRepeatedType: Type, + readRepeatedType: Type, + fieldContext: FieldContext, + recursiveSolver: (Type, Type, FieldContext) => Type): Type + + /** + * Extract collection group containing repeated type of different formats. + */ + def extractGroup(typ: GroupType): Option[CollectionGroup] +} + +/** + * Helper class to carry information from the field. Currently it only contains specific to list collection + * @param name field name + * @param nestedListLevel li + */ +private[scrooge] case class FieldContext(name: String = "", nestedListLevel: Int = 0) + +private[scrooge] sealed trait CollectionGroup { + /** + * Type for the collection. + * For example, given the schema, + * required group my_list (LIST) { + * repeated group list { + * optional binary element (UTF8); + * } + * } + * [[groupType]] refers to this whole schema + * [[repeatedType]] refers to inner `repeated` schema + */ + def groupType: GroupType + + def repeatedType: Type +} + +private[scrooge] sealed case class MapGroup(groupType: GroupType, repeatedType: Type) extends CollectionGroup + +private[scrooge] sealed case class ListGroup(groupType: GroupType, repeatedType: Type) extends CollectionGroup \ No newline at end of file diff --git a/scalding-parquet-scrooge/src/main/scala/com/twitter/scalding/parquet/scrooge/ParquetListFormatter.scala b/scalding-parquet-scrooge/src/main/scala/com/twitter/scalding/parquet/scrooge/ParquetListFormatter.scala new file mode 100644 index 0000000000..f12adbf47b --- /dev/null +++ b/scalding-parquet-scrooge/src/main/scala/com/twitter/scalding/parquet/scrooge/ParquetListFormatter.scala @@ -0,0 +1,308 @@ +package com.twitter.scalding.parquet.scrooge + +import org.apache.parquet.schema.{ GroupType, OriginalType, PrimitiveType, Type } +import org.slf4j.LoggerFactory + +import scala.collection.JavaConverters._ + +/** + * Format parquet list schema of read type to structure of file type. + * The supported formats are in `rules` of [[ParquetListFormatRule]]. + * Please see documentation for each rule. + * + * In a common use case, read schema form thrift struct has tuple format created by + * [[org.apache.parquet.thrift.ThriftSchemaConvertVisitor]] which always suffix + * list element with "_tuple". + */ +private[scrooge] object ParquetListFormatter extends ParquetCollectionFormatter { + + private val logger = LoggerFactory.getLogger(getClass) + + private val rules: Seq[ParquetListFormatRule] = Seq( + PrimitiveElementRule, + PrimitiveArrayRule, + GroupElementRule, + GroupArrayRule, + TupleRule, + StandardRule, + SparkLegacyNullableElementRule) + + def formatCompatibleRepeatedType(fileRepeatedType: Type, + readRepeatedType: Type, + fieldContext: FieldContext, + recursiveSolver: (Type, Type, FieldContext) => Type): Type = { + (findRule(fileRepeatedType), findRule(readRepeatedType)) match { + case (Some(fileRule), Some(readRule)) => { + val readElementType = readRule.elementType(readRepeatedType) + val fileElementType = fileRule.elementType(fileRepeatedType) + val solvedElementType = recursiveSolver(fileElementType, readElementType, fieldContext) + + fileRule.createCompliantRepeatedType( + elementType = solvedElementType, + elementName = readRule.elementName(readRepeatedType), + isElementRequired = readRule.isElementRequired(readRepeatedType), + elementOriginalType = readRule.elementOriginalType(readRepeatedType), + fieldContext = fieldContext) + } + + case _ => readRepeatedType + } + } + + def extractGroup(groupType: GroupType): Option[ListGroup] = { + if (isListGroup(groupType)) { + Some(ListGroup(groupType, groupType.getFields.get(0))) + } else { + None + } + } + + private def isListGroup(groupType: GroupType): Boolean = { + groupType.getOriginalType == OriginalType.LIST && + groupType.getFieldCount == 1 && + groupType.getFields.get(0).isRepetition(Type.Repetition.REPEATED) + } + + private def findRule(repeatedType: Type): Option[ParquetListFormatRule] = { + val ruleFound = rules.find(rule => rule.appliesToType(repeatedType)) + if (ruleFound.isEmpty) logger.warn(s"Unable to find matching rule for repeated type:\n$repeatedType") + ruleFound + } +} + +/** + * Rule allowing conversion from one format to other format by + * 1) detect which format is the repeated list type. + * 2) decompose the repeated type into element and other info. + * 3) construct compliant repeated type from the given element and other info. + * For example, + * if read repeated type matches Rule 1, and file type matches Rule 2. + * Rule 1 will decompose the read type, and + * Rule 2 will take that information to construct repeated element in Rule 2 of file type format. + */ +private[scrooge] sealed trait ParquetListFormatRule { + def elementType(repeatedType: Type): Type + + def elementName(repeatedType: Type): String = this.elementType(repeatedType).getName + + def elementOriginalType(repeatedType: Type): OriginalType = this.elementType(repeatedType).getOriginalType + + private[scrooge] def isElementRequired(repeatedType: Type): Boolean + + private[scrooge] def appliesToType(repeatedType: Type): Boolean + + private[scrooge] def createCompliantRepeatedType(elementType: Type, + elementName: String, + isElementRequired: Boolean, + elementOriginalType: OriginalType, + fieldContext: FieldContext): Type +} + +/** + * Rule 1 in https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules + * Although documentation only mentions `element` primitive and not for `array`, + * Spark does write out with primitive `array` when legacy write format is enabled. + * repeated int32 [element|array]; + */ +private[scrooge] sealed trait PrimitiveListRule extends ParquetListFormatRule { + + def constantElementName: String + + override def elementType(repeatedType: Type): Type = repeatedType + + override private[scrooge] def isElementRequired(repeatedType: Type) = { + // According to rule 1, "the repeated field is not a group, + // then its type is the element type and elements are required." + true + } + + override def appliesToType(repeatedType: Type): Boolean = + repeatedType.isPrimitive && repeatedType.getName == this.constantElementName + + override def createCompliantRepeatedType(typ: Type, name: String, isElementRequired: Boolean, originalType: OriginalType, fieldContext: FieldContext): Type = { + if (!isElementRequired) throw new IllegalArgumentException(s"Primitive ${constantElementName} list format can only take required element") + if (!typ.isPrimitive) throw new IllegalArgumentException(s"Primitive list format cannot take group, but is given $typ") + new PrimitiveType(Type.Repetition.REPEATED, typ.asPrimitiveType.getPrimitiveTypeName, this.constantElementName, originalType) + } +} + +private[scrooge] object PrimitiveElementRule extends PrimitiveListRule { + override def constantElementName: String = "element" +} + +private[scrooge] object PrimitiveArrayRule extends PrimitiveListRule { + override def constantElementName: String = "array" +} + +/** + * Rule 2 in https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules + * Although documentation only mentions `element` group and not for `array`, + * Spark does write out with group `array` when legacy write format is enabled. + * repeated group [element|array] { + * required binary str (UTF8); + * required int32 num; + * } + */ +private[scrooge] sealed trait GroupListRule extends ParquetListFormatRule { + + def constantElementName: String + + override def isElementRequired(repeatedType: Type): Boolean = { + // According Rule 2, + // "If the repeated field is a group with multiple fields, + // then its type is the element type and elements are required." + true + } + + override def elementType(repeatedType: Type): Type = repeatedType + + override def elementName(repeatedType: Type): String = this.constantElementName + + override def appliesToType(repeatedType: Type): Boolean = { + if (repeatedType.isPrimitive) { + false + } else { + val groupType = repeatedType.asGroupType + groupType.getFields.size > 0 && groupType.getName == this.constantElementName + } + } + + override def createCompliantRepeatedType(typ: Type, name: String, isElementRequired: Boolean, originalType: OriginalType, fieldContext: FieldContext): Type = { + if (!isElementRequired) throw new IllegalArgumentException(s"Group ${constantElementName} list format can only take required element") + if (typ.isPrimitive) throw new IllegalArgumentException(s"Group list format cannot take primitive type, but is given $typ") + else new GroupType(Type.Repetition.REPEATED, this.constantElementName, originalType, typ.asGroupType.getFields) + } +} + +private[scrooge] object GroupElementRule extends GroupListRule { + override def constantElementName: String = "element" +} + +private[scrooge] object GroupArrayRule extends GroupListRule { + override def constantElementName: String = "array" +} + +/** + * Rule 3 in https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules + * Although the documentation only mentions group with one field, the generated schema from thrift struct + * does write out both primitive type and group type with multiple fields. + * repeated group my_list_field_tuple { + * required binary str (UTF8); + * } + * This repeated type implies the field name is `my_list_field`. This is the only format where + * info is not fully self-contained. + */ +private[scrooge] object TupleRule extends ParquetListFormatRule { + + private val tupleSuffix = "_tuple" + + override def appliesToType(repeatedType: Type): Boolean = repeatedType.getName.endsWith(tupleSuffix) + + override def elementName(repeatedType: Type): String = { + repeatedType.getName.substring(0, repeatedType.getName.length - tupleSuffix.length) + } + + override def elementType(repeatedType: Type): Type = repeatedType + + override private[scrooge] def isElementRequired(repeatedType: Type) = { + true + } + + override def createCompliantRepeatedType(typ: Type, name: String, isElementRequired: Boolean, originalType: OriginalType, fieldContext: FieldContext): Type = { + // nested list has type name of the form: `field_original_name_tuple_tuple..._tuple` for the depth of list + val suffixed_name = (List(fieldContext.name) ++ (1 to fieldContext.nestedListLevel).toList.map(_ => "tuple")).mkString("_") + if (typ.isPrimitive) { + new PrimitiveType(Type.Repetition.REPEATED, typ.asPrimitiveType.getPrimitiveTypeName, suffixed_name, originalType) + } else { + new GroupType(Type.Repetition.REPEATED, suffixed_name, originalType, typ.asGroupType.getFields) + } + } +} + +private[scrooge] sealed trait ThreeLevelRule extends ParquetListFormatRule { + + def constantElementName: String + + def constantRepeatedGroupName: String + + override def appliesToType(repeatedField: Type): Boolean = { + if (repeatedField.isPrimitive || !(repeatedField.getName == constantRepeatedGroupName)) { + false + } else { + elementType(repeatedField).getName == constantElementName + } + } + + override def elementType(repeatedType: Type): Type = firstField(repeatedType.asGroupType) + + override private[scrooge] def isElementRequired(repeatedType: Type): Boolean = { + elementType(repeatedType).getRepetition == Type.Repetition.REQUIRED + } + + override def elementName(repeatedType: Type): String = constantElementName + + override def createCompliantRepeatedType(originalElementType: Type, name: String, isElementRequired: Boolean, originalType: OriginalType, fieldContext: FieldContext): Type = { + + val repetition = if (isElementRequired) Type.Repetition.REQUIRED else Type.Repetition.OPTIONAL + val elementType = if (originalElementType.isPrimitive) { + new PrimitiveType(repetition, originalElementType.asPrimitiveType.getPrimitiveTypeName, constantElementName, originalType) + } else { + new GroupType( + repetition, + constantElementName, + originalType, + originalElementType.asGroupType.getFields) + } + + new GroupType(Type.Repetition.REPEATED, constantRepeatedGroupName, Seq(elementType).asJava) + } + + private def firstField(groupType: GroupType): Type = { + groupType.getFields.get(0) + } +} + +/** + * Standard parquet list format. + * repeated group list { + * element; + * } + */ +private[scrooge] object StandardRule extends ThreeLevelRule { + + def constantElementName = "element" + + def constantRepeatedGroupName = "list" +} + +/** + * Spark legacy format when element is nullable. + * repeated group bag { + * optional array; + * } + * Documentation on Spark is incorrect at the time of writing. It indicates `optional group bag`, + * but it should be `repeated group bag`, and optional element. + * https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala#L345-L355 + * Writing Dataset[Seq[List]] in Spark with default encoder and legacy mode on will give + * + * message spark_schema { + * optional group value (LIST) { + * repeated group bag { + * optional binary array (UTF8); + * } + * } + * } + */ +private[scrooge] object SparkLegacyNullableElementRule extends ThreeLevelRule { + override def constantElementName: String = "array" + + override def constantRepeatedGroupName: String = "bag" + + override def createCompliantRepeatedType(originalElementType: Type, name: String, isElementRequired: Boolean, originalType: OriginalType, fieldContext: FieldContext): Type = { + if (isElementRequired) { + throw new IllegalArgumentException(s"Spark legacy mode for nullable element cannot take required element. Found: ${originalElementType}") + } else { + super.createCompliantRepeatedType(originalElementType, name, isElementRequired, originalType, fieldContext) + } + } +} diff --git a/scalding-parquet-scrooge/src/main/scala/com/twitter/scalding/parquet/scrooge/ParquetMapFormatter.scala b/scalding-parquet-scrooge/src/main/scala/com/twitter/scalding/parquet/scrooge/ParquetMapFormatter.scala new file mode 100644 index 0000000000..9a10db843c --- /dev/null +++ b/scalding-parquet-scrooge/src/main/scala/com/twitter/scalding/parquet/scrooge/ParquetMapFormatter.scala @@ -0,0 +1,49 @@ +package com.twitter.scalding.parquet.scrooge + +import org.apache.parquet.schema.{ GroupType, OriginalType, Type } + +/** + * Format parquet map schema of read type to structure of file type. + * The supported formats are: + * 1) Standard repeated type of `key_value` without annotation + * 2) Legacy repeated `map` field annotated with (MAP_KEY_VALUE) + * as described in + * https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#maps + * + * In a common use case, read schema from thrift struct has legacy format 2) created by + * [[org.apache.parquet.schema.ConversionPatterns]] + */ +private[scrooge] object ParquetMapFormatter extends ParquetCollectionFormatter { + + def formatCompatibleRepeatedType(fileRepeatedMapType: Type, + readRepeatedMapType: Type, + fieldContext: FieldContext, + recursiveSolver: (Type, Type, FieldContext) => Type): Type = { + val solvedRepeatedType = recursiveSolver(fileRepeatedMapType, readRepeatedMapType, fieldContext) + fileRepeatedMapType.asGroupType().withNewFields(solvedRepeatedType.asGroupType().getFields) + } + + def extractGroup(groupType: GroupType): Option[MapGroup] = { + if (isMapGroup(groupType)) { + Some(MapGroup(groupType, groupType.getFields.get(0))) + } else { + None + } + } + + private def isMapGroup(groupType: GroupType): Boolean = { + (groupType.getOriginalType == OriginalType.MAP) && + (groupType.getFieldCount == 1) && + groupType.getFields.get(0).isRepetition(Type.Repetition.REPEATED) && + (isLegacyRepeatedType(groupType.getFields.get(0)) || + isStandardRepeatedType(groupType.getFields.get(0))) + } + + private def isLegacyRepeatedType(repeatedType: Type) = { + (repeatedType.getName == "map") && (repeatedType.getOriginalType == OriginalType.MAP_KEY_VALUE) + } + + private def isStandardRepeatedType(repeatedType: Type) = { + (repeatedType.getName == "key_value") && (repeatedType.getOriginalType == null) + } +} diff --git a/scalding-parquet-scrooge/src/test/scala/com/twitter/scalding/parquet/scrooge/ParquetCollectionFormatCompatibilityTests.scala b/scalding-parquet-scrooge/src/test/scala/com/twitter/scalding/parquet/scrooge/ParquetCollectionFormatCompatibilityTests.scala new file mode 100644 index 0000000000..948e85d61d --- /dev/null +++ b/scalding-parquet-scrooge/src/test/scala/com/twitter/scalding/parquet/scrooge/ParquetCollectionFormatCompatibilityTests.scala @@ -0,0 +1,1385 @@ +package com.twitter.scalding.parquet.scrooge + +import java.util + +import org.apache.parquet.schema.{ MessageType, MessageTypeParser } +import org.apache.parquet.thrift.{ DecodingSchemaMismatchException, ThriftSchemaConverter } +import org.apache.parquet.thrift.struct.ThriftField.Requirement +import org.apache.parquet.thrift.struct.{ ThriftField, ThriftType } +import org.apache.parquet.thrift.struct.ThriftType.StructType.StructOrUnionType +import org.apache.parquet.thrift.struct.ThriftType.{ ListType, MapType, StructType } +import org.scalatest.{ Matchers, WordSpec } + +class ParquetCollectionFormatCompatibilityTests extends WordSpec with Matchers { + + private def testProjectAndAssertCompatibility(fileSchema: MessageType, + projectedReadSchema: MessageType) = { + val projectedFileSchema = ParquetCollectionFormatCompatibility.projectFileSchema(fileSchema, projectedReadSchema) + ScroogeReadSupport.assertGroupsAreCompatible(fileSchema, projectedFileSchema) + projectedFileSchema + } + + /** + * Helper wrapper to specify repetition string for exhaustive tests + */ + case class TestRepetitions(projectedReadRepetition1: String, projectedReadRepetition2: String, + fileRepetition1: String, fileRepetition2: String) + def feasibleRepetitions = { + for { + projectedRepetition1 <- Seq("required", "optional") + projectedRepetition2 <- Seq("required", "optional") + fileRepetition1 <- Seq("required", "optional") + fileRepetition2 <- Seq("required", "optional") + // when file type is optional, required projected type is breaking + if !(fileRepetition1 == "optional" && projectedRepetition1 == "required") + if !(fileRepetition2 == "optional" && projectedRepetition2 == "required") + } yield { + TestRepetitions(projectedRepetition1, projectedRepetition2, fileRepetition1, fileRepetition2) + } + } + + /** + * The following functions of different list formats are equivalent schemas to describe: + * {{ + * x: Int + * foo_string_list: Seq[Int] + * foo_struct_list: Option[Seq[Struct]] + * foo_list_of_list: Seq[Seq[Long]] + * y: Int + * foo_optional_list: + * }} + */ + def listElementRule(repetition1: String, repetition2: String) = ( + s""" + |message schema { + | $repetition1 int32 x; + | required group foo_string_list (LIST) { + | repeated int32 element; + | } + | optional group foo_struct_list (LIST) { + | repeated group element { + | required binary str (UTF8); + | ${repetition2} int32 num; + | } + | } + | required group foo_list_of_list (LIST) { + | repeated group element (LIST) { + | repeated int64 element; + | } + | } + | $repetition2 int32 y; + |} + """.stripMargin) + + def listArrayRule(repetition1: String, repetition2: String) = ( + s""" + |message schema { + | $repetition1 int32 x; + | required group foo_string_list (LIST) { + | repeated int32 array; + | } + | optional group foo_struct_list (LIST) { + | repeated group array { + | required binary str (UTF8); + | ${repetition2} int32 num; + | } + | } + | required group foo_list_of_list (LIST) { + | repeated group array (LIST) { + | repeated int64 array; + | } + | } + | $repetition2 int32 y; + |} + """.stripMargin) + + def listTupleRule(repetition1: String, repetition2: String) = ( + s""" + |message schema { + | $repetition1 int32 x; + | required group foo_string_list (LIST) { + | repeated int32 foo_string_list_tuple; + | } + | optional group foo_struct_list (LIST) { + | repeated group foo_struct_list_tuple { + | required binary str (UTF8); + | ${repetition2} int32 num; + | } + | } + | required group foo_list_of_list (LIST) { + | repeated group foo_list_of_list_tuple (LIST) { + | repeated int64 foo_list_of_list_tuple_tuple; + | } + | } + | $repetition2 int32 y; + |} + """.stripMargin) + + def listStandardRule(repetition1: String, repetition2: String, nullableElement: Boolean = false) = { + val requiredOrOptional = if (nullableElement) "optional" else "required" + (s""" + |message schema { + | $repetition1 int32 x; + | required group foo_string_list (LIST) { + | repeated group list { + | $requiredOrOptional int32 element; + | } + | } + | optional group foo_struct_list (LIST) { + | repeated group list { + | $requiredOrOptional group element { + | required binary str (UTF8); + | ${repetition2} int32 num; + | } + | } + | } + | required group foo_list_of_list (LIST) { + | repeated group list { + | $requiredOrOptional group element (LIST) { + | repeated group list { + | $requiredOrOptional int64 element; + | } + | } + | } + | } + | $repetition2 int32 y; + |} + """.stripMargin) + } + + val listRequiredElementRules = Seq( + ("element", listElementRule(_, _)), + ("array", listArrayRule(_, _)), + ("tuple", listTupleRule(_, _)), + ("standard", (from: String, to: String) => listStandardRule(from, to, nullableElement = false))) + + // All possible format pairs of list with non-nullable element + for { + (projectedReadRuleName, projectedReadSchemaFunc) <- listRequiredElementRules + (fileRuleName, fileSchemaFunc) <- listRequiredElementRules + } yield { + s"Project for list with non-nullable element file: [${fileRuleName}] read: [${projectedReadRuleName}]" should { + "take option/require specifications from projected read schema" in { + for { + feasibleRepetition <- feasibleRepetitions + } yield { + testProjectedFileSchemaHasReadSchemaRepetitions( + fileSchemaFunc, + projectedReadSchemaFunc, + feasibleRepetition) + } + } + } + } + + def listSparkLegacyNullableElementRule(repetition1: String, repetition2: String) = ( + s""" + |message schema { + | $repetition1 int32 x; + | required group foo_string_list (LIST) { + | repeated group bag { + | optional int32 array; + | } + | } + | optional group foo_struct_list (LIST) { + | repeated group bag { + | optional group array { + | required binary str (UTF8); + | ${repetition2} int32 num; + | } + | } + | } + | required group foo_list_of_list (LIST) { + | repeated group bag { + | optional group array (LIST) { + | repeated group bag { + | optional int64 array; + | } + | } + | } + | } + | $repetition2 int32 y; + |} + """.stripMargin) + + private def testProjectedFileSchemaHasReadSchemaRepetitions( + fileSchemaFunc: (String, String) => String, + projectedReadSchemaFunc: (String, String) => String, + feasibleRepetition: TestRepetitions): Any = { + + val projectedReadSchema = MessageTypeParser.parseMessageType( + projectedReadSchemaFunc( + feasibleRepetition.projectedReadRepetition1, + feasibleRepetition.projectedReadRepetition2)) + val fileSchema = MessageTypeParser.parseMessageType( + fileSchemaFunc( + feasibleRepetition.fileRepetition1, + feasibleRepetition.fileRepetition2)) + val expectedProjectedFileSchema = MessageTypeParser.parseMessageType( + fileSchemaFunc( + feasibleRepetition.projectedReadRepetition1, + feasibleRepetition.projectedReadRepetition2)) + expectedProjectedFileSchema shouldEqual testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + } + + "Project for list with nullable element" should { + val listNullableElementRules = Seq( + ("spark-legacy", listSparkLegacyNullableElementRule(_, _)), + ("standard-with", (from: String, to: String) => listStandardRule(from, to, nullableElement = true))) + for { + (projectedReadRuleName, projectedReadSchemaFunc) <- listNullableElementRules + (fileRuleName, fileSchemaFunc) <- listNullableElementRules + } yield { + s"file: [${fileRuleName}] read: [${projectedReadRuleName}]" should { + "take option/require specifications from projected read schema" in { + for { + feasibleRepetition <- feasibleRepetitions + } yield { + testProjectedFileSchemaHasReadSchemaRepetitions( + fileSchemaFunc, + projectedReadSchemaFunc, + feasibleRepetition) + } + } + } + } + + "failed to format file: required element, read: legacy write with nullable element" in { + for { + feasibleRepetition <- feasibleRepetitions + (_, requiredElementSchemaFunc) <- listRequiredElementRules + } yield { + val e = intercept[IllegalArgumentException] { + testProjectedFileSchemaHasReadSchemaRepetitions( + fileSchemaFunc = listSparkLegacyNullableElementRule, + projectedReadSchemaFunc = requiredElementSchemaFunc, + feasibleRepetition) + } + e.getMessage should include("Spark legacy mode for nullable element cannot take required element") + } + } + } + + "Project for map" should { + "file/read identity" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group map (MAP) { + | repeated group key_value { + | required binary key (UTF8); + | required group value { + | required binary _id (UTF8); + | optional double created; + | } + | } + | } + |} + """.stripMargin) + val projectedFileSchema = testProjectAndAssertCompatibility(fileSchema, fileSchema) + fileSchema shouldEqual projectedFileSchema + } + + "file/read identity from thrift struct (string key, struct value)" in { + val listType = new ListType(new ThriftField("list", 2, Requirement.REQUIRED, new ThriftType.StringType)) + val children = new ThriftField("foo", 3, Requirement.REQUIRED, listType) + val mapValueType = new StructType(util.Arrays.asList(children), + StructOrUnionType.STRUCT) + val message = schemaFromThriftMap(mapValueType) + message shouldEqual MessageTypeParser.parseMessageType( + """ + |message ParquetSchema { + | required group map_field (MAP) = 6 { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | optional group value { + | required group foo (LIST) = 3 { + | repeated binary foo_tuple (UTF8); + | } + | } + | } + | } + |} + """.stripMargin) + + val projectedFileSchema = testProjectAndAssertCompatibility(message, message) + message shouldEqual projectedFileSchema + } + + "file/read identity from thrift struct (string key, list string value)" in { + val listType = new ListType(new ThriftField("list", 2, Requirement.REQUIRED, new ThriftType.StringType)) + val message = schemaFromThriftMap(listType) + message shouldEqual MessageTypeParser.parseMessageType( + """ + |message ParquetSchema { + | required group map_field (MAP) = 6 { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | optional group value (LIST) { + | repeated binary value_tuple (UTF8); + | } + | } + | } + |} + """.stripMargin) + + val projectedFileSchema = testProjectAndAssertCompatibility(message, message) + message shouldEqual projectedFileSchema + } + + "file: standard key_value, read: legacy (MAP_KEY_VALUE)" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group map_field (MAP) { + | repeated group key_value { + | required binary key (UTF8); + | required int32 value; + | } + | } + |} + """.stripMargin) + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | required group map_field (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | optional int32 value; + | } + | } + |} + """.stripMargin) + val projectedFileSchema = testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + val expected = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | required group map_field (MAP) { + | repeated group key_value { + | required binary key (UTF8); + | optional int32 value; + | } + | } + |} + """.stripMargin) + expected shouldEqual projectedFileSchema + } + + "file: legacy (MAP_KEY_VALUE), read: standard key_value" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group map_field (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | required int32 value; + | } + | } + |} + """.stripMargin) + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | required group map_field (MAP) { + | repeated group key_value { + | required binary key (UTF8); + | optional int32 value; + | } + | } + |} + """.stripMargin) + val projectedFileSchema = testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + val expected = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | required group map_field (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | optional int32 value; + | } + | } + |} + """.stripMargin) + expected shouldEqual projectedFileSchema + } + + "map of map, file: standard key_value, read: legacy (MAP_KEY_VALUE)" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group map_of_map_field (MAP) { + | repeated group key_value { + | required binary key (UTF8); + | required group value (MAP) { + | repeated group key_value { + | required binary key (UTF8); + | required group value { + | required binary _id (UTF8); + | required int32 x; + | } + | } + | } + | } + | } + |} + """.stripMargin) + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | required group map_of_map_field (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | required group value (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | required group value { + | optional int32 x; + | } + | } + | } + | } + | } + |} + """.stripMargin) + val projectedFileSchema = testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + val expected = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | required group map_of_map_field (MAP) { + | repeated group key_value { + | required binary key (UTF8); + | required group value (MAP) { + | repeated group key_value { + | required binary key (UTF8); + | required group value { + | optional int32 x; + | } + | } + | } + | } + | } + |} + """.stripMargin) + expected shouldEqual projectedFileSchema + } + + def schemaFromThriftMap(mapValueType: ThriftType) = { + val mapType = new MapType( + new ThriftField("NOT_USED_KEY", 4, Requirement.REQUIRED, new ThriftType.StringType), + new ThriftField("NOT_USED_VALUE", 5, Requirement.REQUIRED, + mapValueType)) + new ThriftSchemaConverter().convert( + new StructType(util.Arrays.asList( + new ThriftField("map_field", 6, Requirement.REQUIRED, mapType)), StructOrUnionType.STRUCT)) + } + } + + "Format compat for list" should { + "file: primitive array, read: x_tuple" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group country_codes (LIST) { + | repeated binary array (UTF8); + | } + | required int32 x; + |} + """.stripMargin) + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group country_codes (LIST) { + | repeated binary country_codes_tuple (UTF8); + | } + |} + """.stripMargin) + val projectedFileSchema = testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + val expected = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group country_codes (LIST) { + | repeated binary array (UTF8); + | } + |} + """.stripMargin) + expected shouldEqual projectedFileSchema + } + + "file: primitive element, read: x_tuple" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group country_codes (LIST) { + | repeated binary element (UTF8); + | } + | required int32 x; + |} + """.stripMargin) + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group country_codes (LIST) { + | repeated binary country_codes_tuple (UTF8); + | } + |} + """.stripMargin) + val projectedFileSchema = testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + val expected = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group country_codes (LIST) { + | repeated binary element (UTF8); + | } + |} + """.stripMargin) + expected shouldEqual projectedFileSchema + } + + "file: 3-level, read: x_tuple" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group country_codes (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | required int32 x; + |} + """.stripMargin) + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group country_codes (LIST) { + | repeated binary country_codes_tuple (UTF8); + | } + |} + """.stripMargin) + val projectedFileSchema = testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + // note optional of result, and field rename + val expected = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group country_codes (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + |} + """.stripMargin) + expected shouldEqual projectedFileSchema + } + + "file: group array, read: nested x_tuple" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group foo (LIST) { + | repeated group array (LIST) { + | repeated binary array (UTF8); + | } + | } + |} + """.stripMargin) + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group foo (LIST) { + | repeated group foo_tuple (LIST) { + | repeated binary foo_tuple_tuple (UTF8); + | } + | } + |} + """.stripMargin) + val projectedFileSchema = testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + // note optional of result, and field rename + val expected = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group foo (LIST) { + | repeated group array (LIST) { + | repeated binary array (UTF8); + | } + | } + |} + """.stripMargin) + expected shouldEqual projectedFileSchema + } + + "file: nested 3-level, read: nested x_tuple" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group foo (LIST) { + | repeated group list { + | required group element (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | } + | } + |} + """.stripMargin) + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group foo (LIST) { + | repeated group foo_tuple (LIST) { + | repeated binary foo_tuple_tuple (UTF8); + | } + | } + |} + """.stripMargin) + val projectedFileSchema = testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + val expected = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group foo (LIST) { + | repeated group list { + | required group element (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | } + | } + |} + """.stripMargin) + expected shouldEqual projectedFileSchema + } + + "file: 3-level, read: binary array" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group country_codes (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | required int32 x; + |} + """.stripMargin) + + // inner list is `binary array` + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group country_codes (LIST) { + | repeated binary array (UTF8); + | } + |} + """.stripMargin) + val projectedFileSchema = testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + + val expected = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group country_codes (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + |} + """.stripMargin) + expected shouldEqual projectedFileSchema + } + + "file: 3-level (identity), read: 3-level" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group country_codes (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | required int32 x; + |} + """.stripMargin) + val projectedFileSchema = testProjectAndAssertCompatibility(fileSchema, fileSchema) + fileSchema shouldEqual projectedFileSchema + } + + "file: nested 3-level, read: nested primitive array" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group array_of_country_codes (LIST) { + | repeated group list { + | required group element (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | } + | } + | required int32 x; + |} + """.stripMargin) + + // inner list is `binary array` + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group array_of_country_codes (LIST) { + | repeated group list { + | required group element (LIST) { + | repeated binary array (UTF8); + | } + | } + | } + |} + """.stripMargin) + val projectedFileSchema = testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + + val expected = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group array_of_country_codes (LIST) { + | repeated group list { + | required group element (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | } + | } + |} + """.stripMargin) + expected shouldEqual projectedFileSchema + } + + "file: 3-level, read: element group" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group country_codes (LIST) { + | repeated group list { + | required group element { + | required binary foo (UTF8); + | required binary bar (UTF8); + | required binary zing (UTF8); + | } + | } + | } + | required int32 x; + |} + """.stripMargin) + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group country_codes (LIST) { + | repeated group element { + | optional binary foo (UTF8); + | required binary zing (UTF8); + | } + | } + |} + """.stripMargin) + val projectedFileSchema = testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + + val expected = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group country_codes (LIST) { + | repeated group list { + | required group element { + | optional binary foo (UTF8); + | required binary zing (UTF8); + | } + | } + | } + |} + """.stripMargin) + expected shouldEqual projectedFileSchema + } + + "file: nested primitive array, read: 3-level" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group array_of_country_codes (LIST) { + | repeated group list { + | required group element (LIST) { + | repeated binary array (UTF8); + | } + | } + | } + | required int32 x; + |} + """.stripMargin) + + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group array_of_country_codes (LIST) { + | repeated group list { + | required group element (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | } + | } + |} + """.stripMargin) + val projectedFileSchema = testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + + val expected = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group array_of_country_codes (LIST) { + | repeated group list { + | required group element (LIST) { + | repeated binary array (UTF8); + | } + | } + | } + |} + """.stripMargin) + expected shouldEqual projectedFileSchema + } + + "file: 3-level, read: x_tuple in group" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | optional group connect_delays (LIST) { + | repeated group list { + | required group element { + | optional binary description (UTF8); + | optional binary created_by (UTF8); + | optional group currencies (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | } + | } + | } + |} + """.stripMargin) + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group connect_delays (LIST) { + | repeated group connect_delays_tuple { + | optional binary description (UTF8); + | optional group currencies (LIST) { + | repeated binary currencies_tuple (UTF8); + | } + | } + | } + |} + """.stripMargin) + val projectedFileSchema = testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + val expected = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group connect_delays (LIST) { + | repeated group list { + | required group element { + | optional binary description (UTF8); + | optional group currencies (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | } + | } + | } + |} + """.stripMargin) + expected shouldEqual projectedFileSchema + } + + "file: x_tuple, read: 3-level" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group foo (LIST) { + | repeated group foo_tuple (LIST) { + | repeated binary foo_tuple_tuple (UTF8); + | } + | } + | required int32 x; + |} + """.stripMargin) + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group foo (LIST) { + | repeated group list { + | required group element (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | } + | } + | optional int32 x; + |} + """.stripMargin) + val projectedFileSchema = testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group foo (LIST) { + | repeated group foo_tuple (LIST) { + | repeated binary foo_tuple_tuple (UTF8); + | } + | } + | optional int32 x; + |} + """.stripMargin) shouldEqual projectedFileSchema + } + + "file: absent, read: optional list " in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group foo (LIST) { + | repeated group foo_tuple (LIST) { + | repeated binary foo_tuple_tuple (UTF8); + | } + | } + | required int32 x; + |} + """.stripMargin) + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group foo (LIST) { + | repeated group list { + | required group element (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | } + | } + | optional group foo_optional (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | optional int32 x; + |} + """.stripMargin) + val projectedFileSchema = testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group foo (LIST) { + | repeated group foo_tuple (LIST) { + | repeated binary foo_tuple_tuple (UTF8); + | } + | } + | optional group foo_optional (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | optional int32 x; + |} + """.stripMargin) shouldEqual projectedFileSchema + } + + "file: 3-level, read: unknown to return read type" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group country_codes (LIST) { + | repeated group list { + | required group element { + | required binary foo (UTF8); + | } + | } + | } + | required int32 x; + |} + """.stripMargin) + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group country_codes (LIST) { + | repeated group unknown_element_format { + | optional binary foo (UTF8); + | } + | } + |} + """.stripMargin) + val projectedFileSchema = testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + + val expected = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group country_codes (LIST) { + | repeated group unknown_element_format { + | optional binary foo (UTF8); + | } + | } + |} + """.stripMargin) + expected shouldEqual projectedFileSchema + } + } + + "Format compat for mixed collection" should { + "list of map: file/read identity from thrift struct" in { + val mapType = new MapType( + new ThriftField("NOT_USED_KEY", 4, Requirement.REQUIRED, new ThriftType.StringType), + new ThriftField("NOT_USED_VALUE", 5, Requirement.REQUIRED, new ThriftType.I64Type)) + val message = new ThriftSchemaConverter().convert( + new StructType(util.Arrays.asList( + new ThriftField("list_of_map", 2, Requirement.REQUIRED, new ListType( + new ThriftField("NOT_USED_ELEMENT", 2, Requirement.REQUIRED, mapType)))), StructOrUnionType.STRUCT)) + + message shouldEqual MessageTypeParser.parseMessageType( + """ + |message ParquetSchema { + | required group list_of_map (LIST) = 2 { + | repeated group list_of_map_tuple (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | optional int64 value; + | } + | } + | } + |} + | + """.stripMargin) + + val projectedFileSchema = testProjectAndAssertCompatibility(message, message) + message shouldEqual projectedFileSchema + } + + "map of list, file: standard, read: thrift-generated" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group map_field (MAP) { + | repeated group key_value { + | required binary key (UTF8); + | required group value { + | optional group foo (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | required int32 x; + | } + | } + | } + |} + | + """.stripMargin) + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ParquetSchema { + | required group map_field (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | optional group value { + | optional group foo (LIST) { + | repeated binary foo_tuple (UTF8); + | } + | optional int32 x; + | } + | } + | } + |} + """.stripMargin) + + val projectedFileSchema = testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + val expected = MessageTypeParser.parseMessageType( + """ + |message ParquetSchema { + | required group map_field (MAP) { + | repeated group key_value { + | required binary key (UTF8); + | optional group value { + | optional group foo (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + | optional int32 x; + | } + | } + | } + |} + """.stripMargin) + expected shouldEqual projectedFileSchema + } + + "file: standard, read: list of map: tuple_x" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group list_of_map (LIST) { + | repeated group list { + | required group element (MAP) { + | repeated group key_value { + | required binary key (UTF8); + | required group value { + | required binary _id (UTF8); + | required double created; + | } + | } + | } + | } + | } + |} + """.stripMargin) + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | required group list_of_map (LIST) { + | repeated group element_tuple (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | optional group value { + | optional double created; + | } + | } + | } + | } + |} + """.stripMargin) + + val projectedFileSchema = testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + val expected = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | required group list_of_map (LIST) { + | repeated group list { + | required group element (MAP) { + | repeated group key_value { + | required binary key (UTF8); + | optional group value { + | optional double created; + | } + | } + | } + | } + | } + |} + """.stripMargin) + expected shouldEqual projectedFileSchema + } + + "file: tuple_x, read: list of map: standard" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group list_of_map (LIST) { + | repeated group list_of_map_tuple (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | required group value { + | required binary _id (UTF8); + | required double created; + | } + | } + | } + | } + |} + """.stripMargin) + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | required group list_of_map (LIST) { + | repeated group list { + | required group element (MAP) { + | repeated group key_value { + | required binary key (UTF8); + | optional group value { + | optional double created; + | } + | } + | } + | } + | } + |} + """.stripMargin) + + val projectedFileSchema = testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + val expected = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | required group list_of_map (LIST) { + | repeated group list_of_map_tuple (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | optional group value { + | optional double created; + | } + | } + | } + | } + |} + """.stripMargin) + expected shouldEqual projectedFileSchema + } + } + + "Format compat: check extra non-optional field projection" should { + "throws on missing (MAP_KEY_VALUE) annotation causing projection of non-existent field" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group map_field (MAP) { + | repeated group key_value { + | required binary key (UTF8); + | required int32 value; + | } + | } + |} + """.stripMargin) + // `map` isn't annotated with `MAP_KEY_VALUE`, and is thus treated as + // an actual field which then fails projection + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | required group map_field (MAP) { + | repeated group map { + | required binary key (UTF8); + | optional int32 value; + | } + | } + |} + """.stripMargin) + + val e = intercept[DecodingSchemaMismatchException] { + testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + } + + e.getMessage should include("non-optional projected read field map:") + } + + "throws on missing `repeated` causing projection of non-existent field" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | optional group foo (LIST) { + | repeated group list { + | required group element { + | required binary zing (UTF8); + | } + | } + | } + |} + """.stripMargin) + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group foo (LIST) { + | required group element { + | optional binary zing (UTF8); + | } + | } + |} + """.stripMargin) + + val e = intercept[DecodingSchemaMismatchException] { + testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + } + + e.getMessage should include("non-optional projected read field element:") + } + + "throws on required but non-existent in target" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group map_field (MAP) { + | repeated group key_value { + | required binary key (UTF8); + | required int32 value; + | } + | } + |} + """.stripMargin) + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | required group map_field (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | optional int32 value; + | required int32 bogus_field; + | } + | } + |} + """.stripMargin) + + val e = intercept[DecodingSchemaMismatchException] { + testProjectAndAssertCompatibility(fileSchema, projectedReadSchema) + } + + e.getMessage should include("non-optional projected read field bogus_field:") + } + } + + "Schema mismatch" should { + "throws exception on inconsistent type between primitive and group" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group foo { + | repeated group bar { + | required binary _id (UTF8); + | required double created; + | } + | } + |} + """.stripMargin) + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | required group foo { + | required binary bar (UTF8); + | } + |} + """.stripMargin) + + val e = intercept[DecodingSchemaMismatchException] { + testProjectAndAssertCompatibility(projectedReadSchema, fileSchema) + } + + e.getMessage should include("Found schema mismatch") + } + + "throws exception optional group in file schema but required group in read schema" in { + val fileSchema = MessageTypeParser.parseMessageType( + """ + |message FileSchema { + | required group foo { + | repeated group bar { + | required binary _id (UTF8); + | required double created; + | } + | } + |} + """.stripMargin) + val projectedReadSchema = MessageTypeParser.parseMessageType( + """ + |message ProjectedReadSchema { + | optional group foo { + | required binary bar (UTF8); + | } + |} + """.stripMargin) + + val e = intercept[DecodingSchemaMismatchException] { + testProjectAndAssertCompatibility(projectedReadSchema, fileSchema) + } + + e.getMessage should include ("Found required projected read field foo") + e.getMessage should include ("on optional file field") + } + } +}