Skip to content

Commit 97709d7

Browse files
lauzadis0marperez
andauthored
fix: remove global booleans previously used for emitting business metrics (#1104)
Co-authored-by: 0marperez <[email protected]>
1 parent 72bc814 commit 97709d7

File tree

3 files changed

+176
-68
lines changed

3 files changed

+176
-68
lines changed

codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/DefaultEndpointProviderGenerator.kt

Lines changed: 79 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule
2424
import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule
2525
import software.amazon.smithy.rulesengine.language.syntax.rule.Rule
2626
import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule
27+
import java.util.stream.Collectors
2728

2829
/**
2930
* The core set of standard library functions available to the rules language.
@@ -49,24 +50,9 @@ typealias EndpointPropertyRenderer = (KotlinWriter, Expression, ExpressionRender
4950
* An expression renderer generates code for an endpoint expression construct.
5051
*/
5152
fun interface ExpressionRenderer {
52-
fun renderExpression(expr: Expression)
53+
fun renderExpression(expr: Expression): EndpointInfo
5354
}
5455

55-
/**
56-
* Will be toggled to true if it is determined an endpoint is account ID based then to false again
57-
*/
58-
private var hasAccountIdBasedEndpoint = false
59-
60-
/**
61-
* Will be toggled to true if determined an endpoint comes from a service endpoint override then to false again
62-
*/
63-
private var hasServiceEndpointOverride = false
64-
65-
/**
66-
* Will be toggled to true when rendering an endpoint URL then to false again
67-
*/
68-
private var renderingEndpointUrl = false
69-
7056
/**
7157
* Renders the default endpoint provider based on the provided rule set.
7258
*/
@@ -121,9 +107,7 @@ class DefaultEndpointProviderGenerator(
121107
}
122108
}
123109

124-
override fun renderExpression(expr: Expression) {
125-
expr.accept(expressionGenerator)
126-
}
110+
override fun renderExpression(expr: Expression): EndpointInfo = expr.accept(expressionGenerator) ?: EndpointInfo.Empty
127111

