Skip to content

Commit e5f8887

Browse files
committed
Wall clock window expiration PoC
1 parent dda7ccf commit e5f8887

File tree

5 files changed

+103
-7
lines changed

5 files changed

+103
-7
lines changed

quixstreams/dataframe/windows/base.py

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717
from typing_extensions import TypeAlias
1818

1919
from quixstreams.context import message_context
20-
from quixstreams.core.stream import TransformExpandedCallback
20+
from quixstreams.core.stream import (
21+
Stream,
22+
TransformExpandedCallback,
23+
TransformFunction,
24+
TransformWallClockExpandedCallback,
25+
)
2126
from quixstreams.core.stream.exceptions import InvalidOperation
2227
from quixstreams.models.topics.manager import TopicManager
2328
from quixstreams.state import WindowedPartitionTransaction
@@ -42,6 +47,8 @@
4247
Iterable[Message],
4348
]
4449

50+
WallClockCallback = Callable[[int, WindowedPartitionTransaction], Iterable[Message]]
51+
4552

4653
class Window(abc.ABC):
4754
def __init__(
@@ -69,6 +76,14 @@ def process_window(
6976
) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]:
7077
pass
7178

79+
@abstractmethod
def process_wall_clock(
    self,
    timestamp_ms: int,
    transaction: WindowedPartitionTransaction,
) -> Iterable[WindowKeyResult]:
    """
    Produce window results for a wall-clock timestamp.

    Concrete windows must override this method.

    :param timestamp_ms: the timestamp to process
        (presumably wall-clock time in milliseconds - confirm with the caller).
    :param transaction: the windowed state transaction to operate on.
    :return: an iterable of `(key, window)` results.
    """
86+
7287
def register_store(self) -> None:
7388
TopicManager.ensure_topics_copartitioned(*self._dataframe.topics)
7489
# Create a config for the changelog topic based on the underlying SDF topics
@@ -83,6 +98,7 @@ def _apply_window(
8398
self,
8499
func: TransformRecordCallbackExpandedWindowed,
85100
name: str,
101+
wall_clock_func: WallClockCallback,
86102
) -> "StreamingDataFrame":
87103
self.register_store()
88104

@@ -92,12 +108,24 @@ def _apply_window(
92108
processing_context=self._dataframe.processing_context,
93109
store_name=name,
94110
)
111+
wall_clock_transform_func = _as_wall_clock(
112+
func=wall_clock_func,
113+
stream_id=self._dataframe.stream_id,
114+
processing_context=self._dataframe.processing_context,
115+
store_name=name,
116+
)
95117
# Manually modify the Stream and clone the source StreamingDataFrame
96118
# to avoid adding "transform" API to it.
97119
# Transform callbacks can modify record key and timestamp,
98120
# and it's prone to misuse.
99-
stream = self._dataframe.stream.add_transform(func=windowed_func, expand=True)
100-
return self._dataframe.__dataframe_clone__(stream=stream)
121+
windowed_stream = self._dataframe.stream.add_transform(
122+
func=windowed_func, expand=True
123+
)
124+
wall_clock_stream = Stream(
125+
TransformFunction(wall_clock_transform_func, expand=True, wall_clock=True)
126+
)
127+
sdf = self._dataframe.__dataframe_clone__(stream=windowed_stream)
128+
return sdf.concat_wall_clock(wall_clock_stream)
101129

