diff --git a/docs/2_concurrency.md b/docs/2_concurrency.md index d9e48f4f..eb063b36 100644 --- a/docs/2_concurrency.md +++ b/docs/2_concurrency.md @@ -62,6 +62,26 @@ async def concurrent_example(): asyncio.run(concurrent_example()) ``` +## Contextual Execution + +By default, statements run on a random session from the pool. When you need to run several queries on the same session, call `borrow` to obtain and reuse a specific session. + +```python +async def contextual_example(): + async with await NebulaAsyncClient.connect( + hosts=["127.0.0.1:9669"], + username="root", + password="NebulaGraph01", + session_pool_config=SessionPoolConfig(), + ) as client: + print("Connected to the server...") + async with client.borrow() as session: + await session.execute("SESSION SET GRAPH movie") + res = await session.execute("MATCH (v:Movie) RETURN count(v)") + res.print() +``` + + ## Understanding Timeout Values The client uses three different timeouts that apply at different stages: diff --git a/src/nebulagraph_python/client/_connection.py b/src/nebulagraph_python/client/_connection.py index 3dc320dc..a2f02630 100644 --- a/src/nebulagraph_python/client/_connection.py +++ b/src/nebulagraph_python/client/_connection.py @@ -96,6 +96,8 @@ class Connection: # Config config: ConnectionConfig + # Track which host was successfully connected for session routing + connected: HostAddress | None = field(default=None, init=False) # Owned Resources _stub: Optional[graph_pb2_grpc.GraphServiceStub] = field(default=None, init=False) @@ -152,6 +154,8 @@ def connect(self): logger.info( f"Successfully connected to {host_addr.host}:{host_addr.port}." ) + # Remember which host we actually connected to + self.connected = host_addr return except Exception as e: logger.warning( @@ -174,6 +178,7 @@ def close(self): self._channel.close() self._channel = None self._stub = None + self.connected = None except Exception: logger.exception("Failed to close connection") @@ -303,6 +308,7 @@ class AsyncConnection: """ config: ConnectionConfig + connected: HostAddress | None = None _stub: Optional[graph_pb2_grpc.GraphServiceStub] = field(default=None, init=False) _channel: Optional[grpc.aio.Channel] = field( default=None, init=False @@ -358,6 +364,7 @@ async def connect(self): logger.info( f"Successfully connected to {host_addr.host}:{host_addr.port} asynchronously." ) + self.connected = host_addr return except Exception as e: logger.warning( @@ -380,6 +387,7 @@ async def close(self): await self._channel.close() self._channel = None self._stub = None + self.connected = None except BaseException: logger.exception("Failed to close async connection") diff --git a/src/nebulagraph_python/client/_session.py b/src/nebulagraph_python/client/_session.py index 09184d50..83bd322b 100644 --- a/src/nebulagraph_python/client/_session.py +++ b/src/nebulagraph_python/client/_session.py @@ -25,7 +25,7 @@ from nebulagraph_python.error import ExecutingError -@dataclass +@dataclass(kw_only=True, frozen=True) class SessionConfig: schema: Optional[str] = None graph: Optional[str] = None @@ -47,31 +47,32 @@ class SessionBase: @dataclass class Session(SessionBase): - conn: "Connection" + _conn: "Connection" def execute( self, statement: str, *, timeout: Optional[float] = None, do_ping: bool = False ): - res = self.conn.execute( + res = self._conn.execute( self._session, statement, timeout=timeout, do_ping=do_ping ) # Retry for only one time if res.status_code == ErrorCode.SESSION_NOT_FOUND.value: - self._session = self.conn.authenticate( + self._session = self._conn.authenticate( self.username, self.password, session_config=self.session_config, auth_options=self.auth_options, ) - res = self.conn.execute( + res = self._conn.execute( self._session, statement, timeout=timeout, do_ping=do_ping ) + res.raise_on_error() return res - def close(self): + def _close(self): """Close session""" try: - self.conn.execute(self._session, "SESSION CLOSE") + self._conn.execute(self._session, "SESSION CLOSE") except Exception: logger.exception("Failed to close session") @@ -84,30 +85,31 @@ def __eq__(self, other): @dataclass class AsyncSession(SessionBase): - conn: "AsyncConnection" + _conn: "AsyncConnection" async def execute( self, statement: str, *, timeout: Optional[float] = None, do_ping: bool = False ): - res = await self.conn.execute( + res = await self._conn.execute( self._session, statement, timeout=timeout, do_ping=do_ping ) # Retry for only one time if res.status_code == ErrorCode.SESSION_NOT_FOUND.value: - self._session = await self.conn.authenticate( + self._session = await self._conn.authenticate( self.username, self.password, session_config=self.session_config, auth_options=self.auth_options, ) - res = await self.conn.execute( + res = await self._conn.execute( self._session, statement, timeout=timeout, do_ping=do_ping ) + res.raise_on_error() return res - async def close(self): + async def _close(self): try: - await self.conn.execute(self._session, "SESSION CLOSE") + await self._conn.execute(self._session, "SESSION CLOSE") except Exception: logger.exception("Failed to close async session") diff --git a/src/nebulagraph_python/client/_session_pool.py b/src/nebulagraph_python/client/_session_pool.py index d08f930b..2a085a33 100644 --- a/src/nebulagraph_python/client/_session_pool.py +++ b/src/nebulagraph_python/client/_session_pool.py @@ -91,7 +91,7 @@ async def connect( except Exception: # Clean up any sessions that were successfully created for session in sessions: - await session.close() + await session._close() raise def __init__( @@ -157,20 +157,20 @@ async def borrow(self): self.busy_sessions_queue.remove(got_session) self.queue_count.release() - async def close(self): + async def _close(self): # Acquire all semaphore permits to prevent new borrows for _ in range(self.config.size): await self.queue_count.acquire() async with self.queue_lock: # Close all free sessions for session in self.free_sessions_queue: - await session.close() + await session._close() # Close all busy sessions (if any remain) for session in self.busy_sessions_queue: logger.error( "Busy sessions remain after acquire all semaphore permits, which indicates a bug in the AsyncSessionPool" ) - await session.close() + await session._close() class SessionPool: @@ -209,7 +209,7 @@ def connect( except Exception: # Clean up any sessions that were successfully created for session in sessions: - session.close() + session._close() raise def __init__( @@ -273,17 +273,17 @@ def borrow(self): self.busy_sessions_queue.remove(got_session) self.queue_count.release() - def close(self): + def _close(self): # Acquire all semaphore permits to prevent new borrows for _ in range(self.config.size): self.queue_count.acquire() with self.queue_lock: # Close all free sessions for session in self.free_sessions_queue: - session.close() + session._close() # Close all busy sessions (if any remain) for session in self.busy_sessions_queue: logger.error( "Busy sessions remain after acquire all semaphore permits, which indicates a bug in the SessionPool" ) - session.close() + session._close() diff --git a/src/nebulagraph_python/client/client.py b/src/nebulagraph_python/client/client.py index 57be1d7b..1b57fa81 100644 --- a/src/nebulagraph_python/client/client.py +++ b/src/nebulagraph_python/client/client.py @@ -13,6 +13,8 @@ # limitations under the License. import logging +from collections.abc import AsyncGenerator, Generator +from contextlib import asynccontextmanager, contextmanager from typing import Any, Dict, List, Literal, Optional, Union from nebulagraph_python.client._connection import ( @@ -120,7 +122,7 @@ async def connect( ) else: self._sessions[host_addr] = AsyncSession( - conn=conn, + _conn=conn, username=username, password=password, session_config=session_config or SessionConfig(), @@ -134,40 +136,31 @@ async def connect( async def execute( self, statement: str, *, timeout: Optional[float] = None, do_ping: bool = False ) -> ResultSet: + async with self.borrow() as session: + return await session.execute(statement, timeout=timeout, do_ping=do_ping) + + @asynccontextmanager + async def borrow(self) -> AsyncGenerator[AsyncSession, None]: if isinstance(self._conn, AsyncConnectionPool): - addr, _conn = await self._conn.next_connection() + addr, conn = await self._conn.next_connection() else: - addr = self._conn.config.hosts[0] + conn = self._conn + addr = conn.connected + if addr is None: + raise ValueError("Connection not connected") + _session = self._sessions[addr] if isinstance(_session, AsyncSessionPool): async with _session.borrow() as session: - return ( - await session.execute(statement, timeout=timeout, do_ping=do_ping) - ).raise_on_error() + yield session else: - return ( - await _session.execute(statement, timeout=timeout, do_ping=do_ping) - ).raise_on_error() - - async def ping(self, timeout: Optional[float] = None) -> bool: - try: - res = ( - (await self.execute(statement="RETURN 1", timeout=timeout)) - .one() - .as_primitive() - ) - if not res == {"1": 1}: - raise ValueError(f"Unexpected result from ping: {res}") - return True - except Exception: - logger.exception("Failed to ping NebulaGraph") - return False + yield _session async def close(self): """Close the client connection and session. No Exception will be raised but an error will be logged.""" for session in self._sessions.values(): - await session.close() + await session._close() await self._conn.close() async def __aenter__(self): @@ -245,7 +238,7 @@ def __init__( ) else: self._sessions[host_addr] = Session( - conn=conn, + _conn=conn, username=username, password=password, session_config=session_config or SessionConfig(), @@ -258,21 +251,29 @@ def __init__( def execute( self, statement: str, *, timeout: Optional[float] = None, do_ping: bool = False ) -> ResultSet: + """Execute a statement using a borrowed session, raising on errors.""" + with self.borrow() as session: + return session.execute(statement, timeout=timeout, do_ping=do_ping) + + @contextmanager + def borrow(self) -> Generator[Session, None, None]: + """Yield a session bound to the selected connection.""" if isinstance(self._conn, ConnectionPool): - addr, _conn = self._conn.next_connection() + addr, conn = self._conn.next_connection() else: - addr = self._conn.config.hosts[0] + conn = self._conn + addr = conn.connected + if addr is None: + raise ValueError("Connection not connected") + + # Route to the correct session (pool or single session) _session = self._sessions[addr] if isinstance(_session, SessionPool): with _session.borrow() as session: - return session.execute( - statement, timeout=timeout, do_ping=do_ping - ).raise_on_error() + yield session else: - return _session.execute( - statement, timeout=timeout, do_ping=do_ping - ).raise_on_error() + yield _session def ping(self, timeout: Optional[float] = None) -> bool: try: @@ -291,7 +292,7 @@ def ping(self, timeout: Optional[float] = None) -> bool: def close(self): """Close the client connection and session. No Exception will be raised but an error will be logged.""" for session in self._sessions.values(): - session.close() + session._close() self._conn.close() def __enter__(self): diff --git a/tests/test_session_pool.py b/tests/test_session_pool.py index 7cc5d38f..05b40872 100644 --- a/tests/test_session_pool.py +++ b/tests/test_session_pool.py @@ -26,9 +26,9 @@ def test_init_basic(self): """Test basic initialization""" mock_conn = Mock() sessions = { - Session(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - Session(conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - Session(conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), } config = SessionPoolConfig(size=3) pool = SessionPool(copy(sessions), config) @@ -42,8 +42,8 @@ def test_init_with_config(self): """Test initialization with custom config""" mock_conn = Mock() sessions = { - Session(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - Session(conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), } config = SessionPoolConfig(size=2, wait_timeout=10.0) pool = SessionPool(copy(sessions), config) @@ -55,9 +55,9 @@ def test_init_with_all_config_params(self): """Test initialization with all configuration parameters""" mock_conn = Mock() sessions = { - Session(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - Session(conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - Session(conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), } config = SessionPoolConfig( size=3, @@ -73,9 +73,9 @@ def test_borrow_single_session(self): """Test borrowing a single session""" mock_conn = Mock() sessions = { - Session(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - Session(conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - Session(conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), } pool = SessionPool(copy(sessions), SessionPoolConfig(size=3)) @@ -95,8 +95,8 @@ def test_borrow_all_sessions(self): """Test borrowing all available sessions""" mock_conn = Mock() sessions = { - Session(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - Session(conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), } pool = SessionPool(copy(sessions), SessionPoolConfig(size=2)) @@ -110,7 +110,7 @@ def test_borrow_timeout_exceeded(self): """Test borrowing when timeout is exceeded""" mock_conn = Mock() sessions = { - Session(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), } config = SessionPoolConfig(size=1, wait_timeout=0.2) pool = SessionPool(copy(sessions), config) @@ -125,7 +125,7 @@ def test_borrow_infinite_wait_with_release(self): """Test borrowing with infinite wait that succeeds when session becomes available""" mock_conn = Mock() sessions = { - Session(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), } config = SessionPoolConfig(size=1, wait_timeout=None) pool = SessionPool(copy(sessions), config) @@ -158,7 +158,7 @@ def test_concurrent_borrowing(self): """Test concurrent borrowing from multiple threads""" mock_conn = Mock() sessions = { - Session(conn=mock_conn, username=f"user{i}", password=f"pass{i}", session_config=None, auth_options=None) + Session(_conn=mock_conn, username=f"user{i}", password=f"pass{i}", session_config=None, auth_options=None) for i in range(5) } config = SessionPoolConfig(size=5) @@ -196,8 +196,8 @@ def test_semaphore_consistency(self): """Test that semaphore behavior stays consistent with actual session availability""" mock_conn = Mock() sessions = { - Session(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - Session(conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), } config = SessionPoolConfig(size=2) pool = SessionPool(copy(sessions), config) @@ -225,49 +225,49 @@ def test_close_all_free_sessions(self): """Test closing pool with all sessions free""" mock_conn = Mock() sessions = { - Session(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - Session(conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - Session(conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), } config = SessionPoolConfig(size=3) pool = SessionPool(copy(sessions), config) # Mock the close_session method for all sessions for session in sessions: - session.close = Mock() + session._close = Mock() - pool.close() + pool._close() # Should close all sessions for session in sessions: - session.close.assert_called_once() + session._close.assert_called_once() @patch('nebulagraph_python.client._session_pool.logger') def test_close_with_busy_sessions(self, mock_logger): """Test closing pool with some busy sessions""" mock_conn = Mock() sessions = { - Session(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - Session(conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - Session(conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), } config = SessionPoolConfig(size=3) pool = SessionPool(copy(sessions), config) # Mock the close_session method for all sessions for session in sessions: - session.close = Mock() + session._close = Mock() # Manually move a session to busy state busy_session = list(sessions)[1] # Get the second session pool.free_sessions_queue.remove(busy_session) pool.busy_sessions_queue.add(busy_session) - pool.close() + pool._close() # Should close all sessions for session in sessions: - session.close.assert_called_once() + session._close.assert_called_once() # Should log error about busy sessions mock_logger.error.assert_called_once() assert "Busy sessions remain" in mock_logger.error.call_args[0][0] @@ -337,8 +337,8 @@ def test_multiple_borrow_release_cycles(self): """Test multiple borrow-release cycles work correctly""" mock_conn = Mock() sessions = { - Session(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - Session(conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), } config = SessionPoolConfig(size=2) pool = SessionPool(copy(sessions), config) @@ -370,9 +370,9 @@ async def test_init_basic(self): """Test basic initialization""" mock_conn = AsyncMock() sessions = { - AsyncSession(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - AsyncSession(conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - AsyncSession(conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), } config = SessionPoolConfig(size=3) pool = AsyncSessionPool(copy(sessions), config) @@ -389,8 +389,8 @@ async def test_init_with_config(self): """Test initialization with custom config""" mock_conn = AsyncMock() sessions = { - AsyncSession(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - AsyncSession(conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), } config = SessionPoolConfig(size=2, wait_timeout=10.0) pool = AsyncSessionPool(copy(sessions), config) @@ -403,9 +403,9 @@ async def test_init_with_all_config_params(self): """Test initialization with all configuration parameters""" mock_conn = AsyncMock() sessions = { - AsyncSession(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - AsyncSession(conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - AsyncSession(conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), } config = SessionPoolConfig( size=3, @@ -422,9 +422,9 @@ async def test_borrow_single_session(self): """Test borrowing a single session""" mock_conn = AsyncMock() sessions = { - AsyncSession(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - AsyncSession(conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - AsyncSession(conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), } config = SessionPoolConfig(size=3) pool = AsyncSessionPool(copy(sessions), config) @@ -446,8 +446,8 @@ async def test_borrow_all_sessions(self): """Test borrowing all available sessions""" mock_conn = AsyncMock() sessions = { - AsyncSession(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - AsyncSession(conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), } config = SessionPoolConfig(size=2) pool = AsyncSessionPool(copy(sessions), config) @@ -463,7 +463,7 @@ async def test_borrow_timeout_exceeded(self): """Test borrowing when timeout is exceeded""" mock_conn = AsyncMock() sessions = { - AsyncSession(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), } config = SessionPoolConfig(size=1, wait_timeout=0.2) pool = AsyncSessionPool(copy(sessions), config) @@ -479,7 +479,7 @@ async def test_borrow_infinite_wait_with_release(self): """Test borrowing with infinite wait that succeeds when session becomes available""" mock_conn = AsyncMock() sessions = { - AsyncSession(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), } config = SessionPoolConfig(size=1, wait_timeout=None) pool = AsyncSessionPool(copy(sessions), config) @@ -507,7 +507,7 @@ async def test_concurrent_borrowing(self): """Test concurrent borrowing from multiple coroutines""" mock_conn = AsyncMock() sessions = { - AsyncSession(conn=mock_conn, username=f"user{i}", password=f"pass{i}", session_config=None, auth_options=None) + AsyncSession(_conn=mock_conn, username=f"user{i}", password=f"pass{i}", session_config=None, auth_options=None) for i in range(5) } config = SessionPoolConfig(size=5) @@ -539,8 +539,8 @@ async def test_semaphore_consistency(self): """Test that semaphore behavior stays consistent with actual session availability""" mock_conn = AsyncMock() sessions = { - AsyncSession(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - AsyncSession(conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), } config = SessionPoolConfig(size=2) pool = AsyncSessionPool(copy(sessions), config) @@ -581,22 +581,22 @@ async def test_close_all_free_sessions(self): """Test closing pool with all sessions free""" mock_conn = AsyncMock() sessions = { - AsyncSession(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - AsyncSession(conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - AsyncSession(conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), } config = SessionPoolConfig(size=3) pool = AsyncSessionPool(copy(sessions), config) # Mock the close_session method for all sessions for session in sessions: - session.close = AsyncMock() + session._close = AsyncMock() - await pool.close() + await pool._close() # Should close all sessions for session in sessions: - session.close.assert_called_once() + session._close.assert_called_once() @pytest.mark.asyncio @patch('nebulagraph_python.client._session_pool.logger') @@ -604,26 +604,26 @@ async def test_close_with_busy_sessions(self, mock_logger): """Test closing pool with some busy sessions""" mock_conn = AsyncMock() sessions = { - AsyncSession(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - AsyncSession(conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - AsyncSession(conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), } pool = AsyncSessionPool(copy(sessions), config=SessionPoolConfig(size=3)) # Mock the close_session method for all sessions for session in sessions: - session.close = AsyncMock() + session._close = AsyncMock() # Manually move a session to busy state busy_session = list(sessions)[1] # Get the second session pool.free_sessions_queue.remove(busy_session) pool.busy_sessions_queue.add(busy_session) - await pool.close() + await pool._close() # Should close all sessions for session in sessions: - session.close.assert_called_once() + session._close.assert_called_once() # Should log error about busy sessions mock_logger.error.assert_called_once() assert "Busy sessions remain" in mock_logger.error.call_args[0][0] @@ -696,8 +696,8 @@ async def test_multiple_borrow_release_cycles(self): """Test multiple borrow-release cycles work correctly""" mock_conn = AsyncMock() sessions = { - AsyncSession(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - AsyncSession(conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), } config = SessionPoolConfig(size=2) pool = AsyncSessionPool(copy(sessions), config) @@ -728,7 +728,7 @@ def test_sync_pool_exception_in_context(self): """Test that sessions are properly returned even when exceptions occur in sync pool""" mock_conn = Mock() sessions = { - Session(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), } config = SessionPoolConfig(size=1) pool = SessionPool(copy(sessions), config) @@ -747,7 +747,7 @@ async def test_async_pool_exception_in_context(self): """Test that sessions are properly returned even when exceptions occur in async pool""" mock_conn = AsyncMock() sessions = { - AsyncSession(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), } config = SessionPoolConfig(size=1) pool = AsyncSessionPool(copy(sessions), config) @@ -765,8 +765,8 @@ def test_sync_multiple_exceptions_in_context(self): """Test multiple exceptions in sync pool context managers""" mock_conn = Mock() sessions = { - Session(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - Session(conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), } config = SessionPoolConfig(size=2) pool = SessionPool(copy(sessions), config) @@ -786,8 +786,8 @@ async def test_async_multiple_exceptions_in_context(self): """Test multiple exceptions in async pool context managers""" mock_conn = AsyncMock() sessions = { - AsyncSession(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - AsyncSession(conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), } config = SessionPoolConfig(size=2) pool = AsyncSessionPool(copy(sessions), config) @@ -821,7 +821,7 @@ def test_sync_zero_timeout(self): mock_conn = Mock() sessions = { - Session(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), } pool = SessionPool(copy(sessions), config) @@ -838,7 +838,7 @@ async def test_async_zero_timeout(self): mock_conn = AsyncMock() sessions = { - AsyncSession(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), } pool = AsyncSessionPool(copy(sessions), config) @@ -888,7 +888,7 @@ def test_sync_pool_with_custom_retry_interval(self): """Test sync pool behavior with custom retry interval""" mock_conn = Mock() sessions = { - Session(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), } config = SessionPoolConfig(size=1, wait_timeout=0.3) pool = SessionPool(copy(sessions), config) @@ -907,7 +907,7 @@ async def test_async_pool_with_custom_retry_interval(self): """Test async pool behavior with custom retry interval""" mock_conn = AsyncMock() sessions = { - AsyncSession(conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), + AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), } config = SessionPoolConfig(size=1, wait_timeout=0.3) pool = AsyncSessionPool(copy(sessions), config) @@ -929,7 +929,7 @@ def test_sync_high_concurrency_stress(self): """Test sync pool under high concurrency stress""" mock_conn = Mock() sessions = { - Session(conn=mock_conn, username=f"user{i}", password=f"pass{i}", session_config=None, auth_options=None) + Session(_conn=mock_conn, username=f"user{i}", password=f"pass{i}", session_config=None, auth_options=None) for i in range(10) } config = SessionPoolConfig(size=10) @@ -969,7 +969,7 @@ async def test_async_high_concurrency_stress(self): """Test async pool under high concurrency stress""" mock_conn = AsyncMock() sessions = { - AsyncSession(conn=mock_conn, username=f"user{i}", password=f"pass{i}", session_config=None, auth_options=None) + AsyncSession(_conn=mock_conn, username=f"user{i}", password=f"pass{i}", session_config=None, auth_options=None) for i in range(10) } config = SessionPoolConfig(size=10)