Skip to content

Commit 5dfbbc4

Browse files
Implement __contains__ for Schema
1 parent 1753210 commit 5dfbbc4

File tree

5 files changed

+53
-12
lines changed

5 files changed

+53
-12
lines changed

packages/aws-event-stream/src/aws_event_stream/_private/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@
1010

1111
def get_payload_member(schema: Schema) -> Schema | None:
1212
for member in schema.members.values():
13-
if EventPayloadTrait.id in member.traits:
13+
if EventPayloadTrait in member:
1414
return member
1515
return None

packages/aws-event-stream/src/aws_event_stream/_private/deserializers.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,7 @@ def read_struct(
158158
headers_deserializer = EventHeaderDeserializer(self._headers)
159159
for key in self._headers.keys():
160160
member_schema = schema.members.get(key)
161-
if (
162-
member_schema is not None
163-
and EventHeaderTrait.id in member_schema.traits
164-
):
161+
if member_schema is not None and EventHeaderTrait in member_schema:
165162
consumer(member_schema, headers_deserializer)
166163

167164
if self._payload_deserializer:

packages/aws-event-stream/src/aws_event_stream/_private/serializers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]:
102102

103103
headers_encoder = EventHeaderEncoder()
104104

105-
if ErrorTrait.id in schema.traits:
105+
if ErrorTrait in schema:
106106
headers_encoder.encode_string(":message-type", "exception")
107107
headers_encoder.encode_string(
108108
":exception-type", schema.expect_member_name()
@@ -214,7 +214,7 @@ def __init__(
214214
self._payload_struct_serializer = payload_struct_serializer
215215

216216
def before(self, schema: "Schema") -> ShapeSerializer:
217-
if EventHeaderTrait.id in schema.traits:
217+
if EventHeaderTrait in schema:
218218
return self._header_serializer
219219
return self._payload_struct_serializer
220220

packages/smithy-core/src/smithy_core/schemas.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22
# SPDX-License-Identifier: Apache-2.0
33
from collections.abc import Mapping
44
from dataclasses import dataclass, field, replace
5-
from typing import TYPE_CHECKING, NotRequired, Required, Self, TypedDict, overload
5+
from typing import NotRequired, Required, Self, TypedDict, overload, Any
66

77
from .exceptions import ExpectationNotMetException, SmithyException
88
from .shapes import ShapeID, ShapeType
9-
10-
if TYPE_CHECKING:
11-
from .traits import Trait, DynamicTrait
9+
from .traits import Trait, DynamicTrait
1210

1311

1412
@dataclass(kw_only=True, frozen=True, init=False)
@@ -164,6 +162,24 @@ def expect_trait(self, t: "type[Trait] | ShapeID") -> "Trait | DynamicTrait":
164162
id = t if isinstance(t, ShapeID) else t.id
165163
return self.traits[id]
166164

165+
def __contains__(self, item: Any):
166+
"""Returns whether the schema has the given member or trait."""
167+
match item:
168+
case type():
169+
if issubclass(item, Trait):
170+
return item.id in self.traits
171+
return False
172+
case ShapeID():
173+
if (member := item.member) is not None:
174+
if self.id.with_member(member) == item:
175+
return member in self.members
176+
return False
177+
return item in self.traits
178+
case str():
179+
return item in self.members
180+
case _:
181+
return False
182+
167183
@classmethod
168184
def collection(
169185
cls,

packages/smithy-core/tests/unit/test_schemas.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,16 @@
22

33
import pytest
44

5+
from typing import Any
6+
57
from smithy_core.exceptions import ExpectationNotMetException
68
from smithy_core.schemas import Schema
79
from smithy_core.shapes import ShapeID, ShapeType
8-
from smithy_core.traits import InternalTrait, DynamicTrait, SensitiveTrait
10+
from smithy_core.traits import (
11+
InternalTrait,
12+
DynamicTrait,
13+
SensitiveTrait,
14+
)
915

1016
ID: ShapeID = ShapeID("ns.foo#bar")
1117
STRING = Schema(id=ShapeID("smithy.api#String"), shape_type=ShapeType.STRING)
@@ -143,3 +149,25 @@ def test_member_constructor_asserts_target_is_not_member():
143149
)
144150
with pytest.raises(ExpectationNotMetException):
145151
Schema.member(id=ShapeID("smithy.example#Foo$bar"), target=target, index=0)
152+
153+
154+
@pytest.mark.parametrize(
155+
"item, contains",
156+
[
157+
(SensitiveTrait, True),
158+
(SensitiveTrait.id, True),
159+
(InternalTrait, False),
160+
(InternalTrait.id, False),
161+
("baz", True),
162+
(ID.with_member("baz"), True),
163+
(ID, False),
164+
],
165+
)
166+
def test_contains(item: Any, contains: bool):
167+
schema = Schema.collection(
168+
id=ID,
169+
members={"baz": {"target": STRING, "index": 0}},
170+
traits=[SensitiveTrait()],
171+
)
172+
173+
assert (item in schema) == contains

0 commit comments

Comments
 (0)