Skip to content

Commit 741a800

Browse files
committed
minor cython optimizations
clearer type annotations, avoid some unnecessary Python calls
1 parent 52b83df commit 741a800

File tree

1 file changed

+48
-50
lines changed

1 file changed

+48
-50
lines changed

zmq/backend/cython/_zmq.py

Lines changed: 48 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@
3333
"""
3434
raise ImportError(msg)
3535

36-
import time
3736
import warnings
3837
from threading import Event
38+
from time import monotonic
3939
from weakref import ref
4040

4141
import cython as C
@@ -143,7 +143,13 @@
143143

144144
import zmq
145145
from zmq.constants import SocketOption, _OptType
146-
from zmq.error import InterruptedSystemCall, ZMQError, _check_version
146+
from zmq.error import (
147+
Again,
148+
ContextTerminated,
149+
InterruptedSystemCall,
150+
ZMQError,
151+
_check_version,
152+
)
147153

148154
IPC_PATH_MAX_LEN = get_ipc_path_max_len()
149155

@@ -162,20 +168,12 @@ def _check_rc(rc: C.int, error_without_errno: bint = False) -> C.int:
162168
return 0
163169
if rc == -1: # if rc < -1, it's a bug in libzmq. Should we warn?
164170
if errno == EINTR:
165-
from zmq.error import InterruptedSystemCall
166-
167171
raise InterruptedSystemCall(errno)
168172
elif errno == EAGAIN:
169-
from zmq.error import Again
170-
171173
raise Again(errno)
172174
elif errno == ZMQ_ETERM:
173-
from zmq.error import ContextTerminated
174-
175175
raise ContextTerminated(errno)
176176
else:
177-
from zmq.error import ZMQError
178-
179177
raise ZMQError(errno)
180178
return 0
181179

@@ -272,6 +270,10 @@ def __init__(
272270
if copy_threshold is None:
273271
copy_threshold = zmq.COPY_THRESHOLD
274272

273+
c_copy_threshold: C.size_t = 0
274+
if copy_threshold is not None:
275+
c_copy_threshold = copy_threshold
276+
275277
zmq_msg_ptr: pointer(zmq_msg_t) = address(self.zmq_msg)
276278
# init more as False
277279
self.more = False
@@ -301,13 +303,16 @@ def __init__(
301303
data_len_c = _asbuffer(data, cast(pointer(p_void), address(data_c)))
302304

303305
# copy unspecified, apply copy_threshold
306+
c_copy: bint = True
304307
if copy is None:
305-
if copy_threshold and data_len_c < copy_threshold:
306-
copy = True
308+
if c_copy_threshold and data_len_c < c_copy_threshold:
309+
c_copy = True
307310
else:
308-
copy = False
311+
c_copy = False
312+
else:
313+
c_copy = copy
309314

310-
if copy:
315+
if c_copy:
311316
# copy message data instead of sharing memory
312317
rc = zmq_msg_init_size(zmq_msg_ptr, data_len_c)
313318
_check_rc(rc)
@@ -710,7 +715,7 @@ def closed(self):
710715
"""Whether the socket is closed"""
711716
return _check_closed_deep(self)
712717

713-
def close(self, linger=None):
718+
def close(self, linger: int | None = None):
714719
"""
715720
Close the socket.
716721
@@ -732,7 +737,7 @@ def close(self, linger=None):
732737
if setlinger:
733738
zmq_setsockopt(self.handle, ZMQ_LINGER, address(linger_c), sizeof(int))
734739
rc = zmq_close(self.handle)
735-
if rc < 0 and zmq_errno() != ENOTSOCK:
740+
if rc < 0 and _zmq_errno() != ENOTSOCK:
736741
# ignore ENOTSOCK (closed by Context)
737742
_check_rc(rc)
738743
self._closed = True
@@ -877,7 +882,7 @@ def get(self, option: C.int):
877882

878883
return result
879884

880-
def bind(self, addr):
885+
def bind(self, addr: str):
881886
"""
882887
Bind the socket to an address.
883888
@@ -894,21 +899,13 @@ def bind(self, addr):
894899
encoded to utf-8 first.
895900
"""
896901
rc: C.int
897-
c_addr: p_char
902+
b_addr: bytes = addr.encode('utf-8')
903+
c_addr: p_char = b_addr
898904

899905
_check_closed(self)
900-
addr_b = addr
901-
if isinstance(addr, str):
902-
addr_b = addr.encode('utf-8')
903-
elif isinstance(addr_b, bytes):
904-
addr = addr_b.decode('utf-8')
905-
906-
if not isinstance(addr_b, bytes):
907-
raise TypeError(f'expected str, got: {addr!r}')
908-
c_addr = addr_b
909906
rc = zmq_bind(self.handle, c_addr)
910907
if rc != 0:
911-
if IPC_PATH_MAX_LEN and zmq_errno() == ENAMETOOLONG:
908+
if IPC_PATH_MAX_LEN and _zmq_errno() == ENAMETOOLONG:
912909
path = addr.split('://', 1)[-1]
913910
msg = (
914911
f'ipc path "{path}" is longer than {IPC_PATH_MAX_LEN} '
@@ -917,7 +914,7 @@ def bind(self, addr):
917914
'to check addr length (if it is defined).'
918915
)
919916
raise ZMQError(msg=msg)
920-
elif zmq_errno() == ENOENT:
917+
elif _zmq_errno() == ENOENT:
921918
path = addr.split('://', 1)[-1]
922919
msg = f'No such file or directory for ipc path "{path}".'
923920
raise ZMQError(msg=msg)
@@ -930,7 +927,7 @@ def bind(self, addr):
930927
else:
931928
break
932929

