Skip to content

Commit 0e6492d

Browse files
Generate Json Schemas for Kotlin
1 parent 5962c20 commit 0e6492d

File tree

5 files changed

+283
-68
lines changed

5 files changed

+283
-68
lines changed

gradle/libs.versions.toml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,24 @@
181181
[libraries.victools-jsonschema-module-jackson.version]
182182
ref = 'victools-json-schema'
183183

184+
[libraries.schema-kenerator-core]
185+
module = 'io.github.smiley4:schema-kenerator-core'
186+
187+
[libraries.schema-kenerator-core.version]
188+
ref = 'schema-kenerator'
189+
190+
[libraries.schema-kenerator-serialization]
191+
module = 'io.github.smiley4:schema-kenerator-serialization'
192+
193+
[libraries.schema-kenerator-serialization.version]
194+
ref = 'schema-kenerator'
195+
196+
[libraries.schema-kenerator-jsonschema]
197+
module = 'io.github.smiley4:schema-kenerator-jsonschema'
198+
199+
[libraries.schema-kenerator-jsonschema.version]
200+
ref = 'schema-kenerator'
201+
184202
[plugins]
185203
aggregate-javadoc = 'io.freefair.aggregate-javadoc:8.6'
186204
dependency-license-report = 'com.github.jk1.dependency-license-report:2.0'
@@ -213,3 +231,4 @@
213231
spring-boot = '3.4.4'
214232
vertx = '4.5.11'
215233
victools-json-schema = '4.37.0'
234+
schema-kenerator = '2.1.2'

sdk-serde-jackson/src/test/java/dev/restate/serde/jackson/JacksonSerdesTest.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,21 @@
1313
import com.fasterxml.jackson.annotation.JsonCreator;
1414
import com.fasterxml.jackson.annotation.JsonProperty;
1515
import com.fasterxml.jackson.core.type.TypeReference;
16+
import com.fasterxml.jackson.databind.node.ObjectNode;
1617
import dev.restate.serde.Serde;
1718
import java.util.List;
1819
import java.util.Objects;
1920
import java.util.Set;
2021
import java.util.stream.Stream;
22+
import org.junit.jupiter.api.Test;
2123
import org.junit.jupiter.params.ParameterizedTest;
2224
import org.junit.jupiter.params.provider.Arguments;
2325
import org.junit.jupiter.params.provider.MethodSource;
2426

