Skip to content

Commit da9a952

Browse files
polybassagpotter2
andauthored
Fix threaded sendrecv (#4538)
* Restore sndrcv behaviour from before 53afe84 * Fix possible race condition of sndrcv * Use much better timeout for threading * Reduce abuse on public servers * fix doip unit tests * add testcase * fix test case * fix unit tests * fix unit tests * fix unit tests * fix unit tests --------- Co-authored-by: gpotter2 <[email protected]>
1 parent 19eeafe commit da9a952

File tree

6 files changed

+125
-35
lines changed

6 files changed

+125
-35
lines changed

scapy/sendrecv.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class debug:
9292
Automatically enabled when a generator is passed as the packet
9393
:param _flood:
9494
:param threaded: if True, packets are sent in a thread and received in another.
95-
defaults to False.
95+
Defaults to True.
9696
:param session: a flow decoder used to handle stream of packets
9797
:param chainEX: if True, exceptions during send will be forwarded
9898
:param stop_filter: Python function applied to each packet to determine if
@@ -128,7 +128,7 @@ def __init__(self,
128128
rcv_pks=None, # type: Optional[SuperSocket]
129129
prebuild=False, # type: bool
130130
_flood=None, # type: Optional[_FloodGenerator]
131-
threaded=False, # type: bool
131+
threaded=True, # type: bool
132132
session=None, # type: Optional[_GlobSessionType]
133133
chainEX=False, # type: bool
134134
stop_filter=None # type: Optional[Callable[[Packet], bool]]
@@ -158,7 +158,7 @@ def __init__(self,
158158
self.noans = 0
159159
self._flood = _flood
160160
self.threaded = threaded
161-
self.breakout = False
161+
self.breakout = Event()
162162
# Instantiate packet holders
163163
if prebuild and not self._flood:
164164
self.tobesent = list(pkt) # type: _PacketIterable
@@ -174,6 +174,7 @@ def __init__(self,
174174
self.timeout = None
175175

176176
while retry >= 0:
177+
self.breakout.clear()
177178
self.hsent = {} # type: Dict[bytes, List[Packet]]
178179

179180
if threaded or self._flood:
@@ -190,7 +191,7 @@ def __init__(self,
190191
except KeyboardInterrupt as ex:
191192
interrupted = ex
192193

193-
self.breakout = True
194+
self.breakout.set()
194195

195196
# Ended. Let's close gracefully
196197
if self._flood:
@@ -251,28 +252,33 @@ def results(self):
251252
# type: () -> Tuple[SndRcvList, PacketList]
252253
return self.ans_result, self.unans_result
253254

255+
def _stop_sniffer_if_done(self) -> None:
256+
"""Close the sniffer if all expected answers have been received"""
257+
if self._send_done and self.noans >= self.notans and not self.multi:
258+
if self.sniffer and self.sniffer.running:
259+
self.sniffer.stop(join=False)
260+
254261
def _sndrcv_snd(self):
255262
# type: () -> None
256263
"""Function used in the sending thread of sndrcv()"""
257264
i = 0
258265
p = None
259266
try:
260267
if self.verbose:
261-
print("Begin emission:")
268+
os.write(1, b"Begin emission\n")
262269
for p in self.tobesent:
263270
# Populate the dictionary of _sndrcv_rcv
264271
# _sndrcv_rcv won't miss the answer of a packet that
265272
# has not been sent
266273
self.hsent.setdefault(p.hashret(), []).append(p)
267274
# Send packet
268275
self.pks.send(p)
269-
if self.inter:
270-
time.sleep(self.inter)
271-
if self.breakout:
276+
time.sleep(self.inter)
277+
if self.breakout.is_set():
272278
break
273279
i += 1
274280
if self.verbose:
275-
print("Finished sending %i packets." % i)
281+
os.write(1, b"\nFinished sending %i packets\n" % i)
276282
except SystemExit:
277283
pass
278284
except Exception:
@@ -291,13 +297,10 @@ def _sndrcv_snd(self):
291297
elif not self._send_done:
292298
self.notans = i
293299
self._send_done = True
294-
# In threaded mode, timeout.
295-
if self.threaded and self.timeout is not None and not self.breakout:
296-
t = time.monotonic() + self.timeout
297-
while time.monotonic() < t:
298-
if self.breakout:
299-
break
300-
time.sleep(0.1)
300+
self._stop_sniffer_if_done()
301+
# In threaded mode, timeout
302+
if self.threaded and self.timeout is not None and not self.breakout.is_set():
303+
self.breakout.wait(timeout=self.timeout)
301304
if self.sniffer and self.sniffer.running:
302305
self.sniffer.stop()
303306

@@ -324,9 +327,7 @@ def _process_packet(self, r):
324327
self.noans += 1
325328
sentpkt._answered = 1
326329
break
327-
if self._send_done and self.noans >= self.notans and not self.multi:
328-
if self.sniffer and self.sniffer.running:
329-
self.sniffer.stop(join=False)
330+
self._stop_sniffer_if_done()
330331
if not ok:
331332
if self.verbose > 1:
332333
os.write(1, b".")
@@ -342,7 +343,7 @@ def _sndrcv_rcv(self, callback):
342343
self.sniffer = AsyncSniffer()
343344
self.sniffer._run(
344345
prn=self._process_packet,
345-
timeout=None if self.threaded else self.timeout,
346+
timeout=None if self.threaded and not self._flood else self.timeout,
346347
store=False,
347348
opened_socket=self.rcv_pks,
348349
session=self.session,

test/contrib/automotive/doip.uts

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ import tempfile
416416
= Test DoIPSocket
417417

418418
server_up = threading.Event()
419+
sniff_up = threading.Event()
419420
def server():
420421
buffer = b'\x02\xfd\x80\x02\x00\x00\x00\x05\x00\x00\x00\x00\x00\x02\xfd\x80\x01\x00\x00\x00\n\x10\x10\x0e\x80P\x03\x002\x01\xf4'
421422
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -426,6 +427,7 @@ def server():
426427
sock.listen(1)
427428
server_up.set()
428429
connection, address = sock.accept()
430+
sniff_up.wait(timeout=1)
429431
connection.send(buffer)
430432
connection.close()
431433
finally:
@@ -437,7 +439,7 @@ server_thread.start()
437439
server_up.wait(timeout=1)
438440
sock = DoIPSocket(activate_routing=False)
439441

440-
pkts = sock.sniff(timeout=1, count=2)
442+
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
441443
server_thread.join(timeout=1)
442444
assert len(pkts) == 2
443445

@@ -446,6 +448,7 @@ assert len(pkts) == 2
446448
~ linux
447449

448450
server_up = threading.Event()
451+
sniff_up = threading.Event()
449452
def server():
450453
buffer = b'\x02\xfd\x80\x02\x00\x00\x00\x05\x00\x00\x00\x00\x00\x02\xfd\x80\x01\x00\x00\x00\n\x10\x10\x0e\x80P\x03\x002\x01\xf4'
451454
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -456,6 +459,7 @@ def server():
456459
sock.listen(1)
457460
server_up.set()
458461
connection, address = sock.accept()
462+
sniff_up.wait(timeout=1)
459463
for i in range(len(buffer)):
460464
connection.send(buffer[i:i+1])
461465
time.sleep(0.01)
@@ -469,13 +473,14 @@ server_thread.start()
469473
server_up.wait(timeout=1)
470474
sock = DoIPSocket(activate_routing=False)
471475

472-
pkts = sock.sniff(timeout=1, count=2)
476+
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
473477
server_thread.join(timeout=1)
474478
assert len(pkts) == 2
475479

476480
= Test DoIPSocket 3
477481

478482
server_up = threading.Event()
483+
sniff_up = threading.Event()
479484
def server():
480485
buffer = b'\x02\xfd\x80\x02\x00\x00\x00\x05\x00\x00\x00\x00\x00\x02\xfd\x80\x01\x00\x00\x00\n\x10\x10\x0e\x80P\x03\x002\x01\xf4'
481486
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -486,6 +491,7 @@ def server():
486491
sock.listen(1)
487492
server_up.set()
488493
connection, address = sock.accept()
494+
sniff_up.wait(timeout=1)
489495
while buffer:
490496
randlen = random.randint(0, len(buffer))
491497
connection.send(buffer[:randlen])
@@ -501,14 +507,15 @@ server_thread.start()
501507
server_up.wait(timeout=1)
502508
sock = DoIPSocket(activate_routing=False)
503509

504-
pkts = sock.sniff(timeout=1, count=2)
510+
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
505511
server_thread.join(timeout=1)
506512
assert len(pkts) == 2
507513

508514

509515
= Test DoIPSocket6
510516

511517
server_up = threading.Event()
518+
sniff_up = threading.Event()
512519
def server():
513520
buffer = b'\x02\xfd\x80\x02\x00\x00\x00\x05\x00\x00\x00\x00\x00\x02\xfd\x80\x01\x00\x00\x00\n\x10\x10\x0e\x80P\x03\x002\x01\xf4'
514521
sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
@@ -519,6 +526,7 @@ def server():
519526
sock.listen(1)
520527
server_up.set()
521528
connection, address = sock.accept()
529+
sniff_up.wait(timeout=1)
522530
connection.send(buffer)
523531
connection.close()
524532
finally:
@@ -530,7 +538,7 @@ server_thread.start()
530538
server_up.wait(timeout=1)
531539
sock = DoIPSocket(ip="::1", activate_routing=False)
532540

533-
pkts = sock.sniff(timeout=1, count=2)
541+
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
534542
server_thread.join(timeout=1)
535543
assert len(pkts) == 2
536544

@@ -604,6 +612,7 @@ def _load_certificate_chain(context) -> None:
604612

605613

606614
server_up = threading.Event()
615+
sniff_up = threading.Event()
607616
def server():
608617
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
609618
_load_certificate_chain(context)
@@ -619,6 +628,7 @@ def server():
619628
ssock.listen(1)
620629
server_up.set()
621630
connection, address = ssock.accept()
631+
sniff_up.wait(timeout=1)
622632
connection.send(buffer)
623633
connection.close()
624634
finally:
@@ -633,14 +643,15 @@ context.check_hostname = False
633643
context.verify_mode = ssl.CERT_NONE
634644
sock = DoIPSocket(activate_routing=False, force_tls=True, context=context)
635645

636-
pkts = sock.sniff(timeout=1, count=2)
646+
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
637647
server_thread.join(timeout=1)
638648
assert len(pkts) == 2
639649

640650
= Test DoIPSslSocket6
641651
~ broken_windows
642652

643653
server_up = threading.Event()
654+
sniff_up = threading.Event()
644655
def server():
645656
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
646657
_load_certificate_chain(context)
@@ -656,6 +667,7 @@ def server():
656667
ssock.listen(1)
657668
server_up.set()
658669
connection, address = ssock.accept()
670+
sniff_up.wait(timeout=1)
659671
connection.send(buffer)
660672
connection.close()
661673
finally:
@@ -670,14 +682,15 @@ context.check_hostname = False
670682
context.verify_mode = ssl.CERT_NONE
671683
sock = DoIPSocket(ip="::1", activate_routing=False, force_tls=True, context=context)
672684

673-
pkts = sock.sniff(timeout=1, count=2)
685+
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
674686
server_thread.join(timeout=1)
675687
assert len(pkts) == 2
676688

677689
= Test UDS_DoIPSslSocket6
678690
~ broken_windows
679691

680692
server_up = threading.Event()
693+
sniff_up = threading.Event()
681694
def server():
682695
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
683696
_load_certificate_chain(context)
@@ -693,6 +706,7 @@ def server():
693706
ssock.listen(1)
694707
server_up.set()
695708
connection, address = ssock.accept()
709+
sniff_up.wait(timeout=1)
696710
connection.send(buffer)
697711
connection.close()
698712
finally:
@@ -707,15 +721,16 @@ context.check_hostname = False
707721
context.verify_mode = ssl.CERT_NONE
708722
sock = UDS_DoIPSocket(ip="::1", activate_routing=False, force_tls=True, context=context)
709723

710-
pkts = sock.sniff(timeout=1, count=2)
724+
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
711725
server_thread.join(timeout=1)
712726
assert len(pkts) == 2
713727

714728
= Test UDS_DualDoIPSslSocket6
715-
~ broken_windows
729+
~ broken_windows not_pypy
716730

717731
server_tcp_up = threading.Event()
718732
server_tls_up = threading.Event()
733+
sniff_up = threading.Event()
719734
def server_tls():
720735
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
721736
_load_certificate_chain(context)
@@ -732,6 +747,7 @@ def server_tls():
732747
ssock.listen(1)
733748
server_tls_up.set()
734749
connection, address = ssock.accept()
750+
sniff_up.wait(timeout=1)
735751
connection.send(buffer)
736752
connection.close()
737753
finally:
@@ -748,7 +764,7 @@ def server_tcp():
748764
server_tcp_up.set()
749765
connection, address = sock.accept()
750766
connection.send(buffer)
751-
connection.shutdown()
767+
connection.shutdown(socket.SHUT_RDWR)
752768
connection.close()
753769
finally:
754770
sock.close()
@@ -767,7 +783,7 @@ context.verify_mode = ssl.CERT_NONE
767783

768784
sock = UDS_DoIPSocket(ip="::1", context=context)
769785

770-
pkts = sock.sniff(timeout=1, count=2)
786+
pkts = sock.sniff(timeout=1, count=2, started_callback=sniff_up.set)
771787
server_tcp_thread.join(timeout=1)
772788
server_tls_thread.join(timeout=1)
773789
assert len(pkts) == 2

test/contrib/automotive/scanner/enumerator.uts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,19 @@ class MockISOTPSocket(SuperSocket):
219219
return len(sx)
220220
@staticmethod
221221
def select(sockets, remain=None):
222+
time.sleep(0)
222223
return sockets
224+
def sr(self, *args, **kargs):
225+
from scapy import sendrecv
226+
return sendrecv.sndrcv(self, *args, threaded=False, **kargs)
227+
def sr1(self, *args, **kargs):
228+
from scapy import sendrecv
229+
ans = sendrecv.sndrcv(self, *args, threaded=False, **kargs)[0] # type: SndRcvList
230+
if len(ans) > 0:
231+
pkt = ans[0][1] # type: Packet
232+
return pkt
233+
else:
234+
return None
223235

224236
sock = MockISOTPSocket()
225237
sock.rcvd_queue.put(b"\x41")

test/regression.uts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1832,7 +1832,7 @@ sck = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
18321832
ssck = StreamSocket(sck)
18331833

18341834
try:
1835-
r = ssck.sr1(ICMP(type='echo-request'), timeout=0.1, chainEX=True)
1835+
r = ssck.sr1(ICMP(type='echo-request'), timeout=0.1, chainEX=True, threaded=False)
18361836
assert False
18371837
except Exception:
18381838
assert True
@@ -2132,7 +2132,7 @@ retry_test(_test)
21322132
~ netaccess needs_root IP ICMP
21332133
def _test():
21342134
packet = IP(dst="8.8.8.8")/ICMP()
2135-
r = srflood(packet, timeout=2)
2135+
r = srflood(packet, timeout=0.5)
21362136
assert packet.sent_time is not None
21372137

21382138
retry_test(_test)
@@ -2142,7 +2142,7 @@ retry_test(_test)
21422142
def _test():
21432143
packet1 = IP(dst="8.8.8.8")/ICMP()
21442144
packet2 = IP(dst="8.8.4.4")/ICMP()
2145-
r = srflood([packet1, packet2], timeout=2)
2145+
r = srflood([packet1, packet2], timeout=0.5)
21462146
assert packet1.sent_time is not None
21472147
assert packet2.sent_time is not None
21482148

0 commit comments

Comments
 (0)