1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from collections import deque
1516import random
1617import string
1718from threading import Event
@@ -37,6 +38,10 @@ def method_1():
3738 with WaitForTopics(topic_list, timeout=5.0):
3839 # 'topic_1' and 'topic_2' received at least one message each
3940 print('Given topics are receiving messages !')
41+ print(wait_for_topics.topics_not_received()) # Should be an empty set
42+ print(wait_for_topics.topics_received()) # Should be {'topic_1', 'topic_2'}
43+ print(wait_for_topics.messages_received('topic_1')) # Should be [message_1, ...]
44+ wait_for_topics.shutdown()
4045
4146 # Method 2, calling wait() and shutdown() manually
4247 def method_2():
@@ -49,9 +54,10 @@ def method_2():
4954 wait_for_topics.shutdown()
5055 """
5156
52- def __init__ (self , topic_tuples , timeout = 5.0 ):
57+ def __init__ (self , topic_tuples , timeout = 5.0 , messages_received_buffer_length = 10 ):
5358 self .topic_tuples = topic_tuples
5459 self .timeout = timeout
60+ self .messages_received_buffer_length = messages_received_buffer_length
5561 self .__ros_context = rclpy .Context ()
5662 rclpy .init (context = self .__ros_context )
5763 self .__ros_executor = SingleThreadedExecutor (context = self .__ros_context )
@@ -64,9 +70,14 @@ def __init__(self, topic_tuples, timeout=5.0):
6470 self .__ros_spin_thread .start ()
6571
6672 def _prepare_ros_node (self ):
67- node_name = '_test_node_' + \
68- '' .join (random .choices (string .ascii_uppercase + string .digits , k = 10 ))
69- self .__ros_node = _WaitForTopicsNode (name = node_name , node_context = self .__ros_context )
73+ node_name = '_test_node_' + '' .join (
74+ random .choices (string .ascii_uppercase + string .digits , k = 10 )
75+ )
76+ self .__ros_node = _WaitForTopicsNode (
77+ name = node_name ,
78+ node_context = self .__ros_context ,
79+ messages_received_buffer_length = self .messages_received_buffer_length ,
80+ )
7081 self .__ros_executor .add_node (self .__ros_node )
7182
7283 def _spin_function (self ):
@@ -91,6 +102,12 @@ def topics_not_received(self):
91102 """Topics that did not receive any messages."""
92103 return self .__ros_node .expected_topics - self .__ros_node .received_topics
93104
105+ def received_messages (self , topic_name ):
106+ """List of received messages of a specific topic."""
107+ if topic_name not in self .__ros_node .received_messages_buffer :
108+ raise KeyError ('No messages received for topic: ' + topic_name )
109+ return list (self .__ros_node .received_messages_buffer [topic_name ])
110+
94111 def __enter__ (self ):
95112 if not self .wait ():
96113 raise RuntimeError ('Did not receive messages on these topics: ' ,
@@ -106,31 +123,49 @@ def __exit__(self, exep_type, exep_value, trace):
106123class _WaitForTopicsNode (Node ):
107124 """Internal node used for subscribing to a set of topics."""
108125
109- def __init__ (self , name = 'test_node' , node_context = None ):
110- super ().__init__ (node_name = name , context = node_context )
126+ def __init__ (
127+ self , name = 'test_node' , node_context = None , messages_received_buffer_length = None
128+ ):
129+ super ().__init__ (node_name = name , context = node_context ) # type: ignore
111130 self .msg_event_object = Event ()
112-
113- def start_subscribers (self , topic_tuples ):
131+ self .messages_received_buffer_length = messages_received_buffer_length
114132 self .subscriber_list = []
115- self .expected_topics = {name for name , _ in topic_tuples }
133+ self .topic_tuples = []
134+ self .expected_topics = set ()
135+ self .received_topics = set ()
136+ self .received_messages_buffer = {}
137+
138+ def _reset (self ):
139+ self .msg_event_object .clear ()
116140 self .received_topics = set ()
141+ for buffer in self .received_messages_buffer .values ():
142+ buffer .clear ()
117143
144+ def start_subscribers (self , topic_tuples ):
145+ self ._reset ()
118146 for topic_name , topic_type in topic_tuples :
119- # Create a subscriber
120- self .subscriber_list .append (
121- self .create_subscription (
122- topic_type ,
123- topic_name ,
124- self .callback_template (topic_name ),
125- 10
147+ if (topic_name , topic_type ) not in self .topic_tuples :
148+ self .topic_tuples .append ((topic_name , topic_type ))
149+ self .expected_topics .add (topic_name )
150+ # Initialize ring buffer of received messages
151+ self .received_messages_buffer [topic_name ] = deque (
152+ maxlen = self .messages_received_buffer_length
153+ )
154+ # Create a subscriber
155+ self .subscriber_list .append (
156+ self .create_subscription (
157+ topic_type ,
158+ topic_name ,
159+ self .callback_template (topic_name ),
160+ 10
161+ )
126162 )
127- )
128163
129164 def callback_template (self , topic_name ):
130-
131165 def topic_callback (data ):
166+ self .get_logger ().debug ('Message received for ' + topic_name )
167+ self .received_messages_buffer [topic_name ].append (data )
132168 if topic_name not in self .received_topics :
133- self .get_logger ().debug ('Message received for ' + topic_name )
134169 self .received_topics .add (topic_name )
135170 if self .received_topics == self .expected_topics :
136171 self .msg_event_object .set ()
0 commit comments