Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions docs/2_concurrency.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,26 @@ async def concurrent_example():
asyncio.run(concurrent_example())
```

## Contextual Execution

By default, statements run on a random session from the pool. When you need to run several queries on the same session, call `borrow` to obtain and reuse a specific session.

```python
async def contextual_example():
async with await NebulaAsyncClient.connect(
hosts=["127.0.0.1:9669"],
username="root",
password="NebulaGraph01",
session_pool_config=SessionPoolConfig(),
) as client:
print("Connected to the server...")
async with client.borrow() as session:
await session.execute("SESSION SET GRAPH movie")
res = await session.execute("MATCH (v:Movie) RETURN count(v)")
res.print()
```


## Understanding Timeout Values

The client uses three different timeouts that apply at different stages:
Expand Down
8 changes: 8 additions & 0 deletions src/nebulagraph_python/client/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ class Connection:

# Config
config: ConnectionConfig
# Track which host was successfully connected for session routing
connected: HostAddress | None = field(default=None, init=False)

# Owned Resources
_stub: Optional[graph_pb2_grpc.GraphServiceStub] = field(default=None, init=False)
Expand Down Expand Up @@ -152,6 +154,8 @@ def connect(self):
logger.info(
f"Successfully connected to {host_addr.host}:{host_addr.port}."
)
# Remember which host we actually connected to
self.connected = host_addr
return
except Exception as e:
logger.warning(
Expand All @@ -174,6 +178,7 @@ def close(self):
self._channel.close()
self._channel = None
self._stub = None
self.connected = None
except Exception:
logger.exception("Failed to close connection")

Expand Down Expand Up @@ -303,6 +308,7 @@ class AsyncConnection:
"""

config: ConnectionConfig
connected: HostAddress | None = None
_stub: Optional[graph_pb2_grpc.GraphServiceStub] = field(default=None, init=False)
_channel: Optional[grpc.aio.Channel] = field(
default=None, init=False
Expand Down Expand Up @@ -358,6 +364,7 @@ async def connect(self):
logger.info(
f"Successfully connected to {host_addr.host}:{host_addr.port} asynchronously."
)
self.connected = host_addr
return
except Exception as e:
logger.warning(
Expand All @@ -380,6 +387,7 @@ async def close(self):
await self._channel.close()
self._channel = None
self._stub = None
self.connected = None
except BaseException:
logger.exception("Failed to close async connection")

Expand Down
28 changes: 15 additions & 13 deletions src/nebulagraph_python/client/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from nebulagraph_python.error import ExecutingError


@dataclass
@dataclass(kw_only=True, frozen=True)
class SessionConfig:
schema: Optional[str] = None
graph: Optional[str] = None
Expand All @@ -47,31 +47,32 @@ class SessionBase:

@dataclass
class Session(SessionBase):
conn: "Connection"
_conn: "Connection"

def execute(
self, statement: str, *, timeout: Optional[float] = None, do_ping: bool = False
):
res = self.conn.execute(
res = self._conn.execute(
self._session, statement, timeout=timeout, do_ping=do_ping
)
# Retry for only one time
if res.status_code == ErrorCode.SESSION_NOT_FOUND.value:
self._session = self.conn.authenticate(
self._session = self._conn.authenticate(
self.username,
self.password,
session_config=self.session_config,
auth_options=self.auth_options,
)
res = self.conn.execute(
res = self._conn.execute(
self._session, statement, timeout=timeout, do_ping=do_ping
)
res.raise_on_error()
return res

def close(self):
def _close(self):
"""Close session"""
try:
self.conn.execute(self._session, "SESSION CLOSE")
self._conn.execute(self._session, "SESSION CLOSE")
except Exception:
logger.exception("Failed to close session")

Expand All @@ -84,30 +85,31 @@ def __eq__(self, other):

@dataclass
class AsyncSession(SessionBase):
conn: "AsyncConnection"
_conn: "AsyncConnection"

async def execute(
self, statement: str, *, timeout: Optional[float] = None, do_ping: bool = False
):
res = await self.conn.execute(
res = await self._conn.execute(
self._session, statement, timeout=timeout, do_ping=do_ping
)
# Retry for only one time
if res.status_code == ErrorCode.SESSION_NOT_FOUND.value:
self._session = await self.conn.authenticate(
self._session = await self._conn.authenticate(
self.username,
self.password,
session_config=self.session_config,
auth_options=self.auth_options,
)
res = await self.conn.execute(
res = await self._conn.execute(
self._session, statement, timeout=timeout, do_ping=do_ping
)
res.raise_on_error()
return res

async def close(self):
async def _close(self):
try:
await self.conn.execute(self._session, "SESSION CLOSE")
await self._conn.execute(self._session, "SESSION CLOSE")
except Exception:
logger.exception("Failed to close async session")

Expand Down
16 changes: 8 additions & 8 deletions src/nebulagraph_python/client/_session_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ async def connect(
except Exception:
# Clean up any sessions that were successfully created
for session in sessions:
await session.close()
await session._close()
raise

def __init__(
Expand Down Expand Up @@ -157,20 +157,20 @@ async def borrow(self):
self.busy_sessions_queue.remove(got_session)
self.queue_count.release()

async def close(self):
async def _close(self):
# Acquire all semaphore permits to prevent new borrows
for _ in range(self.config.size):
await self.queue_count.acquire()
async with self.queue_lock:
# Close all free sessions
for session in self.free_sessions_queue:
await session.close()
await session._close()
# Close all busy sessions (if any remain)
for session in self.busy_sessions_queue:
logger.error(
"Busy sessions remain after acquire all semaphore permits, which indicates a bug in the AsyncSessionPool"
)
await session.close()
await session._close()


class SessionPool:
Expand Down Expand Up @@ -209,7 +209,7 @@ def connect(
except Exception:
# Clean up any sessions that were successfully created
for session in sessions:
session.close()
session._close()
raise

def __init__(
Expand Down Expand Up @@ -273,17 +273,17 @@ def borrow(self):
self.busy_sessions_queue.remove(got_session)
self.queue_count.release()

def close(self):
def _close(self):
# Acquire all semaphore permits to prevent new borrows
for _ in range(self.config.size):
self.queue_count.acquire()
with self.queue_lock:
# Close all free sessions
for session in self.free_sessions_queue:
session.close()
session._close()
# Close all busy sessions (if any remain)
for session in self.busy_sessions_queue:
logger.error(
"Busy sessions remain after acquire all semaphore permits, which indicates a bug in the SessionPool"
)
session.close()
session._close()
69 changes: 35 additions & 34 deletions src/nebulagraph_python/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.

import logging
from collections.abc import AsyncGenerator, Generator
from contextlib import asynccontextmanager, contextmanager
from typing import Any, Dict, List, Literal, Optional, Union

from nebulagraph_python.client._connection import (
Expand Down Expand Up @@ -120,7 +122,7 @@ async def connect(
)
else:
self._sessions[host_addr] = AsyncSession(
conn=conn,
_conn=conn,
username=username,
password=password,
session_config=session_config or SessionConfig(),
Expand All @@ -134,40 +136,31 @@ async def connect(
async def execute(
self, statement: str, *, timeout: Optional[float] = None, do_ping: bool = False
) -> ResultSet:
async with self.borrow() as session:
return await session.execute(statement, timeout=timeout, do_ping=do_ping)

@asynccontextmanager
async def borrow(self) -> AsyncGenerator[AsyncSession, None]:
if isinstance(self._conn, AsyncConnectionPool):
addr, _conn = await self._conn.next_connection()
addr, conn = await self._conn.next_connection()
else:
addr = self._conn.config.hosts[0]
conn = self._conn
addr = conn.connected
if addr is None:
raise ValueError("Connection not connected")

_session = self._sessions[addr]

if isinstance(_session, AsyncSessionPool):
async with _session.borrow() as session:
return (
await session.execute(statement, timeout=timeout, do_ping=do_ping)
).raise_on_error()
yield session
else:
return (
await _session.execute(statement, timeout=timeout, do_ping=do_ping)
).raise_on_error()

async def ping(self, timeout: Optional[float] = None) -> bool:
try:
res = (
(await self.execute(statement="RETURN 1", timeout=timeout))
.one()
.as_primitive()
)
if not res == {"1": 1}:
raise ValueError(f"Unexpected result from ping: {res}")
return True
except Exception:
logger.exception("Failed to ping NebulaGraph")
return False
yield _session

async def close(self):
"""Close the client connection and session. No Exception will be raised but an error will be logged."""
for session in self._sessions.values():
await session.close()
await session._close()
await self._conn.close()

async def __aenter__(self):
Expand Down Expand Up @@ -245,7 +238,7 @@ def __init__(
)
else:
self._sessions[host_addr] = Session(
conn=conn,
_conn=conn,
username=username,
password=password,
session_config=session_config or SessionConfig(),
Expand All @@ -258,21 +251,29 @@ def __init__(
def execute(
self, statement: str, *, timeout: Optional[float] = None, do_ping: bool = False
) -> ResultSet:
"""Execute a statement using a borrowed session, raising on errors."""
with self.borrow() as session:
return session.execute(statement, timeout=timeout, do_ping=do_ping)

@contextmanager
def borrow(self) -> Generator[Session, None, None]:
"""Yield a session bound to the selected connection."""
if isinstance(self._conn, ConnectionPool):
addr, _conn = self._conn.next_connection()
addr, conn = self._conn.next_connection()
else:
addr = self._conn.config.hosts[0]
conn = self._conn
addr = conn.connected
if addr is None:
raise ValueError("Connection not connected")

# Route to the correct session (pool or single session)
_session = self._sessions[addr]

if isinstance(_session, SessionPool):
with _session.borrow() as session:
return session.execute(
statement, timeout=timeout, do_ping=do_ping
).raise_on_error()
yield session
else:
return _session.execute(
statement, timeout=timeout, do_ping=do_ping
).raise_on_error()
yield _session

def ping(self, timeout: Optional[float] = None) -> bool:
try:
Expand All @@ -291,7 +292,7 @@ def ping(self, timeout: Optional[float] = None) -> bool:
def close(self):
"""Close the client connection and session. No Exception will be raised but an error will be logged."""
for session in self._sessions.values():
session.close()
session._close()
self._conn.close()

def __enter__(self):
Expand Down
Loading