Skip to content

Commit 043248b

Browse files
committed
Delinted & PR review
1 parent 4531855 commit 043248b

18 files changed

+205
-178
lines changed

p2p/peer.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Dict,
2121
Iterator,
2222
List,
23+
NamedTuple,
2324
Set,
2425
TYPE_CHECKING,
2526
Tuple,
@@ -209,7 +210,7 @@ async def send_sub_proto_handshake(self) -> None:
209210

210211
@abstractmethod
211212
async def process_sub_proto_handshake(
212-
self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
213+
self, cmd: protocol.Command, msg: protocol.PayloadType) -> None:
213214
raise NotImplementedError("Must be implemented by subclasses")
214215

215216
@contextlib.contextmanager
@@ -365,7 +366,7 @@ async def _run(self) -> None:
365366
self.logger.debug("%s disconnected: %s", self, e)
366367
return
367368

368-
async def read_msg(self) -> Tuple[protocol.Command, protocol._DecodedMsgType]:
369+
async def read_msg(self) -> Tuple[protocol.Command, protocol.PayloadType]:
369370
header_data = await self.read(HEADER_LEN + MAC_LEN)
370371
header = self.decrypt_header(header_data)
371372
frame_size = self.get_frame_size(header)
@@ -392,7 +393,7 @@ async def read_msg(self) -> Tuple[protocol.Command, protocol._DecodedMsgType]:
392393
self.received_msgs[cmd] += 1
393394
return cmd, decoded_msg
394395

395-
def handle_p2p_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
396+
def handle_p2p_msg(self, cmd: protocol.Command, msg: protocol.PayloadType) -> None:
396397
"""Handle the base protocol (P2P) messages."""
397398
if isinstance(cmd, Disconnect):
398399
msg = cast(Dict[str, Any], msg)
@@ -406,12 +407,12 @@ def handle_p2p_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -
406407
else:
407408
raise UnexpectedMessage("Unexpected msg: {} ({})".format(cmd, msg))
408409

409-
def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
410+
def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol.PayloadType) -> None:
410411
cmd_type = type(cmd)
411412

412413
if self._subscribers:
413414
was_added = tuple(
414-
subscriber.add_msg((self, cmd, msg))
415+
subscriber.add_msg(PeerMessage(self, cmd, msg))
415416
for subscriber
416417
in self._subscribers
417418
)
@@ -424,14 +425,14 @@ def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgT
424425
else:
425426
self.logger.warn("Peer %s has no subscribers, discarding %s msg", self, cmd)
426427

427-
def process_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
428+
def process_msg(self, cmd: protocol.Command, msg: protocol.PayloadType) -> None:
428429
if cmd.is_base_protocol:
429430
self.handle_p2p_msg(cmd, msg)
430431
else:
431432
self.handle_sub_proto_msg(cmd, msg)
432433

433434
async def process_p2p_handshake(
434-
self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
435+
self, cmd: protocol.Command, msg: protocol.PayloadType) -> None:
435436
msg = cast(Dict[str, Any], msg)
436437
if not isinstance(cmd, Hello):
437438
await self.disconnect(DisconnectReason.bad_protocol)
@@ -563,8 +564,14 @@ def __hash__(self) -> int:
563564
return hash(self.remote)
564565

565566

567+
class PeerMessage(NamedTuple):
568+
peer: BasePeer
569+
command: protocol.Command
570+
payload: protocol.PayloadType
571+
572+
566573
class PeerSubscriber(ABC):
567-
_msg_queue: 'asyncio.Queue[PEER_MSG_TYPE]' = None
574+
_msg_queue: 'asyncio.Queue[PeerMessage]' = None
568575

569576
@property
570577
@abstractmethod
@@ -609,7 +616,7 @@ def deregister_peer(self, peer: BasePeer) -> None:
609616
pass
610617

611618
@property
612-
def msg_queue(self) -> 'asyncio.Queue[PEER_MSG_TYPE]':
619+
def msg_queue(self) -> 'asyncio.Queue[PeerMessage]':
613620
if self._msg_queue is None:
614621
self._msg_queue = asyncio.Queue(maxsize=self.msg_queue_maxsize)
615622
return self._msg_queue
@@ -618,26 +625,29 @@ def msg_queue(self) -> 'asyncio.Queue[PEER_MSG_TYPE]':
618625
def queue_size(self) -> int:
619626
return self.msg_queue.qsize()
620627

