Skip to content

Commit d93de6d

Browse files
Hoang Phangwaramadze
authored and committed
Add fix for serializer
1 parent 782a4c7 commit d93de6d

File tree

3 files changed

+181
-59
lines changed

3 files changed

+181
-59
lines changed

quixstreams/dataframe/windows/time_based.py

Lines changed: 51 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -274,11 +274,11 @@ def _on_expired_window(
274274
class SessionWindow(Window):
275275
"""
276276
Session window groups events that occur within a specified timeout period.
277-
277+
278278
A session starts with the first event and extends each time a new event arrives
279279
within the timeout period. The session closes after the timeout period with no
280280
new events.
281-
281+
282282
Each session window can have different start and end times based on the actual
283283
events, making sessions dynamic rather than fixed-time intervals.
284284
"""
@@ -400,23 +400,25 @@ def process_window(
400400
late_by_ms=late_by_ms,
401401
)
402402
return [], []
403-
403+
404404
# Look for an existing session that can be extended
405405
session_start = None
406406
session_end = None
407407
can_extend_session = False
408408
existing_aggregated = None
409409
old_window_to_delete = None
410-
410+
411411
# Search for active sessions that can accommodate the new event
412412
search_start = max(0, timestamp_ms - timeout_ms * 2)
413-
windows = state.get_windows(search_start, timestamp_ms + timeout_ms + 1, backwards=True)
414-
413+
windows = state.get_windows(
414+
search_start, timestamp_ms + timeout_ms + 1, backwards=True
415+
)
416+
415417
for (window_start, window_end), aggregated_value, _ in windows:
416418
# Calculate the time gap between the new event and the session's last activity
417419
session_last_activity = window_end - timeout_ms
418420
time_gap = timestamp_ms - session_last_activity
419-
421+
420422
# Check if this session can be extended
421423
if time_gap <= timeout_ms + grace_ms and timestamp_ms >= window_start:
422424
session_start = window_start
@@ -433,12 +435,12 @@ def process_window(
433435

434436
# Process the event for this session
435437
updated_windows: list[WindowKeyResult] = []
436-
438+
437439
# Delete the old window if extending an existing session
438440
if can_extend_session and old_window_to_delete:
439441
old_start, old_end = old_window_to_delete
440-
transaction.delete_window(old_start, old_end, prefix=key)
441-
442+
transaction.delete_window(old_start, old_end, prefix=state._prefix) # type: ignore # noqa: SLF001
443+
442444
# Add to collection if needed
443445
if collect:
444446
state.add_to_collection(
@@ -454,7 +456,11 @@ def process_window(
454456
current_value = self._initialize_value()
455457

456458
aggregated = self._aggregate_value(current_value, value, timestamp_ms)
457-
459+
460+
# By this point, session_start and session_end are guaranteed to be set
461+
assert session_start is not None # noqa: S101
462+
assert session_end is not None # noqa: S101
463+
458464
# Output intermediate results for aggregations
459465
if aggregate:
460466
updated_windows.append(
@@ -463,8 +469,10 @@ def process_window(
463469
self._results(aggregated, [], session_start, session_end),
464470
)
465471
)
466-
467-
state.update_window(session_start, session_end, value=aggregated, timestamp_ms=timestamp_ms)
472+
473+
state.update_window(
474+
session_start, session_end, value=aggregated, timestamp_ms=timestamp_ms
475+
)
468476

469477
# Expire old sessions
470478
if self._closing_strategy == ClosingStrategy.PARTITION:
@@ -489,13 +497,13 @@ def expire_sessions_by_partition(
489497

490498
# Import the parsing function to extract message keys from window keys
491499
from quixstreams.state.rocksdb.windowed.serialization import parse_window_key
492-
500+
493501
expired_results = []
494-
502+
495503
# Collect all keys and extract unique prefixes to avoid iteration conflicts
496504
all_keys = list(transaction.keys())
497505
seen_prefixes = set()
498-
506+
499507
for key_bytes in all_keys:
500508
try:
501509
prefix, start_ms, end_ms = parse_window_key(key_bytes)
@@ -504,21 +512,23 @@ def expire_sessions_by_partition(
504512
except (ValueError, IndexError):
505513
# Skip invalid window key formats
506514
continue
507-
515+
508516
# Expire sessions for each unique prefix
509517
for prefix in seen_prefixes:
510518
state = transaction.as_state(prefix=prefix)
511-
prefix_expired = list(self.expire_sessions_by_key(
512-
prefix, state, expiry_threshold, collect
513-
))
519+
prefix_expired = list(
520+
self.expire_sessions_by_key(prefix, state, expiry_threshold, collect)
521+
)
514522
expired_results.extend(prefix_expired)
515523
count += len(prefix_expired)
516524

517525
if count:
518526
logger.debug(
519-
"Expired %s session windows in %ss", count, round(time.monotonic() - start, 2)
527+
"Expired %s session windows in %ss",
528+
count,
529+
round(time.monotonic() - start, 2),
520530
)
521-
531+
522532
return expired_results
523533

524534
def expire_sessions_by_key(
@@ -532,29 +542,40 @@ def expire_sessions_by_key(
532542
count = 0
533543

534544
# Get all windows and check which ones have expired
535-
all_windows = list(state.get_windows(0, expiry_threshold + self._timeout_ms, backwards=False))
536-
545+
all_windows = list(
546+
state.get_windows(0, expiry_threshold + self._timeout_ms, backwards=False)
547+
)
548+
537549
windows_to_delete = []
538550
for (window_start, window_end), aggregated, _ in all_windows:
539551
# Session expires when the session end time has passed the expiry threshold
540552
if window_end <= expiry_threshold:
541553
collected = []
542554
if collect:
543555
collected = state.get_from_collection(window_start, window_end)
544-
556+
545557
windows_to_delete.append((window_start, window_end))
546558
count += 1
547-
yield (key, self._results(aggregated, collected, window_start, window_end))
559+
yield (
560+
key,
561+
self._results(aggregated, collected, window_start, window_end),
562+
)
548563

549564
# Clean up expired windows
550565
for window_start, window_end in windows_to_delete:
551-
state._transaction.delete_window(window_start, window_end, prefix=state._prefix)
566+
state._transaction.delete_window( # type: ignore # noqa: SLF001
567+
window_start,
568+
window_end,
569+
prefix=state._prefix, # type: ignore # noqa: SLF001
570+
)
552571
if collect:
553572
state.delete_from_collection(window_end, start=window_start)
554573

555574
if count:
556575
logger.debug(
557-
"Expired %s session windows in %ss", count, round(time.monotonic() - start, 2)
576+
"Expired %s session windows in %ss",
577+
count,
578+
round(time.monotonic() - start, 2),
558579
)
559580

560581
def _on_expired_session(
@@ -576,9 +597,9 @@ def _on_expired_session(
576597
topic = "unknown"
577598
partition = -1
578599
offset = -1
579-
600+
580601
to_log = True
581-
602+
582603
# Trigger the "on_late" callback if provided
583604
if self._on_late:
584605
to_log = self._on_late(

quixstreams/state/types.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -391,6 +391,18 @@ def expire_all_windows(
391391
"""
392392
...
393393

394+
def delete_window(self, start_ms: int, end_ms: int, prefix: bytes) -> None:
395+
"""
396+
Delete a specific window from RocksDB.
397+
398+
This method removes a single window entry with the specified start and end timestamps.
399+
400+
:param start_ms: The start timestamp of the window to delete
401+
:param end_ms: The end timestamp of the window to delete
402+
:param prefix: The key prefix for the window
403+
"""
404+
...
405+
394406
def delete_windows(
395407
self, max_start_time: int, delete_values: bool, prefix: bytes
396408
) -> None:

0 commit comments

Comments
 (0)