102130
def final(self) -> "StreamingDataFrame":
103131
"""
@@ -140,9 +168,17 @@ def window_callback(
140168
for key, window in expired_windows:
141169
yield (window, key, window["start"], None)
142170

171+
def wall_clock_callback(
172+
timestamp: int, transaction: WindowedPartitionTransaction
173+
) -> Iterable[Message]:
174+
# TODO: Check if this will work for sliding windows
175+
for key, window in self.process_wall_clock(timestamp, transaction):
176+
yield (window, key, window["start"], None)
177+
143178
return self._apply_window(
144179
func=window_callback,
145180
name=self._name,
181+
wall_clock_func=wall_clock_callback,
146182
)
147183

148184
def current(self) -> "StreamingDataFrame":
@@ -188,7 +224,17 @@ def window_callback(
188224
for key, window in updated_windows:
189225
yield (window, key, window["start"], None)
190226

191-
return self._apply_window(func=window_callback, name=self._name)
227+
def wall_clock_callback(
228+
timestamp: int, transaction: WindowedPartitionTransaction
229+
) -> Iterable[Message]:
230+
# TODO: Implement wall_clock callback
231+
return []
232+
233+
return self._apply_window(
234+
func=window_callback,
235+
name=self._name,
236+
wall_clock_func=wall_clock_callback,
237+
)
192238

193239
# Implemented by SingleAggregationWindowMixin and MultiAggregationWindowMixin
194240
# Single aggregation and multi aggregation windows store aggregations and collections
@@ -424,6 +470,26 @@ def wrapper(
424470
return wrapper
425471

426472

473+
def _as_wall_clock(
    func: WallClockCallback,
    processing_context: "ProcessingContext",
    store_name: str,
    stream_id: str,
) -> TransformWallClockExpandedCallback:
    """
    Adapt a wall-clock window callback so that it can be called with
    a timestamp only.

    The returned wrapper fetches the store transaction for the partition of
    the current message context and passes it to `func` along with the timestamp.

    :param func: a callback accepting a timestamp and a windowed transaction.
    :param processing_context: the context whose checkpoint provides
        the store transactions.
    :param store_name: the name of the store to get the transaction for.
    :param stream_id: the stream id used to look up the store transaction.
    :return: a callable that takes a timestamp and returns the result of `func`.
    """

    def wrapper(timestamp: int) -> Iterable[Message]:
        ctx = message_context()
        store_transaction = processing_context.checkpoint.get_store_transaction(
            stream_id=stream_id,
            partition=ctx.partition,
            store_name=store_name,
        )
        # `cast` only narrows the type for static checkers,
        # it does nothing at runtime.
        windowed_transaction = cast(WindowedPartitionTransaction, store_transaction)
        return func(timestamp, windowed_transaction)

    # Same effect as decorating `wrapper` with `@functools.wraps(func)`
    return functools.wraps(func)(wrapper)
491+
492+
427493
class WindowOnLateCallback(Protocol):
428494
def __call__(
429495
self,

quixstreams/dataframe/windows/count_based.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,13 @@ def process_window(
189189
state.set(key=self.STATE_KEY, value=data)
190190
return updated_windows, expired_windows
191191

192+
def process_wall_clock(
    self,
    timestamp_ms: int,
    transaction: WindowedPartitionTransaction,
) -> Iterable[WindowKeyResult]:
    """
    Wall-clock processing for this window type.

    Both arguments are accepted but unused: the method always returns
    an empty list.

    :param timestamp_ms: ignored.
    :param transaction: ignored.
    :return: an empty list.
    """
    no_results: list = []
    return no_results
198+
192199
def _get_collection_start_id(self, window: CountWindowData) -> int:
193200
start_id = window.get("collection_start_id", _MISSING)
194201
if start_id is _MISSING:

quixstreams/dataframe/windows/time_based.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,11 +200,27 @@ def process_window(
200200

201201
return updated_windows, expired_windows
202202

203+
def process_wall_clock(
    self,
    timestamp_ms: int,
    transaction: WindowedPartitionTransaction,
) -> Iterable[WindowKeyResult]:
    """
    Expire windows across the partition based on the given timestamp.

    The expiration threshold is the greater of `timestamp_ms` and the latest
    expired timestamp stored in the transaction, minus the grace period.
    The stored last expired timestamp is not advanced by this call.

    :param timestamp_ms: the timestamp to expire windows against
        (presumably wall-clock time in milliseconds - confirm with the caller).
    :param transaction: the windowed state transaction to expire windows from.
    :return: the results produced by `expire_by_partition()`.
    """
    previously_expired = transaction.get_latest_expired(prefix=b"")
    # Never use a threshold older than what has already been expired
    reference_timestamp = max(timestamp_ms, previously_expired)
    expiration_threshold = reference_timestamp - self._grace_ms
    return self.expire_by_partition(
        transaction,
        expiration_threshold,
        self.collect,
        advance_last_expired_timestamp=False,
    )
217+
203218
def expire_by_partition(
204219
self,
205220
transaction: WindowedPartitionTransaction,
206221
max_expired_end: int,
207222
collect: bool,
223+
advance_last_expired_timestamp: bool = True,
208224
) -> Iterable[WindowKeyResult]:
209225
for (
210226
window_start,
@@ -214,6 +230,7 @@ def expire_by_partition(
214230
step_ms=self._step_ms if self._step_ms else self._duration_ms,
215231
collect=collect,
216232
delete=True,
233+
advance_last_expired_timestamp=advance_last_expired_timestamp,
217234
):
218235
yield key, self._results(aggregated, collected, window_start, window_end)
219236

quixstreams/state/rocksdb/windowed/transaction.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ def expire_all_windows(
298298
step_ms: int,
299299
delete: bool = True,
300300
collect: bool = False,
301+
advance_last_expired_timestamp: bool = True,
301302
) -> Iterable[ExpiredWindowDetail]:
302303
"""
303304
Get all expired windows for all prefixes from RocksDB up to the specified `max_end_time` timestamp.
@@ -360,9 +361,12 @@ def expire_all_windows(
360361
if collect:
361362
self.delete_from_collection(end=start, prefix=prefix)
362363

363-
self._set_timestamp(
364-
prefix=b"", cache=self._last_expired_timestamps, timestamp_ms=last_expired
365-
)
364+
if advance_last_expired_timestamp:
365+
self._set_timestamp(
366+
prefix=b"",
367+
cache=self._last_expired_timestamps,
368+
timestamp_ms=last_expired,
369+
)
366370

367371
def delete_windows(
368372
self, max_start_time: int, delete_values: bool, prefix: bytes

quixstreams/state/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ def expire_all_windows(
378378
step_ms: int,
379379
delete: bool = True,
380380
collect: bool = False,
381+
advance_last_expired_timestamp: bool = True,
381382
) -> Iterable[ExpiredWindowDetail[V]]:
382383
"""
383384
Get all expired windows for all prefixes from RocksDB up to the specified `max_end_time` timestamp.
@@ -388,6 +389,7 @@ def expire_all_windows(
388389
:param max_end_time: The timestamp up to which windows are considered expired, inclusive.
389390
:param delete: If True, expired windows will be deleted.
390391
:param collect: If True, values will be collected into windows.
392+
:param advance_last_expired_timestamp: If True, the last expired timestamp will be persisted.
391393
"""
392394
...
393395

0 commit comments

Comments
 (0)