@@ -24,6 +24,7 @@ import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule
2424import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule
2525import software.amazon.smithy.rulesengine.language.syntax.rule.Rule
2626import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule
27+ import java.util.stream.Collectors
2728
2829/* *
2930 * The core set of standard library functions available to the rules language.
@@ -49,24 +50,9 @@ typealias EndpointPropertyRenderer = (KotlinWriter, Expression, ExpressionRender
4950 * An expression renderer generates code for an endpoint expression construct.
5051 */
5152fun interface ExpressionRenderer {
52- fun renderExpression (expr : Expression )
53+ fun renderExpression (expr : Expression ): EndpointInfo
5354}
5455
55- /* *
56- * Will be toggled to true if it is determined an endpoint is account ID based then to false again
57- */
58- private var hasAccountIdBasedEndpoint = false
59-
60- /* *
61- * Will be toggled to true if determined an endpoint comes from a service endpoint override then to false again
62- */
63- private var hasServiceEndpointOverride = false
64-
65- /* *
66- * Will be toggled to true when rendering an endpoint URL then to false again
67- */
68- private var renderingEndpointUrl = false
69-
7056/* *
7157 * Renders the default endpoint provider based on the provided rule set.
7258 */
@@ -121,9 +107,7 @@ class DefaultEndpointProviderGenerator(
121107 }
122108 }
123109
124- override fun renderExpression (expr : Expression ) {
125- expr.accept(expressionGenerator)
126- }
110+ override fun renderExpression (expr : Expression ): EndpointInfo = expr.accept(expressionGenerator) ? : EndpointInfo .Empty
127111
128112 private fun renderDocumentation () {
129113 writer.dokka {
@@ -185,11 +169,11 @@ class DefaultEndpointProviderGenerator(
185169 withConditions(rule.conditions) {
186170 writer.withBlock(" return #T(" , " )" , RuntimeTypes .SmithyClient .Endpoints .Endpoint ) {
187171 writeInline(" #T.parse(" , RuntimeTypes .Core .Net .Url .Url )
188- renderingEndpointUrl = true
189- renderExpression(rule.endpoint.url)
190- renderingEndpointUrl = false
172+ val endpointInfo = renderExpression(rule.endpoint.url)
191173 write(" )," )
192174
175+ val hasAccountIdBasedEndpoint = " accountId" in endpointInfo.params
176+ val hasServiceEndpointOverride = " endpoint" in endpointInfo.params
193177 val needAdditionalEndpointProperties = hasAccountIdBasedEndpoint || hasServiceEndpointOverride
194178
195179 if (rule.endpoint.headers.isNotEmpty()) {
@@ -226,11 +210,9 @@ class DefaultEndpointProviderGenerator(
226210
227211 if (hasAccountIdBasedEndpoint) {
228212 writer.write(" #T to params.accountId" , RuntimeTypes .Core .BusinessMetrics .AccountIdBasedEndpointAccountId )
229- hasAccountIdBasedEndpoint = false
230213 }
231214 if (hasServiceEndpointOverride) {
232215 writer.write(" #T to true" , RuntimeTypes .Core .BusinessMetrics .ServiceEndpointOverride )
233- hasServiceEndpointOverride = false
234216 }
235217 }
236218 }
@@ -253,16 +235,24 @@ class DefaultEndpointProviderGenerator(
253235 }
254236}
255237
238+ data class EndpointInfo (val params : MutableSet <String >) {
239+ companion object {
240+ val Empty = EndpointInfo (params = mutableSetOf ())
241+ }
242+
243+ operator fun plus (other : EndpointInfo ) = EndpointInfo (
244+ params = (this .params + other.params).toMutableSet(),
245+ )
246+ }
247+
256248class ExpressionGenerator (
257249 private val writer : KotlinWriter ,
258250 private val rules : EndpointRuleSet ,
259251 private val functions : Map <String , Symbol >,
260- ) : ExpressionVisitor<Unit>, LiteralVisitor<Unit>, TemplateVisitor<Unit> {
261- override fun visitLiteral (literal : Literal ) {
262- literal.accept(this as LiteralVisitor <Unit >)
263- }
252+ ) : ExpressionVisitor<EndpointInfo?>, LiteralVisitor<EndpointInfo?>, TemplateVisitor<EndpointInfo?> {
253+ override fun visitLiteral (literal : Literal ): EndpointInfo ? = literal.accept(this as LiteralVisitor <EndpointInfo ?>)
264254
265- override fun visitRef (reference : Reference ) {
255+ override fun visitRef (reference : Reference ): EndpointInfo {
266256 val referenceName = reference.name.defaultName()
267257 val isParamReference = isParamRef(reference)
268258
@@ -271,90 +261,112 @@ class ExpressionGenerator(
271261 }
272262 writer.writeInline(referenceName)
273263
274- if (renderingEndpointUrl) {
275- if (isParamReference && referenceName == " accountId" ) hasAccountIdBasedEndpoint = true
276- if (isParamReference && referenceName == " endpoint" ) hasServiceEndpointOverride = true
264+ return if (isParamReference) {
265+ EndpointInfo (params = mutableSetOf (referenceName))
266+ } else {
267+ EndpointInfo .Empty
277268 }
278269 }
279270
280- override fun visitGetAttr (getAttr : GetAttr ) {
281- getAttr.target.accept(this )
271+ override fun visitGetAttr (getAttr : GetAttr ): EndpointInfo ? {
272+ val endpointInfo = getAttr.target.accept(this )
282273 getAttr.path.forEach {
283274 when (it) {
284275 is GetAttr .Part .Key -> writer.writeInline(" ?.#L" , it.key().toString())
285276 is GetAttr .Part .Index -> writer.writeInline(" ?.getOrNull(#L)" , it.index())
286277 else -> throw CodegenException (" unexpected path" )
287278 }
288279 }
280+ return endpointInfo
289281 }
290282
291- override fun visitIsSet (target : Expression ) {
292- target.accept(this )
283+ override fun visitIsSet (target : Expression ): EndpointInfo ? {
284+ val endpointInfo = target.accept(this )
293285 writer.writeInline(" != null" )
286+ return endpointInfo
294287 }
295288
296- override fun visitNot (target : Expression ) {
289+ override fun visitNot (target : Expression ): EndpointInfo ? {
297290 writer.writeInline(" !(" )
298- target.accept(this )
291+ val endpointInfo = target.accept(this )
299292 writer.writeInline(" )" )
293+ return endpointInfo
300294 }
301295
302- override fun visitBoolEquals (left : Expression , right : Expression ) {
303- visitEquals(left, right)
304- }
296+ override fun visitBoolEquals (left : Expression , right : Expression ): EndpointInfo ? = visitEquals(left, right)
305297
306- override fun visitStringEquals (left : Expression , right : Expression ) {
307- visitEquals(left, right)
308- }
298+ override fun visitStringEquals (left : Expression , right : Expression ): EndpointInfo ? = visitEquals(left, right)
309299
310- private fun visitEquals (left : Expression , right : Expression ) {
311- left.accept(this )
300+ private fun visitEquals (left : Expression , right : Expression ): EndpointInfo ? {
301+ val leftEndpointInfo = left.accept(this )
312302 writer.writeInline(" == " )
313- right.accept(this )
303+ val rightEndpointInfo = right.accept(this )
304+
305+ return when {
306+ leftEndpointInfo != null && rightEndpointInfo != null -> leftEndpointInfo + rightEndpointInfo
307+ leftEndpointInfo != null -> leftEndpointInfo
308+ else -> rightEndpointInfo
309+ }
314310 }
315311
316- override fun visitLibraryFunction (fn : FunctionDefinition , args : MutableList <Expression >) {
312+ override fun visitLibraryFunction (fn : FunctionDefinition , args : MutableList <Expression >): EndpointInfo ? {
317313 writer.writeInline(" #T(" , functions.getValue(fn.id))
318- args.forEachIndexed { index, it ->
319- it.accept(this )
314+
315+ val endpointInfo = args.foldIndexed(EndpointInfo .Empty ) { index, acc, curr ->
316+ val currEndpointInfo = curr.accept(this )
320317 if (index < args.lastIndex) {
321318 writer.writeInline(" , " )
322319 }
320+ currEndpointInfo?.let { acc + it } ? : acc
323321 }
324322 writer.writeInline(" )" )
323+ return endpointInfo
325324 }
326325
327- override fun visitInteger (value : Int ) {
326+ override fun visitInteger (value : Int ): EndpointInfo ? {
328327 writer.writeInline(" #L" , value)
328+ return null
329329 }
330330
331- override fun visitString (value : Template ) {
331+ override fun visitString (value : Template ): EndpointInfo ? {
332332 writer.writeInline(" \" " )
333- value.accept(this ).forEach {} // must "consume" the stream to actually generate everything
333+ val endpointInfo = value.accept(this )
334+ .collect(Collectors .toList())
335+ .fold(EndpointInfo .Empty ) { acc, curr ->
336+ curr?.let { acc + it } ? : acc
337+ }
334338 writer.writeInline(" \" " )
339+ return endpointInfo
335340 }
336341
337- override fun visitBoolean (value : Boolean ) {
342+ override fun visitBoolean (value : Boolean ): EndpointInfo ? {
338343 writer.writeInline(" #L" , value)
344+ return null
339345 }
340346
341- override fun visitRecord (value : MutableMap <Identifier , Literal >) {
347+ override fun visitRecord (value : MutableMap <Identifier , Literal >): EndpointInfo ? {
348+ var endpointInfo: EndpointInfo ? = null
342349 writer.withInlineBlock(" #T {" , " }" , RuntimeTypes .Core .Content .buildDocument) {
343- value.entries.forEachIndexed { index, (k, v) ->
350+ endpointInfo = value.entries.foldIndexed( EndpointInfo . Empty ) { index, acc , (k, v) ->
344351 writeInline(" #S to " , k.toString())
345- v.accept(this @ExpressionGenerator as LiteralVisitor <Unit >)
352+ val currInfo = v.accept(this @ExpressionGenerator as LiteralVisitor <EndpointInfo ? >)
346353 if (index < value.size - 1 ) write(" " )
354+ currInfo?.let { acc + it } ? : acc
347355 }
348356 }
357+ return endpointInfo
349358 }
350359
351- override fun visitTuple (value : MutableList <Literal >) {
360+ override fun visitTuple (value : MutableList <Literal >): EndpointInfo ? {
361+ var endpointInfo: EndpointInfo ? = null
352362 writer.withInlineBlock(" listOf(" , " )" ) {
353- value.forEachIndexed { index, it ->
354- it .accept(this @ExpressionGenerator as LiteralVisitor <Unit >)
363+ endpointInfo = value.foldIndexed( EndpointInfo . Empty ) { index, acc, curr ->
364+ val localInfo = curr .accept(this @ExpressionGenerator as LiteralVisitor <EndpointInfo ? >)
355365 if (index < value.size - 1 ) write(" ," ) else writeInline(" ," )
366+ localInfo?.let { acc + it } ? : acc
356367 }
357368 }
369+ return endpointInfo
358370 }
359371
360372 override fun visitStaticTemplate (value : String ) = writeTemplateString(value)
@@ -363,17 +375,19 @@ class ExpressionGenerator(
363375 override fun visitDynamicElement (value : Expression ) = writeTemplateExpression(value)
364376
365377 // no-ops for kotlin codegen
366- override fun startMultipartTemplate () {}
367- override fun finishMultipartTemplate () {}
378+ override fun startMultipartTemplate (): EndpointInfo ? = null
379+ override fun finishMultipartTemplate (): EndpointInfo ? = null
368380
369- private fun writeTemplateString (value : String ) {
381+ private fun writeTemplateString (value : String ): EndpointInfo ? {
370382 writer.writeInline(value.replace(" \" " , " \\\" " ))
383+ return null
371384 }
372385
373- private fun writeTemplateExpression (expr : Expression ) {
386+ private fun writeTemplateExpression (expr : Expression ): EndpointInfo ? {
374387 writer.writeInline(" \$ {" )
375- expr.accept(this )
388+ val endpointInfo = expr.accept(this )
376389 writer.writeInline(" }" )
390+ return endpointInfo
377391 }
378392
379393 private fun isParamRef (ref : Reference ): Boolean = rules.parameters.toList().any { it.name == ref.name }
0 commit comments