2527
class JacksonSerdesTest {
2628

29+
record Recursive(String value, Recursive rec) {}
30+
2731
public static class Person {
2832

2933
private final String name;
@@ -75,4 +79,11 @@ private static Stream<Arguments> roundtripTestCases() {
7579
<T> void roundtrip(T value, Serde<T> serde) {
7680
assertThat(serde.deserialize(serde.serialize(value))).isEqualTo(value);
7781
}
82+
83+
@Test
84+
void schemaGenWorksWithRecursion() {
85+
ObjectNode node =
86+
(ObjectNode) ((Serde.JsonSchema) JacksonSerdes.of(Recursive.class).jsonSchema()).schema();
87+
assertThat(node.at("/properties/rec/$ref").textValue()).isEqualTo("#");
88+
}
7889
}

sdk-serde-kotlinx/build.gradle.kts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ description = "Restate SDK Kotlinx Serialization integration"
88
dependencies {
99
api(libs.kotlinx.serialization.json)
1010
implementation(libs.kotlinx.serialization.core)
11+
implementation(libs.schema.kenerator.core)
12+
implementation(libs.schema.kenerator.serialization)
13+
implementation(libs.schema.kenerator.jsonschema)
1114

1215
implementation(project(":common"))
16+
17+
testImplementation(libs.junit.jupiter)
18+
testImplementation(libs.assertj)
1319
}

sdk-serde-kotlinx/src/main/kotlin/dev/restate/serde/kotlinx/KotlinSerializationSerdeFactory.kt

Lines changed: 72 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,25 @@ import dev.restate.serde.Serde
1313
import dev.restate.serde.SerdeFactory
1414
import dev.restate.serde.TypeRef
1515
import dev.restate.serde.TypeTag
16+
import io.github.smiley4.schemakenerator.jsonschema.JsonSchemaSteps.compileReferencing
17+
import io.github.smiley4.schemakenerator.jsonschema.JsonSchemaSteps.generateJsonSchema
18+
import io.github.smiley4.schemakenerator.jsonschema.JsonSchemaSteps.withTitle
19+
import io.github.smiley4.schemakenerator.jsonschema.TitleBuilder
20+
import io.github.smiley4.schemakenerator.jsonschema.data.RefType
21+
import io.github.smiley4.schemakenerator.jsonschema.data.TitleType
22+
import io.github.smiley4.schemakenerator.jsonschema.jsonDsl.JsonArray
23+
import io.github.smiley4.schemakenerator.jsonschema.jsonDsl.JsonNode
24+
import io.github.smiley4.schemakenerator.jsonschema.jsonDsl.JsonObject
25+
import io.github.smiley4.schemakenerator.jsonschema.jsonDsl.JsonTextValue
26+
import io.github.smiley4.schemakenerator.serialization.SerializationSteps.analyzeTypeUsingKotlinxSerialization
27+
import io.github.smiley4.schemakenerator.serialization.SerializationSteps.initial
1628
import java.nio.charset.StandardCharsets
1729
import kotlin.reflect.KClass
1830
import kotlin.reflect.KType
1931
import kotlinx.serialization.*
20-
import kotlinx.serialization.builtins.*
21-
import kotlinx.serialization.descriptors.PrimitiveKind
22-
import kotlinx.serialization.descriptors.SerialDescriptor
23-
import kotlinx.serialization.descriptors.StructureKind
24-
import kotlinx.serialization.encodeToString
32+
import kotlinx.serialization.builtins.nullable
2533
import kotlinx.serialization.json.Json
26-
import kotlinx.serialization.json.JsonArray
27-
import kotlinx.serialization.json.JsonElement
2834
import kotlinx.serialization.json.JsonNull
29-
import kotlinx.serialization.json.JsonTransformingSerializer
3035
import kotlinx.serialization.modules.SerializersModule
3136

3237
/**
@@ -104,6 +109,8 @@ constructor(private val json: Json = Json.Default) : SerdeFactory {
104109

105110
/** Creates a [Serde] implementation using the `kotlinx.serialization` json module. */
106111
fun <T : Any?> jsonSerde(json: Json = Json.Default, serializer: KSerializer<T>): Serde<T> {
112+
val schema = jsonSchema(json, serializer).prettyPrint()
113+
107114
return object : Serde<T> {
108115
@Suppress("WRONG_NULLABILITY_FOR_JAVA_OVERRIDE")
109116
override fun serialize(value: T?): Slice {
@@ -124,75 +131,72 @@ constructor(private val json: Json = Json.Default) : SerdeFactory {
124131
}
125132

126133
override fun jsonSchema(): Serde.Schema {
127-
val schema: JsonSchema = serializer.descriptor.jsonSchema()
128-
return Serde.StringifiedJsonSchema(Json.encodeToString(schema))
134+
return Serde.StringifiedJsonSchema(schema)
129135
}
130136
}
131137
}
132138

133-
@Serializable
134-
@PublishedApi
135-
internal data class JsonSchema(
136-
@Serializable(with = StringListSerializer::class) val type: List<String>? = null,
137-
val format: String? = null,
138-
) {
139-
companion object {
140-
val INT = JsonSchema(type = listOf("number"), format = "int32")
141-
142-
val LONG = JsonSchema(type = listOf("number"), format = "int64")
143-
144-
val DOUBLE = JsonSchema(type = listOf("number"), format = "double")
145-
146-
val FLOAT = JsonSchema(type = listOf("number"), format = "float")
147-
148-
val STRING = JsonSchema(type = listOf("string"))
149-
150-
val BOOLEAN = JsonSchema(type = listOf("boolean"))
151-
152-
val OBJECT = JsonSchema(type = listOf("object"))
153-
154-
val LIST = JsonSchema(type = listOf("array"))
155-
156-
val ANY = JsonSchema()
157-
}
158-
}
159-
160-
object StringListSerializer :
161-
JsonTransformingSerializer<List<String>>(ListSerializer(String.Companion.serializer())) {
162-
override fun transformSerialize(element: JsonElement): JsonElement {
163-
require(element is JsonArray)
164-
return element.singleOrNull() ?: element
139+
private fun <T : Any?> jsonSchema(json: Json, serializer: KSerializer<T>): JsonNode =
140+
runCatching {
141+
val intermediateStep =
142+
initial(serializer.descriptor)
143+
.analyzeTypeUsingKotlinxSerialization {
144+
serializersModule = json.serializersModule
145+
}
146+
.generateJsonSchema()
147+
.withTitle(type = TitleType.SIMPLE)
148+
val compiledSchema = intermediateStep.compileReferencing(RefType.SIMPLE)
149+
150+
// In case of nested schemas, compileReferencing also contains self schema...
151+
val rootSchemaName =
152+
TitleBuilder.BUILDER_SIMPLE(
153+
compiledSchema.typeData, intermediateStep.typeDataById)
154+
155+
// If schema is not json object, then it's boolean, so we're good no need for
156+
// additional manipulation
157+
if (compiledSchema.json !is JsonObject) {
158+
return compiledSchema.json
159+
}
160+
161+
// Assemble the final schema now
162+
val rootNode = compiledSchema.json as JsonObject
163+
// Add $schema
164+
rootNode.properties.put(
165+
"\$schema", JsonTextValue("https://json-schema.org/draft/2020-12/schema"))
166+
// Add $defs
167+
val definitions =
168+
compiledSchema.definitions.filter { it.key != rootSchemaName }.toMutableMap()
169+
if (definitions.isNotEmpty()) {
170+
rootNode.properties.put("\$defs", JsonObject(definitions))
171+
}
172+
// Replace all $refs
173+
rootNode.fixRefsPrefix("#/definitions/$rootSchemaName")
174+
175+
return rootNode
176+
}
177+
.getOrDefault(JsonObject(mutableMapOf()))
178+
179+
private fun JsonNode.fixRefsPrefix(rootDefinition: String) {
180+
when (this) {
181+
is JsonArray -> this.items.forEach { it.fixRefsPrefix(rootDefinition) }
182+
is JsonObject -> this.fixRefsPrefix(rootDefinition)
183+
else -> {}
165184
}
166185
}
167186

168-
/**
169-
* Super simplistic json schema generation. We should replace this with an appropriate library.
170-
*/
171-
@OptIn(ExperimentalSerializationApi::class)
172-
@PublishedApi
173-
internal fun SerialDescriptor.jsonSchema(): JsonSchema {
174-
var schema =
175-
when (this.kind) {
176-
PrimitiveKind.BOOLEAN -> JsonSchema.BOOLEAN
177-
PrimitiveKind.BYTE -> JsonSchema.INT
178-
PrimitiveKind.CHAR -> JsonSchema.STRING
179-
PrimitiveKind.DOUBLE -> JsonSchema.DOUBLE
180-
PrimitiveKind.FLOAT -> JsonSchema.FLOAT
181-
PrimitiveKind.INT -> JsonSchema.INT
182-
PrimitiveKind.LONG -> JsonSchema.LONG
183-
PrimitiveKind.SHORT -> JsonSchema.INT
184-
PrimitiveKind.STRING -> JsonSchema.STRING
185-
StructureKind.LIST -> JsonSchema.LIST
186-
StructureKind.MAP -> JsonSchema.OBJECT
187-
else -> JsonSchema.ANY
187+
private fun JsonObject.fixRefsPrefix(rootDefinition: String) {
188+
this.properties.computeIfPresent("\$ref") { key, node ->
189+
if (node is JsonTextValue) {
190+
if (node.value.startsWith(rootDefinition)) {
191+
JsonTextValue("#/" + node.value.removePrefix(rootDefinition))
192+
} else {
193+
JsonTextValue("#/\$defs/" + node.value.removePrefix("#/definitions/"))
188194
}
189-
190-
// Add nullability constraint
191-
if (this.isNullable && schema.type != null) {
192-
schema = schema.copy(type = schema.type.plus("null"))
195+
} else {
196+
node
197+
}
193198
}
194-
195-
return schema
199+
this.properties.values.forEach { it.fixRefsPrefix(rootDefinition) }
196200
}
197201
}
198202

0 commit comments

Comments
 (0)