88{.push raises : [].}
99
1010import
11- std/ sequtils,
11+ std/ [ sequtils, sets] ,
1212 chronos,
1313 stew/ [byteutils, leb128, endians2],
1414 chronicles,
@@ -35,17 +35,20 @@ const
3535 talkReqOverhead = getTalkReqOverhead (utpProtocolId)
3636 utpHeaderOverhead = 20
3737 maxUtpPayloadSize = maxDiscv5PacketSize - talkReqOverhead - utpHeaderOverhead
38+ maxPendingTransfersPerPeer = 128
3839
3940type
4041 ContentRequest = object
4142 connectionId: uint16
4243 nodeId: NodeId
44+ contentId: ContentId
4345 content: seq [byte ]
4446 timeout: Moment
4547
4648 ContentOffer = object
4749 connectionId: uint16
4850 nodeId: NodeId
51+ contentIds: seq [ContentId ]
4952 contentKeys: ContentKeysList
5053 timeout: Moment
5154
6972 connectionTimeout: Duration
7073 contentReadTimeout* : Duration
7174 rng: ref HmacDrbgContext
75+ pendingTransfers: TableRef [NodeId , HashSet [ContentId ]]
7276 contentQueue* : AsyncQueue [(Opt [NodeId ], ContentKeysList , seq [seq [byte ]])]
7377
7478 StreamManager * = ref object
7579 transport: UtpDiscv5Protocol
7680 streams: seq [PortalStream ]
7781 rng: ref HmacDrbgContext
7882
83+ proc canAddPendingTransfer (
84+ transfers: TableRef [NodeId , HashSet [ContentId ]],
85+ nodeId: NodeId ,
86+ contentId: ContentId ,
87+ limit: int ,
88+ ): bool =
89+ if not transfers.contains (nodeId):
90+ return true
91+
92+ try :
93+ let contentIds = transfers[nodeId]
94+ (contentIds.len () < limit) and not contentIds.contains (contentId)
95+ except KeyError as e:
96+ raiseAssert (e.msg)
97+
98+ proc addPendingTransfer (
99+ transfers: TableRef [NodeId , HashSet [ContentId ]],
100+ nodeId: NodeId ,
101+ contentId: ContentId ,
102+ ) =
103+ if transfers.contains (nodeId):
104+ try :
105+ transfers[nodeId].incl (contentId)
106+ except KeyError as e:
107+ raiseAssert (e.msg)
108+ else :
109+ var contentIds = initHashSet [ContentId ]()
110+ contentIds.incl (contentId)
111+ transfers[nodeId] = contentIds
112+
113+ proc removePendingTransfer (
114+ transfers: TableRef [NodeId , HashSet [ContentId ]],
115+ nodeId: NodeId ,
116+ contentId: ContentId ,
117+ ) =
118+ doAssert transfers.contains (nodeId)
119+
120+ try :
121+ transfers[nodeId].excl (contentId)
122+
123+ if transfers[nodeId].len () == 0 :
124+ transfers.del (nodeId)
125+ except KeyError as e:
126+ raiseAssert (e.msg)
127+
128+ template canAddPendingTransfer * (
129+ stream: PortalStream , nodeId: NodeId , contentId: ContentId
130+ ): bool =
131+ stream.pendingTransfers.canAddPendingTransfer (
132+ srcId, contentId, maxPendingTransfersPerPeer
133+ )
134+
135+ template addPendingTransfer * (
136+ stream: PortalStream , nodeId: NodeId , contentId: ContentId
137+ ) =
138+ addPendingTransfer (stream.pendingTransfers, nodeId, contentId)
139+
140+ template removePendingTransfer * (
141+ stream: PortalStream , nodeId: NodeId , contentId: ContentId
142+ ) =
143+ removePendingTransfer (stream.pendingTransfers, nodeId, contentId)
144+
79145proc pruneAllowedConnections (stream: PortalStream ) =
80146 # Prune requests and offers that didn't receive a connection request
81147 # before `connectionTimeout`.
82148 let now = Moment .now ()
83- stream.contentRequests.keepIf (
84- proc (x: ContentRequest ): bool =
85- x.timeout > now
86- )
87- stream.contentOffers.keepIf (
88- proc (x: ContentOffer ): bool =
89- x.timeout > now
90- )
149+
150+ for i, request in stream.contentRequests:
151+ if request.timeout <= now:
152+ stream.removePendingTransfer (request.nodeId, request.contentId)
153+ stream.contentRequests.del (i)
154+
155+ for i, offer in stream.contentOffers:
156+ if offer.timeout <= now:
157+ for contentId in offer.contentIds:
158+ stream.removePendingTransfer (offer.nodeId, contentId)
159+ stream.contentOffers.del (i)
91160
92161proc addContentOffer * (
93- stream: PortalStream , nodeId: NodeId , contentKeys: ContentKeysList
162+ stream: PortalStream ,
163+ nodeId: NodeId ,
164+ contentKeys: ContentKeysList ,
165+ contentIds: seq [ContentId ],
94166): Bytes2 =
95167 stream.pruneAllowedConnections ()
96168
@@ -107,6 +179,7 @@ proc addContentOffer*(
107179 let contentOffer = ContentOffer (
108180 connectionId: id,
109181 nodeId: nodeId,
182+ contentIds: contentIds,
110183 contentKeys: contentKeys,
111184 timeout: Moment .now () + stream.connectionTimeout,
112185 )
@@ -115,7 +188,7 @@ proc addContentOffer*(
115188 return connectionId
116189
117190proc addContentRequest * (
118- stream: PortalStream , nodeId: NodeId , content: seq [byte ]
191+ stream: PortalStream , nodeId: NodeId , contentId: ContentId , content: seq [byte ]
119192): Bytes2 =
120193 stream.pruneAllowedConnections ()
121194
@@ -129,6 +202,7 @@ proc addContentRequest*(
129202 let contentRequest = ContentRequest (
130203 connectionId: id,
131204 nodeId: nodeId,
205+ contentId: contentId,
132206 content: content,
133207 timeout: Moment .now () + stream.connectionTimeout,
134208 )
@@ -285,6 +359,7 @@ proc new(
285359 transport: transport,
286360 connectionTimeout: connectionTimeout,
287361 contentReadTimeout: contentReadTimeout,
362+ pendingTransfers: newTable [NodeId , HashSet [ContentId ]](),
288363 contentQueue: contentQueue,
289364 rng: rng,
290365 )
@@ -317,13 +392,18 @@ proc handleIncomingConnection(
317392 if request.connectionId == socket.connectionId and
318393 request.nodeId == socket.remoteAddress.nodeId:
319394 let fut = socket.writeContentRequest (stream, request)
395+
396+ stream.removePendingTransfer (request.nodeId, request.contentId)
320397 stream.contentRequests.del (i)
321398 return noCancel (fut)
322399
323400 for i, offer in stream.contentOffers:
324401 if offer.connectionId == socket.connectionId and
325402 offer.nodeId == socket.remoteAddress.nodeId:
326403 let fut = socket.readContentOffer (stream, offer)
404+
405+ for contentId in offer.contentIds:
406+ stream.removePendingTransfer (offer.nodeId, contentId)
327407 stream.contentOffers.del (i)
328408 return noCancel (fut)
329409
0 commit comments