Skip to content

Commit 2da6ea7

Browse files
Remove irrelevant Lock
1 parent 58e6eab commit 2da6ea7

File tree

4 files changed

+38
-59
lines changed

4 files changed

+38
-59
lines changed

src/replit_river/client_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ async def _handle_messages_from_ws(self) -> None:
132132
logger.debug(f"{self._transport_id} got a message %r", msg)
133133

134134
# Update bookkeeping
135-
match await self._seq_manager.check_seq_and_update(msg):
135+
match self._seq_manager.check_seq_and_update(msg):
136136
case IgnoreMessage():
137137
continue
138138
case None:

src/replit_river/seq_manager.py

Lines changed: 30 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
import logging
32
from dataclasses import dataclass
43

@@ -7,12 +6,6 @@
76
logger = logging.getLogger(__name__)
87

98

10-
class IgnoreMessageException(Exception):
11-
"""Exception to ignore a transport message, but good to continue."""
12-
13-
pass
14-
15-
169
class InvalidMessageException(Exception):
1710
"""Error processing a transport message, should raise a exception."""
1811

@@ -46,53 +39,40 @@ class SeqManager:
4639
def __init__(
4740
self,
4841
) -> None:
49-
self._seq_lock = asyncio.Lock()
5042
self.seq = 0
51-
self._ack_lock = asyncio.Lock()
5243
self.ack = 0
5344
self.receiver_ack = 0
5445

55-
async def get_seq_and_increment(self) -> int:
46+
def get_seq_and_increment(self) -> int:
5647
"""Get the current sequence number and increment it.
5748
This removes one lock acquire than get_seq and increment_seq separately.
5849
"""
59-
async with self._seq_lock:
60-
current_value = self.seq
61-
self.seq += 1
62-
return current_value
63-
64-
async def increment_seq(self) -> int:
65-
async with self._seq_lock:
66-
self.seq += 1
67-
return self.seq
68-
69-
async def get_seq(self) -> int:
70-
async with self._seq_lock:
71-
return self.seq
72-
73-
async def get_ack(self) -> int:
74-
async with self._ack_lock:
75-
return self.ack
76-
77-
async def check_seq_and_update(self, msg: TransportMessage) -> IgnoreMessage | None:
78-
async with self._ack_lock:
79-
if msg.seq != self.ack:
80-
if msg.seq < self.ack:
81-
return IgnoreMessage()
82-
else:
83-
logger.warn(
84-
f"Out of order message received got {msg.seq} expected "
85-
f"{self.ack}"
86-
)
87-
88-
raise OutOfOrderMessageException(
89-
f"Out of order message received got {msg.seq} expected "
90-
f"{self.ack}"
91-
)
92-
self.receiver_ack = msg.ack
93-
await self._set_ack(msg.seq + 1)
94-
95-
async def _set_ack(self, new_ack: int) -> int:
96-
async with self._ack_lock:
97-
self.ack = new_ack
98-
return self.ack
50+
current_value = self.seq
51+
self.seq += 1
52+
return current_value
53+
54+
def increment_seq(self) -> int:
55+
self.seq += 1
56+
return self.seq
57+
58+
def get_seq(self) -> int:
59+
return self.seq
60+
61+
def get_ack(self) -> int:
62+
return self.ack
63+
64+
def check_seq_and_update(self, msg: TransportMessage) -> IgnoreMessage | None:
65+
if msg.seq != self.ack:
66+
if msg.seq < self.ack:
67+
return IgnoreMessage()
68+
else:
69+
logger.warn(
70+
f"Out of order message received got {msg.seq} expected {self.ack}"
71+
)
72+
73+
raise OutOfOrderMessageException(
74+
f"Out of order message received got {msg.seq} expected {self.ack}"
75+
)
76+
self.receiver_ack = msg.ack
77+
self.ack = msg.seq + 1
78+
return None

src/replit_river/server_session.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,7 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None:
170170
if isinstance(_stream, IgnoreMessage):
171171
continue
172172
if not stream:
173-
async with self._stream_lock:
174-
self._streams[msg.streamId] = _stream
173+
self._streams[msg.streamId] = _stream
175174
stream = _stream
176175

177176
if msg.controlFlags & STREAM_CLOSED_BIT != 0:
@@ -193,7 +192,7 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None:
193192
async def _open_stream_and_call_handler(
194193
self,
195194
msg: TransportMessage,
196-
tg: asyncio.TaskGroup | None,
195+
tg: asyncio.TaskGroup,
197196
) -> Channel | IgnoreMessage:
198197
if not msg.serviceName or not msg.procedureName:
199198
logger.warning(

src/replit_river/session.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -225,18 +225,18 @@ async def _send_transport_message(
225225

226226
async def get_next_expected_seq(self) -> int:
227227
"""Get the next expected sequence number from the server."""
228-
return await self._seq_manager.get_ack()
228+
return self._seq_manager.get_ack()
229229

230230
async def get_next_sent_seq(self) -> int:
231231
"""Get the next sequence number that the client will send."""
232232
nextMessage = await self._buffer.peek()
233233
if nextMessage:
234234
return nextMessage.seq
235-
return await self._seq_manager.get_seq()
235+
return self._seq_manager.get_seq()
236236

237237
async def get_next_expected_ack(self) -> int:
238238
"""Get the next expected ack that the client expects."""
239-
return await self._seq_manager.get_seq()
239+
return self._seq_manager.get_seq()
240240

241241
async def send_message(
242242
self,
@@ -256,8 +256,8 @@ async def send_message(
256256
id=nanoid.generate(),
257257
from_=self._transport_id, # type: ignore
258258
to=self._to_id,
259-
seq=await self._seq_manager.get_seq_and_increment(),
260-
ack=await self._seq_manager.get_ack(),
259+
seq=self._seq_manager.get_seq_and_increment(),
260+
ack=self._seq_manager.get_ack(),
261261
controlFlags=control_flags,
262262
payload=payload,
263263
serviceName=service_name,

0 commit comments

Comments
 (0)