Skip to content

Commit 60ee2c0

Browse files
authored
Merge pull request #2035 from minrk/types-are-tedious
more annotations for zmq.asyncio.Socket
2 parents 25168fc + 5f0b954 commit 60ee2c0

File tree

3 files changed

+153
-45
lines changed

3 files changed

+153
-45
lines changed

mypy_tests/test_socket_async.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
5+
import zmq
6+
import zmq.asyncio
7+
8+
9+
async def main() -> None:
10+
ctx = zmq.asyncio.Context()
11+
12+
# shadow exercise
13+
sync_ctx: zmq.Context = zmq.Context.shadow(ctx)
14+
ctx2: zmq.asyncio.Context = zmq.asyncio.Context.shadow(sync_ctx)
15+
ctx2 = zmq.asyncio.Context(sync_ctx)
16+
17+
url = "tcp://127.0.0.1:5555"
18+
pub = ctx.socket(zmq.PUB)
19+
sub = ctx.socket(zmq.SUB)
20+
pub.bind(url)
21+
sub.connect(url)
22+
sub.subscribe(b"")
23+
await asyncio.sleep(1)
24+
25+
# shadow exercise
26+
sync_sock: zmq.Socket[bytes] = zmq.Socket.shadow(pub)
27+
s2: zmq.asyncio.Socket = zmq.asyncio.Socket(sync_sock)
28+
s2 = zmq.asyncio.Socket.from_socket(sync_sock)
29+
30+
print("sending")
31+
await pub.send(b"plain")
32+
await pub.send(b"plain")
33+
await pub.send_multipart([b"topic", b"Message"])
34+
await pub.send_multipart([b"topic", b"Message"])
35+
await pub.send_string("asdf")
36+
await pub.send_pyobj(123)
37+
await pub.send_json({"a": "5"})
38+
39+
print("receiving")
40+
msg_bytes: bytes = await sub.recv()
41+
msg_frame: zmq.Frame = await sub.recv(copy=False)
42+
msg_list: list[bytes] = await sub.recv_multipart()
43+
msg_frames: list[zmq.Frame] = await sub.recv_multipart(copy=False)
44+
s: str = await sub.recv_string()
45+
obj = await sub.recv_pyobj()
46+
d = await sub.recv_json()
47+
48+
pub.close()
49+
sub.close()
50+
51+
52+
if __name__ == "__main__":
53+
asyncio.run(main())

zmq/_future.py

Lines changed: 8 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,17 @@
99
from collections import deque
1010
from functools import partial
1111
from itertools import chain
12-
from typing import Any, Awaitable, Callable, NamedTuple, TypeVar, cast, overload
12+
from typing import (
13+
Any,
14+
Awaitable,
15+
Callable,
16+
NamedTuple,
17+
TypeVar,
18+
cast,
19+
)
1320

1421
import zmq as _zmq
1522
from zmq import EVENTS, POLLIN, POLLOUT
16-
from zmq._typing import Literal
1723

1824

1925
class _FutureEvent(NamedTuple):
@@ -260,27 +266,6 @@ def get(self, key):
260266

261267
get.__doc__ = _zmq.Socket.get.__doc__
262268

263-
@overload # type: ignore
264-
def recv_multipart(
265-
self, flags: int = 0, *, track: bool = False
266-
) -> Awaitable[list[bytes]]: ...
267-
268-
@overload
269-
def recv_multipart(
270-
self, flags: int = 0, *, copy: Literal[True], track: bool = False
271-
) -> Awaitable[list[bytes]]: ...
272-
273-
@overload
274-
def recv_multipart(
275-
self, flags: int = 0, *, copy: Literal[False], track: bool = False
276-
) -> Awaitable[list[_zmq.Frame]]: # type: ignore
277-
...
278-
279-
@overload
280-
def recv_multipart(
281-
self, flags: int = 0, copy: bool = True, track: bool = False
282-
) -> Awaitable[list[bytes] | list[_zmq.Frame]]: ...
283-
284269
def recv_multipart(
285270
self, flags: int = 0, copy: bool = True, track: bool = False
286271
) -> Awaitable[list[bytes] | list[_zmq.Frame]]:
@@ -292,19 +277,6 @@ def recv_multipart(
292277
'recv_multipart', dict(flags=flags, copy=copy, track=track)
293278
)
294279

295-
@overload # type: ignore
296-
def recv(self, flags: int = 0, *, track: bool = False) -> Awaitable[bytes]: ...
297-
298-
@overload
299-
def recv(
300-
self, flags: int = 0, *, copy: Literal[True], track: bool = False
301-
) -> Awaitable[bytes]: ...
302-
303-
@overload
304-
def recv(
305-
self, flags: int = 0, *, copy: Literal[False], track: bool = False
306-
) -> Awaitable[_zmq.Frame]: ...
307-
308280
def recv( # type: ignore
309281
self, flags: int = 0, copy: bool = True, track: bool = False
310282
) -> Awaitable[bytes | _zmq.Frame]:
@@ -440,15 +412,6 @@ def cancel_poll(future):
440412

