Skip to content

Commit ab220fc

Browse files
author
Hoang Phan
committed
Add fix for serializer
1 parent 07ac02a commit ab220fc

File tree

3 files changed

+181
-59
lines changed

3 files changed

+181
-59
lines changed

quixstreams/dataframe/windows/time_based.py

Lines changed: 51 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -293,11 +293,11 @@ def _on_expired_window(
293293
class SessionWindow(Window):
294294
"""
295295
Session window groups events that occur within a specified timeout period.
296-
296+
297297
A session starts with the first event and extends each time a new event arrives
298298
within the timeout period. The session closes after the timeout period with no
299299
new events.
300-
300+
301301
Each session window can have different start and end times based on the actual
302302
events, making sessions dynamic rather than fixed-time intervals.
303303
"""
@@ -419,23 +419,25 @@ def process_window(
419419
late_by_ms=late_by_ms,
420420
)
421421
return [], []
422-
422+
423423
# Look for an existing session that can be extended
424424
session_start = None
425425
session_end = None
426426
can_extend_session = False
427427
existing_aggregated = None
428428
old_window_to_delete = None
429-
429+
430430
# Search for active sessions that can accommodate the new event
431431
search_start = max(0, timestamp_ms - timeout_ms * 2)
432-
windows = state.get_windows(search_start, timestamp_ms + timeout_ms + 1, backwards=True)
433-
432+
windows = state.get_windows(
433+
search_start, timestamp_ms + timeout_ms + 1, backwards=True
434+
)
435+
434436
for (window_start, window_end), aggregated_value, _ in windows:
435437
# Calculate the time gap between the new event and the session's last activity
436438
session_last_activity = window_end - timeout_ms
437439
time_gap = timestamp_ms - session_last_activity
438-
440+
439441
# Check if this session can be extended
440442
if time_gap <= timeout_ms + grace_ms and timestamp_ms >= window_start:
441443
session_start = window_start
@@ -452,12 +454,12 @@ def process_window(
452454

453455
# Process the event for this session
454456
updated_windows: list[WindowKeyResult] = []
455-
457+
456458
# Delete the old window if extending an existing session
457459
if can_extend_session and old_window_to_delete:
458460
old_start, old_end = old_window_to_delete
459-
transaction.delete_window(old_start, old_end, prefix=key)
460-
461+
transaction.delete_window(old_start, old_end, prefix=state._prefix) # type: ignore # noqa: SLF001
462+
461463
# Add to collection if needed
462464
if collect:
463465
state.add_to_collection(
@@ -473,7 +475,11 @@ def process_window(
473475
current_value = self._initialize_value()
474476

475477
aggregated = self._aggregate_value(current_value, value, timestamp_ms)
476-
478+
479+
# By this point, session_start and session_end are guaranteed to be set
480+
assert session_start is not None # noqa: S101
481+
assert session_end is not None # noqa: S101
482+
477483
# Output intermediate results for aggregations
478484
if aggregate:
479485
updated_windows.append(
@@ -482,8 +488,10 @@ def process_window(
482488
self._results(aggregated, [], session_start, session_end),
483489
)
484490
)
485-
486-
state.update_window(session_start, session_end, value=aggregated, timestamp_ms=timestamp_ms)
491+
492+
state.update_window(
493+
session_start, session_end, value=aggregated, timestamp_ms=timestamp_ms
494+
)
487495

488496
# Expire old sessions
489497
if self._closing_strategy == ClosingStrategy.PARTITION:
@@ -508,13 +516,13 @@ def expire_sessions_by_partition(
508516

509517
# Import the parsing function to extract message keys from window keys
510518
from quixstreams.state.rocksdb.windowed.serialization import parse_window_key
511-
519+
512520
expired_results = []
513-
521+
514522
# Collect all keys and extract unique prefixes to avoid iteration conflicts
515523
all_keys = list(transaction.keys())
516524
seen_prefixes = set()
517-
525+
518526
for key_bytes in all_keys:
519527
try:
520528
prefix, start_ms, end_ms = parse_window_key(key_bytes)
@@ -523,21 +531,23 @@ def expire_sessions_by_partition(
523531
except (ValueError, IndexError):
524532
# Skip invalid window key formats
525533
continue
526-
534+
527535
# Expire sessions for each unique prefix
528536
for prefix in seen_prefixes:
529537
state = transaction.as_state(prefix=prefix)
530-
prefix_expired = list(self.expire_sessions_by_key(
531-
prefix, state, expiry_threshold, collect
532-
))
538+
prefix_expired = list(
539+
self.expire_sessions_by_key(prefix, state, expiry_threshold, collect)
540+
)
533541
expired_results.extend(prefix_expired)
534542
count += len(prefix_expired)
535543

536544
if count:
537545
logger.debug(
538-
"Expired %s session windows in %ss", count, round(time.monotonic() - start, 2)
546+
"Expired %s session windows in %ss",
547+
count,
548+
round(time.monotonic() - start, 2),
539549
)
540-
550+
541551
return expired_results
542552

543553
def expire_sessions_by_key(
@@ -551,29 +561,40 @@ def expire_sessions_by_key(
551561
count = 0
552562

553563
# Get all windows and check which ones have expired
554-
all_windows = list(state.get_windows(0, expiry_threshold + self._timeout_ms, backwards=False))
555-
564+
all_windows = list(
565+
state.get_windows(0, expiry_threshold + self._timeout_ms, backwards=False)
566+
)
567+
556568
windows_to_delete = []
557569
for (window_start, window_end), aggregated, _ in all_windows:
558570
# Session expires when the session end time has passed the expiry threshold
559571
if window_end <= expiry_threshold:
560572
collected = []
561573
if collect:
562574
collected = state.get_from_collection(window_start, window_end)
563-
575+
564576
windows_to_delete.append((window_start, window_end))
565577
count += 1
566-
yield (key, self._results(aggregated, collected, window_start, window_end))
578+
yield (
579+
key,
580+
self._results(aggregated, collected, window_start, window_end),
581+
)
567582

568583
# Clean up expired windows
569584
for window_start, window_end in windows_to_delete:
570-
state._transaction.delete_window(window_start, window_end, prefix=state._prefix)
585+
state._transaction.delete_window( # type: ignore # noqa: SLF001
586+
window_start,
587+
window_end,
588+
prefix=state._prefix, # type: ignore # noqa: SLF001
589+
)
571590
if collect:
572591
state.delete_from_collection(window_end, start=window_start)
573592

574593
if count:
575594
logger.debug(
576-
"Expired %s session windows in %ss", count, round(time.monotonic() - start, 2)
595+
"Expired %s session windows in %ss",
596+
count,
597+
round(time.monotonic() - start, 2),
577598
)
578599

579600
def _on_expired_session(
@@ -595,9 +616,9 @@ def _on_expired_session(
595616
topic = "unknown"
596617
partition = -1
597618
offset = -1
598-
619+
599620
to_log = True
600-
621+
601622
# Trigger the "on_late" callback if provided
602623
if self._on_late:
603624
to_log = self._on_late(

quixstreams/state/types.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -391,6 +391,18 @@ def expire_all_windows(
391391
"""
392392
...
393393

394+
def delete_window(self, start_ms: int, end_ms: int, prefix: bytes) -> None:
395+
"""
396+
Delete a specific window from RocksDB.
397+
398+
This method removes a single window entry with the specified start and end timestamps.
399+
400+
:param start_ms: The start timestamp of the window to delete
401+
:param end_ms: The end timestamp of the window to delete
402+
:param prefix: The key prefix for the window
403+
"""
404+
...
405+
394406
def delete_windows(
395407
self, max_start_time: int, delete_values: bool, prefix: bytes
396408
) -> None:

0 commit comments

Comments
 (0)