diff --git a/rclpy/rclpy/executors.py b/rclpy/rclpy/executors.py index a588e3540..38733825f 100644 --- a/rclpy/rclpy/executors.py +++ b/rclpy/rclpy/executors.py @@ -174,15 +174,237 @@ def timeout(self, timeout: float) -> None: self._timeout = timeout -class Executor(ContextManager['Executor']): +class AbstractExecutor: + """Abstract Executor API.""" + + @property + def context(self) -> Context: + """Get the context associated with the executor.""" + raise NotImplementedError() + + def wake(self) -> None: + """ + Wake the executor because something changed. + + This is used to tell the executor when entities are created or destroyed. + """ + raise NotImplementedError() + + def shutdown(self, timeout_sec: Optional[float] = None) -> bool: + """ + Stop executing callbacks and wait for their completion. + + :param timeout_sec: Seconds to wait. Block forever if ``None`` or negative. + Don't wait if 0. + :return: ``True`` if all outstanding callbacks finished executing, or ``False`` if the + timeout expires before all outstanding work is done. + """ + raise NotImplementedError() + + def spin(self) -> None: + """Execute callbacks until shutdown.""" + raise NotImplementedError() + + def spin_until_future_complete( + self, + future: Future[Any], + timeout_sec: Optional[float] = None + ) -> None: + """Execute callbacks until a given future is done or a timeout occurs.""" + raise NotImplementedError() + + def spin_once(self, timeout_sec: Optional[float] = None) -> None: + """ + Wait for and execute a single callback. + + This method should not be called from multiple threads. + + :param timeout_sec: Seconds to wait. Block forever if ``None`` or negative. + Don't wait if 0. + """ + raise NotImplementedError() + + def spin_once_until_future_complete( + self, + future: Future[Any], + timeout_sec: Optional[Union[float, TimeoutObject]] = None + ) -> None: + """ + Wait for and execute a single callback. + + This should behave in the same way as :meth:`spin_once`. + If needed by the implementation, it should awake other threads waiting. + + :param future: The executor will wait until this future is done. + :param timeout_sec: Maximum seconds to wait. Block forever if ``None`` or negative. + Don't wait if 0. + """ + raise NotImplementedError() + + @overload + def create_task( + self, + callback: Callable[..., Coroutine[Any, Any, T]], + *args: Any, + **kwargs: Any + ) -> Task[T]: ... + + @overload + def create_task( + self, + callback: Callable[..., T], + *args: Any, + **kwargs: Any + ) -> Task[T]: ... + + def create_task( + self, + callback: Callable[..., Any], + *args: Any, + **kwargs: Any + ) -> Task[Any]: + """ + Add a callback or coroutine to be executed during :meth:`spin` and return a Future. + + Arguments to this function are passed to the callback. + + .. warning:: Created task is queued in the executor in FIFO order, + but users should not rely on the task execution order. + + :param callback: A callback to be run in the executor. + """ + raise NotImplementedError() + + def create_future(self) -> Future[Any]: + """Create a Future object attached to the Executor.""" + raise NotImplementedError() + + def add_node(self, node: 'Node') -> bool: + """ + Add a node whose callbacks should be managed by this executor. + + :param node: The node to add to the executor. + :return: ``True`` if the node was added, ``False`` otherwise. + """ + raise NotImplementedError() + + def remove_node(self, node: 'Node') -> None: + """ + Stop managing this node's callbacks. + + :param node: The node to remove from the executor. + """ + raise NotImplementedError() + + def get_nodes(self) -> List['Node']: + """Return nodes that have been added to this executor.""" + raise NotImplementedError() + + +class BaseExecutor(AbstractExecutor): + """The base class for an executor.""" + + def create_future(self) -> Future: + return Future(executor=self) + + def _take_subscription( + self, + sub: Subscription[Any] + ) -> Optional[Callable[[], Coroutine[None, None, None]]]: + try: + with sub.handle: + msg_info = sub.handle.take_message(sub.msg_type, sub.raw) + if msg_info is None: + return None + + if sub._callback_type is Subscription.CallbackType.MessageOnly: + msg_tuple: Union[Tuple[Msg], Tuple[Msg, MessageInfo]] = (msg_info[0], ) + else: + msg_tuple = msg_info + + async def _execute() -> None: + await await_or_execute(sub.callback, *msg_tuple) + + return _execute + except InvalidHandle: + # Subscription is a Destroyable, which means that on __enter__ it can throw an + # InvalidHandle exception if the entity has already been destroyed. Handle that here + # by just returning an empty argument, which means we will skip doing any real work + # in _execute_subscription below + pass + + return None + + def _take_client( + self, + client: Client[Any, Any] + ) -> Optional[Callable[[], Coroutine[None, None, None]]]: + try: + with client.handle: + header_and_response = client.handle.take_response(client.srv_type.Response) + + async def _execute() -> None: + header, response = header_and_response + if header is None: + return + try: + sequence = header.request_id.sequence_number + future = client.get_pending_request(sequence) + except KeyError: + # The request was cancelled + pass + else: + if isinstance(future, Future) and not future._executor(): + future._set_executor(self) + + future.set_result(response) + return _execute + + except InvalidHandle: + # Client is a Destroyable, which means that on __enter__ it can throw an + # InvalidHandle exception if the entity has already been destroyed. Handle that here + # by just returning an empty argument, which means we will skip doing any real work + # in _execute_client below + pass + + return None + + def _take_service( + self, + srv: Service[Any, Any] + ) -> Optional[Callable[[], Coroutine[None, None, None]]]: + try: + with srv.handle: + request_and_header = srv.handle.service_take_request(srv.srv_type.Request) + + async def _execute() -> None: + (request, header) = request_and_header + if header is None: + return + + response = await await_or_execute(srv.callback, request, srv.srv_type.Response()) + srv.send_response(response, header) + return _execute + except InvalidHandle: + # Service is a Destroyable, which means that on __enter__ it can throw an + # InvalidHandle exception if the entity has already been destroyed. Handle that here + # by just returning an empty argument, which means we will skip doing any real work + # in _execute_service below + pass + + return None + + +class Executor(ContextManager['Executor'], BaseExecutor): """ - The base class for an executor. + The base class for a wait-set based executor. An executor controls the threading model used to process callbacks. Callbacks are units of work like subscription callbacks, timer callbacks, service calls, and received client responses. An executor controls which threads callbacks get executed in. A custom executor must define :meth:`spin_once`. + A custom executor should use :meth:`wait_for_ready_callbacks` to get work. If the executor has any cleanup then it should also define :meth:`shutdown`. :param context: The context to be associated with, or ``None`` for the default global context. @@ -226,30 +448,14 @@ def __init__(self, *, context: Optional[Context] = None) -> None: @property def context(self) -> Context: - """Get the context associated with the executor.""" return self._context - @overload - def create_task(self, callback: Callable[..., Coroutine[Any, Any, T]], - *args: Any, **kwargs: Any - ) -> Task[T]: ... - - @overload - def create_task(self, callback: Callable[..., T], *args: Any, **kwargs: Any - ) -> Task[T]: ... - - def create_task(self, callback: Callable[..., Any], *args: Any, **kwargs: Any - ) -> Task[Any]: - """ - Add a callback or coroutine to be executed during :meth:`spin` and return a Future. - - Arguments to this function are passed to the callback. - - .. warning:: Created task is queued in the executor in FIFO order, - but users should not rely on the task execution order. - - :param callback: A callback to be run in the executor. - """ + def create_task( + self, + callback: Callable[..., Any], + *args: Any, + **kwargs: Any + ) -> Task[Any]: task = Task(callback, args, kwargs, executor=self) with self._tasks_lock: self._tasks.append((task, None, None)) @@ -259,14 +465,6 @@ def create_task(self, callback: Callable[..., Any], *args: Any, **kwargs: Any return task def shutdown(self, timeout_sec: Optional[float] = None) -> bool: - """ - Stop executing callbacks and wait for their completion. - - :param timeout_sec: Seconds to wait. Block forever if ``None`` or negative. - Don't wait if 0. - :return: ``True`` if all outstanding callbacks finished executing, or ``False`` if the - timeout expires before all outstanding work is done. - """ with self._shutdown_lock: if not self._is_shutdown: self._is_shutdown = True @@ -298,12 +496,6 @@ def __del__(self) -> None: self._sigint_gc.destroy() def add_node(self, node: 'Node') -> bool: - """ - Add a node whose callbacks should be managed by this executor. - - :param node: The node to add to the executor. - :return: ``True`` if the node was added, ``False`` otherwise. - """ with self._nodes_lock: if node not in self._nodes: self._nodes.add(node) @@ -315,11 +507,6 @@ def add_node(self, node: 'Node') -> bool: return False def remove_node(self, node: 'Node') -> None: - """ - Stop managing this node's callbacks. - - :param node: The node to remove from the executor. - """ with self._nodes_lock: try: self._nodes.remove(node) @@ -331,21 +518,14 @@ def remove_node(self, node: 'Node') -> None: self._guard.trigger() def wake(self) -> None: - """ - Wake the executor because something changed. - - This is used to tell the executor when entities are created or destroyed. - """ if self._guard: self._guard.trigger() def get_nodes(self) -> List['Node']: - """Return nodes that have been added to this executor.""" with self._nodes_lock: return list(self._nodes) def spin(self) -> None: - """Execute callbacks until shutdown.""" while self._context.ok() and not self._is_shutdown: self.spin_once() @@ -354,7 +534,6 @@ def spin_until_future_complete( future: Future[Any], timeout_sec: Optional[float] = None ) -> None: - """Execute callbacks until a given future is done or a timeout occurs.""" # Make sure the future wakes this executor when it is done future.add_done_callback(lambda x: self.wake()) @@ -385,36 +564,6 @@ def spin_until_future_complete( timeout_left.timeout = end - now - def spin_once(self, timeout_sec: Optional[float] = None) -> None: - """ - Wait for and execute a single callback. - - A custom executor should use :meth:`wait_for_ready_callbacks` to get work. - - This method should not be called from multiple threads. - - :param timeout_sec: Seconds to wait. Block forever if ``None`` or negative. - Don't wait if 0. - """ - raise NotImplementedError() - - def spin_once_until_future_complete( - self, - future: Future[Any], - timeout_sec: Optional[Union[float, TimeoutObject]] = None - ) -> None: - """ - Wait for and execute a single callback. - - This should behave in the same way as :meth:`spin_once`. - If needed by the implementation, it should awake other threads waiting. - - :param future: The executor will wait until this future is done. - :param timeout_sec: Maximum seconds to wait. Block forever if ``None`` or negative. - Don't wait if 0. - """ - raise NotImplementedError() - def _spin_once_until_future_complete( self, future: Future[Any], @@ -467,85 +616,6 @@ async def _execute() -> None: return None - def _take_subscription(self, sub: Subscription[Any] - ) -> Optional[Callable[[], Coroutine[None, None, None]]]: - try: - with sub.handle: - msg_info = sub.handle.take_message(sub.msg_type, sub.raw) - if msg_info is None: - return None - - if sub._callback_type is Subscription.CallbackType.MessageOnly: - msg_tuple: Union[Tuple[Msg], Tuple[Msg, MessageInfo]] = (msg_info[0], ) - else: - msg_tuple = msg_info - - async def _execute() -> None: - await await_or_execute(sub.callback, *msg_tuple) - - return _execute - except InvalidHandle: - # Subscription is a Destroyable, which means that on __enter__ it can throw an - # InvalidHandle exception if the entity has already been destroyed. Handle that here - # by just returning an empty argument, which means we will skip doing any real work - # in _execute_subscription below - pass - - return None - - def _take_client(self, client: Client[Any, Any] - ) -> Optional[Callable[[], Coroutine[None, None, None]]]: - try: - with client.handle: - header_and_response = client.handle.take_response(client.srv_type.Response) - - async def _execute() -> None: - header, response = header_and_response - if header is None: - return - try: - sequence = header.request_id.sequence_number - future = client.get_pending_request(sequence) - except KeyError: - # The request was cancelled - pass - else: - future._set_executor(self) - future.set_result(response) - return _execute - - except InvalidHandle: - # Client is a Destroyable, which means that on __enter__ it can throw an - # InvalidHandle exception if the entity has already been destroyed. Handle that here - # by just returning an empty argument, which means we will skip doing any real work - # in _execute_client below - pass - - return None - - def _take_service(self, srv: Service[Any, Any] - ) -> Optional[Callable[[], Coroutine[None, None, None]]]: - try: - with srv.handle: - request_and_header = srv.handle.service_take_request(srv.srv_type.Request) - - async def _execute() -> None: - (request, header) = request_and_header - if header is None: - return - - response = await await_or_execute(srv.callback, request, srv.srv_type.Response()) - srv.send_response(response, header) - return _execute - except InvalidHandle: - # Service is a Destroyable, which means that on __enter__ it can throw an - # InvalidHandle exception if the entity has already been destroyed. Handle that here - # by just returning an empty argument, which means we will skip doing any real work - # in _execute_service below - pass - - return None - def _take_guard_condition(self, gc: GuardCondition ) -> Callable[[], Coroutine[None, None, None]]: gc._executor_triggered = False diff --git a/rclpy/test/test_executor.py b/rclpy/test/test_executor.py index c340fa70e..f4244a526 100644 --- a/rclpy/test/test_executor.py +++ b/rclpy/test/test_executor.py @@ -17,6 +17,7 @@ import threading import time import unittest +from unittest.mock import Mock import warnings import rclpy @@ -729,6 +730,25 @@ def timer2_callback() -> None: self.node.destroy_timer(timer1) self.node.destroy_client(cli) + def test_create_future_returns_future_with_executor_attached(self) -> None: + self.assertIsNotNone(self.node.handle) + mock = Mock() + + executor = SingleThreadedExecutor(context=self.context) + executor.create_task = mock + + try: + fut = executor.create_future() + + def cb(fut): + ... + + fut.add_done_callback(cb) + fut.set_result('Result') + mock.assert_called_once_with(cb, fut) + finally: + executor.shutdown() + if __name__ == '__main__': unittest.main()