Skip to content

Commit 076b8ca

Browse files
authored
Save primitives from erasure (#166)
* Save optional primitive types from erasure * Override implementation when optional primitive but leave all other annotation properties in tact branch: * Ensure that project compiles on Scala 3 by introducing a Scala 3 stub implementation of `erasedOptionalPrimitives`. Tests still fail * Prepare for supporting collections branch: * Add support for collections
1 parent d30e780 commit 076b8ca

File tree

9 files changed

+162
-19
lines changed

9 files changed

+162
-19
lines changed

.github/workflows/ci.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,13 @@ jobs:
5959
run: sbt '++${{ matrix.scala }}' coverage test coverageReport
6060

6161
- name: Scala build
62-
if: '!startsWith(matrix.scala, ''2.13'')'
62+
if: '!startsWith(matrix.scala, ''2.13'') && !startsWith(matrix.scala, ''3.0'')'
6363
run: sbt '++${{ matrix.scala }}' test
6464

65+
- name: Scala compile
66+
if: startsWith(matrix.scala, '3.0')
67+
run: sbt '++${{ matrix.scala }}' compile
68+
6569
- name: Publish to Codecov.io
6670
if: startsWith(matrix.scala, '2.13')
6771
uses: codecov/codecov-action@v2

build.sbt

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,12 @@ libraryDependencies ++= Seq(
7272
"org.scalatest" %% "scalatest" % "3.2.11" % Test,
7373
"org.slf4j" % "slf4j-simple" % "1.7.36" % Test
7474
)
75+
libraryDependencies ++= {
76+
CrossVersion.partialVersion(Keys.scalaVersion.value) match {
77+
case Some((3, _)) => Seq()
78+
case _ => Seq("org.scala-lang" % "scala-reflect" % scalaVersion.value)
79+
}
80+
}
7581

7682
homepage := Some(new URL("https://github.com/swagger-akka-http/swagger-scala-module"))
7783

@@ -83,10 +89,10 @@ licenses := Seq(("Apache License 2.0", new URL("http://www.apache.org/licenses/L
8389

8490
pomExtra := {
8591
pomExtra.value ++ Group(
86-
<issueManagement>
87-
<system>github</system>
88-
<url>https://github.com/swagger-api/swagger-scala-module/issues</url>
89-
</issueManagement>
92+
<issueManagement>
93+
<system>github</system>
94+
<url>https://github.com/swagger-api/swagger-scala-module/issues</url>
95+
</issueManagement>
9096
<developers>
9197
<developer>
9298
<id>fehguy</id>
@@ -104,7 +110,8 @@ pomExtra := {
104110

105111
ThisBuild / githubWorkflowBuild := Seq(
106112
WorkflowStep.Sbt(List("coverage", "test", "coverageReport"), name = Some("Scala 2.13 build"), cond = Some("startsWith(matrix.scala, '2.13')")),
107-
WorkflowStep.Sbt(List("test"), name = Some("Scala build"), cond = Some("!startsWith(matrix.scala, '2.13')")),
113+
WorkflowStep.Sbt(List("test"), name = Some("Scala build"), cond = Some("!startsWith(matrix.scala, '2.13') && !startsWith(matrix.scala, '3.0')")),
114+
WorkflowStep.Sbt(List("compile"), name = Some("Scala compile"), cond = Some("startsWith(matrix.scala, '3.0')")),
108115
)
109116

110117
ThisBuild / githubWorkflowJavaVersions := Seq(JavaSpec(Zulu, "8"))
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package com.github.swagger.scala.converter
2+
3+
object ErasureHelper {
4+
5+
def erasedOptionalPrimitives(cls: Class[_]): Map[String, Class[_]] = {
6+
import scala.reflect.runtime.universe
7+
val mirror = universe.runtimeMirror(cls.getClassLoader)
8+
val sym = mirror.staticClass(cls.getName)
9+
val properties = sym.selfType.members
10+
.filterNot(_.isMethod)
11+
.filterNot(_.isClass)
12+
13+
properties.flatMap { prop =>
14+
val maybeClass: Option[Class[_]] = prop.typeSignature.typeArgs.headOption.flatMap { signature =>
15+
if (signature.typeSymbol.isClass) {
16+
Option(mirror.runtimeClass(signature.typeSymbol.asClass))
17+
} else None
18+
}
19+
maybeClass.map(prop.name.toString.trim -> _)
20+
}.toMap
21+
}
22+
23+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package com.github.swagger.scala.converter
2+
3+
import io.swagger.v3.oas.models.media.Schema
4+
5+
6+
object ErasureHelper {
7+
8+
def erasedOptionalPrimitives(cls: Class[_]): Map[String, Class[_]] = Map.empty[String, Class[_]]
9+
10+
}
11+

src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import io.swagger.v3.oas.annotations.media.{ArraySchema, Schema => SchemaAnnotat
1515
import io.swagger.v3.oas.models.media.Schema
1616
import org.slf4j.LoggerFactory
1717

18+
import java.util
1819
import scala.util.Try
1920
import scala.util.control.NonFatal
2021

@@ -31,6 +32,7 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte
3132
private val EnumClass = classOf[scala.Enumeration]
3233
private val OptionClass = classOf[scala.Option[_]]
3334
private val IterableClass = classOf[scala.collection.Iterable[_]]
35+
private val MapClass = classOf[Map[_, _]]
3436
private val SetClass = classOf[scala.collection.Set[_]]
3537
private val BigDecimalClass = classOf[BigDecimal]
3638
private val BigIntClass = classOf[BigInt]
@@ -71,24 +73,37 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte
7173
}
7274

7375
private def caseClassSchema(cls: Class[_], `type`: AnnotatedType, context: ModelConverterContext,
74-
chain: Iterator[ModelConverter]): Option[Schema[_]] = {
76+
chain: util.Iterator[ModelConverter]): Option[Schema[_]] = {
77+
val erasedProperties = ErasureHelper.erasedOptionalPrimitives(cls)
78+
7579
if (chain.hasNext) {
7680
Option(chain.next().resolve(`type`, context, chain)).map { schema =>
7781
val introspector = BeanIntrospector(cls)
7882
introspector.properties.foreach { property =>
83+
84+
val propertyClass = getPropertyClass(property)
85+
val isOptional = isOption(propertyClass)
86+
87+
erasedProperties.get(property.name).foreach { erasedType =>
88+
val primitiveType = PrimitiveType.fromType(erasedType)
89+
if (primitiveType != null && isOptional) {
90+
updateTypeOnSchema(schema, primitiveType, property.name)
91+
}
92+
if (primitiveType != null && isIterable(propertyClass) && !isMap(propertyClass)) {
93+
updateTypeOnItemsSchema(schema, primitiveType, property.name)
94+
}
95+
}
7996
getPropertyAnnotations(property) match {
8097
case Seq() => {
81-
val propertyClass = getPropertyClass(property)
82-
val optionalFlag = isOption(propertyClass)
83-
if (optionalFlag && schema.getRequired != null && schema.getRequired.contains(property.name)) {
98+
if (isOptional && schema.getRequired != null && schema.getRequired.contains(property.name)) {
8499
schema.getRequired.remove(property.name)
85-
} else if (!optionalFlag) {
100+
} else if (!isOptional) {
86101
addRequiredItem(schema, property.name)
87102
}
88103
}
89104
case annotations => {
90105
val required = getRequiredSettings(annotations).headOption
91-
.getOrElse(!isOption(getPropertyClass(property)))
106+
.getOrElse(!isOptional)
92107
if (required) addRequiredItem(schema, property.name)
93108
}
94109
}
@@ -100,6 +115,28 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte
100115
}
101116
}
102117

118+
private def updateTypeOnSchema(schema: Schema[_], primitiveType: PrimitiveType, propertyName: String) = {
119+
val property = schema.getProperties.get(propertyName)
120+
val updatedSchema = correctSchema(property, primitiveType)
121+
schema.addProperty(propertyName, updatedSchema)
122+
}
123+
124+
private def updateTypeOnItemsSchema(schema: Schema[_], primitiveType: PrimitiveType, propertyName: String) = {
125+
val property = schema.getProperties.get(propertyName)
126+
val updatedSchema = correctSchema(property.getItems, primitiveType)
127+
property.setItems(updatedSchema)
128+
schema.addProperty(propertyName, property)
129+
}
130+
131+
private def correctSchema(itemSchema: Schema[_], primitiveType: PrimitiveType) = {
132+
val primitiveProperty = primitiveType.createProperty()
133+
val propAsString = objectMapper.writeValueAsString(itemSchema)
134+
val correctedSchema = objectMapper.readValue(propAsString, primitiveProperty.getClass)
135+
correctedSchema.setType(primitiveProperty.getType)
136+
correctedSchema.setFormat(primitiveProperty.getFormat)
137+
correctedSchema
138+
}
139+
103140
private def getRequiredSettings(annotatedType: AnnotatedType): Seq[Boolean] = annotatedType match {
104141
case _: AnnotatedTypeForOption => Seq.empty
105142
case _ => getRequiredSettings(nullSafeList(annotatedType.getCtxAnnotations))
@@ -276,6 +313,7 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte
276313

277314
private def isOption(cls: Class[_]): Boolean = cls == OptionClass
278315
private def isIterable(cls: Class[_]): Boolean = IterableClass.isAssignableFrom(cls)
316+
private def isMap(cls: Class[_]): Boolean = MapClass.isAssignableFrom(cls)
279317
private def isCaseClass(cls: Class[_]): Boolean = ProductClass.isAssignableFrom(cls)
280318

281319
private def nullSafeList[T](array: Array[T]): List[T] = Option(array) match {

src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,10 @@ class ModelPropertyParserTest extends AnyFlatSpec with Matchers with OptionValue
107107
val model = schemas.get("ModelWOptionInt")
108108
model should be (defined)
109109
model.value.getProperties should not be (null)
110-
val optInt = model.value.getProperties().get("optInt")
110+
val optInt = model.value.getProperties.get("optInt")
111111
optInt should not be (null)
112-
optInt shouldBe a [Schema[_]]
112+
optInt shouldBe a [IntegerSchema]
113+
optInt.asInstanceOf[IntegerSchema].getFormat shouldEqual "int32"
113114
nullSafeList(model.value.getRequired) shouldBe empty
114115
}
115116

@@ -123,9 +124,22 @@ class ModelPropertyParserTest extends AnyFlatSpec with Matchers with OptionValue
123124
optInt should not be (null)
124125
optInt shouldBe a [IntegerSchema]
125126
optInt.asInstanceOf[IntegerSchema].getFormat shouldEqual "int32"
127+
optInt.getDescription shouldBe "This is an optional int"
126128
nullSafeList(model.value.getRequired) shouldBe empty
127129
}
128130

131+
it should "allow annotation to override required with Scala Option Int" in {
132+
val converter = ModelConverters.getInstance()
133+
val schemas = converter.readAll(classOf[ModelWOptionIntSchemaOverrideForRequired]).asScala.toMap
134+
val model = schemas.get("ModelWOptionIntSchemaOverrideForRequired")
135+
model should be(defined)
136+
model.value.getProperties should not be (null)
137+
val optInt = model.value.getProperties().get("optInt")
138+
optInt should not be (null)
139+
optInt shouldBe an [IntegerSchema]
140+
nullSafeList(model.value.getRequired) shouldEqual Seq("optInt")
141+
}
142+
129143
it should "process Model with Scala Option Long" in {
130144
val converter = ModelConverters.getInstance()
131145
val schemas = converter.readAll(classOf[ModelWOptionLong]).asScala.toMap
@@ -134,7 +148,7 @@ class ModelPropertyParserTest extends AnyFlatSpec with Matchers with OptionValue
134148
model.value.getProperties should not be (null)
135149
val optLong = model.value.getProperties().get("optLong")
136150
optLong should not be (null)
137-
optLong shouldBe a [Schema[_]]
151+
optLong shouldBe a [IntegerSchema]
138152
nullSafeList(model.value.getRequired) shouldBe empty
139153
}
140154

@@ -324,6 +338,43 @@ class ModelPropertyParserTest extends AnyFlatSpec with Matchers with OptionValue
324338
nullSafeList(arraySchema.getRequired()) shouldBe empty
325339
}
326340

341+
it should "process Model with Scala Seq Int" in {
342+
val converter = ModelConverters.getInstance()
343+
val schemas = converter.readAll(classOf[ModelWSeqInt]).asScala.toMap
344+
val model = findModel(schemas, "ModelWSeqInt")
345+
model should be(defined)
346+
model.value.getProperties should not be (null)
347+
348+
val stringsField = model.value.getProperties.get("ints")
349+
350+
stringsField shouldBe a[ArraySchema]
351+
val arraySchema = stringsField.asInstanceOf[ArraySchema]
352+
arraySchema.getUniqueItems() shouldBe (null)
353+
arraySchema.getItems shouldBe a[IntegerSchema]
354+
nullSafeMap(arraySchema.getProperties()) shouldBe empty
355+
nullSafeList(arraySchema.getRequired()) shouldBe empty
356+
}
357+
358+
it should "process Model with Scala Seq Int (annotated)" in {
359+
val converter = ModelConverters.getInstance()
360+
val schemas = converter.readAll(classOf[ModelWSeqIntAnnotated]).asScala.toMap
361+
val model = findModel(schemas, "ModelWSeqIntAnnotated")
362+
model should be(defined)
363+
model.value.getProperties should not be (null)
364+
365+
val stringsField = model.value.getProperties.get("ints")
366+
367+
stringsField shouldBe a[ArraySchema]
368+
val arraySchema = stringsField.asInstanceOf[ArraySchema]
369+
arraySchema.getUniqueItems() shouldBe (null)
370+
371+
372+
arraySchema.getItems shouldBe a[IntegerSchema]
373+
arraySchema.getItems.getDescription shouldBe "These are ints"
374+
nullSafeMap(arraySchema.getProperties()) shouldBe empty
375+
nullSafeList(arraySchema.getRequired()) shouldBe empty
376+
}
377+
327378
it should "process Model with Scala Set" in {
328379
val converter = ModelConverters.getInstance()
329380
val schemas = converter.readAll(classOf[ModelWSetString]).asScala.toMap

src/test/scala/com/github/swagger/scala/converter/ScalaModelTest.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class ScalaModelTest extends AnyFlatSpec with Matchers {
7070

7171
val date = userSchema.getProperties().get("date")
7272
date shouldBe a [DateTimeSchema]
73-
//date.getDescription should be ("the birthdate")
73+
// date.getDescription should be ("the birthdate")
7474
}
7575

7676
it should "read a model with vector property" in {
@@ -85,15 +85,15 @@ class ScalaModelTest extends AnyFlatSpec with Matchers {
8585
val model = schemas("ModelWithIntVector")
8686
val prop = model.getProperties().get("ints")
8787
prop shouldBe a [ArraySchema]
88-
prop.asInstanceOf[ArraySchema].getItems.getType should be ("object")
88+
prop.asInstanceOf[ArraySchema].getItems.getType should be ("integer")
8989
}
9090

9191
it should "read a model with vector of booleans" in {
9292
val schemas = ModelConverters.getInstance().readAll(classOf[ModelWithBooleanVector]).asScala
9393
val model = schemas("ModelWithBooleanVector")
9494
val prop = model.getProperties().get("bools")
9595
prop shouldBe a [ArraySchema]
96-
prop.asInstanceOf[ArraySchema].getItems.getType should be ("object")
96+
prop.asInstanceOf[ArraySchema].getItems.getType should be ("boolean")
9797
}
9898
}
9999

src/test/scala/models/ModelWOptionInt.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,6 @@ import io.swagger.v3.oas.annotations.media.Schema
44

55
case class ModelWOptionInt(optInt: Option[Int])
66

7-
case class ModelWOptionIntSchemaOverride(@Schema(implementation = classOf[Int]) optInt: Option[Int])
7+
case class ModelWOptionIntSchemaOverride(@Schema(description = "This is an optional int") optInt: Option[Int])
8+
9+
case class ModelWOptionIntSchemaOverrideForRequired(@Schema(required = true) optInt: Option[Int])
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package models
2+
3+
import io.swagger.v3.oas.annotations.media.{ArraySchema, Schema}
4+
5+
case class ModelWSeqInt(ints: Seq[Int])
6+
7+
case class ModelWSeqIntAnnotated(@ArraySchema(arraySchema = new Schema(required = false), schema = new Schema(description = "These are ints")) ints: Seq[Int])

0 commit comments

Comments
 (0)