@@ -6,16 +6,22 @@ package software.amazon.smithy.kotlin.codegen.rendering.endpoints
66
77import software.amazon.smithy.codegen.core.CodegenException
88import software.amazon.smithy.codegen.core.Symbol
9+ import software.amazon.smithy.jmespath.JmespathExpression
910import software.amazon.smithy.kotlin.codegen.KotlinSettings
1011import software.amazon.smithy.kotlin.codegen.core.*
1112import software.amazon.smithy.kotlin.codegen.integration.SectionId
1213import software.amazon.smithy.kotlin.codegen.model.*
1314import software.amazon.smithy.kotlin.codegen.model.knowledge.EndpointParameterIndex
1415import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator
16+ import software.amazon.smithy.kotlin.codegen.rendering.waiters.KotlinJmespathExpressionVisitor
1517import software.amazon.smithy.model.knowledge.TopDownIndex
18+ import software.amazon.smithy.model.shapes.MemberShape
1619import software.amazon.smithy.model.shapes.OperationShape
20+ import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter
1721import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType
1822import software.amazon.smithy.rulesengine.traits.ClientContextParamsTrait
23+ import software.amazon.smithy.rulesengine.traits.OperationContextParamDefinition
24+ import software.amazon.smithy.rulesengine.traits.StaticContextParamDefinition
1925import software.amazon.smithy.utils.StringUtils
2026
2127object EndpointBusinessMetrics : SectionId
@@ -77,6 +83,7 @@ class EndpointResolverAdapterGenerator(
7783 val topDownIndex = TopDownIndex .of(ctx.model)
7884 val operations = topDownIndex.getContainedOperations(ctx.service)
7985 val epParameterIndex = EndpointParameterIndex .of(ctx.model)
86+ val operationsWithContextBindings = operations.filter { epParameterIndex.hasContextParams(it) }
8087
8188 writer.write(
8289 " private typealias BindOperationContextParamsFn = (#T.Builder, #T) -> Unit" ,
@@ -88,24 +95,28 @@ class EndpointResolverAdapterGenerator(
8895 " private val opContextBindings = mapOf<String, BindOperationContextParamsFn> (" ,
8996 " )" ,
9097 ) {
91- val operationsWithContextBindings = operations.filter { epParameterIndex.hasContextParams(it) }
9298 operationsWithContextBindings.forEach { op ->
93- val bindFn = op.bindEndpointContextFn(ctx.settings) { fnWriter ->
94- fnWriter.withBlock(
95- " private fun #L(builder: #T.Builder, request: #T): Unit {" ,
96- " }" ,
97- op.bindEndpointContextFnName(),
98- EndpointParametersGenerator .getSymbol(ctx.settings),
99- RuntimeTypes .HttpClient .Operation .ResolveEndpointRequest ,
100- ) {
101- renderBindOperationContextParams(epParameterIndex, op, fnWriter)
102- }
103- }
104- write(" #S to ::#T," , op.id.name, bindFn)
99+ write(" #S to ::#L," , op.id.name, op.bindEndpointContextFnName())
105100 }
106101 }
102+
103+ operationsWithContextBindings.forEach { op ->
104+ renderBindOperationContextFunction(op, epParameterIndex)
105+ }
107106 }
108107
108+ private fun renderBindOperationContextFunction (op : OperationShape , epParameterIndex : EndpointParameterIndex ) =
109+ writer.write(" " )
110+ .withBlock(
111+ " private fun #L(builder: #T.Builder, request: #T): Unit {" ,
112+ " }" ,
113+ op.bindEndpointContextFnName(),
114+ EndpointParametersGenerator .getSymbol(ctx.settings),
115+ RuntimeTypes .HttpClient .Operation .ResolveEndpointRequest ,
116+ ) {
117+ renderBindOperationContextParams(epParameterIndex, op)
118+ }
119+
109120 private fun renderResolveEndpointParams () {
110121 // NOTE: this is internal as it's re-used for auth scheme resolver generators in specific instances where they
111122 // fallback to endpoint rules (e.g. S3 & EventBridge)
@@ -119,14 +130,21 @@ class EndpointResolverAdapterGenerator(
119130 ) {
120131 writer.addImport(RuntimeTypes .Core .Collections .get)
121132 withBlock(" return #T {" , " }" , EndpointParametersGenerator .getSymbol(ctx.settings)) {
122- // The SEP dictates a specific source order to use when binding parameters (from most specific to least):
123- // 1. staticContextParams (from operation shape)
124- // 2. contextParam (from member of operation input shape)
125- // 3. clientContextParams (from service shape)
126- // 4. builtin binding
127- // 5. builtin default
128- // Sources 4 and 5 are SDK-specific, builtin bindings are plugged in and rendered beforehand such that any bindings
129- // from source 1 or 2 can supersede them.
133+ /*
134+ The spec dictates a specific source order to use when binding parameters (from most specific to least):
135+
136+ 1. staticContextParams (from operation shape)
137+ 2. contextParam (from member of operation input shape)
138+ 3. operationContextParams (from operation shape)
139+ 4. clientContextParams (from service shape)
140+ 5. builtin binding
141+ 6. builtin default
142+
143+ Sources 5 and 6 are SDK-specific
144+
145+ Builtin bindings are plugged in and rendered beforehand such that any bindings from source 1, 2, or 3
146+ can supersede them.
147+ */
130148
131149 // Render builtins
132150 if (rules != null ) {
@@ -140,7 +158,7 @@ class EndpointResolverAdapterGenerator(
140158 // Render client context
141159 renderBindClientContextParams(ctx, writer)
142160
143- // Render operation static/input context (if any)
161+ // Render operation static/input/operation context (if any)
144162 write(" val opName = request.context[#T.OperationName]" , RuntimeTypes .SmithyClient .SdkClientOption )
145163 write(" opContextBindings[opName]?.invoke(this, request)" )
146164 }
@@ -167,42 +185,87 @@ class EndpointResolverAdapterGenerator(
167185 private fun renderBindOperationContextParams (
168186 epParameterIndex : EndpointParameterIndex ,
169187 op : OperationShape ,
170- writer : KotlinWriter ,
171188 ) {
172189 if (rules == null ) return
190+
173191 val staticContextParams = epParameterIndex.staticContextParams(op)
174192 val inputContextParams = epParameterIndex.inputContextParams(op)
193+ val operationContextParams = epParameterIndex.operationContextParams(op)
175194
176- if (inputContextParams.isNotEmpty()) {
177- writer.addImport(RuntimeTypes .Core .Collections .get)
178- writer.write(" @Suppress(#S)" , " UNCHECKED_CAST" )
179- val opInputShape = ctx.model.expectShape(op.inputShape)
180- val inputSymbol = ctx.symbolProvider.toSymbol(opInputShape)
181- writer.write(" val input = request.context[#T.OperationInput] as #T" , RuntimeTypes .HttpClient .Operation .HttpOperationContext , inputSymbol)
182- }
195+ if (inputContextParams.isNotEmpty()) renderInput(op)
183196
184197 for (param in rules.parameters.toList()) {
185198 val paramName = param.name.toString()
186199 val paramDefaultName = param.defaultName()
187200
201+ // Check static params
188202 val staticParam = staticContextParams?.parameters?.get(paramName)
189-
190203 if (staticParam != null ) {
191- writer.writeInline(" builder.#L = " , paramDefaultName)
192- when (param.type) {
193- ParameterType .STRING -> writer.write(" #S" , staticParam.value.expectStringNode().value)
194- ParameterType .BOOLEAN -> writer.write(" #L" , staticParam.value.expectBooleanNode().value)
195- else -> throw CodegenException (" unexpected static context param type ${param.type} " )
196- }
204+ renderStaticParam(staticParam, paramDefaultName, param)
205+ continue
206+ }
207+
208+ // Check input params
209+ val inputParam = inputContextParams[paramName]
210+ if (inputParam != null ) {
211+ renderInputParam(inputParam, paramDefaultName)
197212 continue
198213 }
199214
200- inputContextParams[paramName]?.let {
201- writer.write(" builder.#L = input.#L" , paramDefaultName, it.defaultName())
215+ // Check operation params
216+ val operationParam = operationContextParams?.get(paramName)
217+ if (operationParam != null ) {
218+ renderOperationParam(operationParam, paramDefaultName, op, inputContextParams)
202219 }
203220 }
204221 }
205222
223+ private fun renderInput (op : OperationShape ) {
224+ writer.addImport(RuntimeTypes .Core .Collections .get)
225+ writer.write(" @Suppress(#S)" , " UNCHECKED_CAST" )
226+ val opInputShape = ctx.model.expectShape(op.inputShape)
227+ val inputSymbol = ctx.symbolProvider.toSymbol(opInputShape)
228+ writer.write(" val input = request.context[#T.OperationInput] as #T" , RuntimeTypes .HttpClient .Operation .HttpOperationContext , inputSymbol)
229+ }
230+
231+ private fun renderStaticParam (staticParam : StaticContextParamDefinition , paramDefaultName : String , param : Parameter ) {
232+ writer.writeInline(" builder.#L = " , paramDefaultName)
233+ when (param.type) {
234+ ParameterType .STRING -> writer.write(" #S" , staticParam.value.expectStringNode().value)
235+ ParameterType .BOOLEAN -> writer.write(" #L" , staticParam.value.expectBooleanNode().value)
236+ ParameterType .STRING_ARRAY -> writer.write(" #L" , staticParam.value.expectArrayNode().elements.format())
237+ else -> throw CodegenException (" unexpected static context param type ${param.type} " )
238+ }
239+ }
240+
241+ private fun renderInputParam (inputParam : MemberShape , paramDefaultName : String ) {
242+ writer.write(" builder.#L = input.#L" , paramDefaultName, inputParam.defaultName())
243+ }
244+
245+ private fun renderOperationParam (operationParam : OperationContextParamDefinition , paramDefaultName : String , op : OperationShape , inputContextParams : Map <String , MemberShape >) {
246+ val opInputShape = ctx.model.expectShape(op.inputShape)
247+
248+ if (inputContextParams.isEmpty()) {
249+ // This will already be rendered in the block if inputContextParams is not empty
250+ renderInput(op)
251+ }
252+
253+ val jmespathVisitor = KotlinJmespathExpressionVisitor (
254+ GenerationContext (
255+ ctx.model,
256+ ctx.symbolProvider,
257+ ctx.settings,
258+ ),
259+ writer,
260+ opInputShape,
261+ " input" , // reference the operation input during jmespath codegen
262+ )
263+ val expression = JmespathExpression .parse(operationParam.path)
264+ val expressionResult = expression.accept(jmespathVisitor)
265+
266+ writer.write(" builder.#L = #L" , paramDefaultName, expressionResult.identifier)
267+ }
268+
206269 private fun renderBindClientContextParams (ctx : ProtocolGenerator .GenerationContext , writer : KotlinWriter ) {
207270 val clientContextParams = ctx.service.getTrait<ClientContextParamsTrait >() ? : return
208271 if (rules == null ) return
0 commit comments