Skip to content

Commit 5e6d82d

Browse files
committed
Wall clock window expiration PoC
1 parent da7d135 commit 5e6d82d

File tree

6 files changed

+105
-8
lines changed

6 files changed

+105
-8
lines changed

quixstreams/dataframe/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from datetime import timedelta
1+
from datetime import datetime, timedelta
22
from typing import Union
33

44

@@ -22,3 +22,8 @@ def ensure_milliseconds(delta: Union[int, timedelta]) -> int:
2222
f'Timedelta must be either "int" representing milliseconds '
2323
f'or "datetime.timedelta", got "{type(delta)}"'
2424
)
25+
26+
27+
def now() -> int:
    """
    Return the current wall-clock time as milliseconds since the Unix epoch.

    The reading is taken from a timezone-aware UTC datetime, so the result
    does not depend on how the host's local timezone is configured.

    :return: the current epoch time in whole milliseconds
        (the sub-millisecond fraction is truncated).
    """
    # Imported locally so the module-level import line stays untouched.
    from datetime import timezone

    return int(datetime.now(timezone.utc).timestamp() * 1000)

quixstreams/dataframe/windows/base.py

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717
from typing_extensions import TypeAlias
1818

1919
from quixstreams.context import message_context
20-
from quixstreams.core.stream import TransformExpandedCallback
20+
from quixstreams.core.stream import (
21+
Stream,
22+
TransformExpandedCallback,
23+
TransformFunction,
24+
)
2125
from quixstreams.core.stream.exceptions import InvalidOperation
2226
from quixstreams.models.topics.manager import TopicManager
2327
from quixstreams.state import WindowedPartitionTransaction
@@ -42,6 +46,8 @@
4246
Iterable[Message],
4347
]
4448

49+
WallClockCallback = Callable[[WindowedPartitionTransaction], Iterable[Message]]
50+
4551

4652
class Window(abc.ABC):
4753
def __init__(
@@ -69,6 +75,13 @@ def process_window(
6975
) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]:
7076
pass
7177

78+
@abstractmethod
def process_wall_clock(
    self,
    transaction: WindowedPartitionTransaction,
) -> Iterable[WindowKeyResult]:
    """
    Produce window results for the wall-clock path.

    Concrete window classes must override this method; the base
    implementation does nothing.

    :param transaction: the windowed partition transaction to work with.
    :return: an iterable of window key results.
    """
    ...
84+
7285
def register_store(self) -> None:
7386
TopicManager.ensure_topics_copartitioned(*self._dataframe.topics)
7487
# Create a config for the changelog topic based on the underlying SDF topics
@@ -83,6 +96,7 @@ def _apply_window(
8396
self,
8497
func: TransformRecordCallbackExpandedWindowed,
8598
name: str,
99+
wall_clock_func: WallClockCallback,
86100
) -> "StreamingDataFrame":
87101
self.register_store()
88102

@@ -92,12 +106,24 @@ def _apply_window(
92106
processing_context=self._dataframe.processing_context,
93107
store_name=name,
94108
)
109+
wall_clock_transform_func = _as_wall_clock(
110+
func=wall_clock_func,
111+
stream_id=self._dataframe.stream_id,
112+
processing_context=self._dataframe.processing_context,
113+
store_name=name,
114+
)
95115
# Manually modify the Stream and clone the source StreamingDataFrame
96116
# to avoid adding "transform" API to it.
97117
# Transform callbacks can modify record key and timestamp,
98118
# and it's prone to misuse.
99-
stream = self._dataframe.stream.add_transform(func=windowed_func, expand=True)
100-
return self._dataframe.__dataframe_clone__(stream=stream)
119+
windowed_stream = self._dataframe.stream.add_transform(
120+
func=windowed_func, expand=True
121+
)
122+
wall_clock_stream = Stream(
123+
func=TransformFunction(wall_clock_transform_func, expand=True)
124+
)
125+
sdf = self._dataframe.__dataframe_clone__(stream=windowed_stream)
126+
return sdf.concat_wall_clock(wall_clock_stream)
101127

