Skip to content

Commit 6e4364d

Browse files
authored
fix: reconnects for websocket connection closed ok and ws-api (#1655)
1 parent 9b9c9e9 commit 6e4364d

File tree

4 files changed

+259
-9
lines changed

4 files changed

+259
-9
lines changed

binance/ws/keepalive_websocket.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,23 @@ def __init__(
3131
self._timer = None
3232
self._subscription_id = None
3333
self._listen_key = None # Used for non spot stream types
34+
self._uses_ws_api_subscription = False # True when using ws_api
3435

3536
async def __aexit__(self, *args, **kwargs):
36-
if not self._path:
37-
return
3837
if self._timer:
3938
self._timer.cancel()
4039
self._timer = None
4140
# Clean up subscription if it exists
4241
if self._subscription_id is not None:
42+
# Unregister the queue from ws_api before unsubscribing
43+
if hasattr(self._client, 'ws_api') and self._client.ws_api:
44+
self._client.ws_api.unregister_subscription_queue(self._subscription_id)
4345
await self._unsubscribe_from_user_data_stream()
46+
if self._uses_ws_api_subscription:
47+
# For ws_api subscriptions, we don't manage the connection
48+
return
49+
if not self._path:
50+
return
4451
await super().__aexit__(*args, **kwargs)
4552

4653
def _build_path(self):
@@ -51,16 +58,43 @@ def _build_path(self):
5158

5259
async def _before_connect(self):
5360
if self._keepalive_type == "user":
61+
# Subscribe via ws_api and register our own queue for events
5462
self._subscription_id = await self._subscribe_to_user_data_stream()
55-
# Reuse the ws_api connection that's already established
56-
self.ws = self._client.ws_api.ws
57-
self.ws_state = self._client.ws_api.ws_state
58-
self._queue = self._client.ws_api._queue
63+
self._uses_ws_api_subscription = True
64+
# Register our queue with ws_api so events get routed to us
65+
self._client.ws_api.register_subscription_queue(self._subscription_id, self._queue)
66+
self._path = f"user_subscription:{self._subscription_id}"
5967
return
6068
if not self._listen_key:
6169
self._listen_key = await self._get_listen_key()
6270
self._build_path()
6371

72+
async def connect(self):
73+
"""Override connect to handle ws_api subscriptions differently."""
74+
if self._keepalive_type == "user":
75+
# For user sockets using ws_api subscription:
76+
# - Subscribe via ws_api (done in _before_connect)
77+
# - Don't create our own websocket connection
78+
# - Don't start a read loop (ws_api handles reading)
79+
await self._before_connect()
80+
await self._after_connect()
81+
return
82+
# For other keepalive types, use normal connection logic
83+
await super().connect()
84+
85+
async def recv(self):
86+
"""Override recv to work without a read loop for ws_api subscriptions."""
87+
if self._uses_ws_api_subscription:
88+
# For ws_api subscriptions, just read from queue
89+
res = None
90+
while not res:
91+
try:
92+
res = await asyncio.wait_for(self._queue.get(), timeout=self.TIMEOUT)
93+
except asyncio.TimeoutError:
94+
self._log.debug(f"no message in {self.TIMEOUT} seconds")
95+
return res
96+
return await super().recv()
97+
6498
async def _after_connect(self):
6599
if self._timer is None:
66100
self._start_socket_timer()

binance/ws/reconnecting_websocket.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
pass
1616

1717
try:
18-
from websockets.exceptions import ConnectionClosedError # type: ignore
18+
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK # type: ignore
1919
except ImportError:
20-
from websockets import ConnectionClosedError # type: ignore
20+
from websockets import ConnectionClosedError, ConnectionClosedOK # type: ignore
2121

2222

2323
Proxy = None
@@ -226,6 +226,7 @@ async def _read_loop(self):
226226
asyncio.IncompleteReadError,
227227
gaierror,
228228
ConnectionClosedError,
229+
ConnectionClosedOK,
229230
BinanceWebsocketClosed,
230231
) as e:
231232
# reports errors and continue loop

binance/ws/websocket_api.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,18 @@ def __init__(self, url: str, tld: str = "com", testnet: bool = False, https_prox
1414
self._testnet = testnet
1515
self._responses: Dict[str, asyncio.Future] = {}
1616
self._connection_lock: Optional[asyncio.Lock] = None
17+
# Subscription queues for routing user data stream events
18+
self._subscription_queues: Dict[str, asyncio.Queue] = {}
1719
super().__init__(url=url, prefix="", path="", is_binary=False, https_proxy=https_proxy)
1820

21+
def register_subscription_queue(self, subscription_id: str, queue: asyncio.Queue) -> None:
22+
"""Register a queue to receive events for a specific subscription."""
23+
self._subscription_queues[subscription_id] = queue
24+
25+
def unregister_subscription_queue(self, subscription_id: str) -> None:
26+
"""Unregister a subscription queue."""
27+
self._subscription_queues.pop(subscription_id, None)
28+
1929
@property
2030
def connection_lock(self) -> asyncio.Lock:
2131
if self._connection_lock is None:
@@ -33,7 +43,21 @@ def _handle_message(self, msg):
3343
# Check if this is a subscription event (user data stream, etc.)
3444
# These have 'subscriptionId' and 'event' fields instead of 'id'
3545
if "subscriptionId" in parsed_msg and "event" in parsed_msg:
36-
return parsed_msg["event"]
46+
subscription_id = parsed_msg["subscriptionId"]
47+
event = parsed_msg["event"]
48+
# Route to the registered subscription queue if one exists
49+
if subscription_id in self._subscription_queues:
50+
queue = self._subscription_queues[subscription_id]
51+
try:
52+
queue.put_nowait(event)
53+
except asyncio.QueueFull:
54+
self._log.error(f"Subscription queue full for {subscription_id}, dropping event")
55+
except Exception as e:
56+
self._log.error(f"Error putting event in subscription queue for {subscription_id}: {e}")
57+
return None # Don't put in main queue
58+
else:
59+
# No registered queue, return event for main queue (backward compat)
60+
return event
3761

3862
req_id, exception = None, None
3963
if "id" in parsed_msg:
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
"""
2+
Integration tests for user socket with ws_api subscription routing.
3+
4+
These tests verify that the user socket correctly:
5+
1. Uses ws_api for subscription (not creating its own connection)
6+
2. Has its own queue for receiving events (not sharing ws_api's queue)
7+
3. Does not start its own read loop (ws_api handles reading)
8+
4. Properly cleans up subscriptions on exit
9+
10+
Requirements:
11+
- Binance testnet API credentials (configured in conftest.py)
12+
- Network connectivity to testnet
13+
14+
Run with: pytest tests/test_user_socket_integration.py -v
15+
"""
16+
import asyncio
17+
import pytest
18+
import pytest_asyncio
19+
20+
from binance import BinanceSocketManager
21+
22+
23+
@pytest_asyncio.fixture
24+
async def socket_manager(clientAsync):
25+
"""Create a BinanceSocketManager using the clientAsync fixture from conftest."""
26+
return BinanceSocketManager(clientAsync)
27+
28+
29+
class TestUserSocketArchitecture:
30+
"""Tests verifying the user socket architecture is correct."""
31+
32+
@pytest.mark.asyncio
33+
async def test_user_socket_has_separate_queue(self, clientAsync, socket_manager):
34+
"""User socket should have its own queue, not share ws_api's queue."""
35+
user_socket = socket_manager.user_socket()
36+
37+
async with user_socket:
38+
# Queues should be different objects
39+
assert user_socket._queue is not clientAsync.ws_api._queue, \
40+
"user_socket should have its own queue, not share ws_api's queue"
41+
42+
@pytest.mark.asyncio
43+
async def test_user_socket_uses_ws_api_subscription(self, clientAsync, socket_manager):
44+
"""User socket should use ws_api subscription mechanism."""
45+
user_socket = socket_manager.user_socket()
46+
47+
async with user_socket:
48+
# Should be marked as using ws_api subscription
49+
assert user_socket._uses_ws_api_subscription is True, \
50+
"user_socket should be marked as using ws_api subscription"
51+
52+
# Should have a subscription ID
53+
assert user_socket._subscription_id is not None, \
54+
"user_socket should have a subscription ID"
55+
56+
@pytest.mark.asyncio
57+
async def test_user_socket_no_read_loop(self, clientAsync, socket_manager):
58+
"""User socket should NOT have its own read loop (ws_api handles reading)."""
59+
user_socket = socket_manager.user_socket()
60+
61+
async with user_socket:
62+
# user_socket should not have started its own read loop
63+
assert user_socket._handle_read_loop is None, \
64+
"user_socket should not have its own read loop"
65+
66+
# ws_api should have a read loop
67+
assert clientAsync.ws_api._handle_read_loop is not None, \
68+
"ws_api should have a read loop"
69+
70+
@pytest.mark.asyncio
71+
async def test_user_socket_queue_registered_with_ws_api(self, clientAsync, socket_manager):
72+
"""User socket's queue should be registered with ws_api for event routing."""
73+
user_socket = socket_manager.user_socket()
74+
75+
async with user_socket:
76+
sub_id = user_socket._subscription_id
77+
78+
# Subscription should be registered in ws_api
79+
assert sub_id in clientAsync.ws_api._subscription_queues, \
80+
"Subscription should be registered with ws_api"
81+
82+
# Registered queue should be user_socket's queue
83+
registered_queue = clientAsync.ws_api._subscription_queues[sub_id]
84+
assert registered_queue is user_socket._queue, \
85+
"Registered queue should be user_socket's queue"
86+
87+
@pytest.mark.asyncio
88+
async def test_user_socket_cleanup_on_exit(self, clientAsync, socket_manager):
89+
"""User socket should unregister from ws_api on exit."""
90+
user_socket = socket_manager.user_socket()
91+
92+
async with user_socket:
93+
sub_id = user_socket._subscription_id
94+
# Verify it's registered while connected
95+
assert sub_id in clientAsync.ws_api._subscription_queues
96+
97+
# After exit, subscription should be unregistered
98+
assert sub_id not in clientAsync.ws_api._subscription_queues, \
99+
"Subscription should be unregistered after exit"
100+
101+
102+
class TestUserSocketFunctionality:
103+
"""Tests verifying user socket functionality works correctly."""
104+
105+
@pytest.mark.asyncio
106+
async def test_user_socket_recv_timeout(self, clientAsync, socket_manager):
107+
"""User socket recv() should timeout gracefully when no events."""
108+
user_socket = socket_manager.user_socket()
109+
110+
async with user_socket:
111+
# recv() should timeout without errors (no events on quiet account)
112+
with pytest.raises(asyncio.TimeoutError):
113+
await asyncio.wait_for(user_socket.recv(), timeout=2)
114+
115+
@pytest.mark.asyncio
116+
async def test_user_socket_context_manager(self, clientAsync, socket_manager):
117+
"""User socket should work as async context manager."""
118+
user_socket = socket_manager.user_socket()
119+
120+
# Should not be connected initially
121+
assert user_socket._subscription_id is None
122+
123+
async with user_socket:
124+
# Should be connected inside context
125+
assert user_socket._subscription_id is not None
126+
assert user_socket._uses_ws_api_subscription is True
127+
128+
# Subscription ID is cleared after unsubscribe
129+
assert user_socket._subscription_id is None
130+
131+
132+
class TestNonUserSockets:
133+
"""Tests verifying other socket types still work normally."""
134+
135+
@pytest.mark.asyncio
136+
async def test_margin_socket_not_using_ws_api_subscription(self, clientAsync, socket_manager):
137+
"""Non-user KeepAliveWebsockets (like margin socket) should not use ws_api subscription."""
138+
# margin_socket is a KeepAliveWebsocket with keepalive_type="margin"
139+
# Create it but don't connect - just check the flag
140+
margin_socket = socket_manager.margin_socket()
141+
142+
# Before connecting, the flag should be False (default)
143+
assert margin_socket._uses_ws_api_subscription is False, \
144+
"Margin socket should not use ws_api subscription"
145+
146+
# The _keepalive_type should be "margin", not "user"
147+
assert margin_socket._keepalive_type == "margin"
148+
149+
150+
class TestWsApiSubscriptionRouting:
151+
"""Tests verifying ws_api correctly routes subscription events."""
152+
153+
@pytest.mark.asyncio
154+
async def test_ws_api_has_subscription_queues(self, clientAsync):
155+
"""ws_api should have subscription queues dict."""
156+
# Ensure ws_api is initialized
157+
await clientAsync.ws_api._ensure_ws_connection()
158+
159+
assert hasattr(clientAsync.ws_api, '_subscription_queues'), \
160+
"ws_api should have _subscription_queues attribute"
161+
assert isinstance(clientAsync.ws_api._subscription_queues, dict), \
162+
"_subscription_queues should be a dict"
163+
164+
@pytest.mark.asyncio
165+
async def test_ws_api_register_unregister_queue(self, clientAsync):
166+
"""ws_api should be able to register and unregister queues."""
167+
await clientAsync.ws_api._ensure_ws_connection()
168+
169+
test_queue = asyncio.Queue()
170+
test_sub_id = "test_subscription_123"
171+
172+
# Register
173+
clientAsync.ws_api.register_subscription_queue(test_sub_id, test_queue)
174+
assert test_sub_id in clientAsync.ws_api._subscription_queues
175+
assert clientAsync.ws_api._subscription_queues[test_sub_id] is test_queue
176+
177+
# Unregister
178+
clientAsync.ws_api.unregister_subscription_queue(test_sub_id)
179+
assert test_sub_id not in clientAsync.ws_api._subscription_queues
180+
181+
@pytest.mark.asyncio
182+
async def test_ws_api_unregister_nonexistent_is_safe(self, clientAsync):
183+
"""Unregistering a non-existent subscription should not raise."""
184+
await clientAsync.ws_api._ensure_ws_connection()
185+
186+
# Should not raise
187+
clientAsync.ws_api.unregister_subscription_queue("nonexistent_sub_id")
188+
189+
190+
if __name__ == "__main__":
191+
pytest.main([__file__, "-v"])

0 commit comments

Comments
 (0)