933-
def connect(self, addr):
930+
def connect(self, addr: str) -> None:
934931
"""
935932
Connect to a remote 0MQ socket.
936933
@@ -943,14 +940,10 @@ def connect(self, addr):
943940
encoded to utf-8 first.
944941
"""
945942
rc: C.int
946-
c_addr: p_char
943+
b_addr: bytes = addr.encode('utf-8')
944+
c_addr: p_char = b_addr
947945

948946
_check_closed(self)
949-
if isinstance(addr, str):
950-
addr = addr.encode('utf-8')
951-
if not isinstance(addr, bytes):
952-
raise TypeError(f'expected str, got: {addr!r}')
953-
c_addr = addr
954947

955948
while True:
956949
try:
@@ -1061,7 +1054,7 @@ def monitor(self, addr, events: C.int = ZMQ_EVENT_ALL):
10611054

10621055
_check_rc(zmq_socket_monitor(self.handle, c_addr, events))
10631056

1064-
def join(self, group):
1057+
def join(self, group: str | bytes):
10651058
"""
10661059
Join a RADIO-DISH group
10671060
@@ -1076,7 +1069,8 @@ def join(self, group):
10761069
raise RuntimeError("libzmq must be built with draft support")
10771070
if isinstance(group, str):
10781071
group = group.encode('utf8')
1079-
rc: C.int = zmq_join(self.handle, group)
1072+
c_group: bytes = group
1073+
rc: C.int = zmq_join(self.handle, c_group)
10801074
_check_rc(rc)
10811075

10821076
def leave(self, group):
@@ -1152,8 +1146,10 @@ def send(self, data, flags=0, copy: bint = True, track: bint = False):
11521146
else:
11531147
if self.copy_threshold:
11541148
buf = memoryview(data)
1149+
nbytes: C.int = buf.nbytes
1150+
copy_threshold: C.int = self.copy_threshold
11551151
# always copy messages smaller than copy_threshold
1156-
if buf.nbytes < self.copy_threshold:
1152+
if nbytes < copy_threshold:
11571153
_send_copy(self.handle, buf, flags)
11581154
return zmq._FINISHED_TRACKER
11591155
msg = Frame(data, track=track, copy_threshold=self.copy_threshold)
@@ -1310,7 +1306,7 @@ def _check_closed_deep(s: Socket) -> bint:
13101306
s.handle, ZMQ_TYPE, cast(p_void, address(stype)), address(sz)
13111307
)
13121308
if rc < 0:
1313-
errno = zmq_errno()
1309+
errno = _zmq_errno()
13141310
if errno == ENOTSOCK:
13151311
s._closed = True
13161312
return True
@@ -1465,7 +1461,7 @@ def _setsockopt(handle: p_void, option: C.int, optval: p_void, sz: size_t):
14651461
# General utility functions
14661462

14671463

1468-
def zmq_errno():
1464+
def zmq_errno() -> C.int:
14691465
"""Return the integer errno of the most recent zmq error."""
14701466
return _zmq_errno()
14711467

@@ -1487,17 +1483,14 @@ def zmq_version_info() -> tuple[int, int, int]:
14871483
return (major, minor, patch)
14881484

14891485

1490-
def has(capability) -> bool:
1486+
def has(capability: str) -> bool:
14911487
"""Check for zmq capability by name (e.g. 'ipc', 'curve')
14921488
14931489
.. versionadded:: libzmq-4.1
14941490
.. versionadded:: 14.1
14951491
"""
14961492
_check_version((4, 1), 'zmq.has')
1497-
ccap: bytes
1498-
if isinstance(capability, str):
1499-
capability = capability.encode('utf8')
1500-
ccap = capability
1493+
ccap: bytes = capability.encode('utf8')
15011494
return bool(zmq_has(ccap))
15021495

15031496

@@ -1578,6 +1571,8 @@ def zmq_poll(sockets, timeout: C.int = -1):
15781571
"""
15791572
rc: C.int
15801573
i: C.int
1574+
fileno: fd_t
1575+
events: C.int
15811576
pollitems: pointer(zmq_pollitem_t) = NULL
15821577
nsockets: C.int = len(sockets)
15831578

@@ -1596,8 +1591,9 @@ def zmq_poll(sockets, timeout: C.int = -1):
15961591
pollitems[i].events = events
15971592
pollitems[i].revents = 0
15981593
elif isinstance(s, int):
1594+
fileno = s
15991595
pollitems[i].socket = NULL
1600-
pollitems[i].fd = s
1596+
pollitems[i].fd = fileno
16011597
pollitems[i].events = events
16021598
pollitems[i].revents = 0
16031599
elif hasattr(s, 'fileno'):
@@ -1618,17 +1614,19 @@ def zmq_poll(sockets, timeout: C.int = -1):
16181614
f"a fileno() method: {s!r}"
16191615
)
16201616

1621-
ms_passed: int = 0
1617+
ms_passed: C.int = 0
1618+
tic: C.int
16221619
try:
16231620
while True:
1624-
start = time.monotonic()
1621+
start: C.int = monotonic()
16251622
with nogil:
16261623
rc = zmq_poll_c(pollitems, nsockets, timeout)
16271624
try:
16281625
_check_rc(rc)
16291626
except InterruptedSystemCall:
16301627
if timeout > 0:
1631-
ms_passed = int(1000 * (time.monotonic() - start))
1628+
tic = monotonic()
1629+
ms_passed = int(1000 * (tic - start))
16321630
if ms_passed < 0:
16331631
# don't allow negative ms_passed,
16341632
# which can happen on old Python versions without time.monotonic.

0 commit comments

Comments
 (0)