diff --git a/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/CodegenTest.kt b/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/CodegenTest.kt index 6adec37b..94eef43c 100644 --- a/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/CodegenTest.kt +++ b/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/CodegenTest.kt @@ -57,8 +57,8 @@ class CodegenTest : TestDefinitions.TestSuite { @Exclusive suspend fun complexType( context: ObjectContext, - request: Map> - ): Map> { + request: Map> + ): Map> { return mapOf() } } diff --git a/sdk-api-kotlin/build.gradle.kts b/sdk-api-kotlin/build.gradle.kts index f04a7682..d7d0abbc 100644 --- a/sdk-api-kotlin/build.gradle.kts +++ b/sdk-api-kotlin/build.gradle.kts @@ -13,6 +13,8 @@ dependencies { implementation(kotlinLibs.kotlinx.serialization.core) implementation(kotlinLibs.kotlinx.serialization.json) + implementation("io.bkbn:kompendium-json-schema:4.0.0-alpha") + implementation(coreLibs.log4j.api) implementation(platform(coreLibs.opentelemetry.bom)) implementation(coreLibs.opentelemetry.kotlin) diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/KtSerdes.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/KtSerdes.kt index 798bac18..65e456b7 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/KtSerdes.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/KtSerdes.kt @@ -9,14 +9,27 @@ package dev.restate.sdk.kotlin import dev.restate.sdk.common.DurablePromiseKey +import dev.restate.sdk.common.RichSerde import dev.restate.sdk.common.Serde import dev.restate.sdk.common.StateKey +import io.bkbn.kompendium.json.schema.KotlinXSchemaConfigurator +import io.bkbn.kompendium.json.schema.SchemaGenerator +import io.bkbn.kompendium.json.schema.definition.AnyOfDefinition +import io.bkbn.kompendium.json.schema.definition.ArrayDefinition +import io.bkbn.kompendium.json.schema.definition.JsonSchema +import io.bkbn.kompendium.json.schema.definition.MapDefinition +import io.bkbn.kompendium.json.schema.definition.OneOfDefinition +import io.bkbn.kompendium.json.schema.definition.ReferenceDefinition +import io.bkbn.kompendium.json.schema.definition.TypeDefinition import java.nio.ByteBuffer import java.nio.charset.StandardCharsets import kotlin.reflect.typeOf import kotlinx.serialization.KSerializer +import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.encodeToJsonElement +import kotlinx.serialization.json.jsonObject import kotlinx.serialization.serializer object KtStateKey { @@ -70,12 +83,13 @@ object KtSerdes { } /** Creates a [Serde] implementation using the `kotlinx.serialization` json module. */ - fun json(serializer: KSerializer): Serde { - return object : Serde { + inline fun json(serializer: KSerializer): Serde { + return object : RichSerde { override fun serialize(value: T?): ByteArray { if (value == null) { return Json.encodeToString(JsonNull.serializer(), JsonNull).encodeToByteArray() } + return Json.encodeToString(serializer, value).encodeToByteArray() } @@ -86,6 +100,39 @@ object KtSerdes { override fun contentType(): String { return "application/json" } + + override fun jsonSchema(): String { + fun JsonSchema.sanitizeRefs(): JsonSchema { + return when (this) { + is AnyOfDefinition -> this.copy(anyOf = this.anyOf.map { it.sanitizeRefs() }.toSet()) + is ArrayDefinition -> this.items.sanitizeRefs() + is MapDefinition -> + this.copy(additionalProperties = this.additionalProperties.sanitizeRefs()) + is OneOfDefinition -> this.copy(oneOf = this.oneOf.map { it.sanitizeRefs() }.toSet()) + is ReferenceDefinition -> + this.copy(`$ref` = this.`$ref`.replaceFirst("#/components/schemas", "#/\$defs")) + is TypeDefinition -> + this.copy(properties = this.properties?.mapValues { it.value.sanitizeRefs() }) + else -> this + } + } + + val nestedSchemas = mutableMapOf() + val rootSchema = + SchemaGenerator.fromTypeToSchema( + type = typeOf(), + cache = nestedSchemas, + schemaConfigurator = KotlinXSchemaConfigurator(), + ) + + val defsSchemas: Map = + nestedSchemas.mapValues { e -> e.value.sanitizeRefs() } + val rootElement = + Json.encodeToJsonElement(JsonSchema.serializer(), rootSchema.sanitizeRefs()) + .jsonObject + ("\$defs" to Json.encodeToJsonElement(defsSchemas)) + + return Json.encodeToString(rootElement) + } } } } diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/RichSerde.java b/sdk-common/src/main/java/dev/restate/sdk/common/RichSerde.java index 20e5418f..2e7e0961 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/RichSerde.java +++ b/sdk-common/src/main/java/dev/restate/sdk/common/RichSerde.java @@ -14,12 +14,15 @@ /** * Richer version of {@link Serde} containing schema information. * + *

This API should be considered unstable to implement. + * *

You can create one using {@link #withSchema(Object, Serde)}. */ public interface RichSerde extends Serde { /** - * @return a Draft 2020-12 Json Schema + * @return a Draft 2020-12 Json Schema. It should be self-contained, and MUST not contain refs to + * files. If the schema shouldn't be serialized with Jackson, return a {@link String} */ Object jsonSchema(); diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java b/sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java index 03f58f29..782e0c6e 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java @@ -10,6 +10,7 @@ import static dev.restate.sdk.core.ServiceProtocol.*; +import com.fasterxml.jackson.core.JsonProcessingException; import dev.restate.sdk.common.HandlerType; import dev.restate.sdk.common.RichSerde; import dev.restate.sdk.common.ServiceType; @@ -108,8 +109,17 @@ private static Input convertHandlerInput(HandlerSpecification spec) { : new Input().withRequired(true).withContentType(acceptContentType); if (spec.getRequestSerde() instanceof RichSerde) { - input.setJsonSchema( - Objects.requireNonNull(((RichSerde) spec.getRequestSerde()).jsonSchema())); + Object jsonSchema = + Objects.requireNonNull(((RichSerde) spec.getRequestSerde()).jsonSchema()); + if (jsonSchema instanceof String) { + // We need to convert it to databind JSON value + try { + jsonSchema = MANIFEST_OBJECT_MAPPER.readTree((String) jsonSchema); + } catch (JsonProcessingException e) { + throw new RuntimeException("The schema generated by RichSerde is not a valid JSON", e); + } + } + input.setJsonSchema(jsonSchema); } return input; } @@ -123,8 +133,17 @@ private static Output convertHandlerOutput(HandlerSpecification spec) { .withSetContentTypeIfEmpty(false); if (spec.getResponseSerde() instanceof RichSerde) { - output.setJsonSchema( - Objects.requireNonNull(((RichSerde) spec.getResponseSerde()).jsonSchema())); + Object jsonSchema = + Objects.requireNonNull(((RichSerde) spec.getResponseSerde()).jsonSchema()); + if (jsonSchema instanceof String) { + // We need to convert it to databind JSON value + try { + jsonSchema = MANIFEST_OBJECT_MAPPER.readTree((String) jsonSchema); + } catch (JsonProcessingException e) { + throw new RuntimeException("The schema generated by RichSerde is not a valid JSON", e); + } + } + output.setJsonSchema(jsonSchema); } return output; diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ServiceProtocol.java b/sdk-core/src/main/java/dev/restate/sdk/core/ServiceProtocol.java index 241aca53..f410a638 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ServiceProtocol.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/ServiceProtocol.java @@ -135,7 +135,7 @@ static String serviceDiscoveryProtocolVersionToHeaderValue( "Service discovery protocol version '%s' has no header value", version.getNumber())); } - private static final ObjectMapper MANIFEST_OBJECT_MAPPER = new ObjectMapper(); + static final ObjectMapper MANIFEST_OBJECT_MAPPER = new ObjectMapper(); @JsonFilter("V2FieldsFilter") interface V2Mixin {} diff --git a/test-services/build.gradle.kts b/test-services/build.gradle.kts index 98c677c5..db358bcc 100644 --- a/test-services/build.gradle.kts +++ b/test-services/build.gradle.kts @@ -70,4 +70,9 @@ jib { tasks.jar { manifest { attributes["Main-Class"] = "dev.restate.sdk.testservices.MainKt" } } -application { mainClass.set("dev.restate.sdk.testservices.MainKt") } +tasks.withType { + classpath("$projectDir/generated/ksp/main/resources") +} + +application { + mainClass.set("dev.restate.sdk.testservices.MainKt") }