diff --git a/ros2topic/package.xml b/ros2topic/package.xml index 367de1f7a..00217de65 100644 --- a/ros2topic/package.xml +++ b/ros2topic/package.xml @@ -12,6 +12,7 @@ Apache License 2.0 BSD-3-Clause + MIT License Aditya Pande Dirk Thomas diff --git a/ros2topic/ros2topic/eval/__init__.py b/ros2topic/ros2topic/eval/__init__.py new file mode 100644 index 000000000..7f949d743 --- /dev/null +++ b/ros2topic/ros2topic/eval/__init__.py @@ -0,0 +1,161 @@ +# Copyright 2022 Yaroslav Polyakov +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + + +"""Safe user-supplied python expression evaluation.""" + +import ast +import dataclasses + +__version__ = '2.0.3' + + +class EvalException(Exception): + pass + + +class ValidationException(EvalException): + pass + + +class CompilationException(EvalException): + exc = None + + def __init__(self, exc): + super().__init__(exc) + self.exc = exc + + +class ExecutionException(EvalException): + exc = None + + def __init__(self, exc): + super().__init__(exc) + self.exc = exc + + +@dataclasses.dataclass +class EvalModel: + """eval security model.""" + + nodes: list = dataclasses.field(default_factory=list) + allowed_functions: list = dataclasses.field(default_factory=list) + imported_functions: dict = dataclasses.field(default_factory=dict) + attributes: list = dataclasses.field(default_factory=list) + + def clone(self): + return EvalModel(**dataclasses.asdict(self)) + + +class SafeAST(ast.NodeVisitor): + """AST-tree walker class.""" + + def __init__(self, model: EvalModel): + self.model = model + + def generic_visit(self, node): + """Check node, raise exception if node is not in whitelist.""" + if type(node).__name__ in self.model.nodes: + + if isinstance(node, ast.Attribute): + print(self.model.attributes) + print(node) + if node.attr not in self.model.attributes: + raise ValidationException( + 'Attribute {aname} is not allowed'.format( + aname=node.attr)) + + if isinstance(node, ast.Call): + if isinstance(node.func, ast.Name): + if node.func.id not in self.model.allowed_functions and \ + node.func.id not in self.model.imported_functions: + raise ValidationException( + 'Call to function {fname}() is not allowed'.format( + fname=node.func.id)) + else: + # Call to allowed function. good. No exception + pass + elif isinstance(node.func, ast.Attribute): + pass + # print("attr:", node.func.attr) + else: + raise ValidationException('Indirect function call') + + ast.NodeVisitor.generic_visit(self, node) + else: + raise ValidationException( + 'Node type {optype!r} is not allowed. (whitelist it manually)'.format( + optype=type(node).__name__)) + + +base_eval_model = EvalModel( + nodes=[ + # 123, 'asdf' + 'Num', 'Str', + # any expression or constant + 'Expression', 'Constant', + # == ... + 'Compare', 'Eq', 'NotEq', 'Gt', 'GtE', 'Lt', 'LtE', + # variable name + 'Name', 'Load', + 'BinOp', + 'Add', 'Sub', 'USub', + 'Subscript', 'Index', # person['name'] + 'BoolOp', 'And', 'Or', 'UnaryOp', 'Not', # True and True + 'In', 'NotIn', # "aaa" in i['list'] + 'IfExp', # for if expressions, like: expr1 if expr2 else expr3 + 'NameConstant', # for True and False constants + 'Div', 'Mod' + ], +) + + +mult_eval_model = base_eval_model.clone() +mult_eval_model.nodes.append('Mul') + + +class Expr(): + + def __init__(self, expr, model=None, filename=None): + + self.expr = expr + self.model = model or base_eval_model + + try: + self.node = ast.parse(self.expr, '', 'eval') + except SyntaxError as e: + raise CompilationException(e) + + v = SafeAST(model=self.model) + v.visit(self.node) + + self.code = compile(self.node, filename or '', 'eval') + + def safe_eval(self, ctx=None): + + try: + result = eval(self.code, self.model.imported_functions, ctx) + except Exception as e: + raise ExecutionException(e) + + return result + + def __str__(self): + return ('Expr(expr={expr!r})'.format(expr=self.expr)) diff --git a/ros2topic/ros2topic/verb/hz.py b/ros2topic/ros2topic/verb/hz.py index be2d71047..4c11f3564 100644 --- a/ros2topic/ros2topic/verb/hz.py +++ b/ros2topic/ros2topic/verb/hz.py @@ -29,8 +29,8 @@ # This file is originally from: # https://github.com/ros/ros_comm/blob/6e5016f4b2266d8a60c9a1e163c4928b8fc7115e/tools/rostopic/src/rostopic/__init__.py -from collections import defaultdict +from collections import defaultdict, OrderedDict import functools import math import threading @@ -48,7 +48,9 @@ from ros2topic.api import get_msg_class from ros2topic.api import positive_int from ros2topic.api import TopicNameCompleter +from ros2topic.eval import base_eval_model, Expr from ros2topic.verb import VerbExtension +from rosidl_runtime_py.convert import message_to_ordereddict DEFAULT_WINDOW_SIZE = 10000 @@ -91,18 +93,75 @@ def main(self, *, args): return main(args) +def _setup_base_safe_eval(): + safe_eval_model = base_eval_model.clone() + + # extend base_eval_model + safe_eval_model.nodes.extend(['Call', 'Attribute', 'List', 'Tuple', 'Dict', 'Set', + 'ListComp', 'DictComp', 'SetComp', 'comprehension', + 'Mult', 'Pow', 'boolop', 'mod', 'Invert', + 'Is', 'IsNot', 'FloorDiv', 'If', 'For']) + + # allow-list safe Python built-in functions + safe_builtins = [ + 'abs', 'all', 'any', 'bin', 'bool', 'chr', 'cmp', 'divmod', 'enumerate', + 'float', 'format', 'hex', 'id', 'int', 'isinstance', 'issubclass', + 'len', 'list', 'long', 'max', 'min', 'ord', 'pow', 'range', 'reversed', + 'round', 'slice', 'sorted', 'str', 'sum', 'tuple', 'type', 'unichr', + 'unicode', 'xrange', 'zip', 'filter', 'dict', 'set', 'next' + ] + + safe_eval_model.allowed_functions.extend(safe_builtins) + return safe_eval_model + + +def _get_nested_messages(msg_ordereddict): + """List all message field names recursively.""" + all_attributes = [] + for (k, v) in msg_ordereddict.items(): + all_attributes.append(k) + if type(v) is OrderedDict: + nested_attrs = _get_nested_messages(v) + all_attributes.extend(nested_attrs) + return all_attributes + + +def _setup_safe_eval(safe_eval_model, msg_class, topic): + # allow-list topic builtins, msg attributes + topic_builtins = [i for i in dir(topic) if not i.startswith('_')] + safe_eval_model.attributes.extend(topic_builtins) + + # recursively get all nested message attributes + # msg_class in this case is a prototype that needs to be instantiated to get + # an ordered dictionary of message fields + msg_ordereddict = message_to_ordereddict(msg_class()) + msg_attributes = _get_nested_messages(msg_ordereddict) + safe_eval_model.attributes.extend(msg_attributes) + return safe_eval_model + + def main(args): - topics = args.topic_name - if args.filter_expr: - def expr_eval(expr): - def eval_fn(m): - return eval(expr) - return eval_fn - filter_expr = expr_eval(args.filter_expr) - else: + with DirectNode(args) as node: + topics = args.topic_name filter_expr = None + # set up custom safe eval model for filter expression + if args.filter_expr: + safe_eval_model = _setup_base_safe_eval() + for topic in topics: + msg_class = get_msg_class( + node, topic, blocking=True, include_hidden_topics=True) + if msg_class is None: + continue + + safe_eval_model = _setup_safe_eval(safe_eval_model, msg_class, topic) + + def expr_eval(expr): + def eval_fn(m): + safe_expression = Expr(expr, model=safe_eval_model) + return eval(safe_expression.code) + return eval_fn + filter_expr = expr_eval(args.filter_expr) - with DirectNode(args) as node: _rostopic_hz(node.node, topics, qos_args=args, window_size=args.window_size, filter_expr=filter_expr, use_wtime=args.use_wtime) diff --git a/ros2topic/test/test_cli.py b/ros2topic/test/test_cli.py index b67786c84..b8184c6ab 100644 --- a/ros2topic/test/test_cli.py +++ b/ros2topic/test/test_cli.py @@ -938,6 +938,19 @@ def test_filtered_topic_hz(self): average_rate = float(average_rate_line_pattern.match(head_line).group(1)) assert math.isclose(average_rate, 0.5, rel_tol=1e-2) + # check that use of eval() on hz verb cannot be exploited + try: + self.launch_topic_command( + arguments=[ + 'hz', + '--filter', + '__import__("os").system("cat /etc/passwd")', + '/chatter' + ] + ) + except ValueError as e: + self.assertIn('Attribute system is not allowed', str(e)) + @launch_testing.markers.retry_on_failure(times=5, delay=1) def test_topic_bw(self): with self.launch_topic_command(arguments=['bw', '/defaults']) as topic_command: