Skip to content

Commit 54202ae

Browse files
committed
implement Socket.recv_into
wrapper for `zmq_recv` API, matches stdlib `socket.recv_into`
1 parent 5e0cdbc commit 54202ae

File tree

15 files changed

+448
-10
lines changed

15 files changed

+448
-10
lines changed

examples/recv_into/discard.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""
2+
use recv_into with an empty buffer to discard unwanted message frames
3+
4+
avoids unnecessary allocations for message frames that won't be used
5+
"""
6+
7+
import logging
8+
import os
9+
import random
10+
import secrets
11+
import time
12+
from pathlib import Path
13+
from tempfile import TemporaryDirectory
14+
from threading import Thread
15+
16+
import zmq
17+
18+
EMPTY = bytearray()
19+
20+
21+
def subscriber(url: str) -> None:
22+
log = logging.getLogger("subscriber")
23+
with zmq.Context() as ctx, ctx.socket(zmq.SUB) as sub:
24+
sub.linger = 0
25+
sub.connect(url)
26+
sub.subscribe(b"")
27+
log.info("Receiving...")
28+
while True:
29+
frame_0 = sub.recv_string()
30+
if frame_0 == "exit":
31+
log.info("Exiting...")
32+
break
33+
elif frame_0 == "large":
34+
discarded_bytes = 0
35+
discarded_frames = 0
36+
while sub.rcvmore:
37+
discarded_bytes += sub.recv_into(EMPTY)
38+
discarded_frames += 1
39+
log.info(
40+
"Discarding large message frames: %s, bytes: %s",
41+
discarded_frames,
42+
discarded_bytes,
43+
)
44+
else:
45+
msg: list = [frame_0]
46+
if sub.rcvmore:
47+
msg.extend(sub.recv_multipart(flags=zmq.DONTWAIT))
48+
log.info("Received %s", msg)
49+
log.info("Done")
50+
51+
52+
def publisher(url) -> None:
53+
log = logging.getLogger("publisher")
54+
choices = ["large", "small"]
55+
with zmq.Context() as ctx, ctx.socket(zmq.PUB) as pub:
56+
pub.linger = 1000
57+
pub.bind(url)
58+
time.sleep(1)
59+
for i in range(10):
60+
kind = random.choice(choices)
61+
frames = [kind.encode()]
62+
if kind == "large":
63+
for _ in range(random.randint(0, 5)):
64+
chunk_size = random.randint(1024, 2048)
65+
chunk = os.urandom(chunk_size)
66+
frames.append(chunk)
67+
else:
68+
for _ in range(random.randint(0, 3)):
69+
chunk_size = random.randint(0, 5)
70+
chunk = secrets.token_hex(chunk_size).encode()
71+
frames.append(chunk)
72+
nbytes = sum(len(chunk) for chunk in frames)
73+
log.info("Sending %s: %s bytes", kind, nbytes)
74+
pub.send_multipart(frames)
75+
time.sleep(0.1)
76+
log.info("Sending exit")
77+
pub.send(b"exit")
78+
log.info("Done")
79+
80+
81+
def main() -> None:
82+
logging.basicConfig(level=logging.INFO)
83+
with TemporaryDirectory() as td:
84+
sock_path = Path(td) / "example.sock"
85+
url = f"ipc://{sock_path}"
86+
s_thread = Thread(
87+
target=subscriber, args=(url,), daemon=True, name="subscriber"
88+
)
89+
s_thread.start()
90+
publisher(url)
91+
s_thread.join(timeout=3)
92+
93+
94+
if __name__ == "__main__":
95+
main()

examples/recv_into/recv_into_array.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""
2+
Use recv_into to receive data directly into a numpy array
3+
"""
4+
5+
import numpy as np
6+
import numpy.testing as nt
7+
8+
import zmq
9+
10+
url = "inproc://test"
11+
12+
13+
def main() -> None:
14+
A = (np.random.random((5, 5)) * 255).astype(dtype=np.int64)
15+
B = np.empty_like(A)
16+
assert not (A == B).all()
17+
18+
with (
19+
zmq.Context() as ctx,
20+
ctx.socket(zmq.PUSH) as push,
21+
ctx.socket(zmq.PULL) as pull,
22+
):
23+
push.bind(url)
24+
pull.connect(url)
25+
print("sending:\n", A)
26+
push.send(A)
27+
bytes_received = pull.recv_into(B)
28+
print(f"received {bytes_received} bytes:\n", B)
29+
assert bytes_received == A.nbytes
30+
nt.assert_allclose(A, B)
31+
32+
33+
if __name__ == "__main__":
34+
main()

tests/test_asyncio.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,61 @@ async def test_recv(create_bound_pair):
6464
assert recvd == b"there"
6565

6666

