1212 SpecificShapeDeserializer ,
1313)
1414from smithy_core .schemas import Schema
15+ from smithy_core .shapes import ShapeType
1516from smithy_core .utils import expect_type
1617from smithy_event_stream .aio .interfaces import AsyncEventReceiver
1718
1819from ..events import HEADERS_DICT , Event
1920from ..exceptions import EventError , UnmodeledEventError
20- from . import INITIAL_REQUEST_EVENT_TYPE , INITIAL_RESPONSE_EVENT_TYPE
21- from .traits import EVENT_HEADER_TRAIT , EVENT_PAYLOAD_TRAIT
21+ from . import (
22+ INITIAL_REQUEST_EVENT_TYPE ,
23+ INITIAL_RESPONSE_EVENT_TYPE ,
24+ get_payload_member ,
25+ )
26+ from .traits import EVENT_HEADER_TRAIT
2227
2328INITIAL_MESSAGE_TYPES = (INITIAL_REQUEST_EVENT_TYPE , INITIAL_RESPONSE_EVENT_TYPE )
2429
@@ -69,25 +74,24 @@ def read_struct(
6974 ) -> None :
7075 headers = self ._event .message .headers
7176
72- payload_deserializer = None
73- if self ._event .message .payload :
74- payload_deserializer = self ._payload_codec .create_deserializer (
75- self ._event .message .payload
76- )
77-
78- message_deserializer = EventMessageDeserializer (headers , payload_deserializer )
79-
8077 match headers .get (":message-type" ):
8178 case "event" :
8279 member_name = expect_type (str , headers [":event-type" ])
8380 if member_name in INITIAL_MESSAGE_TYPES :
8481 # If it's an initial message, skip straight to deserialization.
82+ message_deserializer = self ._create_deserializer (schema , headers )
8583 message_deserializer .read_struct (schema , consumer )
8684 else :
87- consumer (schema .members [member_name ], message_deserializer )
85+ member_schema = schema .members [member_name ]
86+ message_deserializer = self ._create_deserializer (
87+ member_schema , headers
88+ )
89+ consumer (member_schema , message_deserializer )
8890 case "exception" :
8991 member_name = expect_type (str , headers [":exception-type" ])
90- consumer (schema .members [member_name ], message_deserializer )
92+ member_schema = schema .members [member_name ]
93+ message_deserializer = self ._create_deserializer (member_schema , headers )
94+ consumer (member_schema , message_deserializer )
9195 case "error" :
9296 # The `application/vnd.amazon.eventstream` format allows for explicitly
9397 # unmodeled exceptions. These exceptions MUST have the `:error-code`
@@ -99,13 +103,49 @@ def read_struct(
99103 case _:
100104 raise EventError (f"Unknown event structure: { self ._event } " )
101105
106+ def _create_deserializer (
107+ self , schema : Schema , headers : HEADERS_DICT
108+ ) -> ShapeDeserializer :
109+ payload_member = get_payload_member (schema )
110+ payload_deserializer = self ._create_payload_deserializer (payload_member )
111+ return EventMessageDeserializer (headers , payload_deserializer , payload_member )
112+
113+ def _create_payload_deserializer (
114+ self , payload_member : Schema | None
115+ ) -> ShapeDeserializer | None :
116+ if not self ._event .message .payload :
117+ return
118+
119+ if payload_member is not None and payload_member .shape_type in (
120+ ShapeType .BLOB ,
121+ ShapeType .STRING ,
122+ ):
123+ return RawPayloadDeserializer (self ._event .message .payload )
124+
125+ return self ._payload_codec .create_deserializer (self ._event .message .payload )
126+
127+
128+ class RawPayloadDeserializer (SpecificShapeDeserializer ):
129+ def __init__ (self , payload : bytes ) -> None :
130+ self ._payload = payload
131+
132+ def read_string (self , schema : Schema ) -> str :
133+ return self ._payload .decode ("utf-8" )
134+
135+ def read_blob (self , schema : Schema ) -> bytes :
136+ return self ._payload
137+
102138
103139class EventMessageDeserializer (SpecificShapeDeserializer ):
104140 def __init__ (
105- self , headers : HEADERS_DICT , payload_deserializer : ShapeDeserializer | None
141+ self ,
142+ headers : HEADERS_DICT ,
143+ payload_deserializer : ShapeDeserializer | None ,
144+ payload_member : Schema | None ,
106145 ) -> None :
107146 self ._headers = headers
108147 self ._payload_deserializer = payload_deserializer
148+ self ._payload_member = payload_member
109149
110150 def read_struct (
111151 self ,
@@ -119,17 +159,11 @@ def read_struct(
119159 consumer (member_schema , headers_deserializer )
120160
121161 if self ._payload_deserializer :
122- if ( payload_member := self ._get_payload_member ( schema )) is not None :
123- consumer (payload_member , self ._payload_deserializer )
162+ if self ._payload_member is not None :
163+ consumer (self . _payload_member , self ._payload_deserializer )
124164 else :
125165 self ._payload_deserializer .read_struct (schema , consumer )
126166
127- def _get_payload_member (self , schema : "Schema" ) -> "Schema | None" :
128- for member in schema .members .values ():
129- if EVENT_PAYLOAD_TRAIT in member .traits :
130- return member
131- return None
132-
133167
134168class EventHeaderDeserializer (SpecificShapeDeserializer ):
135169 def __init__ (self , headers : HEADERS_DICT ) -> None :
0 commit comments