Skip to content

Commit 0ec7950

Browse files
committed
Refactor state_keys method return types in context and server_context modules
1 parent 8d2a144 commit 0ec7950

File tree

2 files changed

+9
-10
lines changed

2 files changed

+9
-10
lines changed

python/restate/context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def get(self,
9898
"""
9999

100100
@abc.abstractmethod
101-
def state_keys(self) -> RestateDurableFuture[List[str]]:
101+
def state_keys(self) -> Awaitable[List[str]]:
102102
"""Returns the list of keys in the store."""
103103

104104
@abc.abstractmethod
@@ -323,7 +323,7 @@ def get(self,
323323
"""
324324

325325
@abc.abstractmethod
326-
def state_keys(self) -> RestateDurableFuture[List[str]]:
326+
def state_keys(self) -> Awaitable[List[str]]:
327327
"""
328328
Returns the list of keys in the store.
329329
"""

python/restate/server_context.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def __init__(self,
210210
self.attempt_headers = attempt_headers
211211
self.send = send
212212
self.receive = receive
213-
self.run_coros_to_execute: dict[int, Callable[[], Awaitable[bytes | Failure]]] = {}
213+
self.run_coros_to_execute: dict[int, Callable[[], Awaitable[None]]] = {}
214214
self.sync_point = SyncPoint()
215215

216216
async def enter(self):
@@ -330,9 +330,11 @@ async def fetch_result():
330330
if not self.vm.is_completed(handle):
331331
await self.create_poll_or_cancel_coroutine([handle])
332332
res = self.must_take_notification(handle)
333-
if res is None or serde is None or not isinstance(res, bytes):
333+
if res is None or serde is None:
334334
return res
335-
return serde.deserialize(res)
335+
if isinstance(res, bytes):
336+
return serde.deserialize(res)
337+
return res
336338

337339
return fetch_result
338340

@@ -361,8 +363,8 @@ def get(self, name: str, serde: Serde[T] = JsonSerde()) -> Awaitable[Optional[T]
361363
handle = self.vm.sys_get_state(name)
362364
return self.create_future(handle, serde) # type: ignore
363365

364-
def state_keys(self) -> RestateDurableFuture[List[str]]:
365-
return self.create_future(self.vm.sys_get_state_keys()) # type: ignore
366+
def state_keys(self) -> Awaitable[List[str]]:
367+
return self.create_future(self.vm.sys_get_state_keys())
366368

367369
def set(self, name: str, value: T, serde: Serde[T] = JsonSerde()) -> None:
368370
"""Set the value associated with the given name."""
@@ -398,11 +400,9 @@ async def create_run_coroutine(self,
398400

399401
buffer = serde.serialize(action_result)
400402
self.vm.propose_run_completion_success(handle, buffer)
401-
return buffer
402403
except TerminalError as t:
403404
failure = Failure(code=t.status_code, message=t.message)
404405
self.vm.propose_run_completion_failure(handle, failure)
405-
return failure
406406
# pylint: disable=W0718
407407
except Exception as e:
408408
if max_attempts is None and max_retry_duration is None:
@@ -412,7 +412,6 @@ async def create_run_coroutine(self,
412412
max_duration_ms = None if max_retry_duration is None else int(max_retry_duration.total_seconds() * 1000)
413413
config = RunRetryConfig(max_attempts=max_attempts, max_duration=max_duration_ms)
414414
self.vm.propose_run_completion_transient(handle, failure=failure, attempt_duration_ms=1, config=config)
415-
return failure
416415
# pylint: disable=W0236
417416
# pylint: disable=R0914
418417
def run(self,

0 commit comments

Comments
 (0)