From daebaa3d71664854af2495a5bec46f76315bc508 Mon Sep 17 00:00:00 2001 From: Gregor Ihmor Date: Thu, 14 Nov 2024 21:48:57 +0100 Subject: [PATCH 1/2] Add information about which required field was missing --- .../scalapb/compiler/ParseFromGenerator.scala | 20 ++++++++++++++++--- e2e/src/test/scala/NoBoxSpec.scala | 2 +- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/compiler-plugin/src/main/scala/scalapb/compiler/ParseFromGenerator.scala b/compiler-plugin/src/main/scala/scalapb/compiler/ParseFromGenerator.scala index 1ef3923ee..9a8b78984 100644 --- a/compiler-plugin/src/main/scala/scalapb/compiler/ParseFromGenerator.scala +++ b/compiler-plugin/src/main/scala/scalapb/compiler/ParseFromGenerator.scala @@ -228,9 +228,23 @@ private[compiler] class ParseFromGenerator( val r = (0 until (requiredFieldMap.size + 63) / 64) .map(i => s"__requiredFields$i != 0L") .mkString(" || ") - p.add( - s"""if (${r}) { throw new _root_.com.google.protobuf.InvalidProtocolBufferException("Message missing required fields.") } """ - ) + p.add(s"""if (${r}) {""") + .indent + .add("val __missingFields = Seq.newBuilder[_root_.scala.Predef.String]") + .print(requiredFieldMap.toSeq.sortBy(_._2)) { + case (p, (fieldDescriptor, fieldNumber)) => + val bitmask = s"0x${"%x".format(1L << fieldNumber)}L" + val fieldVariable = s"__requiredFields${fieldNumber / 64}" + p.add( + s"""if (($fieldVariable & $bitmask) != 0L) __missingFields += "${fieldDescriptor.scalaName}"""" + ) + } + .add( + s"""val __message = s"Message missing required fields: $${__missingFields.result.mkString(", ")}"""", + s"""throw new _root_.com.google.protobuf.InvalidProtocolBufferException(__message)""" + ) + .outdent + .add("}") } .add(s"$myFullScalaName(") .indented( diff --git a/e2e/src/test/scala/NoBoxSpec.scala b/e2e/src/test/scala/NoBoxSpec.scala index 5ae180fd0..c72f26790 100644 --- a/e2e/src/test/scala/NoBoxSpec.scala +++ b/e2e/src/test/scala/NoBoxSpec.scala @@ -36,7 +36,7 @@ class NoBoxSpec extends AnyFlatSpec with Matchers { "RequiredCar" should "fail validation if required field is missing" in { intercept[InvalidProtocolBufferException] { RequiredCar.parseFrom(Array.empty[Byte]) - }.getMessage must be("Message missing required fields.") + }.getMessage must be("Message missing required fields: tyre1") } "RequiredCar" should "fail parsing from text if field is empty" in { From a811bce4200ce954076992b6d4d3b9adf71762fb Mon Sep 17 00:00:00 2001 From: Gregor Ihmor Date: Mon, 18 Nov 2024 21:47:21 +0100 Subject: [PATCH 2/2] Add tests for required feature --- .../scalapb/compiler/ParseFromGenerator.scala | 22 ++++---- e2e/src/test/scala/RequiredFieldsSpec.scala | 51 ++++++++++++++++++- 2 files changed, 62 insertions(+), 11 deletions(-) diff --git a/compiler-plugin/src/main/scala/scalapb/compiler/ParseFromGenerator.scala b/compiler-plugin/src/main/scala/scalapb/compiler/ParseFromGenerator.scala index 9a8b78984..21ab611bf 100644 --- a/compiler-plugin/src/main/scala/scalapb/compiler/ParseFromGenerator.scala +++ b/compiler-plugin/src/main/scala/scalapb/compiler/ParseFromGenerator.scala @@ -110,8 +110,11 @@ private[compiler] class ParseFromGenerator( private def usesBaseTypeInBuilder(field: FieldDescriptor) = field.isSingular - val requiredFieldMap: Map[FieldDescriptor, Int] = - message.fields.filter(fd => fd.isRequired || fd.noBoxRequired).zipWithIndex.toMap + private val requiredFields: Seq[(FieldDescriptor, Int)] = + message.fields.filter(fd => fd.isRequired || fd.noBoxRequired).zipWithIndex + + private val requiredFieldMap: Map[FieldDescriptor, Int] = + requiredFields.toMap val myFullScalaName = message.scalaType.fullNameWithMaybeRoot(message) @@ -231,16 +234,15 @@ private[compiler] class ParseFromGenerator( p.add(s"""if (${r}) {""") .indent .add("val __missingFields = Seq.newBuilder[_root_.scala.Predef.String]") - .print(requiredFieldMap.toSeq.sortBy(_._2)) { - case (p, (fieldDescriptor, fieldNumber)) => - val bitmask = s"0x${"%x".format(1L << fieldNumber)}L" - val fieldVariable = s"__requiredFields${fieldNumber / 64}" - p.add( - s"""if (($fieldVariable & $bitmask) != 0L) __missingFields += "${fieldDescriptor.scalaName}"""" - ) + .print(requiredFields) { case (p, (fieldDescriptor, fieldNumber)) => + val bitmask = f"${1L << fieldNumber}%#018xL" + val fieldVariable = s"__requiredFields${fieldNumber / 64}" + p.add( + s"""if (($fieldVariable & $bitmask) != 0L) __missingFields += "${fieldDescriptor.scalaName}"""" + ) } .add( - s"""val __message = s"Message missing required fields: $${__missingFields.result.mkString(", ")}"""", + s"""val __message = s"Message missing required fields: $${__missingFields.result().mkString(", ")}"""", s"""throw new _root_.com.google.protobuf.InvalidProtocolBufferException(__message)""" ) .outdent diff --git a/e2e/src/test/scala/RequiredFieldsSpec.scala b/e2e/src/test/scala/RequiredFieldsSpec.scala index 13dbf6b8a..af07e19e3 100644 --- a/e2e/src/test/scala/RequiredFieldsSpec.scala +++ b/e2e/src/test/scala/RequiredFieldsSpec.scala @@ -1,11 +1,60 @@ import com.google.protobuf.InvalidProtocolBufferException import com.thesamet.proto.e2e.reqs.RequiredFields +import protobuf_unittest.unittest.TestEmptyMessage +import scalapb.UnknownFieldSet import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.must.Matchers class RequiredFieldsSpec extends AnyFlatSpec with Matchers { + + private val descriptor = RequiredFields.javaDescriptor + + private def partialMessage(fields: Map[Int, Int]): Array[Byte] = { + val fieldSet = fields.foldLeft(UnknownFieldSet.empty){ case (fieldSet, (field, value)) => + fieldSet + .withField(field, UnknownFieldSet.Field(varint = Seq(value))) + } + + TestEmptyMessage(fieldSet).toByteArray + } + + private val allFieldsSet: Map[Int, Int] = (100 to 164).map(i => (i, i)).toMap + "RequiredMessage" should "throw InvalidProtocolBufferException for empty byte array" in { - intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(Array[Byte]())) + val exception = intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(Array[Byte]())) + + exception.getMessage() must startWith("Message missing required fields") + } + + it should "throw no exception when all fields are set correctly" in { + val parsed = RequiredFields.parseFrom(partialMessage(allFieldsSet)) + parsed must be(a[RequiredFields]) + parsed.f0 must be(100) + parsed.f64 must be(164) + } + + it should "throw an exception if a field is missing and name the missing field" in { + val fields = allFieldsSet.removed(123) + val exception = intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(partialMessage(fields))) + + exception.getMessage() must be("Message missing required fields: f23") + } + + it should "throw an exception if a multiple fields are missing and name those missing fields" in { + val fields = allFieldsSet.removed(123).removed(164).removed(130) + val exception = intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(partialMessage(fields))) + + exception.getMessage() must be("Message missing required fields: f23, f30, f64") + } + + it should "sort the missing fields by field number" in { + val fields = Map.empty[Int, Int] + val exception = intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(partialMessage(fields))) + val missingFields =exception.getMessage().stripPrefix("Message missing required fields: ").split(", ") + + missingFields.sortBy[Int](field => descriptor.findFieldByName(field).getNumber()) must be(missingFields) + + missingFields.toSeq mustBe Seq.tabulate(65)(i => s"f$i") } }