102128
def final(self) -> "StreamingDataFrame":
103129
"""
@@ -140,9 +166,17 @@ def window_callback(
140166
for key, window in expired_windows:
141167
yield (window, key, window["start"], None)
142168

169+
def wall_clock_callback(
    transaction: WindowedPartitionTransaction,
) -> Iterable[Message]:
    """
    Yield one message per result returned by `process_wall_clock`.

    Each yielded tuple holds the window result, its key, the window's
    "start" value and `None`, in that order.

    :param transaction: the windowed partition transaction passed through
        to `process_wall_clock`.
    """
    # TODO: Check if this will work for sliding windows
    wall_clock_results = self.process_wall_clock(transaction)
    for result_key, result_window in wall_clock_results:
        yield result_window, result_key, result_window["start"], None
175+
143176
return self._apply_window(
144177
func=window_callback,
145178
name=self._name,
179+
wall_clock_func=wall_clock_callback,
146180
)
147181

148182
def current(self) -> "StreamingDataFrame":
@@ -188,7 +222,17 @@ def window_callback(
188222
for key, window in updated_windows:
189223
yield (window, key, window["start"], None)
190224

191-
return self._apply_window(func=window_callback, name=self._name)
225+
def wall_clock_callback(
    transaction: WindowedPartitionTransaction,
) -> Iterable[Message]:
    """
    Placeholder wall-clock callback: currently produces no messages.

    :param transaction: the windowed partition transaction (unused for now).
    :return: an empty list.
    """
    # TODO: Implement wall_clock callback
    no_messages: list = []
    return no_messages
230+
231+
return self._apply_window(
232+
func=window_callback,
233+
name=self._name,
234+
wall_clock_func=wall_clock_callback,
235+
)
192236

193237
# Implemented by SingleAggregationWindowMixin and MultiAggregationWindowMixin
194238
# Single aggregation and multi aggregation windows store aggregations and collections
@@ -424,6 +468,28 @@ def wrapper(
424468
return wrapper
425469

426470

471+
def _as_wall_clock(
    func: WallClockCallback,
    processing_context: "ProcessingContext",
    store_name: str,
    stream_id: str,
) -> TransformExpandedCallback:
    """
    Adapt a wall-clock callback to the expanded transform callback signature.

    The returned wrapper accepts `(value, key, timestamp, headers)` but does
    not use any of them. It fetches the store transaction for the partition
    of the current message context and hands it to `func`.

    :param func: callback that receives the windowed partition transaction.
    :param processing_context: provides the checkpoint from which the store
        transaction is obtained.
    :param store_name: name of the store to get the transaction for.
    :param stream_id: stream ID used to look up the store transaction.
    :return: the wrapper callable, which returns whatever `func` returns.
    """

    @functools.wraps(func)
    def wrapper(
        value: Any, key: Any, timestamp: int, headers: Any
    ) -> Iterable[Message]:
        # Only the partition of the current message context is needed;
        # the record fields passed in are intentionally ignored.
        current_ctx = message_context()
        store_transaction = processing_context.checkpoint.get_store_transaction(
            stream_id=stream_id,
            partition=current_ctx.partition,
            store_name=store_name,
        )
        return func(cast(WindowedPartitionTransaction, store_transaction))

    return wrapper
491+
492+
427493
class WindowOnLateCallback(Protocol):
428494
def __call__(
429495
self,

quixstreams/dataframe/windows/count_based.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,12 @@ def process_window(
189189
state.set(key=self.STATE_KEY, value=data)
190190
return updated_windows, expired_windows
191191

192+
def process_wall_clock(
    self,
    transaction: WindowedPartitionTransaction,
) -> Iterable[WindowKeyResult]:
    """
    Wall-clock hook for this window type: produces no results.

    :param transaction: the windowed partition transaction (unused).
    :return: an empty list.
    """
    no_results: list = []
    return no_results
197+
192198
def _get_collection_start_id(self, window: CountWindowData) -> int:
193199
start_id = window.get("collection_start_id", _MISSING)
194200
if start_id is _MISSING:

quixstreams/dataframe/windows/time_based.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import TYPE_CHECKING, Any, Iterable, Literal, Optional
44

55
from quixstreams.context import message_context
6+
from quixstreams.dataframe.utils import now
67
from quixstreams.state import WindowedPartitionTransaction, WindowedState
78

89
from .base import (
@@ -200,11 +201,23 @@ def process_window(
200201

201202
return updated_windows, expired_windows
202203

204+
def process_wall_clock(
    self,
    transaction: WindowedPartitionTransaction,
) -> Iterable[WindowKeyResult]:
    """
    Expire windows using the current wall-clock time.

    The expiration cutoff is the current time from `now()` minus the
    configured grace period. The last expired timestamp is deliberately
    not advanced on this path.

    :param transaction: the windowed partition transaction to expire
        windows from.
    :return: the iterable of window key results produced by
        `expire_by_partition`.
    """
    wall_clock_cutoff = now() - self._grace_ms
    return self.expire_by_partition(
        transaction=transaction,
        max_expired_end=wall_clock_cutoff,
        collect=self.collect,
        advance_last_expired_timestamp=False,
    )
214+
203215
def expire_by_partition(
204216
self,
205217
transaction: WindowedPartitionTransaction,
206218
max_expired_end: int,
207219
collect: bool,
220+
advance_last_expired_timestamp: bool = True,
208221
) -> Iterable[WindowKeyResult]:
209222
for (
210223
window_start,
@@ -214,6 +227,7 @@ def expire_by_partition(
214227
step_ms=self._step_ms if self._step_ms else self._duration_ms,
215228
collect=collect,
216229
delete=True,
230+
advance_last_expired_timestamp=advance_last_expired_timestamp,
217231
):
218232
yield key, self._results(aggregated, collected, window_start, window_end)
219233

quixstreams/state/rocksdb/windowed/transaction.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ def expire_all_windows(
298298
step_ms: int,
299299
delete: bool = True,
300300
collect: bool = False,
301+
advance_last_expired_timestamp: bool = True,
301302
) -> Iterable[ExpiredWindowDetail]:
302303
"""
303304
Get all expired windows for all prefixes from RocksDB up to the specified `max_end_time` timestamp.
@@ -360,9 +361,12 @@ def expire_all_windows(
360361
if collect:
361362
self.delete_from_collection(end=start, prefix=prefix)
362363

363-
self._set_timestamp(
364-
prefix=b"", cache=self._last_expired_timestamps, timestamp_ms=last_expired
365-
)
364+
if advance_last_expired_timestamp:
365+
self._set_timestamp(
366+
prefix=b"",
367+
cache=self._last_expired_timestamps,
368+
timestamp_ms=last_expired,
369+
)
366370

367371
def delete_windows(
368372
self, max_start_time: int, delete_values: bool, prefix: bytes

quixstreams/state/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ def expire_all_windows(
378378
step_ms: int,
379379
delete: bool = True,
380380
collect: bool = False,
381+
advance_last_expired_timestamp: bool = True,
381382
) -> Iterable[ExpiredWindowDetail[V]]:
382383
"""
383384
Get all expired windows for all prefixes from RocksDB up to the specified `max_end_time` timestamp.
@@ -388,6 +389,7 @@ def expire_all_windows(
388389
:param max_end_time: The timestamp up to which windows are considered expired, inclusive.
389390
:param delete: If True, expired windows will be deleted.
390391
:param collect: If True, values will be collected into windows.
392+
:param advance_last_expired_timestamp: If True, the last expired timestamp will be persisted.
391393
"""
392394
...
393395

0 commit comments

Comments
 (0)