Skip to content

Commit b11f126

Browse files
minrksteamraven
andauthored
[experimental] support socket.FD on draft thread-safe sockets (#2103)
gated by warning, may be removed if it causes problems will have no effect if/when getsockopt(ZMQ_FD) is implemented for threadsafe sockets --------- Co-authored-by: Matthew Hawn <[email protected]>
1 parent 4765057 commit b11f126

File tree

9 files changed

+209
-14
lines changed

9 files changed

+209
-14
lines changed

tests/test_asyncio.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,33 @@ async def test_poll_leak():
405405
assert len(s._recv_futures) == 0
406406

407407

408+
async def test_draft_asyncio():
409+
if not zmq.has("draft"):
410+
pytest.skip("draft API")
411+
if zmq.zmq_version_info() < (4, 3, 2):
412+
pytest.skip("requires libzmq 4.3.2 for zmq_poller_fd")
413+
with zmq.asyncio.Context() as ctx, ctx.socket(zmq.CLIENT) as client, ctx.socket(
414+
zmq.SERVER
415+
) as server:
416+
server.bind_to_random_port("tcp://127.0.0.1")
417+
client.connect(server.last_endpoint)
418+
server.rcvtimeo = client.rcvtimeo = 100
419+
with pytest.raises(zmq.Again):
420+
await server.recv()
421+
with pytest.raises(zmq.Again):
422+
await client.recv()
423+
server.rcvtimeo = client.rcvtimeo = server.sndtimeo = client.sndtimeo = 3000
424+
recv_future = asyncio.ensure_future(server.recv(copy=False))
425+
assert not recv_future.done()
426+
await client.send(b'request')
427+
msg = await recv_future
428+
recv_future = asyncio.ensure_future(client.recv())
429+
assert not recv_future.done()
430+
await server.send(msg)
431+
response = await recv_future
432+
assert response == b'request'
433+
434+
408435
class ProcessForTeardownTest(Process):
409436
def run(self):
410437
"""Leave context, socket and event loop upon implicit disposal"""

tests/test_draft.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,13 @@
66
import pytest
77

88
import zmq
9-
from zmq_test_utils import BaseZMQTestCase
9+
from zmq_test_utils import BaseZMQTestCase, skip_pypy
1010

11+
pytestmark = pytest.mark.skipif(not zmq.DRAFT_API, reason="draft api unavailable")
1112

12-
class TestDraftSockets(BaseZMQTestCase):
13-
def setUp(self):
14-
if not zmq.DRAFT_API:
15-
pytest.skip("draft api unavailable")
16-
super().setUp()
1713

14+
class TestDraftSockets(BaseZMQTestCase):
15+
@skip_pypy
1816
def test_client_server(self):
1917
client, server = self.create_bound_pair(zmq.CLIENT, zmq.SERVER)
2018
client.send(b'request')
@@ -24,6 +22,15 @@ def test_client_server(self):
2422
reply = self.recv(client)
2523
assert reply == b'reply'
2624

25+
def test_client_server_frame(self):
26+
client, server = self.create_bound_pair(zmq.CLIENT, zmq.SERVER)
27+
client.send(b'request')
28+
msg = self.recv(server, copy=False)
29+
server.send(msg)
30+
reply = self.recv(client)
31+
assert reply == b'request'
32+
33+
@skip_pypy
2734
def test_radio_dish(self):
2835
dish, radio = self.create_bound_pair(zmq.DISH, zmq.RADIO)
2936
dish.rcvtimeo = 250
@@ -45,3 +52,13 @@ def test_radio_dish(self):
4552
received_count += 1
4653
# assert that we got *something*
4754
assert len(received.intersection(sent)) >= 5
55+
56+
57+
def test_draft_fd():
58+
if zmq.zmq_version_info() < (4, 3, 2):
59+
pytest.skip("requires libzmq 4.3.2 for zmq_poller_fd")
60+
with zmq.Context() as ctx, ctx.socket(zmq.SERVER) as s:
61+
fd = s.FD
62+
assert isinstance(fd, int)
63+
fd_2 = s.get(zmq.FD)
64+
assert fd_2 == fd

zmq/backend/cffi/_cdefs.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ typedef struct
6868

6969
int zmq_poll(zmq_pollitem_t *items, int nitems, long timeout);
7070

71+
// draft poller
72+
void *zmq_poller_new ();
73+
int zmq_poller_destroy (void **poller_p_);
74+
int zmq_poller_add (void *poller_, void *socket_, void *user_data_, short events_);
75+
int zmq_poller_fd (void *poller_, ZMQ_FD_T *fd_);
76+
7177
// miscellany
7278
void * memcpy(void *restrict s1, const void *restrict s2, size_t n);
7379
void * malloc(size_t sz);

zmq/backend/cffi/socket.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
# Distributed under the terms of the Modified BSD License.
55

66
import errno as errno_mod
7+
import warnings
78

89
import zmq
910
from zmq.constants import SocketOption, _OptType
10-
from zmq.error import ZMQError, _check_rc
11+
from zmq.error import ZMQError, _check_rc, _check_version
1112

1213
from ._cffi import ffi
1314
from ._cffi import lib as C
@@ -49,7 +50,8 @@ def value_binary_data(val, length):
4950
return ffi.new(f'char[{length + 1:d}]', val), ffi.sizeof('char') * length
5051

5152

52-
ZMQ_FD_64BIT = ffi.sizeof('ZMQ_FD_T') == 8
53+
_fd_size = ffi.sizeof('ZMQ_FD_T')
54+
ZMQ_FD_64BIT = _fd_size == 8
5355

5456
IPC_PATH_MAX_LEN = C.get_ipc_path_max_len()
5557

@@ -100,6 +102,8 @@ class Socket:
100102
_closed = None
101103
_ref = None
102104
_shadow = False
105+
_draft_poller = None
106+
_draft_poller_ptr = None
103107
copy_threshold = 0
104108

105109
def __init__(self, context=None, socket_type=None, shadow=0, copy_threshold=None):
@@ -108,6 +112,7 @@ def __init__(self, context=None, socket_type=None, shadow=0, copy_threshold=None
108112
self.copy_threshold = copy_threshold
109113

110114
self.context = context
115+
self._draft_poller = self._draft_poller_ptr = None
111116
if shadow:
112117
self._zmq_socket = ffi.cast("void *", shadow)
113118
self._shadow = True
@@ -152,6 +157,10 @@ def closed(self):
152157
def close(self, linger=None):
153158
rc = 0
154159
if not self._closed and hasattr(self, '_zmq_socket'):
160+
if self._draft_poller_ptr is not None:
161+
rc = C.zmq_poller_destroy(self._draft_poller_ptr)
162+
self._draft_poller = self._draft_poller_ptr = None
163+
155164
if self._zmq_socket is not None:
156165
if linger is not None:
157166
self.set(zmq.LINGER, linger)
@@ -242,11 +251,55 @@ def get(self, option):
242251
else:
243252
opt_type = option._opt_type
244253

254+
if option == zmq.FD and self._draft_poller is not None:
255+
c_value_pointer, _ = new_pointer_from_opt(option)
256+
C.zmq_poller_fd(self._draft_poller, ffi.cast('void*', c_value_pointer))
257+
return int(c_value_pointer[0])
258+
245259
c_value_pointer, c_sizet_pointer = new_pointer_from_opt(option, length=255)
246260

247-
_retry_sys_call(
248-
C.zmq_getsockopt, self._zmq_socket, option, c_value_pointer, c_sizet_pointer
249-
)
261+
try:
262+
_retry_sys_call(
263+
C.zmq_getsockopt,
264+
self._zmq_socket,
265+
option,
266+
c_value_pointer,
267+
c_sizet_pointer,
268+
)
269+
except ZMQError as e:
270+
if (
271+
option == SocketOption.FD
272+
and e.errno == zmq.Errno.EINVAL
273+
and self.get(SocketOption.THREAD_SAFE)
274+
):
275+
_check_version((4, 3, 2), "draft socket FD support via zmq_poller_fd")
276+
if not zmq.has('draft'):
277+
raise RuntimeError("libzmq must be built with draft support")
278+
warnings.warn(zmq.error.DraftFDWarning(), stacklevel=2)
279+
280+
# create a poller and retrieve its fd
281+
self._draft_poller_ptr = ffi.new("void*[1]")
282+
self._draft_poller_ptr[0] = self._draft_poller = C.zmq_poller_new()
283+
if self._draft_poller == ffi.NULL:
284+
# failed (why?), raise original error
285+
self._draft_poller_ptr = self._draft_poller = None
286+
raise
287+
# register self with poller
288+
rc = C.zmq_poller_add(
289+
self._draft_poller,
290+
self._zmq_socket,
291+
ffi.NULL,
292+
zmq.POLLIN | zmq.POLLOUT,
293+
)
294+
_check_rc(rc)
295+
# use poller fd as proxy for ours
296+
rc = C.zmq_poller_fd(
297+
self._draft_poller, ffi.cast('void *', c_value_pointer)
298+
)
299+
_check_rc(rc)
300+
return int(c_value_pointer[0])
301+
else:
302+
raise
250303

251304
sz = c_sizet_pointer[0]
252305
v = value_from_opt_pointer(option, c_value_pointer, sz)

zmq/backend/cython/_zmq.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ cdef class Socket:
4444
cdef public bint _closed # bool property for a closed socket.
4545
cdef public int copy_threshold # threshold below which pyzmq will always copy messages
4646
cdef int _pid # the pid of the process which created me (for fork safety)
47+
cdef void *_draft_poller # The C handle for the zmq poller for draft socket zmq.FD support
4748

4849
# cpdef methods for direct-cython access:
4950
cpdef object send(self, data, int flags=*, bint copy=*, bint track=*)

zmq/backend/cython/_zmq.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,16 @@
8888
ZMQ_ENOTSOCK,
8989
ZMQ_ETERM,
9090
ZMQ_EVENT_ALL,
91+
ZMQ_FD,
9192
ZMQ_IDENTITY,
9293
ZMQ_IO_THREADS,
9394
ZMQ_LINGER,
9495
ZMQ_POLLIN,
96+
ZMQ_POLLOUT,
9597
ZMQ_RCVMORE,
9698
ZMQ_ROUTER,
9799
ZMQ_SNDMORE,
100+
ZMQ_THREAD_SAFE,
98101
ZMQ_TYPE,
99102
_zmq_version,
100103
fd_t,
@@ -131,6 +134,10 @@
131134
zmq_msg_set_routing_id,
132135
zmq_msg_size,
133136
zmq_msg_t,
137+
zmq_poller_add,
138+
zmq_poller_destroy,
139+
zmq_poller_fd,
140+
zmq_poller_new,
134141
zmq_pollitem_t,
135142
zmq_proxy,
136143
zmq_proxy_steerable,
@@ -700,6 +707,7 @@ def __init__(
700707
):
701708
# pre-init
702709
self.handle = NULL
710+
self._draft_poller = NULL
703711
self._pid = 0
704712
self._shadow = False
705713
self.context = None
@@ -756,6 +764,12 @@ def close(self, linger: int | None = None):
756764
if self.handle != NULL and not self._closed and getpid() == self._pid:
757765
if setlinger:
758766
zmq_setsockopt(self.handle, ZMQ_LINGER, address(linger_c), sizeof(int))
767+
768+
# teardown draft poller
769+
if self._draft_poller != NULL:
770+
zmq_poller_destroy(address(self._draft_poller))
771+
self._draft_poller = NULL
772+
759773
rc = zmq_close(self.handle)
760774
if rc < 0 and _zmq_errno() != ENOTSOCK:
761775
# ignore ENOTSOCK (closed by Context)
@@ -834,6 +848,10 @@ def get(self, option: C.int):
834848
835849
See the 0MQ API documentation for details on specific options.
836850
851+
.. versionchanged:: 27
852+
Added experimental support for ZMQ_FD for draft sockets via `zmq_poller_fd`.
853+
Requires libzmq >=4.3.2 built with draft support.
854+
837855
Parameters
838856
----------
839857
option : int
@@ -882,11 +900,49 @@ def get(self, option: C.int):
882900
self.handle, option, cast(p_void, address(optval_int64_c)), address(sz)
883901
)
884902
result = optval_int64_c
903+
elif option == ZMQ_FD and self._draft_poller != NULL:
904+
# draft sockets use FD of a draft zmq_poller as proxy
905+
rc = zmq_poller_fd(self._draft_poller, address(optval_fd_c))
906+
_check_rc(rc)
907+
result = optval_fd_c
885908
elif opt_type == _OptType.fd:
886909
sz = sizeof(fd_t)
887-
_getsockopt(
888-
self.handle, option, cast(p_void, address(optval_fd_c)), address(sz)
889-
)
910+
try:
911+
_getsockopt(
912+
self.handle, option, cast(p_void, address(optval_fd_c)), address(sz)
913+
)
914+
except ZMQError as e:
915+
# threadsafe sockets don't support ZMQ_FD (yet!)
916+
# fallback on zmq_poller_fd as proxy with the same behavior
917+
# until libzmq fixes this.
918+
# if upstream fixes it, this branch will never be taken
919+
if (
920+
option == ZMQ_FD
921+
and e.errno == zmq.Errno.EINVAL
922+
and self.get(ZMQ_THREAD_SAFE)
923+
):
924+
_check_version(
925+
(4, 3, 2), "draft socket FD support via zmq_poller_fd"
926+
)
927+
if not zmq.has('draft'):
928+
raise RuntimeError("libzmq must be built with draft support")
929+
warnings.warn(zmq.error.DraftFDWarning(), stacklevel=2)
930+
931+
# create a poller and retrieve its fd
932+
self._draft_poller = zmq_poller_new()
933+
if self._draft_poller == NULL:
934+
# failed (why?), raise original error
935+
raise
936+
# register self with poller
937+
rc = zmq_poller_add(
938+
self._draft_poller, self.handle, NULL, ZMQ_POLLIN | ZMQ_POLLOUT
939+
)
940+
_check_rc(rc)
941+
# use poller fd as proxy for ours
942+
rc = zmq_poller_fd(self._draft_poller, address(optval_fd_c))
943+
_check_rc(rc)
944+
else:
945+
raise
890946
result = optval_fd_c
891947
else:
892948
# default is to assume int, which is what most new sockopts will be

zmq/backend/cython/libzmq.pxd

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,10 @@ cdef extern from "zmq.h" nogil:
119119
uint32_t zmq_msg_routing_id(zmq_msg_t *msg)
120120
int zmq_msg_set_group(zmq_msg_t *msg, const char *group)
121121
const char *zmq_msg_group(zmq_msg_t *msg)
122+
123+
void *zmq_poller_new ()
124+
int zmq_poller_destroy (void **poller_p_)
125+
int zmq_poller_add (void *poller_, void *socket_, void *user_data_, short events_)
126+
int zmq_poller_modify (void *poller_, void *socket_, short events_)
127+
int zmq_poller_remove (void *poller_, void *socket_)
128+
int zmq_poller_fd (void *poller_, fd_t *fd_)

zmq/error.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,22 @@
77
from errno import EINTR
88

99

10+
class DraftFDWarning(RuntimeWarning):
11+
"""Warning for using experimental FD on draft sockets.
12+
13+
.. versionadded:: 27
14+
"""
15+
16+
def __init__(self, msg=""):
17+
if not msg:
18+
msg = (
19+
"pyzmq's back-fill socket.FD support on thread-safe sockets is experimental, and may be removed."
20+
" This warning will go away automatically if/when libzmq implements socket.FD on thread-safe sockets."
21+
" You can suppress this warning with `warnings.simplefilter('ignore', zmq.error.DraftFDWarning)"
22+
)
23+
super().__init__(msg)
24+
25+
1026
class ZMQBaseError(Exception):
1127
"""Base exception class for 0MQ errors in Python."""
1228

@@ -201,6 +217,7 @@ def _check_version(
201217

202218

203219
__all__ = [
220+
"DraftFDWarning",
204221
"ZMQBaseError",
205222
"ZMQBindError",
206223
"ZMQError",

zmq/utils/zmq_compat.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@
6060
#if ZMQ_VERSION >= 40200
6161
#define PYZMQ_DRAFT_42
6262
#endif
63+
#if ZMQ_VERSION >= 40302
64+
#define PYZMQ_DRAFT_432
65+
#endif
6366
#endif
6467
#ifndef PYZMQ_DRAFT_42
6568
#define zmq_join(s, group) _missing
@@ -68,6 +71,14 @@
6871
#define zmq_msg_routing_id(msg) 0
6972
#define zmq_msg_set_group(msg, group) _missing
7073
#define zmq_msg_group(msg) NULL
74+
#define zmq_poller_new() NULL
75+
#define zmq_poller_destroy(poller_p) _missing
76+
#define zmq_poller_add(poller, socket, userdata, events) _missing
77+
#define zmq_poller_modify(poller, socket, events) _missing
78+
#define zmq_poller_remove(poller, socket) _missing
79+
#endif
80+
#ifndef PYZMQ_DRAFT_432
81+
#define zmq_poller_fd(poller, fd) _missing
7182
#endif
7283

7384
#if ZMQ_VERSION >= 40100

0 commit comments

Comments
 (0)