Skip to content

Commit 9c5053c

Browse files
committed
checkpoint 2
1 parent cb42102 commit 9c5053c

File tree

9 files changed

+190
-11
lines changed

9 files changed

+190
-11
lines changed

codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import software.amazon.smithy.python.codegen.sections.SendRequestSection;
2626
import software.amazon.smithy.python.codegen.sections.SignRequestSection;
2727
import software.amazon.smithy.python.codegen.writer.PythonWriter;
28+
import software.amazon.smithy.utils.CaseUtils;
2829
import software.amazon.smithy.utils.SmithyInternalApi;
2930

3031
/**
@@ -705,6 +706,7 @@ private void writeDefaultPlugins(PythonWriter writer, Collection<SymbolReference
705706
*/
706707
private void generateOperation(PythonWriter writer, OperationShape operation) {
707708
var operationSymbol = context.symbolProvider().toSymbol(operation);
709+
var operationMethodSymbol = operationSymbol.expectProperty(SymbolProperties.OPERATION_METHOD);
708710
var pluginSymbol = CodegenUtils.getPluginSymbol(context.settings());
709711

710712
var input = context.model().expectShape(operation.getInputShape());
@@ -715,7 +717,7 @@ private void generateOperation(PythonWriter writer, OperationShape operation) {
715717

716718
writer.openBlock("async def $L(self, input: $T, plugins: list[$T] | None = None) -> $T:",
717719
"",
718-
operationSymbol.getName(),
720+
operationMethodSymbol.getName(),
719721
inputSymbol,
720722
pluginSymbol,
721723
outputSymbol,

codegen/core/src/main/java/software/amazon/smithy/python/codegen/DirectedPythonClientCodegen.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import software.amazon.smithy.codegen.core.directed.GenerateIntEnumDirective;
2222
import software.amazon.smithy.codegen.core.directed.GenerateListDirective;
2323
import software.amazon.smithy.codegen.core.directed.GenerateMapDirective;
24+
import software.amazon.smithy.codegen.core.directed.GenerateOperationDirective;
2425
import software.amazon.smithy.codegen.core.directed.GenerateServiceDirective;
2526
import software.amazon.smithy.codegen.core.directed.GenerateStructureDirective;
2627
import software.amazon.smithy.codegen.core.directed.GenerateUnionDirective;
@@ -35,6 +36,7 @@
3536
import software.amazon.smithy.python.codegen.generators.IntEnumGenerator;
3637
import software.amazon.smithy.python.codegen.generators.ListGenerator;
3738
import software.amazon.smithy.python.codegen.generators.MapGenerator;
39+
import software.amazon.smithy.python.codegen.generators.OperationGenerator;
3840
import software.amazon.smithy.python.codegen.generators.ProtocolGenerator;
3941
import software.amazon.smithy.python.codegen.generators.SchemaGenerator;
4042
import software.amazon.smithy.python.codegen.generators.ServiceErrorGenerator;
@@ -133,6 +135,19 @@ public void generateService(GenerateServiceDirective<GenerationContext, PythonSe
133135
protocolGenerator.generateProtocolTests(directive.context());
134136
}
135137

138+
@Override
139+
public void generateOperation(GenerateOperationDirective<GenerationContext, PythonSettings> directive) {
140+
DirectedCodegen.super.generateOperation(directive);
141+
142+
directive.context().writerDelegator().useShapeWriter(directive.shape(), writer -> {
143+
OperationGenerator generator = new OperationGenerator(
144+
directive.context(),
145+
writer,
146+
directive.shape());
147+
generator.run();
148+
});
149+
}
150+
136151
@Override
137152
public void generateStructure(GenerateStructureDirective<GenerationContext, PythonSettings> directive) {
138153
directive.context().writerDelegator().useShapeWriter(directive.shape(), writer -> {

codegen/core/src/main/java/software/amazon/smithy/python/codegen/PythonSymbolProvider.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,14 @@ public Symbol bigDecimalShape(BigDecimalShape shape) {
276276
public Symbol operationShape(OperationShape shape) {
277277
// Operation names are escaped like members because ultimately they're
278278
// properties on an object too.
279-
var name = escaper.escapeMemberName(CaseUtils.toSnakeCase(shape.getId().getName(service)));
280-
return createGeneratedSymbolBuilder(shape, name, "client").build();
279+
var methodName = escaper.escapeMemberName(CaseUtils.toSnakeCase(shape.getId().getName(service)));
280+
var methodSymbol = createGeneratedSymbolBuilder(shape, methodName, "client").build();
281+
282+
// We add a symbol for the method in the client as a property, whereas the actual
283+
// operation symbol points to the generated type for it
284+
return createGeneratedSymbolBuilder(shape, getDefaultShapeName(shape), SHAPES_FILE)
285+
.putProperty(SymbolProperties.OPERATION_METHOD, methodSymbol)
286+
.build();
281287
}
282288

283289
@Override

codegen/core/src/main/java/software/amazon/smithy/python/codegen/SymbolProperties.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,5 +73,11 @@ public final class SymbolProperties {
7373
*/
7474
public static final Property<Symbol> DESERIALIZER = Property.named("deserializer");
7575

76+
/**
77+
* Contains a symbol pointing to an operation shape's method in the client. This is
78+
* only used for operations.
79+
*/
80+
public static final Property<Symbol> OPERATION_METHOD = Property.named("operationMethod");
81+
7682
private SymbolProperties() {}
7783
}
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package software.amazon.smithy.python.codegen.generators;
6+
7+
import software.amazon.smithy.codegen.core.Symbol;
8+
import software.amazon.smithy.codegen.core.SymbolProvider;
9+
import software.amazon.smithy.model.Model;
10+
import software.amazon.smithy.model.knowledge.ServiceIndex;
11+
import software.amazon.smithy.model.shapes.OperationShape;
12+
import software.amazon.smithy.model.shapes.ServiceShape;
13+
import software.amazon.smithy.model.shapes.ShapeId;
14+
import software.amazon.smithy.python.codegen.CodegenUtils;
15+
import software.amazon.smithy.python.codegen.GenerationContext;
16+
import software.amazon.smithy.python.codegen.SymbolProperties;
17+
import software.amazon.smithy.python.codegen.writer.PythonWriter;
18+
import software.amazon.smithy.utils.SmithyInternalApi;
19+
20+
import java.util.List;
21+
import java.util.logging.Logger;
22+
23+
@SmithyInternalApi
24+
public final class OperationGenerator implements Runnable {
25+
private static final Logger LOGGER = Logger.getLogger(OperationGenerator.class.getName());
26+
27+
private final GenerationContext context;
28+
private final PythonWriter writer;
29+
private final OperationShape shape;
30+
private final SymbolProvider symbolProvider;
31+
private final Model model;
32+
33+
34+
public OperationGenerator(GenerationContext context, PythonWriter writer, OperationShape shape) {
35+
this.context = context;
36+
this.writer = writer;
37+
this.shape = shape;
38+
this.symbolProvider = context.symbolProvider();
39+
this.model = context.model();
40+
}
41+
42+
@Override
43+
public void run() {
44+
var opSymbol = symbolProvider.toSymbol(shape);
45+
var inSymbol = symbolProvider.toSymbol(model.expectShape(shape.getInputShape()));
46+
var outSymbol = symbolProvider.toSymbol(model.expectShape(shape.getOutputShape()));
47+
48+
writer.addStdlibImport("dataclasses", "dataclass");
49+
writer.addImport("smithy_core.schemas", "APIOperation");
50+
writer.addImport("smithy_core.type_registry", "TypeRegistry");
51+
52+
writer.write("""
53+
@dataclass(kw_only=True, frozen=True)
54+
class $1L(APIOperation["$2T", "$3T"]):
55+
input = $2T
56+
output = $3T
57+
schema = $4T
58+
input_schema = $5T
59+
output_schema = $6T
60+
error_registry = TypeRegistry({
61+
$7C
62+
})
63+
effective_auth_schemes = [
64+
$8C
65+
]
66+
""",
67+
opSymbol.getName(),
68+
inSymbol,
69+
outSymbol,
70+
opSymbol.expectProperty(SymbolProperties.SCHEMA),
71+
inSymbol.expectProperty(SymbolProperties.SCHEMA),
72+
outSymbol.expectProperty(SymbolProperties.SCHEMA),
73+
writer.consumer(this::writeErrorTypeRegistry),
74+
writer.consumer(this::writeAuthSchemes)
75+
// TODO: Docs? Maybe not necessary on the operation type itself
76+
// TODO: Singleton?
77+
);
78+
}
79+
80+
private void writeErrorTypeRegistry(PythonWriter writer) {
81+
List<ShapeId> errors = shape.getErrors();
82+
if (!errors.isEmpty()) {
83+
writer.addImport("smithy_core.shapes", "ShapeID");
84+
}
85+
for (var error : errors) {
86+
var errSymbol = symbolProvider.toSymbol(model.expectShape(error));
87+
writer.write("ShapeID($S): $T,", error, errSymbol);
88+
}
89+
}
90+
91+
private void writeAuthSchemes(PythonWriter writer) {
92+
var authSchemes = ServiceIndex.of(model).getEffectiveAuthSchemes(context.settings().service(), shape.getId(),
93+
ServiceIndex.AuthSchemeMode.NO_AUTH_AWARE);
94+
95+
if (!authSchemes.isEmpty()) {
96+
writer.addImport("smithy_core.shapes", "ShapeID");
97+
}
98+
99+
for(var authSchemeId : authSchemes.keySet()) {
100+
writer.write("ShapeID($S)", authSchemeId);
101+
}
102+
103+
}
104+
}

packages/smithy-core/src/smithy_core/documents.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,7 @@ def shape_type(self) -> ShapeType:
146146
@property
147147
def discriminator(self) -> ShapeID:
148148
"""The shape ID that corresponds to the contents of the document."""
149-
# TODO: custom exception?
150-
raise NotImplementedError(f"{self} document has no discriminator.")
149+
return self._schema.id
151150

152151
def is_none(self) -> bool:
153152
"""Indicates whether the document contains a null value."""

packages/smithy-core/src/smithy_core/schemas.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .exceptions import ExpectationNotMetException, SmithyException
88
from .shapes import ShapeID, ShapeType
99
from .traits import Trait, DynamicTrait, IdempotencyTokenTrait, StreamingTrait
10-
10+
from .type_registry import TypeRegistry
1111

1212
if TYPE_CHECKING:
1313
from .serializers import SerializeableShape
@@ -289,8 +289,7 @@ class APIOperation[I: "SerializeableShape", O: "DeserializeableShape"]:
289289
output_schema: Schema
290290
"""The schema of the operation's output shape."""
291291

292-
# TODO: Add a type registry for errors
293-
error_registry: Any
292+
error_registry: TypeRegistry
294293
"""A TypeRegistry used to create errors."""
295294

296295
effective_auth_schemes: Sequence[ShapeID]

packages/smithy-core/src/smithy_core/type_registry.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,18 @@
1313
class TypeRegistry:
1414
def __init__(
1515
self,
16-
types: dict[ShapeID, DeserializeableShape],
16+
types: dict[ShapeID, type[DeserializeableShape]],
1717
sub_registry: "TypeRegistry | None" = None,
1818
):
1919
self._types = types
2020
self._sub_registry = sub_registry
2121

2222
def get(self, shape: ShapeID) -> type[DeserializeableShape]:
2323
if shape in self._types:
24-
return type(self._types[shape])
24+
return self._types[shape]
2525
if self._sub_registry is not None:
2626
return self._sub_registry.get(shape)
27-
raise ValueError(f"Unknown shape: {shape}") # TODO: real exception?
27+
raise KeyError(f"Unknown shape: {shape}")
2828

2929
def deserialize(self, document: Document) -> DeserializeableShape:
3030
return document.as_shape(self.get(document.discriminator))
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer
2+
from smithy_core.documents import Document
3+
from smithy_core.schemas import Schema
4+
from smithy_core.shapes import ShapeID, ShapeType
5+
from smithy_core.type_registry import TypeRegistry
6+
import pytest
7+
8+
9+
class TestTypeRegistry:
10+
def test_get(self):
11+
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
12+
13+
result = registry.get(ShapeID("com.example#Test"))
14+
15+
assert result == TestShape
16+
17+
def test_get_sub_registry(self):
18+
sub_registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
19+
registry = TypeRegistry({}, sub_registry)
20+
21+
result = registry.get(ShapeID("com.example#Test"))
22+
23+
assert result == TestShape
24+
25+
def test_get_no_match(self):
26+
registry = TypeRegistry({ShapeID("com.example#Test"): TestShape})
27+
28+
with pytest.raises(KeyError, match="Unknown shape: com.example#Test2"):
29+
registry.get(ShapeID("com.example#Test2"))
30+
31+
def test_deserialize(self):
32+
shape_id = ShapeID("com.example#Test")
33+
registry = TypeRegistry({shape_id: TestShape})
34+
35+
result = registry.deserialize(Document("abc123", schema=TestShape.schema))
36+
37+
assert isinstance(result, TestShape) and result.value == "abc123"
38+
39+
40+
class TestShape(DeserializeableShape):
41+
schema = Schema(id=ShapeID("com.example#Test"), shape_type=ShapeType.STRING)
42+
43+
def __init__(self, value: str):
44+
self.value = value
45+
46+
@classmethod
47+
def deserialize(cls, deserializer: ShapeDeserializer) -> "TestShape":
48+
return TestShape(deserializer.read_string(schema=TestShape.schema))

0 commit comments

Comments
 (0)