621-
def add_msg(self, msg: 'PEER_MSG_TYPE') -> bool:
628+
def add_msg(self, msg: PeerMessage) -> bool:
622629
peer, cmd, _ = msg
623630

624631
if not self.is_subscription_command(type(cmd)):
625-
self.logger.trace( # type: ignore
626-
"Discarding %s msg from %s; not subscribed to msg type; "
627-
"subscriptions: %s",
628-
cmd, peer, self.subscription_msg_types,
629-
)
632+
if hasattr(self, 'logger'):
633+
self.logger.trace( # type: ignore
634+
"Discarding %s msg from %s; not subscribed to msg type; "
635+
"subscriptions: %s",
636+
cmd, peer, self.subscription_msg_types,
637+
)
630638
return False
631639

632640
try:
633-
self.logger.trace( # type: ignore
634-
"Adding %s msg from %s to queue; queue_size=%d", cmd, peer, self.queue_size)
641+
if hasattr(self, 'logger'):
642+
self.logger.trace( # type: ignore
643+
"Adding %s msg from %s to queue; queue_size=%d", cmd, peer, self.queue_size)
635644
self.msg_queue.put_nowait(msg)
636645
return True
637646
except asyncio.queues.QueueFull:
638-
self.logger.warn( # type: ignore
639-
"%s msg queue is full; discarding %s msg from %s",
640-
self.__class__.__name__, cmd, peer)
647+
if hasattr(self, 'logger'):
648+
self.logger.warn( # type: ignore
649+
"%s msg queue is full; discarding %s msg from %s",
650+
self.__class__.__name__, cmd, peer)
641651
return False
642652

643653
@contextlib.contextmanager
@@ -663,7 +673,7 @@ class MsgBuffer(PeerSubscriber):
663673
subscription_msg_types = {protocol.Command}
664674

665675
@to_tuple
666-
def get_messages(self) -> Iterator['PEER_MSG_TYPE']:
676+
def get_messages(self) -> Iterator[PeerMessage]:
667677
while not self.msg_queue.empty():
668678
yield self.msg_queue.get_nowait()
669679

@@ -740,7 +750,7 @@ async def start_peer(self, peer: BasePeer) -> None:
740750

741751
def _add_peer(self,
742752
peer: BasePeer,
743-
msgs: Tuple[Tuple[protocol.Command, protocol._DecodedMsgType], ...]) -> None:
753+
msgs: Tuple[Tuple[protocol.Command, protocol.PayloadType], ...]) -> None:
744754
"""Add the given peer to the pool.
745755
746756
Appart from adding it to our list of connected nodes and adding each of our subscriber's
@@ -753,7 +763,7 @@ def _add_peer(self,
753763
subscriber.register_peer(peer)
754764
peer.add_subscriber(subscriber)
755765
for cmd, msg in msgs:
756-
subscriber.add_msg((peer, cmd, msg))
766+
subscriber.add_msg(PeerMessage(peer, cmd, msg))
757767

758768
async def _run(self) -> None:
759769
# FIXME: PeerPool should probably no longer be a BaseService, but for now we're keeping it
@@ -1006,9 +1016,6 @@ def __init__(self,
10061016
self.genesis_hash = genesis_hash
10071017

10081018

1009-
PEER_MSG_TYPE = Tuple[BasePeer, protocol.Command, protocol._DecodedMsgType]
1010-
1011-
10121019
def _test() -> None:
10131020
"""
10141021
Create a Peer instance connected to a local geth instance and log messages exchanged with it.

p2p/protocol.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,23 @@
2323
)
2424
from p2p.utils import get_devp2p_cmd_id
2525

26-
2726
# Workaround for import cycles caused by type annotations:
2827
# http://mypy.readthedocs.io/en/latest/common_issues.html#import-cycles
2928
if TYPE_CHECKING:
3029
from p2p.peer import ChainInfo, BasePeer # noqa: F401
3130

