|
1 | 1 | package com.github.swagger.enumeratum.converter |
2 | 2 |
|
3 | 3 | import java.util.Iterator |
4 | | -import com.github.swagger.scala.converter.AnnotatedTypeForOption |
| 4 | +import com.github.swagger.scala.converter.{AnnotatedTypeForOption, SwaggerScalaModelConverter} |
5 | 5 | import enumeratum.{Enum, EnumEntry} |
6 | 6 | import io.swagger.v3.core.converter._ |
7 | 7 | import io.swagger.v3.core.jackson.ModelResolver |
8 | 8 | import io.swagger.v3.core.util.{Json, PrimitiveType} |
9 | 9 | import io.swagger.v3.oas.annotations.Parameter |
10 | | -import io.swagger.v3.oas.annotations.media.Schema.AccessMode |
11 | | -import io.swagger.v3.oas.annotations.media.{Schema => SchemaAnnotation} |
| 10 | +import io.swagger.v3.oas.annotations.media.Schema.{AccessMode, RequiredMode} |
| 11 | +import io.swagger.v3.oas.annotations.media.{ArraySchema, Schema => SchemaAnnotation} |
12 | 12 | import io.swagger.v3.oas.models.media.Schema |
13 | 13 |
|
| 14 | +import java.lang.annotation.Annotation |
| 15 | + |
14 | 16 | class SwaggerEnumeratumModelConverter extends ModelResolver(Json.mapper()) { |
15 | 17 | private val enumEntryClass = classOf[EnumEntry] |
16 | 18 |
|
17 | | - def noneIfEmpty(s: String): Option[String] = Option(s).filter(_.trim.nonEmpty) |
| 19 | + private def noneIfEmpty(s: String): Option[String] = Option(s).filter(_.trim.nonEmpty) |
18 | 20 |
|
19 | 21 | override def resolve(annotatedType: AnnotatedType, context: ModelConverterContext, chain: Iterator[ModelConverter]): Schema[_] = { |
20 | 22 | val javaType = _mapper.constructType(annotatedType.getType) |
@@ -84,11 +86,43 @@ class SwaggerEnumeratumModelConverter extends ModelResolver(Json.mapper()) { |
84 | 86 |
|
85 | 87 | private def getRequiredSettings(annotatedType: AnnotatedType): Seq[Boolean] = annotatedType match { |
86 | 88 | case _: AnnotatedTypeForOption => Seq.empty |
87 | | - case _ => { |
88 | | - nullSafeList(annotatedType.getCtxAnnotations).collect { |
89 | | - case p: Parameter => p.required() |
90 | | - case s: SchemaAnnotation => s.required() |
| 89 | + case _ => getRequiredSettings(nullSafeList(annotatedType.getCtxAnnotations)) |
| 90 | + } |
| 91 | + |
| 92 | + private def getRequiredSettings(annotations: Seq[Annotation]): Seq[Boolean] = { |
| 93 | + val flags = annotations.collect { |
| 94 | + case p: Parameter => if (p.required()) RequiredMode.REQUIRED else RequiredMode.NOT_REQUIRED |
| 95 | + case s: SchemaAnnotation => { |
| 96 | + if (s.requiredMode() == RequiredMode.AUTO) { |
| 97 | + if (s.required()) { |
| 98 | + RequiredMode.REQUIRED |
| 99 | + } else if (SwaggerScalaModelConverter.isRequiredBasedOnAnnotation) { |
| 100 | + RequiredMode.NOT_REQUIRED |
| 101 | + } else { |
| 102 | + RequiredMode.AUTO |
| 103 | + } |
| 104 | + } else { |
| 105 | + s.requiredMode() |
| 106 | + } |
91 | 107 | } |
| 108 | + case a: ArraySchema => { |
| 109 | + if (a.arraySchema().requiredMode() == RequiredMode.AUTO) { |
| 110 | + if (a.arraySchema().required()) { |
| 111 | + RequiredMode.REQUIRED |
| 112 | + } else if (SwaggerScalaModelConverter.isRequiredBasedOnAnnotation) { |
| 113 | + RequiredMode.NOT_REQUIRED |
| 114 | + } else { |
| 115 | + RequiredMode.AUTO |
| 116 | + } |
| 117 | + } else { |
| 118 | + a.arraySchema().requiredMode() |
| 119 | + } |
| 120 | + } |
| 121 | + } |
| 122 | + flags.flatMap { |
| 123 | + case RequiredMode.REQUIRED => Some(true) |
| 124 | + case RequiredMode.NOT_REQUIRED => Some(false) |
| 125 | + case _ => None |
92 | 126 | } |
93 | 127 | } |
94 | 128 |
|
|
0 commit comments