@@ -10,12 +10,18 @@ import software.amazon.smithy.kotlin.codegen.core.withBlock
1010import software.amazon.smithy.kotlin.codegen.core.withInlineBlock
1111import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
1212import software.amazon.smithy.kotlin.codegen.model.getTrait
13+ import software.amazon.smithy.kotlin.codegen.service.MediaType.ANY
14+ import software.amazon.smithy.kotlin.codegen.service.MediaType.JSON
15+ import software.amazon.smithy.kotlin.codegen.service.MediaType.OCTET_STREAM
16+ import software.amazon.smithy.kotlin.codegen.service.MediaType.PLAIN_TEXT
1317import software.amazon.smithy.model.shapes.OperationShape
18+ import software.amazon.smithy.model.shapes.ShapeId
1419import software.amazon.smithy.model.traits.AuthTrait
1520import software.amazon.smithy.model.traits.HttpBearerAuthTrait
1621import software.amazon.smithy.model.traits.HttpLabelTrait
1722import software.amazon.smithy.model.traits.HttpQueryTrait
1823import software.amazon.smithy.model.traits.HttpTrait
24+ import software.amazon.smithy.model.traits.MediaTypeTrait
1925import software.amazon.smithy.model.traits.TimestampFormatTrait
2026import software.amazon.smithy.utils.AbstractCodeWriter
2127
@@ -239,22 +245,22 @@ internal class KtorStubGenerator(
239245 " OPTIONS" -> RuntimeTypes .KtorServerRouting .options
240246 else -> error(" Unsupported http trait ${httpTrait.method} " )
241247 }
242- val contentType = ContentType .fromServiceShape(ctx, serviceShape, shape)
248+ val contentType = MediaType .fromServiceShape(ctx, serviceShape, shape.input.get() )
243249 val contentTypeGuard = when (contentType) {
244- ContentType .CBOR -> " cbor()"
245- ContentType .JSON -> " json()"
246- ContentType .PLAIN_TEXT -> " text()"
247- ContentType . BINARY -> " binary()"
248- ContentType . MEDIA_TYPE -> " any()"
250+ MediaType .CBOR -> " cbor()"
251+ MediaType .JSON -> " json()"
252+ MediaType .PLAIN_TEXT -> " text()"
253+ MediaType . OCTET_STREAM -> " binary()"
254+ MediaType . ANY -> " any()"
249255 }
250256
251- val acceptTypeGuard = when (contentType) {
252- ContentType . CBOR -> " cbor() "
253- ContentType . JSON ,
254- ContentType . PLAIN_TEXT ,
255- ContentType . BINARY ,
256- ContentType . MEDIA_TYPE ,
257- -> " json ()"
257+ val acceptType = MediaType .fromServiceShape(ctx, serviceShape, shape.output.get())
258+ val acceptTypeGuard = when (acceptType) {
259+ MediaType . CBOR -> " cbor() "
260+ MediaType . JSON -> " json() "
261+ MediaType . PLAIN_TEXT -> " text() "
262+ MediaType . OCTET_STREAM -> " binary() "
263+ MediaType . ANY -> " any ()"
258264 }
259265
260266 withBlock(" #T (#S) {" , " }" , RuntimeTypes .KtorServerRouting .route, uri) {
@@ -301,7 +307,7 @@ internal class KtorStubGenerator(
301307 " Malformed CBOR output" ,
302308 )
303309 }
304- .call { renderResponseCall(writer, contentType , successCode) }
310+ .call { renderResponseCall(writer, acceptType , successCode, shape.output.get() ) }
305311 }
306312 withBlock(" catch (t: Throwable) {" , " }" ) {
307313 write(" throw t" )
@@ -397,11 +403,12 @@ internal class KtorStubGenerator(
397403
398404 private fun renderResponseCall (
399405 w : KotlinWriter ,
400- contentType : ContentType ,
406+ acceptType : MediaType ,
401407 successCode : Int ,
408+ outputShapeId : ShapeId ,
402409 ) {
403- when (contentType ) {
404- ContentType .CBOR -> w.withBlock(
410+ when (acceptType ) {
411+ MediaType .CBOR -> w.withBlock(
405412 " #T.#T(" ,
406413 " )" ,
407414 RuntimeTypes .KtorServerCore .applicationCall,
@@ -414,11 +421,20 @@ internal class KtorStubGenerator(
414421 RuntimeTypes .KtorServerHttp .HttpStatusCode ,
415422 )
416423 }
417- ContentType .JSON ,
418- ContentType .PLAIN_TEXT ,
419- ContentType .BINARY ,
420- ContentType .MEDIA_TYPE ,
421- -> w.withBlock(
424+ OCTET_STREAM -> w.withBlock(
425+ " #T.#T(" ,
426+ " )" ,
427+ RuntimeTypes .KtorServerCore .applicationCall,
428+ RuntimeTypes .KtorServerRouting .responseRespondBytes,
429+ ) {
430+ write(" bytes = response," )
431+ write(" contentType = #T," , RuntimeTypes .KtorServerHttp .OctetStream )
432+ write(
433+ " status = #T.fromValue($successCode )," ,
434+ RuntimeTypes .KtorServerHttp .HttpStatusCode ,
435+ )
436+ }
437+ JSON -> w.withBlock(
422438 " #T.#T(" ,
423439 " )" ,
424440 RuntimeTypes .KtorServerCore .applicationCall,
@@ -431,6 +447,36 @@ internal class KtorStubGenerator(
431447 RuntimeTypes .KtorServerHttp .HttpStatusCode ,
432448 )
433449 }
450+ PLAIN_TEXT -> w.withBlock(
451+ " #T.#T(" ,
452+ " )" ,
453+ RuntimeTypes .KtorServerCore .applicationCall,
454+ RuntimeTypes .KtorServerRouting .responseResponseText,
455+ ) {
456+ write(" text = response," )
457+ write(" contentType = #T," , RuntimeTypes .KtorServerHttp .PlainText )
458+ write(
459+ " status = #T.fromValue($successCode )," ,
460+ RuntimeTypes .KtorServerHttp .HttpStatusCode ,
461+ )
462+ }
463+ ANY -> {
464+ val outputShape = ctx.model.expectShape(outputShapeId)
465+ val mediaTraits = outputShape.allMembers.values.firstNotNullOf { it.getTrait<MediaTypeTrait >() }
466+ w.withBlock(
467+ " #T.#T(" ,
468+ " )" ,
469+ RuntimeTypes .KtorServerCore .applicationCall,
470+ RuntimeTypes .KtorServerRouting .responseRespondBytes,
471+ ) {
472+ write(" bytes = response," )
473+ write(" contentType = #S," , mediaTraits.value)
474+ write(
475+ " status = #T.fromValue($successCode )," ,
476+ RuntimeTypes .KtorServerHttp .HttpStatusCode ,
477+ )
478+ }
479+ }
434480 }
435481 }
436482
0 commit comments