67+
async def test_recv_into(create_bound_pair):
68+
a, b = create_bound_pair()
69+
b.rcvtimeo = 1000
70+
msg = [
71+
b'hello',
72+
b'there world',
73+
b'part 3',
74+
b'rest',
75+
]
76+
a.send_multipart(msg)
77+
78+
# default nbytes: fits in array
79+
buf = bytearray(10)
80+
nbytes = await b.recv_into(buf)
81+
assert nbytes == len(msg[0])
82+
assert buf[:nbytes] == msg[0]
83+
84+
# default nbytes: truncates to sizeof(buf)
85+
buf = bytearray(4)
86+
nbytes = await b.recv_into(buf, flags=zmq.DONTWAIT)
87+
# returned nbytes is the actual received length,
88+
# which indicates truncation
89+
assert nbytes == len(msg[1])
90+
assert buf[:] == msg[1][: len(buf)]
91+
92+
# specify nbytes, truncates
93+
buf = bytearray(10)
94+
nbytes = 4
95+
nbytes_recvd = await b.recv_into(buf, nbytes=nbytes)
96+
assert nbytes_recvd == len(msg[2])
97+
98+
# recv_into empty buffer discards everything
99+
buf = bytearray(10)
100+
view = memoryview(buf)[:0]
101+
assert view.nbytes == 0
102+
nbytes = await b.recv_into(view)
103+
assert nbytes == len(msg[3])
104+
105+
106+
async def test_recv_into_bad(create_bound_pair):
107+
a, b = create_bound_pair()
108+
b.rcvtimeo = 1000
109+
110+
# bad calls
111+
# make sure flags work
112+
with pytest.raises(zmq.Again):
113+
await b.recv_into(bytearray(5), flags=zmq.DONTWAIT)
114+
115+
await a.send(b'msg')
116+
# negative nbytes
117+
buf = bytearray(10)
118+
with pytest.raises(ValueError):
119+
await b.recv_into(buf, nbytes=-1)
120+
121+
67122
@mark.skipif(not hasattr(zmq, "RCVTIMEO"), reason="requires RCVTIMEO")
68123
async def test_recv_timeout(push_pull):
69124
a, b = push_pull

tests/test_socket.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,79 @@ def test_recv_multipart(self):
455455
for i in range(3):
456456
assert self.recv_multipart(b) == [msg]
457457

458+
def test_recv_into(self):
459+
a, b = self.create_bound_pair()
460+
if not self.green:
461+
b.rcvtimeo = 1000
462+
msg = [
463+
b'hello',
464+
b'there world',
465+
b'part 3',
466+
b'rest',
467+
]
468+
a.send_multipart(msg)
469+
470+
# default nbytes: fits in array
471+
buf = bytearray(10)
472+
nbytes = b.recv_into(buf)
473+
assert nbytes == len(msg[0])
474+
assert buf[:nbytes] == msg[0]
475+
476+
# default nbytes: truncates to sizeof(buf)
477+
buf = bytearray(4)
478+
nbytes = b.recv_into(buf)
479+
# returned nbytes is the actual received length,
480+
# which indicates truncation
481+
assert nbytes == len(msg[1])
482+
assert buf[:] == msg[1][: len(buf)]
483+
484+
# specify nbytes, truncates
485+
buf = bytearray(10)
486+
nbytes = 4
487+
nbytes_recvd = b.recv_into(buf, nbytes=nbytes)
488+
assert nbytes_recvd == len(msg[2])
489+
assert buf[:nbytes] == msg[2][:nbytes]
490+
# didn't recv excess bytes
491+
assert buf[nbytes:] == bytearray(10 - nbytes)
492+
493+
# recv_into empty buffer discards everything
494+
buf = bytearray(10)
495+
view = memoryview(buf)[:0]
496+
assert view.nbytes == 0
497+
nbytes = b.recv_into(view)
498+
assert nbytes == len(msg[3])
499+
assert buf == bytearray(10)
500+
501+
def test_recv_into_bad(self):
502+
a, b = self.create_bound_pair()
503+
b.rcvtimeo = 1000
504+
505+
# bad calls
506+
507+
# negative nbytes
508+
buf = bytearray(10)
509+
with pytest.raises(ValueError):
510+
b.recv_into(buf, nbytes=-1)
511+
# not contiguous
512+
buf = memoryview(bytearray(10))[::2]
513+
with pytest.raises(ValueError):
514+
b.recv_into(buf)
515+
# readonly
516+
buf = memoryview(b"readonly")
517+
with pytest.raises(ValueError):
518+
b.recv_into(buf)
519+
# too big
520+
buf = bytearray(10)
521+
with pytest.raises(ValueError):
522+
b.recv_into(buf, nbytes=11)
523+
# not memory-viewable
524+
with pytest.raises(TypeError):
525+
b.recv_into(pytest)
526+
527+
# make sure flags work
528+
with pytest.raises(zmq.Again):
529+
b.recv_into(bytearray(5), flags=zmq.DONTWAIT)
530+
458531
def test_close_after_destroy(self):
459532
"""s.close() after ctx.destroy() should be fine"""
460533
ctx = self.Context()

tests/zmq_test_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def create_bound_pair(
134134
s2.setsockopt(zmq.LINGER, 0)
135135
s2.connect(f'{interface}:{port}')
136136
self.sockets.extend([s1, s2])
137+
s2.setsockopt(zmq.LINGER, 0)
137138
return s1, s2
138139

139140
def ping_pong(self, s1, s2, msg):

0 commit comments

Comments
 (0)