441413
return future
442414

443-
# overrides only necessary for updated types
444-
def recv_string(self, *args, **kwargs) -> Awaitable[str]: # type: ignore
445-
return super().recv_string(*args, **kwargs) # type: ignore
446-
447-
def send_string( # type: ignore
448-
self, s: str, flags: int = 0, encoding: str = 'utf-8'
449-
) -> Awaitable[None]:
450-
return super().send_string(s, flags=flags, encoding=encoding) # type: ignore
451-
452415
def _add_timeout(self, future, timeout):
453416
"""Add a timeout for a send or recv Future"""
454417

zmq/_future.pyi

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
"""type annotations for async sockets"""
2+
3+
from __future__ import annotations
4+
5+
from asyncio import Future
6+
from pickle import DEFAULT_PROTOCOL
7+
from typing import Any, Awaitable, Literal, Sequence, TypeVar, overload
8+
9+
import zmq as _zmq
10+
11+
class _AsyncPoller(_zmq.Poller):
12+
_socket_class: type[_AsyncSocket]
13+
14+
def poll(self, timeout=-1) -> Awaitable[list[tuple[Any, int]]]: ... # type: ignore
15+
16+
T = TypeVar("T", bound="_AsyncSocket")
17+
18+
class _AsyncSocket(_zmq.Socket[Future]):
19+
@classmethod
20+
def from_socket(cls: type[T], socket: _zmq.Socket, io_loop: Any = None) -> T: ...
21+
def send( # type: ignore
22+
self,
23+
data: Any,
24+
flags: int = 0,
25+
copy: bool = True,
26+
track: bool = False,
27+
routing_id: int | None = None,
28+
group: str | None = None,
29+
) -> Awaitable[_zmq.MessageTracker | None]: ...
30+
@overload # type: ignore
31+
def recv(self, flags: int = 0, *, track: bool = False) -> Awaitable[bytes]: ...
32+
@overload
33+
def recv(
34+
self, flags: int = 0, *, copy: Literal[True], track: bool = False
35+
) -> Awaitable[bytes]: ...
36+
@overload
37+
def recv(
38+
self, flags: int = 0, *, copy: Literal[False], track: bool = False
39+
) -> Awaitable[_zmq.Frame]: ...
40+
@overload
41+
def recv(
42+
self, flags: int = 0, copy: bool = True, track: bool = False
43+
) -> Awaitable[bytes | _zmq.Frame]: ...
44+
def send_multipart( # type: ignore
45+
self,
46+
msg_parts: Sequence,
47+
flags: int = 0,
48+
copy: bool = True,
49+
track: bool = False,
50+
routing_id: int | None = None,
51+
group: str | None = None,
52+
) -> Awaitable[_zmq.MessageTracker | None]: ...
53+
@overload # type: ignore
54+
def recv_multipart(
55+
self, flags: int = 0, *, track: bool = False
56+
) -> Awaitable[list[bytes]]: ...
57+
@overload
58+
def recv_multipart(
59+
self, flags: int = 0, *, copy: Literal[True], track: bool = False
60+
) -> Awaitable[list[bytes]]: ...
61+
@overload
62+
def recv_multipart(
63+
self, flags: int = 0, *, copy: Literal[False], track: bool = False
64+
) -> Awaitable[list[_zmq.Frame]]: ...
65+
@overload
66+
def recv_multipart(
67+
self, flags: int = 0, copy: bool = True, track: bool = False
68+
) -> Awaitable[list[bytes] | list[_zmq.Frame]]: ...
69+
70+
# serialization wrappers
71+
72+
def send_string( # type: ignore
73+
self,
74+
u: str,
75+
flags: int = 0,
76+
copy: bool = True,
77+
*,
78+
encoding: str = 'utf-8',
79+
**kwargs,
80+
) -> Awaitable[_zmq.Frame | None]: ...
81+
def recv_string( # type: ignore
82+
self, flags: int = 0, encoding: str = 'utf-8'
83+
) -> Awaitable[str]: ...
84+
def send_pyobj( # type: ignore
85+
self, obj: Any, flags: int = 0, protocol: int = DEFAULT_PROTOCOL, **kwargs
86+
) -> Awaitable[_zmq.Frame | None]: ...
87+
def recv_pyobj(self, flags: int = 0) -> Awaitable[Any]: ... # type: ignore
88+
def send_json( # type: ignore
89+
self, obj: Any, flags: int = 0, **kwargs
90+
) -> Awaitable[_zmq.Frame | None]: ...
91+
def recv_json(self, flags: int = 0, **kwargs) -> Awaitable[Any]: ... # type: ignore
92+
def poll(self, timeout=-1) -> Awaitable[list[tuple[Any, int]]]: ... # type: ignore

0 commit comments

Comments
 (0)