Skip to content

Commit bd942cf

Browse files
authored
try to fix issues with case class required fields (#62)
1 parent 8decdb2 commit bd942cf

File tree

2 files changed

+97
-10
lines changed

2 files changed

+97
-10
lines changed

src/main/scala/com/github/swagger/scala/converter/SwaggerScalaModelConverter.scala

Lines changed: 95 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
package com.github.swagger.scala.converter
22

3+
import java.lang.annotation.Annotation
34
import java.lang.reflect.ParameterizedType
45
import java.util.Iterator
56

67
import com.fasterxml.jackson.databind.JavaType
78
import com.fasterxml.jackson.databind.`type`.ReferenceType
9+
import com.fasterxml.jackson.module.scala.introspect.{BeanIntrospector, PropertyDescriptor}
810
import com.fasterxml.jackson.module.scala.{DefaultScalaModule, JsonScalaEnumeration}
911
import io.swagger.v3.core.converter._
1012
import io.swagger.v3.core.jackson.ModelResolver
@@ -27,6 +29,8 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte
2729
private val SetClass = classOf[scala.collection.Set[_]]
2830
private val BigDecimalClass = classOf[BigDecimal]
2931
private val BigIntClass = classOf[BigInt]
32+
private val ProductClass = classOf[Product]
33+
private val AnyClass = classOf[Any]
3034

3135
override def resolve(`type`: AnnotatedType, context: ModelConverterContext, chain: Iterator[ModelConverter]): Schema[_] = {
3236
val javaType = _mapper.constructType(`type`.getType)
@@ -40,6 +44,8 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte
4044
resolve(nextType(baseType, `type`, javaType), context, chain)
4145
} else if (!annotatedOverrides.headOption.getOrElse(true)) {
4246
resolve(nextType(new AnnotatedTypeForOption(), `type`, javaType), context, chain)
47+
} else if (isCaseClass(cls)) {
48+
caseClassSchema(cls, `type`, context, chain).getOrElse(None.orNull)
4349
} else if (chain.hasNext) {
4450
val nextResolved = Option(chain.next().resolve(`type`, context, chain))
4551
nextResolved match {
@@ -68,14 +74,41 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte
6874
}
6975
}
7076

71-
private def getRequiredSettings(`type`: AnnotatedType): Seq[Boolean] = `type` match {
72-
case _: AnnotatedTypeForOption => Seq.empty
73-
case _ => {
74-
nullSafeList(`type`.getCtxAnnotations).collect {
75-
case p: Parameter => p.required()
76-
case s: SchemaAnnotation => s.required()
77-
case a: ArraySchema => a.arraySchema().required()
77+
private def caseClassSchema(cls: Class[_], `type`: AnnotatedType, context: ModelConverterContext,
78+
chain: Iterator[ModelConverter]): Option[Schema[_]] = {
79+
if (chain.hasNext) {
80+
Option(chain.next().resolve(`type`, context, chain)).map { schema =>
81+
val introspector = BeanIntrospector(cls)
82+
introspector.properties.foreach { property =>
83+
getPropertyAnnotations(property) match {
84+
case Seq() => {
85+
val propertyClass = getPropertyClass(property)
86+
if (!isOption(propertyClass)) addRequiredItem(schema, property.name)
87+
}
88+
case annotations => {
89+
val required = getRequiredSettings(annotations).headOption
90+
.getOrElse(!isOption(getPropertyClass(property)))
91+
if (required) addRequiredItem(schema, property.name)
92+
}
93+
}
94+
}
95+
schema
7896
}
97+
} else {
98+
None
99+
}
100+
}
101+
102+
private def getRequiredSettings(annotatedType: AnnotatedType): Seq[Boolean] = annotatedType match {
103+
case _: AnnotatedTypeForOption => Seq.empty
104+
case _ => getRequiredSettings(nullSafeList(annotatedType.getCtxAnnotations))
105+
}
106+
107+
private def getRequiredSettings(annotations: Seq[Annotation]): Seq[Boolean] = {
108+
annotations.collect {
109+
case p: Parameter => p.required()
110+
case s: SchemaAnnotation => s.required()
111+
case a: ArraySchema => a.arraySchema().required()
79112
}
80113
}
81114

@@ -183,8 +216,63 @@ class SwaggerScalaModelConverter extends ModelResolver(SwaggerScalaModelConverte
183216
else None
184217
}
185218

219+
private def getPropertyClass(property: PropertyDescriptor): Class[_] = {
220+
property.param match {
221+
case Some(constructorParameter) => {
222+
val types = constructorParameter.constructor.getParameterTypes
223+
if (constructorParameter.index > types.size) {
224+
AnyClass
225+
} else {
226+
types(constructorParameter.index)
227+
}
228+
}
229+
case _ => property.field match {
230+
case Some(field) => field.getType
231+
case _ => property.setter match {
232+
case Some(setter) if setter.getParameterCount == 1 => {
233+
setter.getParameterTypes()(0)
234+
}
235+
case _ => property.beanSetter match {
236+
case Some(setter) if setter.getParameterCount == 1 => {
237+
setter.getParameterTypes()(0)
238+
}
239+
case _ => AnyClass
240+
}
241+
}
242+
}
243+
}
244+
}
245+
246+
private def getPropertyAnnotations(property: PropertyDescriptor): Seq[Annotation] = {
247+
property.param match {
248+
case Some(constructorParameter) => {
249+
val types = constructorParameter.constructor.getParameterAnnotations
250+
if (constructorParameter.index > types.size) {
251+
Seq.empty
252+
} else {
253+
types(constructorParameter.index).toSeq
254+
}
255+
}
256+
case _ => property.field match {
257+
case Some(field) => field.getAnnotations.toSeq
258+
case _ => property.setter match {
259+
case Some(setter) if setter.getParameterCount == 1 => {
260+
setter.getAnnotations().toSeq
261+
}
262+
case _ => property.beanSetter match {
263+
case Some(setter) if setter.getParameterCount == 1 => {
264+
setter.getAnnotations().toSeq
265+
}
266+
case _ => Seq.empty
267+
}
268+
}
269+
}
270+
}
271+
}
272+
186273
private def isOption(cls: Class[_]): Boolean = cls == OptionClass
187274
private def isIterable(cls: Class[_]): Boolean = IterableClass.isAssignableFrom(cls)
275+
private def isCaseClass(cls: Class[_]): Boolean = ProductClass.isAssignableFrom(cls)
188276

189277
private def nullSafeList[T](array: Array[T]): List[T] = Option(array) match {
190278
case None => List.empty[T]

src/test/scala/com/github/swagger/scala/converter/ModelPropertyParserTest.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,8 +347,7 @@ class ModelPropertyParserTest extends AnyFlatSpec with Matchers with OptionValue
347347
val1Field shouldBe a [IntegerSchema]
348348
val val2Field = model.value.getProperties.get("val2")
349349
val2Field shouldBe a [IntegerSchema]
350-
//TODO try to fix this
351-
//model.value.getRequired().asScala shouldEqual Seq("val1", "val2")
350+
model.value.getRequired().asScala shouldEqual Seq("val1", "val2")
352351
}
353352

354353
private def findModel(schemas: Map[String, Schema[_]], name: String): Option[Schema[_]] = {
@@ -367,7 +366,7 @@ class ModelPropertyParserTest extends AnyFlatSpec with Matchers with OptionValue
367366
val schemas = converter.readAll(classOf[ModelWStringSeq]).asScala.toMap
368367
val model = findModel(schemas, "ModelWStringSeq")
369368
model should be(defined)
370-
nullSafeList(model.value.getRequired) shouldEqual Seq()
369+
nullSafeList(model.value.getRequired) shouldBe empty
371370
}
372371

373372
it should "process Array-Model with forced required Scala Option Seq" in {

0 commit comments

Comments
 (0)