diff --git a/rqt_bag/src/rqt_bag/bag_timeline.py b/rqt_bag/src/rqt_bag/bag_timeline.py index 0f1b261..5e2f385 100644 --- a/rqt_bag/src/rqt_bag/bag_timeline.py +++ b/rqt_bag/src/rqt_bag/bag_timeline.py @@ -28,6 +28,7 @@ import threading import time +from typing import Callable, Iterable, Iterator, Optional, Tuple, Union from python_qt_binding.QtCore import qDebug, Qt, QTimer, qWarning, Signal from python_qt_binding.QtWidgets import QGraphicsScene, QMessageBox @@ -43,6 +44,7 @@ from .message_loader_thread import MessageLoaderThread from .player import Player from .recorder import Recorder +from .rosbag2 import Entry, Rosbag2 from .timeline_frame import TimelineFrame @@ -67,6 +69,7 @@ def __init__(self, context, publish_clock): super(BagTimeline, self).__init__() self._bags = [] self._bag_lock = threading.RLock() + self._bags_size = 0 self.background_task = None # Display string self.background_task_cancel = False @@ -158,6 +161,7 @@ def add_bag(self, bag): :param bag: ros bag file, ''rosbag2.bag'' """ self._bags.append(bag) + self._bags_size += bag.size() bag_topics = bag.get_topics() qDebug('Topics from this bag: {}'.format(bag_topics)) @@ -187,8 +191,7 @@ def add_bag(self, bag): self._timeline_frame.index_cache_cv.notify() def file_size(self): - with self._bag_lock: - return sum(b.size() for b in self._bags) + return self._bags_size # TODO Rethink API and if these need to be visible def _get_start_stamp(self): @@ -264,47 +267,37 @@ def get_datatype(self, topic): datatype = bag_datatype return datatype - def get_entries(self, topics, start_stamp, end_stamp): + def get_entries(self, topics: Optional[Union[str, Iterable[str]]], + start_stamp: Time, end_stamp: Time, + progress_cb: Optional[Callable[[int], None]] = None) -> Iterator[Entry]: """ Get a generator for bag entries. :param topics: list of topics to query, ''list(str)'' :param start_stamp: stamp to start at, ''rclpy.time.Time'' - :param end_stamp: stamp to end at, ''rclpy.time,Time'' + :param end_stamp: stamp to end at, ''rclpy.time.Time'' + :param progress_cb: callback function to report progress, called once per each percent. :returns: entries the bag file, ''msg'' """ - with self._bag_lock: - bag_entries = [] - for b in self._bags: - bag_start_time = b.get_earliest_timestamp() - if bag_start_time is not None and bag_start_time > end_stamp: - continue + for b, entry in self.get_entries_with_bags(topics, start_stamp, end_stamp, progress_cb): + yield entry + return None - bag_end_time = b.get_latest_timestamp() - if bag_end_time is not None and bag_end_time < start_stamp: - continue - - # Get all of the entries for each topic. When opening multiple - # bags, the requested topic may not be in a given bag database - for topic in topics: - entries = b.get_entries_in_range(start_stamp, end_stamp, topic) - if entries is not None: - bag_entries.extend(entries) - - for entry in sorted(bag_entries, key=lambda entry: entry.timestamp): - yield entry - - def get_entries_with_bags(self, topic, start_stamp, end_stamp): + def get_entries_with_bags(self, topics: Optional[Union[str, Iterable[str]]], + start_stamp: Time, end_stamp: Time, + progress_cb: Optional[Callable[[int], None]] = None) \ + -> Iterator[Tuple[Rosbag2, Entry]]: """ Get a generator of bag entries. - :param topics: list of topics to query, ''list(str)'' + :param topics: list of topics to query (if None, all topics are used), ''list(str)'' :param start_stamp: stamp to start at, ''rclpy.time.Time'' - :param end_stamp: stamp to end at, ''rclpy.time,Time'' - :returns: tuple of (bag, entry) for the entries in the bag file, ''(rosbag2.bag, msg)'' + :param end_stamp: stamp to end at, ''rclpy.time.Time'' + :param progress_cb: callback function to report progress, called once per each percent. + :returns: tuple of (bag, entry) for the entries in the bag file, ''(rosbag2.Rosbag2, msg)'' """ with self._bag_lock: - bag_entries = [] + relevant_bags = [] for b in self._bags: bag_start_time = b.get_earliest_timestamp() if bag_start_time is not None and bag_start_time > end_stamp: @@ -314,11 +307,61 @@ def get_entries_with_bags(self, topic, start_stamp, end_stamp): if bag_end_time is not None and bag_end_time < start_stamp: continue - for entry in b.get_entries_in_range(start_stamp, end_stamp): - bag_entries.append((b, entry)) + relevant_bags.append(b) + + generators = {} + last_entries = {} + for b in relevant_bags: + generators[b] = b.entries_in_range_generator(start_stamp, end_stamp, topics) + try: + last_entries[b] = next(generators[b]) + except StopIteration: + last_entries[b] = None + + to_delete = [] + for b in last_entries: + if last_entries[b] is None: + to_delete.append(b) + + for b in to_delete: + del last_entries[b] + del generators[b] + relevant_bags.remove(b) + + if progress_cb is not None: + progress = 0 + num_entries = 0 + estimated_num_entries = 0 + for b in relevant_bags: + estimated_num_entries += b.estimate_num_entries_in_range( + start_stamp, end_stamp, topics) + + while any(last_entries.values()): + min_bag = None + min_entry = None + for b, entry in last_entries.items(): + if entry is not None: + if min_entry is None or entry.timestamp < min_entry.timestamp: + min_bag = b + min_entry = entry + if min_bag is None: + return + + if progress_cb is not None: + num_entries += 1 + new_progress = int(100.0 * (float(num_entries) / estimated_num_entries)) + if new_progress != progress: + progress_cb(new_progress) + progress = new_progress + + yield min_bag, min_entry + + try: + last_entries[min_bag] = next(generators[min_bag]) + except StopIteration: + last_entries[min_bag] = None - for bag, entry in sorted(bag_entries, key=lambda item: item[1].timestamp): - yield bag, entry + return def get_entry(self, t, topic): """ diff --git a/rqt_bag/src/rqt_bag/index_cache_thread.py b/rqt_bag/src/rqt_bag/index_cache_thread.py index f349137..3d2ddd4 100644 --- a/rqt_bag/src/rqt_bag/index_cache_thread.py +++ b/rqt_bag/src/rqt_bag/index_cache_thread.py @@ -45,6 +45,9 @@ def __init__(self, timeline): self.start() def run(self): + # Delay start of the indexing so that the basic UI has time to be loaded + time.sleep(2.0) + while not self._stop_flag: with self.timeline.index_cache_cv: # Wait until the cache is dirty @@ -52,23 +55,15 @@ def run(self): self.timeline.index_cache_cv.wait() if self._stop_flag: return - # Update the index for one topic - total_topics = len(self.timeline.topics) - update_step = max(1, total_topics / 100) - topic_num = 1 - progress = 0 - updated = False - for topic in self.timeline.topics: - if topic in self.timeline.invalidated_caches: - updated |= (self.timeline._update_index_cache(topic) > 0) - if topic_num % update_step == 0 or topic_num == total_topics: - new_progress = int(100.0 * (float(topic_num) / total_topics)) - if new_progress != progress: - progress = new_progress - if not self._stop_flag: - self.timeline.scene().background_progress = progress - self.timeline.scene().status_bar_changed_signal.emit() - topic_num += 1 + + # Update the index for all invalidated topics + def progress_cb(progress: int) -> None: + if not self._stop_flag: + self.timeline.scene().background_progress = progress + self.timeline.scene().status_bar_changed_signal.emit() + + topics = self.timeline.invalidated_caches.intersection(set(self.timeline.topics)) + updated = (self.timeline._update_index_cache(topics, progress_cb) > 0) if updated: self.timeline.scene().background_progress = 0 diff --git a/rqt_bag/src/rqt_bag/rosbag2.py b/rqt_bag/src/rqt_bag/rosbag2.py index a808fb8..cf5b3f6 100644 --- a/rqt_bag/src/rqt_bag/rosbag2.py +++ b/rqt_bag/src/rqt_bag/rosbag2.py @@ -31,17 +31,21 @@ from collections import namedtuple import os +from typing import Callable, Iterable, Iterator, List, Optional, Union from rclpy import logging from rclpy.clock import Clock, ClockType from rclpy.duration import Duration from rclpy.serialization import deserialize_message +from rclpy.time import Time import rosbag2_py from rosbag2_py import get_default_storage_id, StorageFilter from rosidl_runtime_py.utilities import get_message +from rqt_bag import bag_helper + WRITE_ONLY_MSG = 'open for writing only, returning None' Entry = namedtuple('Entry', ['topic', 'data', 'timestamp']) @@ -154,25 +158,111 @@ def get_entry_after(self, timestamp, topic=None): self.reader.reset_filter() return result - def get_entries_in_range(self, t_start, t_end, topic=None): + def get_entries_in_range(self, t_start: Time, t_end: Time, + topic: Optional[Union[str, Iterable[str]]] = None, + progress_cb: Optional[Callable[[int], None]] = None) \ + -> Optional[List[Entry]]: + """ + Get a list of all entries in a given time range, sorted by receive stamp. + + Do not use this function for large bags. It will load all entries into memory. Use + entries_in_range_generator() instead and process the data as they are returned. + + :param t_start: stamp to start at, ''rclpy.time.Time'' + :param t_end: stamp to end at, ''rclpy.time.Time'' + :param topic: topic or list of topics to query (if None, all topics are), ''list(str)'' + :param progress_cb: callback function to report progress, called once per each percent. + :returns: entries in the bag file, ''list(Entry)'' + """ if not self.reader: self._logger.warn('get_entries_in_range - ' + WRITE_ONLY_MSG) return None + return list(self.entries_in_range_generator(t_start, t_end, topic, progress_cb)) + + def entries_in_range_generator(self, t_start: Time, t_end: Time, + topic: Optional[Union[str, Iterable[str]]] = None, + progress_cb: Optional[Callable[[int], None]] = None) \ + -> Iterator[Entry]: + """ + Get a generator of all entries in a given time range, sorted by receive stamp. + + :param t_start: stamp to start at, ''rclpy.time.Time'' + :param t_end: stamp to end at, ''rclpy.time.Time'' + :param topic: topic or list of topics to query (if None, all topics are), ''list(str)'' + :param progress_cb: callback function to report progress, called once per each percent. + :returns: generator of entries in the bag file, ''Generator(Entry)'' + """ + if not self.reader: + self._logger.warn('entries_in_range_generator - ' + WRITE_ONLY_MSG) + return + + if isinstance(topic, Iterable) and not isinstance(topic, str): + topics = topic + else: + topics = [topic] if topic is not None else [] + self.reader.set_read_order(rosbag2_py.ReadOrder(reverse=False)) - self.reader.set_filter(rosbag2_py.StorageFilter(topics=[topic] if topic else [])) + self.reader.set_filter(rosbag2_py.StorageFilter(topics=topics)) self.reader.seek(t_start.nanoseconds) - entries = [] + if progress_cb is not None: + num_entries = 0 + progress = 0 + estimated_num_entries = self.estimate_num_entries_in_range(t_start, t_end, topic) + while self.reader.has_next(): next_entry = self.read_next() if next_entry.timestamp <= t_end.nanoseconds: - entries.append(next_entry) + if progress_cb is not None: + num_entries += 1 + new_progress = int(100.0 * (float(num_entries) / estimated_num_entries)) + if new_progress != progress: + progress_cb(new_progress) + progress = new_progress + yield next_entry else: break # No filter self.reader.reset_filter() - return entries + + if progress_cb is not None and progress != 100: + progress_cb(100) + + return + + def estimate_num_entries_in_range(self, t_start: Time, t_end: Time, + topic: Optional[Union[str, Iterable[str]]] = None) -> int: + """ + Estimate the number of entries in the given time range. + + The computation is only approximate, based on the assumption that messages are distributed + evenly across the whole bag on every topic. + + :param t_start: stamp to start at, ''rclpy.time.Time'' + :param t_end: stamp to end at, ''rclpy.time.Time'' + :param topic: topic or list of topics to query (if None, all topics are), ''list(str)'' + :returns: the approximate number of entries, ''int'' + """ + if not self.reader: + self._logger.warn('estimate_num_entries_in_range - ' + WRITE_ONLY_MSG) + return 0 + + if isinstance(topic, Iterable) and not isinstance(topic, str): + topics = topic + else: + topics = [topic] if topic is not None else [] + + range_duration = t_end - t_start + bag_duration = self.get_latest_timestamp() - self.get_earliest_timestamp() + fraction = bag_helper.to_sec(range_duration) / bag_helper.to_sec(bag_duration) + + num_messages = 0 + for t_info in self.metadata.topics_with_message_count: + if t_info.topic_metadata.name in topics: + num_messages += t_info.message_count + + return int(fraction * num_messages) def read_next(self): return Entry(*self.reader.read_next()) diff --git a/rqt_bag/src/rqt_bag/timeline_frame.py b/rqt_bag/src/rqt_bag/timeline_frame.py index cdfaa83..0972ce6 100644 --- a/rqt_bag/src/rqt_bag/timeline_frame.py +++ b/rqt_bag/src/rqt_bag/timeline_frame.py @@ -28,6 +28,7 @@ import bisect import threading +from typing import Callable, Iterable, Optional, Union from python_qt_binding.QtCore import qDebug, QPointF, QRectF, Qt, qWarning, Slot from python_qt_binding.QtGui import QBrush, QColor, QCursor, QFont, \ @@ -870,43 +871,51 @@ def set_renderer_active(self, topic, active): # Index Caching functions - def _update_index_cache(self, topic): + def _update_index_cache(self, topic: Optional[Union[str, Iterable[str]]], + progress_cb: Optional[Callable[[int], None]] = None) -> int: """ - Update the cache of message timestamps for the given topic. + Update the cache of message timestamps for the given topic(s). + :param topic: topic or list of topics to update the cache for, ''list(str)'' + :param progress_cb: callback function to report progress, called once per each percent. :return: number of messages added to the index cache """ if self._start_stamp is None or self._end_stamp is None: return 0 - if topic not in self.index_cache: - # Don't have any cache of messages in this topic - start_time = self._start_stamp - topic_cache = [] - self.index_cache[topic] = topic_cache + if isinstance(topic, Iterable) and not isinstance(topic, str): + topics = [t for t in topic if t in self.invalidated_caches] + if len(topics) == 0: + return 0 else: - topic_cache = self.index_cache[topic] - - # Check if the cache has been invalidated if topic not in self.invalidated_caches: return 0 - - if len(topic_cache) == 0: - start_time = self._start_stamp + topics = [topic] + + start_time = self._start_stamp + for t in topics: + if t not in self.index_cache: + # Don't have any cache of messages in this topic + topic_cache = [] + self.index_cache[t] = topic_cache else: - start_time = Time(seconds=max(0.0, topic_cache[-1])) + topic_cache = self.index_cache[t] + if len(topic_cache) > 0: + start_time = min(start_time, Time(seconds=max(0.0, topic_cache[-1]))) end_time = self._end_stamp - topic_cache_len = len(topic_cache) - - for entry in self.scene().get_entries([topic], start_time, end_time): + newly_added = 0 + for entry in self.scene().get_entries(topics, start_time, end_time, progress_cb): + topic_cache = self.index_cache[entry.topic] topic_cache.append(bag_helper.to_sec(Time(nanoseconds=entry.timestamp))) + newly_added += 1 - if topic in self.invalidated_caches: - self.invalidated_caches.remove(topic) + for t in topics: + if t in self.invalidated_caches: + self.invalidated_caches.remove(t) - return len(topic_cache) - topic_cache_len + return newly_added def cache_message(self, topic, t): """