32-
33-
_DecodedMsgType = Union[
31+
PayloadType = Union[
3432
Dict[str, Any],
3533
List[rlp.Serializable],
3634
Tuple[rlp.Serializable, ...],
3735
]
3836

37+
# A payload to be delivered with a request
38+
TRequestPayload = TypeVar('TRequestPayload', bound=PayloadType, covariant=True)
39+
40+
# for backwards compatibility for internal references in p2p:
41+
_DecodedMsgType = PayloadType
42+
3943

4044
class Command:
4145
_cmd_id: int = None
@@ -63,7 +67,7 @@ def is_base_protocol(self) -> bool:
6367
def __str__(self) -> str:
6468
return "{} (cmd_id={})".format(self.__class__.__name__, self.cmd_id)
6569

66-
def encode_payload(self, data: Union[_DecodedMsgType, sedes.CountableList]) -> bytes:
70+
def encode_payload(self, data: Union[PayloadType, sedes.CountableList]) -> bytes:
6771
if isinstance(data, dict): # convert dict to ordered list
6872
if not isinstance(self.structure, list):
6973
raise ValueError("Command.structure must be a list when data is a dict")
@@ -79,7 +83,7 @@ def encode_payload(self, data: Union[_DecodedMsgType, sedes.CountableList]) -> b
7983
encoder = sedes.List([type_ for _, type_ in self.structure])
8084
return rlp.encode(data, sedes=encoder)
8185

82-
def decode_payload(self, rlp_data: bytes) -> _DecodedMsgType:
86+
def decode_payload(self, rlp_data: bytes) -> PayloadType:
8387
if isinstance(self.structure, sedes.CountableList):
8488
decoder = self.structure
8589
else:
@@ -100,13 +104,13 @@ def decode_payload(self, rlp_data: bytes) -> _DecodedMsgType:
100104
in zip(self.structure, data)
101105
}
102106

103-
def decode(self, data: bytes) -> _DecodedMsgType:
107+
def decode(self, data: bytes) -> PayloadType:
104108
packet_type = get_devp2p_cmd_id(data)
105109
if packet_type != self.cmd_id:
106110
raise MalformedMessage("Wrong packet type: {}".format(packet_type))
107111
return self.decode_payload(data[1:])
108112

109-
def encode(self, data: _DecodedMsgType) -> Tuple[bytes, bytes]:
113+
def encode(self, data: PayloadType) -> Tuple[bytes, bytes]:
110114
payload = self.encode_payload(data)
111115
enc_cmd_id = rlp.encode(self.cmd_id, sedes=rlp.sedes.big_endian_int)
112116
frame_size = len(enc_cmd_id) + len(payload)
@@ -126,15 +130,12 @@ def encode(self, data: _DecodedMsgType) -> Tuple[bytes, bytes]:
126130
return header, body
127131

128132

129-
TCommandPayload = TypeVar('TCommandPayload', bound=_DecodedMsgType)
130-
131-
132-
class BaseRequest(ABC, Generic[TCommandPayload]):
133+
class BaseRequest(ABC, Generic[TRequestPayload]):
133134
"""
134135
Must define command_payload during init. This is the data that will
135136
be sent to the peer with the request command.
136137
"""
137-
command_payload: TCommandPayload
138+
command_payload: TRequestPayload
138139

139140
@property
140141
@abstractmethod
@@ -165,7 +166,7 @@ def __init__(self, peer: 'BasePeer', cmd_id_offset: int) -> None:
165166
def send(self, header: bytes, body: bytes) -> None:
166167
self.peer.send(header, body)
167168

168-
def send_request(self, request: BaseRequest[_DecodedMsgType]) -> None:
169+
def send_request(self, request: BaseRequest[PayloadType]) -> None:
169170
command = self.cmd_by_type[request.cmd_type]
170171
header, body = command.encode(request.command_payload)
171172
self.send(header, body)

tests/trinity/core/p2p-proto/test_block_bodies_request_object.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -76,29 +76,29 @@ def mk_headers(*counts):
7676
def test_block_bodies_request_empty_response_is_valid():
7777
headers_bundle = mk_headers((2, 3), (8, 4), (0, 1), (0, 0))
7878
headers, _, _, _, _ = zip(*headers_bundle)
79-
request = GetBlockBodiesValidator(headers)
80-
request.validate_result(tuple())
79+
validator = GetBlockBodiesValidator(headers)
80+
validator.validate_result(tuple())
8181

8282