128112
private fun renderDocumentation() {
129113
writer.dokka {
@@ -185,11 +169,11 @@ class DefaultEndpointProviderGenerator(
185169
withConditions(rule.conditions) {
186170
writer.withBlock("return #T(", ")", RuntimeTypes.SmithyClient.Endpoints.Endpoint) {
187171
writeInline("#T.parse(", RuntimeTypes.Core.Net.Url.Url)
188-
renderingEndpointUrl = true
189-
renderExpression(rule.endpoint.url)
190-
renderingEndpointUrl = false
172+
val endpointInfo = renderExpression(rule.endpoint.url)
191173
write("),")
192174

175+
val hasAccountIdBasedEndpoint = "accountId" in endpointInfo.params
176+
val hasServiceEndpointOverride = "endpoint" in endpointInfo.params
193177
val needAdditionalEndpointProperties = hasAccountIdBasedEndpoint || hasServiceEndpointOverride
194178

195179
if (rule.endpoint.headers.isNotEmpty()) {
@@ -226,11 +210,9 @@ class DefaultEndpointProviderGenerator(
226210

227211
if (hasAccountIdBasedEndpoint) {
228212
writer.write("#T to params.accountId", RuntimeTypes.Core.BusinessMetrics.AccountIdBasedEndpointAccountId)
229-
hasAccountIdBasedEndpoint = false
230213
}
231214
if (hasServiceEndpointOverride) {
232215
writer.write("#T to true", RuntimeTypes.Core.BusinessMetrics.ServiceEndpointOverride)
233-
hasServiceEndpointOverride = false
234216
}
235217
}
236218
}
@@ -253,16 +235,24 @@ class DefaultEndpointProviderGenerator(
253235
}
254236
}
255237

238+
data class EndpointInfo(val params: MutableSet<String>) {
239+
companion object {
240+
val Empty = EndpointInfo(params = mutableSetOf())
241+
}
242+
243+
operator fun plus(other: EndpointInfo) = EndpointInfo(
244+
params = (this.params + other.params).toMutableSet(),
245+
)
246+
}
247+
256248
class ExpressionGenerator(
257249
private val writer: KotlinWriter,
258250
private val rules: EndpointRuleSet,
259251
private val functions: Map<String, Symbol>,
260-
) : ExpressionVisitor<Unit>, LiteralVisitor<Unit>, TemplateVisitor<Unit> {
261-
override fun visitLiteral(literal: Literal) {
262-
literal.accept(this as LiteralVisitor<Unit>)
263-
}
252+
) : ExpressionVisitor<EndpointInfo?>, LiteralVisitor<EndpointInfo?>, TemplateVisitor<EndpointInfo?> {
253+
override fun visitLiteral(literal: Literal): EndpointInfo? = literal.accept(this as LiteralVisitor<EndpointInfo?>)
264254

265-
override fun visitRef(reference: Reference) {
255+
override fun visitRef(reference: Reference): EndpointInfo {
266256
val referenceName = reference.name.defaultName()
267257
val isParamReference = isParamRef(reference)
268258

@@ -271,90 +261,112 @@ class ExpressionGenerator(
271261
}
272262
writer.writeInline(referenceName)
273263

274-
if (renderingEndpointUrl) {
275-
if (isParamReference && referenceName == "accountId") hasAccountIdBasedEndpoint = true
276-
if (isParamReference && referenceName == "endpoint") hasServiceEndpointOverride = true
264+
return if (isParamReference) {
265+
EndpointInfo(params = mutableSetOf(referenceName))
266+
} else {
267+
EndpointInfo.Empty
277268
}
278269
}
279270

280-
override fun visitGetAttr(getAttr: GetAttr) {
281-
getAttr.target.accept(this)
271+
override fun visitGetAttr(getAttr: GetAttr): EndpointInfo? {
272+
val endpointInfo = getAttr.target.accept(this)
282273
getAttr.path.forEach {
283274
when (it) {
284275
is GetAttr.Part.Key -> writer.writeInline("?.#L", it.key().toString())
285276
is GetAttr.Part.Index -> writer.writeInline("?.getOrNull(#L)", it.index())
286277
else -> throw CodegenException("unexpected path")
287278
}
288279
}
280+
return endpointInfo
289281
}
290282

291-
override fun visitIsSet(target: Expression) {
292-
target.accept(this)
283+
override fun visitIsSet(target: Expression): EndpointInfo? {
284+
val endpointInfo = target.accept(this)
293285
writer.writeInline(" != null")
286+
return endpointInfo
294287
}
295288

296-
override fun visitNot(target: Expression) {
289+
override fun visitNot(target: Expression): EndpointInfo? {
297290
writer.writeInline("!(")
298-
target.accept(this)
291+
val endpointInfo = target.accept(this)
299292
writer.writeInline(")")
293+
return endpointInfo
300294
}
301295

302-
override fun visitBoolEquals(left: Expression, right: Expression) {
303-
visitEquals(left, right)
304-
}
296+
override fun visitBoolEquals(left: Expression, right: Expression): EndpointInfo? = visitEquals(left, right)
305297

306-
override fun visitStringEquals(left: Expression, right: Expression) {
307-
visitEquals(left, right)
308-
}
298+
override fun visitStringEquals(left: Expression, right: Expression): EndpointInfo? = visitEquals(left, right)
309299

310-
private fun visitEquals(left: Expression, right: Expression) {
311-
left.accept(this)
300+
private fun visitEquals(left: Expression, right: Expression): EndpointInfo? {
301+
val leftEndpointInfo = left.accept(this)
312302
writer.writeInline(" == ")
313-
right.accept(this)
303+
val rightEndpointInfo = right.accept(this)
304+
305+
return when {
306+
leftEndpointInfo != null && rightEndpointInfo != null -> leftEndpointInfo + rightEndpointInfo
307+
leftEndpointInfo != null -> leftEndpointInfo
308+
else -> rightEndpointInfo
309+
}
314310
}
315311

316-
override fun visitLibraryFunction(fn: FunctionDefinition, args: MutableList<Expression>) {
312+
override fun visitLibraryFunction(fn: FunctionDefinition, args: MutableList<Expression>): EndpointInfo? {
317313
writer.writeInline("#T(", functions.getValue(fn.id))
318-
args.forEachIndexed { index, it ->
319-
it.accept(this)
314+
315+
val endpointInfo = args.foldIndexed(EndpointInfo.Empty) { index, acc, curr ->
316+
val currEndpointInfo = curr.accept(this)
320317
if (index < args.lastIndex) {
321318
writer.writeInline(", ")
322319
}
320+
currEndpointInfo?.let { acc + it } ?: acc
323321
}
324322
writer.writeInline(")")
323+
return endpointInfo
325324
}
326325

327-
override fun visitInteger(value: Int) {
326+
override fun visitInteger(value: Int): EndpointInfo? {
328327
writer.writeInline("#L", value)
328+
return null
329329
}
330330

331-
override fun visitString(value: Template) {
331+
override fun visitString(value: Template): EndpointInfo? {
332332
writer.writeInline("\"")
333-
value.accept(this).forEach {} // must "consume" the stream to actually generate everything
333+
val endpointInfo = value.accept(this)
334+
.collect(Collectors.toList())
335+
.fold(EndpointInfo.Empty) { acc, curr ->
336+
curr?.let { acc + it } ?: acc
337+
}
334338
writer.writeInline("\"")
339+
return endpointInfo
335340
}
336341

337-
override fun visitBoolean(value: Boolean) {
342+
override fun visitBoolean(value: Boolean): EndpointInfo? {
338343
writer.writeInline("#L", value)
344+
return null
339345
}
340346

341-
override fun visitRecord(value: MutableMap<Identifier, Literal>) {
347+
override fun visitRecord(value: MutableMap<Identifier, Literal>): EndpointInfo? {
348+
var endpointInfo: EndpointInfo? = null
342349
writer.withInlineBlock("#T {", "}", RuntimeTypes.Core.Content.buildDocument) {
343-
value.entries.forEachIndexed { index, (k, v) ->
350+
endpointInfo = value.entries.foldIndexed(EndpointInfo.Empty) { index, acc, (k, v) ->
344351
writeInline("#S to ", k.toString())
345-
v.accept(this@ExpressionGenerator as LiteralVisitor<Unit>)
352+
val currInfo = v.accept(this@ExpressionGenerator as LiteralVisitor<EndpointInfo?>)
346353
if (index < value.size - 1) write("")
354+
currInfo?.let { acc + it } ?: acc
347355
}
348356
}
357+
return endpointInfo
349358
}
350359

351-
override fun visitTuple(value: MutableList<Literal>) {
360+
override fun visitTuple(value: MutableList<Literal>): EndpointInfo? {
361+
var endpointInfo: EndpointInfo? = null
352362
writer.withInlineBlock("listOf(", ")") {
353-
value.forEachIndexed { index, it ->
354-
it.accept(this@ExpressionGenerator as LiteralVisitor<Unit>)
363+
endpointInfo = value.foldIndexed(EndpointInfo.Empty) { index, acc, curr ->
364+
val localInfo = curr.accept(this@ExpressionGenerator as LiteralVisitor<EndpointInfo?>)
355365
if (index < value.size - 1) write(",") else writeInline(",")
366+
localInfo?.let { acc + it } ?: acc
356367
}
357368
}
369+
return endpointInfo
358370
}
359371

360372
override fun visitStaticTemplate(value: String) = writeTemplateString(value)
@@ -363,17 +375,19 @@ class ExpressionGenerator(
363375
override fun visitDynamicElement(value: Expression) = writeTemplateExpression(value)
364376

365377
// no-ops for kotlin codegen
366-
override fun startMultipartTemplate() {}
367-
override fun finishMultipartTemplate() {}
378+
override fun startMultipartTemplate(): EndpointInfo? = null
379+
override fun finishMultipartTemplate(): EndpointInfo? = null
368380

369-
private fun writeTemplateString(value: String) {
381+
private fun writeTemplateString(value: String): EndpointInfo? {
370382
writer.writeInline(value.replace("\"", "\\\""))
383+
return null
371384
}
372385

373-
private fun writeTemplateExpression(expr: Expression) {
386+
private fun writeTemplateExpression(expr: Expression): EndpointInfo? {
374387
writer.writeInline("\${")
375-
expr.accept(this)
388+
val endpointInfo = expr.accept(this)
376389
writer.writeInline("}")
390+
return endpointInfo
377391
}
378392

379393
private fun isParamRef(ref: Reference): Boolean = rules.parameters.toList().any { it.name == ref.name }

codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/DefaultEndpointProviderTestGenerator.kt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,7 @@ class DefaultEndpointProviderTestGenerator(
101101
}
102102
}
103103

104-
override fun renderExpression(expr: Expression) {
105-
expr.accept(expressionGenerator)
106-
}
104+
override fun renderExpression(expr: Expression): EndpointInfo = expr.accept(expressionGenerator) ?: EndpointInfo.Empty
107105

108106
private fun renderTestCase(index: Int, case: EndpointTestCase) {
109107
case.documentation.ifPresent {

codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/endpoints/DefaultEndpointProviderGeneratorTest.kt

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ class DefaultEndpointProviderGeneratorTest {
3131
"QuxName": {
3232
"type": "string",
3333
"required": false
34+
},
35+
"accountId": {
36+
"type": "string",
37+
"required": false
38+
},
39+
"endpoint": {
40+
"type": "string",
41+
"required": false
3442
}
3543
},
3644
"rules": [
@@ -118,6 +126,57 @@ class DefaultEndpointProviderGeneratorTest {
118126
"fooheader": ["barheader"]
119127
}
120128
}
129+
},
130+
{
131+
"documentation": "account id based endpoint and service endpoint override",
132+
"type": "endpoint",
133+
"conditions": [
134+
{
135+
"fn": "isSet",
136+
"argv": [
137+
{"ref": "accountId"}
138+
]
139+
},
140+
{
141+
"fn": "isSet",
142+
"argv": [
143+
{"ref": "endpoint"}
144+
]
145+
}
146+
],
147+
"endpoint": {
148+
"url": "https://{accountId}.{endpoint}"
149+
}
150+
},
151+
{
152+
"documentation": "service endpoint override",
153+
"type": "endpoint",
154+
"conditions": [
155+
{
156+
"fn": "isSet",
157+
"argv": [
158+
{"ref": "endpoint"}
159+
]
160+
}
161+
],
162+
"endpoint": {
163+
"url": "https://{endpoint}"
164+
}
165+
},
166+
{
167+
"documentation": "account id based endpoint",
168+
"type": "endpoint",
169+
"conditions": [
170+
{
171+
"fn": "isSet",
172+
"argv": [
173+
{"ref": "accountId"}
174+
]
175+
}
176+
],
177+
"endpoint": {
178+
"url": "https://{accountId}"
179+
}
121180
}
122181
]
123182
}
@@ -216,4 +275,41 @@ class DefaultEndpointProviderGeneratorTest {
216275
""".formatForTest(indent = " ")
217276
generatedClass.shouldContainOnlyOnceWithDiff(expected)
218277
}
278+
279+
@Test
280+
fun testBusinessMetrics() {
281+
val moneySign = "$"
282+
283+
val accountIdAndEndpoint = """
284+
return Endpoint(
285+
Url.parse("https://$moneySign{params.accountId}.$moneySign{params.endpoint}"),
286+
attributes = attributesOf {
287+
AccountIdBasedEndpointAccountId to params.accountId
288+
ServiceEndpointOverride to true
289+
},
290+
)
291+
"""
292+
293+
val accountId = """
294+
return Endpoint(
295+
Url.parse("https://$moneySign{params.accountId}"),
296+
attributes = attributesOf {
297+
AccountIdBasedEndpointAccountId to params.accountId
298+
},
299+
)
300+
"""
301+
302+
val endpoint = """
303+
return Endpoint(
304+
Url.parse("https://$moneySign{params.endpoint}"),
305+
attributes = attributesOf {
306+
ServiceEndpointOverride to true
307+
},
308+
)
309+
"""
310+
311+
generatedClass.shouldContainOnlyOnceWithDiff(accountIdAndEndpoint)
312+
generatedClass.shouldContainOnlyOnceWithDiff(accountId)
313+
generatedClass.shouldContainOnlyOnceWithDiff(endpoint)
314+
}
219315
}

0 commit comments

Comments
 (0)