Skip to content

Commit d30ef74

Browse files
eulersIDcrisisAaron Gibson
andauthored
Permit streaming_callback of AsyncHTTPClient to be a coroutine. (#3471)
Co-authored-by: Aaron Gibson <[email protected]>
1 parent eb64d13 commit d30ef74

File tree

5 files changed

+56
-6
lines changed

5 files changed

+56
-6
lines changed

tornado/curl_httpclient.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
import re
2323
import threading
2424
import time
25+
import inspect
2526
from io import BytesIO
2627

28+
from tornado import gen
2729
from tornado import httputil
2830
from tornado import ioloop
2931

@@ -368,6 +370,13 @@ def _curl_setup_request(
368370
)
369371
if request.streaming_callback:
370372

373+
if gen.is_coroutine_function(
374+
request.streaming_callback
375+
) or inspect.iscoroutinefunction(request.streaming_callback):
376+
raise TypeError(
377+
"'CurlAsyncHTTPClient' does not support async streaming_callbacks."
378+
)
379+
371380
def write_function(b: Union[bytes, bytearray]) -> int:
372381
assert request.streaming_callback is not None
373382
self.io_loop.add_callback(request.streaming_callback, b)

tornado/httpclient.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
from tornado.ioloop import IOLoop
5454
from tornado.util import Configurable
5555

56-
from typing import Type, Any, Union, Dict, Callable, Optional, cast
56+
from typing import Type, Any, Union, Dict, Callable, Optional, Awaitable, cast
5757

5858

5959
class HTTPClient:
@@ -372,7 +372,9 @@ def __init__(
372372
user_agent: Optional[str] = None,
373373
use_gzip: Optional[bool] = None,
374374
network_interface: Optional[str] = None,
375-
streaming_callback: Optional[Callable[[bytes], None]] = None,
375+
streaming_callback: Optional[
376+
Callable[[bytes], Optional[Awaitable[None]]]
377+
] = None,
376378
header_callback: Optional[Callable[[str], None]] = None,
377379
prepare_curl_callback: Optional[Callable[[Any], None]] = None,
378380
proxy_host: Optional[str] = None,

tornado/simple_httpclient.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from io import BytesIO
3434
import urllib.parse
3535

36-
from typing import Dict, Any, Callable, Optional, Type, Union
36+
from typing import Dict, Any, Callable, Optional, Type, Union, Awaitable
3737
from types import TracebackType
3838
import typing
3939

@@ -687,14 +687,15 @@ def finish(self) -> None:
687687
def _on_end_request(self) -> None:
688688
self.stream.close()
689689

690-
def data_received(self, chunk: bytes) -> None:
690+
def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]:
691691
if self._should_follow_redirect():
692692
# We're going to follow a redirect so just discard the body.
693-
return
693+
return None
694694
if self.request.streaming_callback is not None:
695-
self.request.streaming_callback(chunk)
695+
return self.request.streaming_callback(chunk)
696696
else:
697697
self.chunks.append(chunk)
698+
return None
698699

699700

700701
if __name__ == "__main__":

tornado/test/curl_httpclient_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from tornado.testing import AsyncHTTPTestCase
66
from tornado.test import httpclient_test
77
from tornado.web import Application, RequestHandler
8+
from tornado import gen
89

910

1011
try:
@@ -123,3 +124,19 @@ def test_digest_auth_non_ascii(self):
123124
auth_password="barユ£",
124125
)
125126
self.assertEqual(response.body, b"ok")
127+
128+
def test_streaming_callback_not_permitted(self):
129+
@gen.coroutine
130+
def _recv_chunk(chunk):
131+
yield gen.moment
132+
133+
with self.assertRaises(TypeError):
134+
self.fetch("/digest", streaming_callback=_recv_chunk)
135+
136+
import asyncio
137+
138+
async def _async_recv_chunk(chunk):
139+
await asyncio.sleep(0)
140+
141+
with self.assertRaises(TypeError):
142+
self.fetch("/digest", streaming_callback=_async_recv_chunk)

tornado/test/simple_httpclient_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,27 @@ def test_streaming_follow_redirects(self):
539539
num_start_lines = len([h for h in headers if h.startswith("HTTP/")])
540540
self.assertEqual(num_start_lines, 1)
541541

542+
def test_streaming_callback_coroutine(self: typing.Any):
543+
headers = [] # type: typing.List[str]
544+
chunk_bytes = [] # type: typing.List[bytes]
545+
546+
import asyncio
547+
548+
async def _put_chunk(chunk):
549+
await asyncio.sleep(0)
550+
chunk_bytes.append(chunk)
551+
552+
self.fetch(
553+
"/chunk",
554+
header_callback=headers.append,
555+
streaming_callback=_put_chunk,
556+
)
557+
chunks = list(map(to_unicode, chunk_bytes))
558+
self.assertEqual("".join(chunks), "asdfqwer")
559+
# Make sure we only got one set of headers.
560+
num_start_lines = len([h for h in headers if h.startswith("HTTP/")])
561+
self.assertEqual(num_start_lines, 1)
562+
542563

543564
class SimpleHTTPClientTestCase(AsyncHTTPTestCase, SimpleHTTPClientTestMixin):
544565
def setUp(self):

0 commit comments

Comments
 (0)