@@ -52,7 +52,7 @@ class EventType(Enum):
5252 retry : Optional [int ]
5353
5454 def __str__ (self ) -> str :
55- if self .event == "output" :
55+ if self .event == ServerSentEvent . EventType . OUTPUT :
5656 return self .data
5757
5858 return ""
@@ -114,7 +114,7 @@ def decode(self, line: str) -> Optional[ServerSentEvent]:
114114 return None
115115
116116 fieldname , _ , value = line .partition (":" )
117- value = value .lstrip ( )
117+ value = value .removeprefix ( " " )
118118
119119 if fieldname == "event" :
120120 if event := ServerSentEvent .EventType (value ):
@@ -138,26 +138,28 @@ def __iter__(self) -> Iterator[ServerSentEvent]:
138138 line = line .rstrip ("\n " )
139139 sse = decoder .decode (line )
140140 if sse is not None :
141- if sse .event == "done" :
142- return
143- if sse .event == "error" :
141+ if sse .event == ServerSentEvent .EventType .ERROR :
144142 raise RuntimeError (sse .data )
145143
146144 yield sse
147145
146+ if sse .event == ServerSentEvent .EventType .DONE :
147+ return
148+
148149 async def __aiter__ (self ) -> AsyncIterator [ServerSentEvent ]:
149150 decoder = EventSource .Decoder ()
150151 async for line in self .response .aiter_lines ():
151152 line = line .rstrip ("\n " )
152153 sse = decoder .decode (line )
153154 if sse is not None :
154- if sse .event == "done" :
155- return
156- if sse .event == "error" :
155+ if sse .event == ServerSentEvent .EventType .ERROR :
157156 raise RuntimeError (sse .data )
158157
159158 yield sse
160159
160+ if sse .event == ServerSentEvent .EventType .DONE :
161+ return
162+
161163
162164def stream (
163165 client : "Client" ,
0 commit comments