|
19 | 19 | from dataclasses import dataclass |
20 | 20 | from pathlib import Path |
21 | 21 | from enum import Enum |
| 22 | +from ..utils.logging import warning |
22 | 23 | from ..utils.primitives import TPrimitive |
23 | 24 | from .status import HTTP_STATUS |
24 | 25 | from ..utils.io import DEFAULT_ENCODING, asWritable |
@@ -139,26 +140,58 @@ def __init__( |
139 | 140 |
|
140 | 141 |
|
141 | 142 | class HTTPBodyIO: |
142 | | - __slots__ = ["reader", "read", "expected", "remaining"] |
| 143 | + __slots__ = ["reader", "read", "expected", "remaining", "existing"] |
143 | 144 | """Represents a body that is loaded from a reader IO.""" |
144 | 145 |
|
145 | | - def __init__(self, reader: "HTTPBodyReader", expected: int | None = None): |
| 146 | + def __init__( |
| 147 | + self, |
| 148 | + reader: "HTTPBodyReader", |
| 149 | + expected: int | None = None, |
| 150 | + existing: bytes | None = None, |
| 151 | + ): |
146 | 152 | self.reader: HTTPBodyReader = reader |
147 | 153 | self.read: int = 0 |
148 | 154 | self.expected: int | None = expected |
149 | 155 | self.remaining: int | None = expected |
| 156 | + self.existing: bytes | None = existing |
150 | 157 |
|
151 | | - async def load( |
| 158 | + async def _read( |
152 | 159 | self, |
153 | 160 | ) -> bytes | None: |
154 | | - """Loads all the data and returns a list of bodies.""" |
155 | | - payload = await self.reader.load() |
156 | | - if payload: |
157 | | - n = len(payload) |
158 | | - self.read += n |
159 | | - if self.remaining is not None: |
| 161 | + """Reads the next available bytes""" |
| 162 | + if self.existing and self.read == 0: |
| 163 | + self.read += len(self.existing) |
| 164 | + return self.existing |
| 165 | + elif self.remaining: |
| 166 | + # FIXME: We should probably have a timeout there |
| 167 | + try: |
| 168 | + payload = await self.reader.load() |
| 169 | + except TimeoutError: |
| 170 | + warning( |
| 171 | + "Request body loading timed out", |
| 172 | + Remaining=self.remaining, |
| 173 | + Read=self.read, |
| 174 | + ) |
| 175 | + return None |
| 176 | + if payload: |
| 177 | + n = len(payload) |
| 178 | + self.read += n |
160 | 179 | self.remaining -= n |
161 | | - return payload |
| 180 | + return payload |
| 181 | + else: |
| 182 | + return None |
| 183 | + |
| 184 | + async def load(self) -> bytes: |
| 185 | + """Fully loads the body.""" |
| 186 | + res = bytearray() |
| 187 | + # FIXME: This would read other requests as well if there is no |
| 188 | + # remaining -- there should be at least a delimiter. |
| 189 | + while True: |
| 190 | + chunk = await self._read() |
| 191 | + if chunk: |
| 192 | + res += chunk |
| 193 | + else: |
| 194 | + return bytes(res) |
162 | 195 |
|
163 | 196 |
|
164 | 197 | class HTTPBodyBlob(NamedTuple): |
@@ -457,10 +490,17 @@ def contentType(self) -> str | None: |
457 | 490 |
|
458 | 491 | @property |
459 | 492 | def body(self) -> HTTPBodyIO | HTTPBodyBlob: |
| 493 | + b = self._body |
460 | 494 | if self._body is None: |
461 | 495 | if not self._reader: |
462 | 496 | raise RuntimeError("Request has no reader, can't read body") |
463 | 497 | self._body = HTTPBodyIO(self._reader) |
| 498 | + elif isinstance(self._body, HTTPBodyBlob) and self._body.remaining: |
| 499 | + if not self._reader: |
| 500 | + raise RuntimeError("Request has no reader, can't read body") |
| 501 | + self._body = HTTPBodyIO( |
| 502 | + self._reader, expected=self._body.remaining, existing=self._body.payload |
| 503 | + ) |
464 | 504 | return self._body |
465 | 505 |
|
466 | 506 | @property |
|
0 commit comments