Skip to content

Commit f8fde4d

Browse files
author
盐粒 Yanli
authored
[Feature] Support borrow session for contextual execution (#386)
* support borrow session on client * add doc * run lint
1 parent c560c75 commit f8fde4d

File tree

6 files changed

+162
-131
lines changed

6 files changed

+162
-131
lines changed

docs/2_concurrency.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,26 @@ async def concurrent_example():
6262
asyncio.run(concurrent_example())
6363
```
6464

65+
## Contextual Execution
66+
67+
By default, statements run on a random session from the pool. When you need to run several queries on the same session, call `borrow` to obtain and reuse a specific session.
68+
69+
```python
70+
async def contextual_example():
71+
async with await NebulaAsyncClient.connect(
72+
hosts=["127.0.0.1:9669"],
73+
username="root",
74+
password="NebulaGraph01",
75+
session_pool_config=SessionPoolConfig(),
76+
) as client:
77+
print("Connected to the server...")
78+
async with client.borrow() as session:
79+
await session.execute("SESSION SET GRAPH movie")
80+
res = await session.execute("MATCH (v:Movie) RETURN count(v)")
81+
res.print()
82+
```
83+
84+
6585
## Understanding Timeout Values
6686

6787
The client uses three different timeouts that apply at different stages:

src/nebulagraph_python/client/_connection.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ class Connection:
9696

9797
# Config
9898
config: ConnectionConfig
99+
# Track which host was successfully connected for session routing
100+
connected: HostAddress | None = field(default=None, init=False)
99101

100102
# Owned Resources
101103
_stub: Optional[graph_pb2_grpc.GraphServiceStub] = field(default=None, init=False)
@@ -152,6 +154,8 @@ def connect(self):
152154
logger.info(
153155
f"Successfully connected to {host_addr.host}:{host_addr.port}."
154156
)
157+
# Remember which host we actually connected to
158+
self.connected = host_addr
155159
return
156160
except Exception as e:
157161
logger.warning(
@@ -174,6 +178,7 @@ def close(self):
174178
self._channel.close()
175179
self._channel = None
176180
self._stub = None
181+
self.connected = None
177182
except Exception:
178183
logger.exception("Failed to close connection")
179184

@@ -303,6 +308,7 @@ class AsyncConnection:
303308
"""
304309

305310
config: ConnectionConfig
311+
connected: HostAddress | None = None
306312
_stub: Optional[graph_pb2_grpc.GraphServiceStub] = field(default=None, init=False)
307313
_channel: Optional[grpc.aio.Channel] = field(
308314
default=None, init=False
@@ -358,6 +364,7 @@ async def connect(self):
358364
logger.info(
359365
f"Successfully connected to {host_addr.host}:{host_addr.port} asynchronously."
360366
)
367+
self.connected = host_addr
361368
return
362369
except Exception as e:
363370
logger.warning(
@@ -380,6 +387,7 @@ async def close(self):
380387
await self._channel.close()
381388
self._channel = None
382389
self._stub = None
390+
self.connected = None
383391
except BaseException:
384392
logger.exception("Failed to close async connection")
385393

src/nebulagraph_python/client/_session.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from nebulagraph_python.error import ExecutingError
2626

2727

28-
@dataclass
28+
@dataclass(kw_only=True, frozen=True)
2929
class SessionConfig:
3030
schema: Optional[str] = None
3131
graph: Optional[str] = None
@@ -47,31 +47,32 @@ class SessionBase:
4747

4848
@dataclass
4949
class Session(SessionBase):
50-
conn: "Connection"
50+
_conn: "Connection"
5151

5252
def execute(
5353
self, statement: str, *, timeout: Optional[float] = None, do_ping: bool = False
5454
):
55-
res = self.conn.execute(
55+
res = self._conn.execute(
5656
self._session, statement, timeout=timeout, do_ping=do_ping
5757
)
5858
# Retry for only one time
5959
if res.status_code == ErrorCode.SESSION_NOT_FOUND.value:
60-
self._session = self.conn.authenticate(
60+
self._session = self._conn.authenticate(
6161
self.username,
6262
self.password,
6363
session_config=self.session_config,
6464
auth_options=self.auth_options,
6565
)
66-
res = self.conn.execute(
66+
res = self._conn.execute(
6767
self._session, statement, timeout=timeout, do_ping=do_ping
6868
)
69+
res.raise_on_error()
6970
return res
7071

71-
def close(self):
72+
def _close(self):
7273
"""Close session"""
7374
try:
74-
self.conn.execute(self._session, "SESSION CLOSE")
75+
self._conn.execute(self._session, "SESSION CLOSE")
7576
except Exception:
7677
logger.exception("Failed to close session")
7778

@@ -84,30 +85,31 @@ def __eq__(self, other):
8485

8586
@dataclass
8687
class AsyncSession(SessionBase):
87-
conn: "AsyncConnection"
88+
_conn: "AsyncConnection"
8889

8990
async def execute(
9091
self, statement: str, *, timeout: Optional[float] = None, do_ping: bool = False
9192
):
92-
res = await self.conn.execute(
93+
res = await self._conn.execute(
9394
self._session, statement, timeout=timeout, do_ping=do_ping
9495
)
9596
# Retry for only one time
9697
if res.status_code == ErrorCode.SESSION_NOT_FOUND.value:
97-
self._session = await self.conn.authenticate(
98+
self._session = await self._conn.authenticate(
9899
self.username,
99100
self.password,
100101
session_config=self.session_config,
101102
auth_options=self.auth_options,
102103
)
103-
res = await self.conn.execute(
104+
res = await self._conn.execute(
104105
self._session, statement, timeout=timeout, do_ping=do_ping
105106
)
107+
res.raise_on_error()
106108
return res
107109

108-
async def close(self):
110+
async def _close(self):
109111
try:
110-
await self.conn.execute(self._session, "SESSION CLOSE")
112+
await self._conn.execute(self._session, "SESSION CLOSE")
111113
except Exception:
112114
logger.exception("Failed to close async session")
113115

src/nebulagraph_python/client/_session_pool.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ async def connect(
9191
except Exception:
9292
# Clean up any sessions that were successfully created
9393
for session in sessions:
94-
await session.close()
94+
await session._close()
9595
raise
9696

9797
def __init__(
@@ -157,20 +157,20 @@ async def borrow(self):
157157
self.busy_sessions_queue.remove(got_session)
158158
self.queue_count.release()
159159

160-
async def close(self):
160+
async def _close(self):
161161
# Acquire all semaphore permits to prevent new borrows
162162
for _ in range(self.config.size):
163163
await self.queue_count.acquire()
164164
async with self.queue_lock:
165165
# Close all free sessions
166166
for session in self.free_sessions_queue:
167-
await session.close()
167+
await session._close()
168168
# Close all busy sessions (if any remain)
169169
for session in self.busy_sessions_queue:
170170
logger.error(
171171
"Busy sessions remain after acquire all semaphore permits, which indicates a bug in the AsyncSessionPool"
172172
)
173-
await session.close()
173+
await session._close()
174174

175175

176176
class SessionPool:
@@ -209,7 +209,7 @@ def connect(
209209
except Exception:
210210
# Clean up any sessions that were successfully created
211211
for session in sessions:
212-
session.close()
212+
session._close()
213213
raise
214214

215215
def __init__(
@@ -273,17 +273,17 @@ def borrow(self):
273273
self.busy_sessions_queue.remove(got_session)
274274
self.queue_count.release()
275275

276-
def close(self):
276+
def _close(self):
277277
# Acquire all semaphore permits to prevent new borrows
278278
for _ in range(self.config.size):
279279
self.queue_count.acquire()
280280
with self.queue_lock:
281281
# Close all free sessions
282282
for session in self.free_sessions_queue:
283-
session.close()
283+
session._close()
284284
# Close all busy sessions (if any remain)
285285
for session in self.busy_sessions_queue:
286286
logger.error(
287287
"Busy sessions remain after acquire all semaphore permits, which indicates a bug in the SessionPool"
288288
)
289-
session.close()
289+
session._close()

src/nebulagraph_python/client/client.py

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515
import logging
16+
from collections.abc import AsyncGenerator, Generator
17+
from contextlib import asynccontextmanager, contextmanager
1618
from typing import Any, Dict, List, Literal, Optional, Union
1719

1820
from nebulagraph_python.client._connection import (
@@ -120,7 +122,7 @@ async def connect(
120122
)
121123
else:
122124
self._sessions[host_addr] = AsyncSession(
123-
conn=conn,
125+
_conn=conn,
124126
username=username,
125127
password=password,
126128
session_config=session_config or SessionConfig(),
@@ -134,40 +136,31 @@ async def connect(
134136
async def execute(
135137
self, statement: str, *, timeout: Optional[float] = None, do_ping: bool = False
136138
) -> ResultSet:
139+
async with self.borrow() as session:
140+
return await session.execute(statement, timeout=timeout, do_ping=do_ping)
141+
142+
@asynccontextmanager
143+
async def borrow(self) -> AsyncGenerator[AsyncSession, None]:
137144
if isinstance(self._conn, AsyncConnectionPool):
138-
addr, _conn = await self._conn.next_connection()
145+
addr, conn = await self._conn.next_connection()
139146
else:
140-
addr = self._conn.config.hosts[0]
147+
conn = self._conn
148+
addr = conn.connected
149+
if addr is None:
150+
raise ValueError("Connection not connected")
151+
141152
_session = self._sessions[addr]
142153

143154
if isinstance(_session, AsyncSessionPool):
144155
async with _session.borrow() as session:
145-
return (
146-
await session.execute(statement, timeout=timeout, do_ping=do_ping)
147-
).raise_on_error()
156+
yield session
148157
else:
149-
return (
150-
await _session.execute(statement, timeout=timeout, do_ping=do_ping)
151-
).raise_on_error()
152-
153-
async def ping(self, timeout: Optional[float] = None) -> bool:
154-
try:
155-
res = (
156-
(await self.execute(statement="RETURN 1", timeout=timeout))
157-
.one()
158-
.as_primitive()
159-
)
160-
if not res == {"1": 1}:
161-
raise ValueError(f"Unexpected result from ping: {res}")
162-
return True
163-
except Exception:
164-
logger.exception("Failed to ping NebulaGraph")
165-
return False
158+
yield _session
166159

167160
async def close(self):
168161
"""Close the client connection and session. No Exception will be raised but an error will be logged."""
169162
for session in self._sessions.values():
170-
await session.close()
163+
await session._close()
171164
await self._conn.close()
172165

173166
async def __aenter__(self):
@@ -245,7 +238,7 @@ def __init__(
245238
)
246239
else:
247240
self._sessions[host_addr] = Session(
248-
conn=conn,
241+
_conn=conn,
249242
username=username,
250243
password=password,
251244
session_config=session_config or SessionConfig(),
@@ -258,21 +251,29 @@ def __init__(
258251
def execute(
259252
self, statement: str, *, timeout: Optional[float] = None, do_ping: bool = False
260253
) -> ResultSet:
254+
"""Execute a statement using a borrowed session, raising on errors."""
255+
with self.borrow() as session:
256+
return session.execute(statement, timeout=timeout, do_ping=do_ping)
257+
258+
@contextmanager
259+
def borrow(self) -> Generator[Session, None, None]:
260+
"""Yield a session bound to the selected connection."""
261261
if isinstance(self._conn, ConnectionPool):
262-
addr, _conn = self._conn.next_connection()
262+
addr, conn = self._conn.next_connection()
263263
else:
264-
addr = self._conn.config.hosts[0]
264+
conn = self._conn
265+
addr = conn.connected
266+
if addr is None:
267+
raise ValueError("Connection not connected")
268+
269+
# Route to the correct session (pool or single session)
265270
_session = self._sessions[addr]
266271

267272
if isinstance(_session, SessionPool):
268273
with _session.borrow() as session:
269-
return session.execute(
270-
statement, timeout=timeout, do_ping=do_ping
271-
).raise_on_error()
274+
yield session
272275
else:
273-
return _session.execute(
274-
statement, timeout=timeout, do_ping=do_ping
275-
).raise_on_error()
276+
yield _session
276277

277278
def ping(self, timeout: Optional[float] = None) -> bool:
278279
try:
@@ -291,7 +292,7 @@ def ping(self, timeout: Optional[float] = None) -> bool:
291292
def close(self):
292293
"""Close the client connection and session. No Exception will be raised but an error will be logged."""
293294
for session in self._sessions.values():
294-
session.close()
295+
session._close()
295296
self._conn.close()
296297

297298
def __enter__(self):

0 commit comments

Comments
 (0)