1313# limitations under the License.
1414
1515import logging
16+ from collections .abc import AsyncGenerator , Generator
17+ from contextlib import asynccontextmanager , contextmanager
1618from typing import Any , Dict , List , Literal , Optional , Union
1719
1820from nebulagraph_python .client ._connection import (
@@ -120,7 +122,7 @@ async def connect(
120122 )
121123 else :
122124 self ._sessions [host_addr ] = AsyncSession (
123- conn = conn ,
125+ _conn = conn ,
124126 username = username ,
125127 password = password ,
126128 session_config = session_config or SessionConfig (),
@@ -134,40 +136,31 @@ async def connect(
134136 async def execute (
135137 self , statement : str , * , timeout : Optional [float ] = None , do_ping : bool = False
136138 ) -> ResultSet :
139+ async with self .borrow () as session :
140+ return await session .execute (statement , timeout = timeout , do_ping = do_ping )
141+
142+ @asynccontextmanager
143+ async def borrow (self ) -> AsyncGenerator [AsyncSession , None ]:
137144 if isinstance (self ._conn , AsyncConnectionPool ):
138- addr , _conn = await self ._conn .next_connection ()
145+ addr , conn = await self ._conn .next_connection ()
139146 else :
140- addr = self ._conn .config .hosts [0 ]
147+ conn = self ._conn
148+ addr = conn .connected
149+ if addr is None :
150+ raise ValueError ("Connection not connected" )
151+
141152 _session = self ._sessions [addr ]
142153
143154 if isinstance (_session , AsyncSessionPool ):
144155 async with _session .borrow () as session :
145- return (
146- await session .execute (statement , timeout = timeout , do_ping = do_ping )
147- ).raise_on_error ()
156+ yield session
148157 else :
149- return (
150- await _session .execute (statement , timeout = timeout , do_ping = do_ping )
151- ).raise_on_error ()
152-
153- async def ping (self , timeout : Optional [float ] = None ) -> bool :
154- try :
155- res = (
156- (await self .execute (statement = "RETURN 1" , timeout = timeout ))
157- .one ()
158- .as_primitive ()
159- )
160- if not res == {"1" : 1 }:
161- raise ValueError (f"Unexpected result from ping: { res } " )
162- return True
163- except Exception :
164- logger .exception ("Failed to ping NebulaGraph" )
165- return False
158+ yield _session
166159
167160 async def close (self ):
168161 """Close the client connection and session. No Exception will be raised but an error will be logged."""
169162 for session in self ._sessions .values ():
170- await session .close ()
163+ await session ._close ()
171164 await self ._conn .close ()
172165
173166 async def __aenter__ (self ):
@@ -245,7 +238,7 @@ def __init__(
245238 )
246239 else :
247240 self ._sessions [host_addr ] = Session (
248- conn = conn ,
241+ _conn = conn ,
249242 username = username ,
250243 password = password ,
251244 session_config = session_config or SessionConfig (),
@@ -258,21 +251,29 @@ def __init__(
258251 def execute (
259252 self , statement : str , * , timeout : Optional [float ] = None , do_ping : bool = False
260253 ) -> ResultSet :
254+ """Execute a statement using a borrowed session, raising on errors."""
255+ with self .borrow () as session :
256+ return session .execute (statement , timeout = timeout , do_ping = do_ping )
257+
258+ @contextmanager
259+ def borrow (self ) -> Generator [Session , None , None ]:
260+ """Yield a session bound to the selected connection."""
261261 if isinstance (self ._conn , ConnectionPool ):
262- addr , _conn = self ._conn .next_connection ()
262+ addr , conn = self ._conn .next_connection ()
263263 else :
264- addr = self ._conn .config .hosts [0 ]
264+ conn = self ._conn
265+ addr = conn .connected
266+ if addr is None :
267+ raise ValueError ("Connection not connected" )
268+
269+ # Route to the correct session (pool or single session)
265270 _session = self ._sessions [addr ]
266271
267272 if isinstance (_session , SessionPool ):
268273 with _session .borrow () as session :
269- return session .execute (
270- statement , timeout = timeout , do_ping = do_ping
271- ).raise_on_error ()
274+ yield session
272275 else :
273- return _session .execute (
274- statement , timeout = timeout , do_ping = do_ping
275- ).raise_on_error ()
276+ yield _session
276277
277278 def ping (self , timeout : Optional [float ] = None ) -> bool :
278279 try :
@@ -291,7 +292,7 @@ def ping(self, timeout: Optional[float] = None) -> bool:
291292 def close (self ):
292293 """Close the client connection and session. No Exception will be raised but an error will be logged."""
293294 for session in self ._sessions .values ():
294- session .close ()
295+ session ._close ()
295296 self ._conn .close ()
296297
297298 def __enter__ (self ):
0 commit comments