@@ -14,11 +14,13 @@ import software.amazon.smithy.kotlin.codegen.service.MediaType.ANY
1414import software.amazon.smithy.kotlin.codegen.service.MediaType.JSON
1515import software.amazon.smithy.kotlin.codegen.service.MediaType.OCTET_STREAM
1616import software.amazon.smithy.kotlin.codegen.service.MediaType.PLAIN_TEXT
17+ import software.amazon.smithy.model.shapes.MapShape
1718import software.amazon.smithy.model.shapes.OperationShape
1819import software.amazon.smithy.model.shapes.ShapeId
1920import software.amazon.smithy.model.traits.AuthTrait
2021import software.amazon.smithy.model.traits.HttpBearerAuthTrait
2122import software.amazon.smithy.model.traits.HttpLabelTrait
23+ import software.amazon.smithy.model.traits.HttpQueryParamsTrait
2224import software.amazon.smithy.model.traits.HttpQueryTrait
2325import software.amazon.smithy.model.traits.HttpTrait
2426import software.amazon.smithy.model.traits.MediaTypeTrait
@@ -248,19 +250,19 @@ internal class KtorStubGenerator(
248250 val contentType = MediaType .fromServiceShape(ctx, serviceShape, shape.input.get())
249251 val contentTypeGuard = when (contentType) {
250252 MediaType .CBOR -> " cbor()"
251- MediaType . JSON -> " json()"
252- MediaType . PLAIN_TEXT -> " text()"
253- MediaType . OCTET_STREAM -> " binary()"
254- MediaType . ANY -> " any()"
253+ JSON -> " json()"
254+ PLAIN_TEXT -> " text()"
255+ OCTET_STREAM -> " binary()"
256+ ANY -> " any()"
255257 }
256258
257259 val acceptType = MediaType .fromServiceShape(ctx, serviceShape, shape.output.get())
258260 val acceptTypeGuard = when (acceptType) {
259261 MediaType .CBOR -> " cbor()"
260- MediaType . JSON -> " json()"
261- MediaType . PLAIN_TEXT -> " text()"
262- MediaType . OCTET_STREAM -> " binary()"
263- MediaType . ANY -> " any()"
262+ JSON -> " json()"
263+ PLAIN_TEXT -> " text()"
264+ OCTET_STREAM -> " binary()"
265+ ANY -> " any()"
264266 }
265267
266268 withBlock(" #T (#S) {" , " }" , RuntimeTypes .KtorServerRouting .route, uri) {
@@ -329,7 +331,7 @@ internal class KtorStubGenerator(
329331 val memberName = member.key
330332 val memberShape = member.value
331333
332- val httpLabelVariableName = " call.parameters[\" $memberName \" ]"
334+ val httpLabelVariableName = " call.parameters[\" $memberName \" ]? "
333335 val targetShape = ctx.model.expectShape(memberShape.target)
334336 writer.writeInline(" $memberName = " )
335337 .call {
@@ -346,14 +348,16 @@ internal class KtorStubGenerator(
346348
347349 private fun readHttpQuery (shape : OperationShape , writer : KotlinWriter ) {
348350 val inputShape = ctx.model.expectShape(shape.input.get())
351+ val httpQueryKeys = mutableSetOf<String >()
349352 inputShape.allMembers
350353 .filter { member -> member.value.hasTrait(HttpQueryTrait .ID ) }
351354 .forEach { member ->
352355 val memberName = member.key
353356 val memberShape = member.value
354357 val httpQueryTrait = memberShape.getTrait<HttpQueryTrait >()!!
355- val httpQueryVariableName = " call.request.queryParameters[\" ${httpQueryTrait.value} \" ]"
358+ val httpQueryVariableName = " call.request.queryParameters[\" ${httpQueryTrait.value} \" ]? "
356359 val targetShape = ctx.model.expectShape(memberShape.target)
360+ httpQueryKeys.add(httpQueryTrait.value)
357361 writer.writeInline(" $memberName = " )
358362 .call {
359363 when {
@@ -365,7 +369,7 @@ internal class KtorStubGenerator(
365369 " ?: emptyList())"
366370 writer.withBlock(" $httpQueryListVariableName .mapNotNull{" , " }" ) {
367371 renderCastingPrimitiveFromShapeType(
368- " it" ,
372+ " it? " ,
369373 listMemberTargetShapeId.type,
370374 writer,
371375 listMemberShape.getTrait<TimestampFormatTrait >() ? : targetShape.getTrait<TimestampFormatTrait >(),
@@ -383,6 +387,33 @@ internal class KtorStubGenerator(
383387 }
384388 }
385389 }
390+ val httpQueryParamsMember = inputShape.allMembers.values.firstOrNull { it.hasTrait(HttpQueryParamsTrait .ID ) }
391+ httpQueryParamsMember?.apply {
392+ val httpQueryParamsMemberName = httpQueryParamsMember.memberName
393+ val httpQueryParamsMapShape = ctx.model.expectShape(httpQueryParamsMember.target) as MapShape
394+ val httpQueryParamsMapValueTypeShape = ctx.model.expectShape(httpQueryParamsMapShape.value.target)
395+ println (httpQueryParamsMapShape)
396+ val httpQueryKeysLiteral = httpQueryKeys.joinToString(" , " ) { " \" $it \" " }
397+ writer.withInlineBlock(" $httpQueryParamsMemberName = call.request.queryParameters.entries().filter { (key, _) ->" , " }" ) {
398+ write(" key !in setOf($httpQueryKeysLiteral )" )
399+ }
400+ .withBlock(" .associate { (key, values) ->" , " }" ) {
401+ if (httpQueryParamsMapValueTypeShape.isListShape) {
402+ write(" key to values!!" )
403+ } else {
404+ write(" key to values.first()" )
405+ }
406+ }
407+ .withBlock(" .mapValues { (_, value) ->" , " }" ) {
408+ renderCastingPrimitiveFromShapeType(
409+ " value" ,
410+ httpQueryParamsMapValueTypeShape.type,
411+ writer,
412+ httpQueryParamsMapValueTypeShape.getTrait<TimestampFormatTrait >() ? : httpQueryParamsMapShape.getTrait<TimestampFormatTrait >(),
413+ " Unsupported type ${httpQueryParamsMapValueTypeShape.type} for httpQuery" ,
414+ )
415+ }
416+ }
386417 }
387418
388419 private fun renderRoutingAuth (w : KotlinWriter , shape : OperationShape ) {
@@ -668,13 +699,25 @@ internal class KtorStubGenerator(
668699 writer.withBlock(" public class AcceptTypeGuardConfig {" , " }" ) {
669700 write(" public var allow: List<#T> = emptyList()" , RuntimeTypes .KtorServerHttp .ContentType )
670701 write(" " )
702+ withBlock(" public fun any(): Unit {" , " }" ) {
703+ write(" allow = listOf(#T)" , RuntimeTypes .KtorServerHttp .Any )
704+ }
705+ write(" " )
671706 withBlock(" public fun json(): Unit {" , " }" ) {
672707 write(" allow = listOf(#T)" , RuntimeTypes .KtorServerHttp .Json )
673708 }
674709 write(" " )
675710 withBlock(" public fun cbor(): Unit {" , " }" ) {
676711 write(" allow = listOf(#T)" , RuntimeTypes .KtorServerHttp .Cbor )
677712 }
713+ write(" " )
714+ withBlock(" public fun text(): Unit {" , " }" ) {
715+ write(" allow = listOf(#T)" , RuntimeTypes .KtorServerHttp .PlainText )
716+ }
717+ write(" " )
718+ withBlock(" public fun binary(): Unit {" , " }" ) {
719+ write(" allow = listOf(#T)" , RuntimeTypes .KtorServerHttp .OctetStream )
720+ }
678721 }
679722 .write(" " )
680723
0 commit comments