8383
def test_block_bodies_request_valid_with_full_response():
8484
headers_bundle = mk_headers((2, 3), (8, 4), (0, 1), (0, 0))
8585
headers, bodies, transactions_roots, trie_data_dicts, uncles_hashes = zip(*headers_bundle)
8686
transactions_bundles = tuple(zip(transactions_roots, trie_data_dicts))
8787
bodies_bundle = tuple(zip(bodies, transactions_bundles, uncles_hashes))
88-
request = GetBlockBodiesValidator(headers)
89-
request.validate_result(bodies_bundle)
88+
validator = GetBlockBodiesValidator(headers)
89+
validator.validate_result(bodies_bundle)
9090

9191

9292
def test_block_bodies_request_valid_with_partial_response():
9393
headers_bundle = mk_headers((2, 3), (8, 4), (0, 1), (0, 0))
9494
headers, bodies, transactions_roots, trie_data_dicts, uncles_hashes = zip(*headers_bundle)
9595
transactions_bundles = tuple(zip(transactions_roots, trie_data_dicts))
9696
bodies_bundle = tuple(zip(bodies, transactions_bundles, uncles_hashes))
97-
request = GetBlockBodiesValidator(headers)
97+
validator = GetBlockBodiesValidator(headers)
9898

99-
request.validate_result(bodies_bundle[:2])
100-
request.validate_result(bodies_bundle[2:])
101-
request.validate_result((bodies_bundle[0], bodies_bundle[2], bodies_bundle[3]))
99+
validator.validate_result(bodies_bundle[:2])
100+
validator.validate_result(bodies_bundle[2:])
101+
validator.validate_result((bodies_bundle[0], bodies_bundle[2], bodies_bundle[3]))
102102

103103

104104
def test_block_bodies_request_with_fully_invalid_response():
@@ -112,17 +112,17 @@ def test_block_bodies_request_with_fully_invalid_response():
112112
w_transactions_bundles = tuple(zip(w_transactions_roots, w_trie_data_dicts))
113113
w_bodies_bundle = tuple(zip(w_bodies, w_transactions_bundles, w_uncles_hashes))
114114

115-
request = GetBlockBodiesValidator(headers)
115+
validator = GetBlockBodiesValidator(headers)
116116
with pytest.raises(ValidationError):
117-
request.validate_result(w_bodies_bundle)
117+
validator.validate_result(w_bodies_bundle)
118118

119119

120120
def test_block_bodies_request_with_extra_unrequested_bodies():
121121
headers_bundle = mk_headers((2, 3), (8, 4), (0, 1), (0, 0))
122122
headers, bodies, transactions_roots, trie_data_dicts, uncles_hashes = zip(*headers_bundle)
123123
transactions_bundles = tuple(zip(transactions_roots, trie_data_dicts))
124124
bodies_bundle = tuple(zip(bodies, transactions_bundles, uncles_hashes))
125-
request = GetBlockBodiesValidator(headers)
125+
validator = GetBlockBodiesValidator(headers)
126126

127127
wrong_headers_bundle = mk_headers((3, 2), (4, 8), (1, 0), (0, 0))
128128
w_headers, w_bodies, w_transactions_roots, w_trie_data_dicts, w_uncles_hashes = zip(
@@ -131,6 +131,6 @@ def test_block_bodies_request_with_extra_unrequested_bodies():
131131
w_transactions_bundles = tuple(zip(w_transactions_roots, w_trie_data_dicts))
132132
w_bodies_bundle = tuple(zip(w_bodies, w_transactions_bundles, w_uncles_hashes))
133133

134-
request = GetBlockBodiesValidator(headers)
134+
validator = GetBlockBodiesValidator(headers)
135135
with pytest.raises(ValidationError):
136-
request.validate_result(bodies_bundle + w_bodies_bundle)
136+
validator.validate_result(bodies_bundle + w_bodies_bundle)

tests/trinity/core/p2p-proto/test_headers_request_validator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,10 @@ def test_header_request_sequence_matching(
129129
params,
130130
sequence,
131131
is_match):
132-
request = BlockHeadersValidator(*params)
132+
validator = BlockHeadersValidator(*params)
133133

134134
if is_match:
135-
request._validate_sequence(sequence)
135+
validator._validate_sequence(sequence)
136136
else:
137137
with pytest.raises(ValidationError):
138-
request._validate_sequence(sequence)
138+
validator._validate_sequence(sequence)

0 commit comments

Comments
 (0)