Skip to content

Commit 069a162

Browse files
Simple Json Schemas for Kotlin (#403)
1 parent fa728dd commit 069a162

File tree

5 files changed

+115
-10
lines changed

5 files changed

+115
-10
lines changed

sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/CodegenTest.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ class CodegenTest : TestDefinitions.TestSuite {
5757
@Exclusive
5858
suspend fun complexType(
5959
context: ObjectContext,
60-
request: Map<Output, List<out Input>>
61-
): Map<Input, List<out Output>> {
60+
request: Map<String, List<out Input>>
61+
): Map<String, List<out Output>> {
6262
return mapOf()
6363
}
6464
}

sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/KtSerdes.kt

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,26 @@
99
package dev.restate.sdk.kotlin
1010

1111
import dev.restate.sdk.common.DurablePromiseKey
12+
import dev.restate.sdk.common.RichSerde
1213
import dev.restate.sdk.common.Serde
1314
import dev.restate.sdk.common.StateKey
1415
import java.nio.ByteBuffer
1516
import java.nio.charset.StandardCharsets
1617
import kotlin.reflect.typeOf
18+
import kotlinx.serialization.ExperimentalSerializationApi
1719
import kotlinx.serialization.KSerializer
20+
import kotlinx.serialization.Serializable
21+
import kotlinx.serialization.builtins.ListSerializer
22+
import kotlinx.serialization.builtins.serializer
23+
import kotlinx.serialization.descriptors.PrimitiveKind
24+
import kotlinx.serialization.descriptors.SerialDescriptor
25+
import kotlinx.serialization.descriptors.StructureKind
26+
import kotlinx.serialization.encodeToString
1827
import kotlinx.serialization.json.Json
28+
import kotlinx.serialization.json.JsonArray
29+
import kotlinx.serialization.json.JsonElement
1930
import kotlinx.serialization.json.JsonNull
31+
import kotlinx.serialization.json.JsonTransformingSerializer
2032
import kotlinx.serialization.serializer
2133

2234
object KtStateKey {
@@ -70,12 +82,13 @@ object KtSerdes {
7082
}
7183

7284
/** Creates a [Serde] implementation using the `kotlinx.serialization` json module. */
73-
fun <T : Any?> json(serializer: KSerializer<T>): Serde<T> {
74-
return object : Serde<T> {
85+
inline fun <reified T : Any?> json(serializer: KSerializer<T>): Serde<T> {
86+
return object : RichSerde<T> {
7587
override fun serialize(value: T?): ByteArray {
7688
if (value == null) {
7789
return Json.encodeToString(JsonNull.serializer(), JsonNull).encodeToByteArray()
7890
}
91+
7992
return Json.encodeToString(serializer, value).encodeToByteArray()
8093
}
8194

@@ -86,6 +99,76 @@ object KtSerdes {
8699
override fun contentType(): String {
87100
return "application/json"
88101
}
102+
103+
override fun jsonSchema(): String {
104+
val schema: JsonSchema = serializer.descriptor.jsonSchema()
105+
return Json.encodeToString(schema)
106+
}
89107
}
90108
}
109+
110+
@Serializable
111+
@PublishedApi
112+
internal data class JsonSchema(
113+
@Serializable(with = StringListSerializer::class) val type: List<String>? = null,
114+
val format: String? = null,
115+
) {
116+
companion object {
117+
val INT = JsonSchema(type = listOf("number"), format = "int32")
118+
119+
val LONG = JsonSchema(type = listOf("number"), format = "int64")
120+
121+
val DOUBLE = JsonSchema(type = listOf("number"), format = "double")
122+
123+
val FLOAT = JsonSchema(type = listOf("number"), format = "float")
124+
125+
val STRING = JsonSchema(type = listOf("string"))
126+
127+
val BOOLEAN = JsonSchema(type = listOf("boolean"))
128+
129+
val OBJECT = JsonSchema(type = listOf("object"))
130+
131+
val LIST = JsonSchema(type = listOf("array"))
132+
133+
val ANY = JsonSchema()
134+
}
135+
}
136+
137+
object StringListSerializer :
138+
JsonTransformingSerializer<List<String>>(ListSerializer(String.Companion.serializer())) {
139+
override fun transformSerialize(element: JsonElement): JsonElement {
140+
require(element is JsonArray)
141+
return element.singleOrNull() ?: element
142+
}
143+
}
144+
145+
/**
146+
* Super simplistic json schema generation. We should replace this with an appropriate library.
147+
*/
148+
@OptIn(ExperimentalSerializationApi::class)
149+
@PublishedApi
150+
internal fun SerialDescriptor.jsonSchema(): JsonSchema {
151+
var schema =
152+
when (this.kind) {
153+
PrimitiveKind.BOOLEAN -> JsonSchema.BOOLEAN
154+
PrimitiveKind.BYTE -> JsonSchema.INT
155+
PrimitiveKind.CHAR -> JsonSchema.STRING
156+
PrimitiveKind.DOUBLE -> JsonSchema.DOUBLE
157+
PrimitiveKind.FLOAT -> JsonSchema.FLOAT
158+
PrimitiveKind.INT -> JsonSchema.INT
159+
PrimitiveKind.LONG -> JsonSchema.LONG
160+
PrimitiveKind.SHORT -> JsonSchema.INT
161+
PrimitiveKind.STRING -> JsonSchema.STRING
162+
StructureKind.LIST -> JsonSchema.LIST
163+
StructureKind.MAP -> JsonSchema.OBJECT
164+
else -> JsonSchema.ANY
165+
}
166+
167+
// Add nullability constraint
168+
if (this.isNullable && schema.type != null) {
169+
schema = schema.copy(type = schema.type.plus("null"))
170+
}
171+
172+
return schema
173+
}
91174
}

sdk-common/src/main/java/dev/restate/sdk/common/RichSerde.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@
1414
/**
1515
* Richer version of {@link Serde} containing schema information.
1616
*
17+
* <p>This API should be considered unstable to implement.
18+
*
1719
* <p>You can create one using {@link #withSchema(Object, Serde)}.
1820
*/
1921
public interface RichSerde<T extends @Nullable Object> extends Serde<T> {
2022

2123
/**
22-
* @return a Draft 2020-12 Json Schema
24+
* @return a Draft 2020-12 Json Schema. It should be self-contained, and MUST not contain refs to
25+
* files. If the schema shouldn't be serialized with Jackson, return a {@link String}
2326
*/
2427
Object jsonSchema();
2528

sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import static dev.restate.sdk.core.ServiceProtocol.*;
1212

13+
import com.fasterxml.jackson.core.JsonProcessingException;
1314
import dev.restate.sdk.common.HandlerType;
1415
import dev.restate.sdk.common.RichSerde;
1516
import dev.restate.sdk.common.ServiceType;
@@ -108,8 +109,17 @@ private static Input convertHandlerInput(HandlerSpecification<?, ?> spec) {
108109
: new Input().withRequired(true).withContentType(acceptContentType);
109110

110111
if (spec.getRequestSerde() instanceof RichSerde) {
111-
input.setJsonSchema(
112-
Objects.requireNonNull(((RichSerde<?>) spec.getRequestSerde()).jsonSchema()));
112+
Object jsonSchema =
113+
Objects.requireNonNull(((RichSerde<?>) spec.getRequestSerde()).jsonSchema());
114+
if (jsonSchema instanceof String) {
115+
// We need to convert it to databind JSON value
116+
try {
117+
jsonSchema = MANIFEST_OBJECT_MAPPER.readTree((String) jsonSchema);
118+
} catch (JsonProcessingException e) {
119+
throw new RuntimeException("The schema generated by RichSerde is not a valid JSON", e);
120+
}
121+
}
122+
input.setJsonSchema(jsonSchema);
113123
}
114124
return input;
115125
}
@@ -123,8 +133,17 @@ private static Output convertHandlerOutput(HandlerSpecification<?, ?> spec) {
123133
.withSetContentTypeIfEmpty(false);
124134

125135
if (spec.getResponseSerde() instanceof RichSerde) {
126-
output.setJsonSchema(
127-
Objects.requireNonNull(((RichSerde<?>) spec.getResponseSerde()).jsonSchema()));
136+
Object jsonSchema =
137+
Objects.requireNonNull(((RichSerde<?>) spec.getResponseSerde()).jsonSchema());
138+
if (jsonSchema instanceof String) {
139+
// We need to convert it to databind JSON value
140+
try {
141+
jsonSchema = MANIFEST_OBJECT_MAPPER.readTree((String) jsonSchema);
142+
} catch (JsonProcessingException e) {
143+
throw new RuntimeException("The schema generated by RichSerde is not a valid JSON", e);
144+
}
145+
}
146+
output.setJsonSchema(jsonSchema);
128147
}
129148

130149
return output;

sdk-core/src/main/java/dev/restate/sdk/core/ServiceProtocol.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ static String serviceDiscoveryProtocolVersionToHeaderValue(
135135
"Service discovery protocol version '%s' has no header value", version.getNumber()));
136136
}
137137

138-
private static final ObjectMapper MANIFEST_OBJECT_MAPPER = new ObjectMapper();
138+
static final ObjectMapper MANIFEST_OBJECT_MAPPER = new ObjectMapper();
139139

140140
@JsonFilter("V2FieldsFilter")
141141
interface V2Mixin {}

0 commit comments

Comments
 (0)