Skip to content

Commit 6e97964

Browse files
[behaviours] add ProbabilisticBehaviour(Behaviour)
1 parent ff5d051 commit 6e97964

File tree

2 files changed

+102
-0
lines changed

2 files changed

+102
-0
lines changed

py_trees/behaviours.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import copy
1717
import functools
1818
import operator
19+
import random
1920
import typing
2021

2122
from . import behaviour, blackboard, common, meta
@@ -736,3 +737,47 @@ def update(self) -> common.Status:
736737
"|".join(["T" if result else "F" for result in results])
737738
)
738739
return common.Status.FAILURE
740+
741+
742+
class ProbabilisticBehaviour(behaviour.Behaviour):
743+
"""
744+
Return a status based on a probability distribution. If unspecified - a uniform distribution will be used.
745+
746+
Args:
747+
name: name of the behaviour
748+
weights: 3 probabilities that correspond to returning :data:`~py_trees.common.Status.SUCCESS`,
749+
:data:`~py_trees.common.Status.FAILURE` and :data:`~py_trees.common.Status.RUNNING` respectively.
750+
751+
.. note:: Probability distribution does not need to be normalised, it will be normalised internally.
752+
753+
Raises:
754+
ValueError if only some probabilities are specified
755+
756+
"""
757+
758+
def __init__(self, name: str, weights: typing.Optional[typing.List[float]] = None):
759+
if weights is not None and (type(weights) is not list or len(weights) != 3):
760+
raise ValueError(
761+
"Either all or none of the probabilities must be specified"
762+
)
763+
764+
super(ProbabilisticBehaviour, self).__init__(name=name)
765+
766+
self._population = [
767+
common.Status.SUCCESS,
768+
common.Status.FAILURE,
769+
common.Status.RUNNING,
770+
]
771+
self._weights = weights if weights is not None else [1.0, 1.0, 1.0]
772+
773+
def update(self) -> common.Status:
774+
"""
775+
Return a status based on a probability distribution.
776+
777+
Returns:
778+
:data:`~py_trees.common.Status.SUCCESS` with probability weights[0],
779+
:data:`~py_trees.common.Status.FAILURE` with probability weights[1] and
780+
:data:`~py_trees.common.Status.RUNNING` with probability weights[2].
781+
"""
782+
self.logger.debug("%s.update()" % self.__class__.__name__)
783+
return random.choices(self._population, self._weights, k=1)[0]
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#!/usr/bin/env python
2+
#
3+
# License: BSD
4+
# https://raw.githubusercontent.com/splintered-reality/py_trees/devel/LICENSE
5+
#
6+
7+
##############################################################################
8+
# Imports
9+
##############################################################################
10+
11+
import py_trees
12+
import py_trees.console as console
13+
import py_trees.tests
14+
import pytest
15+
16+
##############################################################################
17+
# Logging Level
18+
##############################################################################
19+
20+
py_trees.logging.level = py_trees.logging.Level.DEBUG
21+
logger = py_trees.logging.Logger("Tests")
22+
23+
##############################################################################
24+
# Tests
25+
##############################################################################
26+
27+
28+
def test_probabilistic_behaviour_workflow() -> None:
29+
console.banner("Probabilistic Behaviour")
30+
31+
with pytest.raises(ValueError) as context: # if raised, context survives
32+
# intentional error -> silence mypy
33+
unused_root = py_trees.behaviours.ProbabilisticBehaviour( # noqa: F841 [unused]
34+
name="ProbabilisticBehaviour", weights="invalid_type" # type: ignore[arg-type]
35+
)
36+
py_trees.tests.print_assert_details("ValueError raised", "raised", "not raised")
37+
py_trees.tests.print_assert_details("ValueError raised", "yes", "yes")
38+
assert "ValueError" == context.typename
39+
40+
root = py_trees.behaviours.ProbabilisticBehaviour(
41+
name="ProbabilisticBehaviour", weights=[0.0, 0.0, 1.0]
42+
)
43+
44+
py_trees.tests.print_assert_details(
45+
text="task not yet ticked",
46+
expected=py_trees.common.Status.INVALID,
47+
result=root.status,
48+
)
49+
assert root.status == py_trees.common.Status.INVALID
50+
51+
root.tick_once()
52+
py_trees.tests.print_assert_details(
53+
text="task ticked once",
54+
expected=py_trees.common.Status.RUNNING,
55+
result=root.status,
56+
)
57+
assert root.status == py_trees.common.Status.RUNNING

0 commit comments

Comments
 (0)