diff --git a/README.rst b/README.rst index 1becc86..464f133 100644 --- a/README.rst +++ b/README.rst @@ -482,6 +482,86 @@ You can add your own classes & limit access to attrs: will now allow access to `foo.bar` but not allow anything else. +Assignment Support +------------------ + +If you want to allow modification of the `names` dictionary using assignment or augmented assignment +(`=`, `+=`, etc.), set `assign_modify_names=True`. + +.. code-block:: pycon + + >>> names = dict() + >>> simple_eval("a = 10 * 2", names=names, assign_modify_names=True) + >>> print(names['a']) + 20 + + >>> names = dict(a=10) + >>> simple_eval("a += 5", names=names, assign_modify_names=True) + >>> print(names['a']) + 15 + +When using the `SimpleEval` class, updated values are available in the `.results` attribute: + +.. code-block:: pycon + + >>> s = SimpleEval(names=dict(a=10, b=5), assign_modify_names=True) + >>> s.eval("b += a") + >>> print(s.results) + {'b': 15} + +Note: Assignment to attributes (e.g., `a.b = 1`) or tuples (e.g., `a, b = (1, 2)`) is not supported. + +Multiple Expressions +-------------------- + +By default, only the first expression is evaluated. To evaluate multiple expressions separated +by `;` or newlines and return the last expression's result, set `multiple_expression_support=True`. + +.. code-block:: pycon + + >>> simple_eval("5 * 2; 6 * 2", multiple_expression_support=False) + 10 + + >>> simple_eval("5 * 2\n 6 * 2", multiple_expression_support=True) + 12 + +Combined with assignment: + +.. code-block:: pycon + + >>> simple_eval("a=5;b=2;a + b", multiple_expression_support=True, assign_modify_names=True) + 7 + +Attribute Chain Flattening +-------------------------- + +If `attr_chain_flattening=True`, then attributes can be treated as flat keys in `names`. + +.. code-block:: pycon + + >>> simple_eval("a + a.b", names={"a": 1, "a.b": 2}, attr_chain_flattening=True) + 3 + +If both a flat key and an actual attribute exist, the flat key takes precedence: + +.. code-block:: pycon + + >>> from types import SimpleNamespace + >>> simple_eval("a.attr", names={"a": SimpleNamespace(attr=True), "a.attr": False}, attr_chain_flattening=True) + False + +With assignment enabled, only flat keys are written, as attribute assignment is unsupported: + +.. code-block:: pycon + + >>> from types import SimpleNamespace + >>> names = {"a": SimpleNamespace(b=1)} + >>> simple_eval("a.b = 10", names=names, attr_chain_flattening=True, assign_modify_names=True) + >>> print(names['a.b']) + 10 + >>> print(names['a'].b) + 1 + Other... -------- diff --git a/simpleeval.py b/simpleeval.py index 9976064..06bf248 100644 --- a/simpleeval.py +++ b/simpleeval.py @@ -43,7 +43,8 @@ - impala2 (Kirill Stepanov) (massive _eval refactor) - gk (ugik) (Other iterables than str can DOS too, and can be made) - daveisfera (Dave Johansen) 'not' Boolean op, Pycharm, pep8, various other fixes -- xaled (Khalid Grandi) method chaining correctly, double-eval bugfix. +- xaled (Khalid Grandi) method chaining correctly, double-eval bugfix, + adding support for name assignments, multiple expressions and attribute chain flattening. - EdwardBetts (Edward Betts) spelling correction. - charlax (Charles-Axel Dein charlax) Makefile and cleanups - mommothazaz123 (Andrew Zhu) f"string" support, Python 3.8 support @@ -466,6 +467,9 @@ def safe_lshift(a, b): # pylint: disable=invalid-name DEFAULT_NAMES = {"True": True, "False": False, "None": None} ATTR_INDEX_FALLBACK = True +ATTR_CHAIN_FLATTENING = False +ASSIGN_MODIFY_NAMES = False +MULTIPLE_EXPRESSION_SUPPORT = False ######################################## @@ -481,7 +485,7 @@ class SimpleEval(object): # pylint: disable=too-few-public-methods expr = "" - def __init__(self, operators=None, functions=None, names=None, allowed_attrs=None): + def __init__(self, operators=None, functions=None, names=None, allowed_attrs=None, **options): """ Create the evaluator instance. Set up valid operators (+,-, etc) functions (add, random, get_val, whatever) and names.""" @@ -492,6 +496,7 @@ def __init__(self, operators=None, functions=None, names=None, allowed_attrs=Non functions = DEFAULT_FUNCTIONS.copy() if names is None: names = DEFAULT_NAMES.copy() + self.results = dict() # updated or set names self.operators = operators self.functions = functions @@ -535,7 +540,12 @@ def __init__(self, operators=None, functions=None, names=None, allowed_attrs=Non # Defaults: - self.ATTR_INDEX_FALLBACK = ATTR_INDEX_FALLBACK + self.ATTR_INDEX_FALLBACK = options.get("attr_index_fallback", ATTR_INDEX_FALLBACK) + self.attr_chain_flattening = options.get("attr_chain_flattening", ATTR_CHAIN_FLATTENING) + self.assign_modify_names = options.get("assign_modify_names", ASSIGN_MODIFY_NAMES) + self.multiple_expression_support = options.get( + "multiple_expression_support", MULTIPLE_EXPRESSION_SUPPORT + ) # Check for forbidden functions: @@ -546,33 +556,52 @@ def __init__(self, operators=None, functions=None, names=None, allowed_attrs=Non def __del__(self): self.nodes = None - @staticmethod - def parse(expr): + def parse(self, expr): """parse an expression into a node tree""" parsed = ast.parse(expr.strip()) if not parsed.body: raise InvalidExpression("Sorry, cannot evaluate empty string") - if len(parsed.body) > 1: - warnings.warn( - "'{}' contains multiple expressions. Only the first will be used.".format(expr), - MultipleExpressions, - ) - return parsed.body[0] + + if self.multiple_expression_support: + return parsed.body + else: + if len(parsed.body) > 1: + warnings.warn( + "'{}' contains multiple expressions. Only the first will be used.".format( + expr + ), + MultipleExpressions, + ) + return parsed.body[0] def eval(self, expr, previously_parsed=None): """evaluate an expression, using the operators, functions and names previously set up.""" + # clear results + self.results.clear() # set a copy of the expression aside, so we can give nice errors... self.expr = expr - return self._eval(previously_parsed or self.parse(expr)) + # parse + parsed_expressions = previously_parsed or self.parse(expr) + if not isinstance(parsed_expressions, list): + parsed_expressions = [parsed_expressions] + + ret = None + for parsed_expression in parsed_expressions: + ret = self._eval(parsed_expression) + + return ret def _eval(self, node): """The internal evaluator used on each node in the parsed tree.""" + if self.attr_chain_flattening and isinstance(node, ast.Attribute): + node = self._flatten_expr(node) + try: handler = self.nodes[type(node)] except KeyError: @@ -586,16 +615,33 @@ def _eval_expr(self, node): return self._eval(node.value) def _eval_assign(self, node): - warnings.warn( - "Assignment ({}) attempted, but this is ignored".format(self.expr), AssignmentAttempted - ) - return self._eval(node.value) + # Raise assignment attempt warnings before node evaluation to align with test case expectations + if not self.assign_modify_names: + warnings.warn( + "Assignment ({}) attempted, but this is ignored".format(self.expr), + AssignmentAttempted, + ) + + evaluated_value = self._eval(node.value) + if self.assign_modify_names: + for target in node.targets: + self._assign_value(target, evaluated_value) + + return evaluated_value def _eval_aug_assign(self, node): - warnings.warn( - "Assignment ({}) attempted, but this is ignored".format(self.expr), AssignmentAttempted - ) - return self._eval(node.value) + # Raise assignment attempt warnings before node evaluation to align with test case expectations + if not self.assign_modify_names: + warnings.warn( + "Assignment ({}) attempted, but this is ignored".format(self.expr), + AssignmentAttempted, + ) + + evaluated_value = self._eval(node.value) + if self.assign_modify_names: + evaluated_value = self._aug_assign_value(node.target, node.op, evaluated_value) + + return evaluated_value @staticmethod def _eval_import(node): @@ -804,6 +850,84 @@ def _eval_formattedvalue(self, node): return fmt.format(self._eval(node.value)) return self._eval(node.value) + def _flatten_expr(self, expr_node): + chain = self._get_attr_chain(expr_node) + + if chain: + flattened = self._flatten_chain(chain, ctx=expr_node.ctx) + if flattened: + return flattened + return expr_node + + @staticmethod + def _get_attr_chain(node): + """Recursively collect attribute chain from the AST node.""" + chain = [] + while isinstance(node, ast.Attribute): + chain.append(node.attr) + node = node.value + if isinstance(node, ast.Name): + chain.append(node.id) + chain.reverse() + return chain + return None + + def _flatten_chain(self, chain, ctx=None): + """Try to find the longest prefix of the chain that exists in names""" + for i in range(len(chain), 0, -1): + prefix = ".".join(chain[:i]) + if prefix in self.names: + if i == len(chain): + # Fully matched + return ast.Name(id=prefix, ctx=ctx) + else: + # Partially matched + base = ast.Name(id=prefix, ctx=ctx) + for attr in chain[i:]: + base = ast.Attribute(value=base, attr=attr, ctx=ctx) + return base + return None # No flattening + + def _assign_value(self, target, value): + if isinstance(target, ast.Name): + self._assign_update(target.id, value) + return + + if isinstance(target, ast.Attribute) and self.attr_chain_flattening: + chain = self._get_attr_chain(target) + if chain: + self._assign_update(".".join(chain), value) + return + + raise FeatureNotAvailable(f"Sorry, {type(target)} Assign is not available.") + + def _aug_assign_value(self, target, operation, value): + def calculate_new_value(_target_value): + try: + operator = self.operators[type(operation)] + except KeyError: + raise OperatorNotDefined(operation, self.expr) + return operator(_target_value, value) + + if isinstance(target, ast.Name): + value = calculate_new_value(self.names[target.id]) + self._assign_update(target.id, value) + return value + + if isinstance(target, ast.Attribute) and self.attr_chain_flattening: + chain = self._get_attr_chain(target) + if chain: + key = ".".join(chain) + value = calculate_new_value(self.names[key]) + self._assign_update(key, value) + return value + + raise FeatureNotAvailable(f"Sorry, {type(target)} Aug Assign is not available.") + + def _assign_update(self, name, value): + self.names[name] = value + self.results[name] = value + class EvalWithCompoundTypes(SimpleEval): """ @@ -920,12 +1044,13 @@ def do_generator(gi=0): return to_return -def simple_eval(expr, operators=None, functions=None, names=None, allowed_attrs=None): +def simple_eval(expr, operators=None, functions=None, names=None, allowed_attrs=None, **options): """Simply evaluate an expression""" s = SimpleEval( operators=operators, functions=functions, names=names, allowed_attrs=allowed_attrs, + **options, ) return s.eval(expr) diff --git a/test_simpleeval.py b/test_simpleeval.py index 42c384e..1da7629 100644 --- a/test_simpleeval.py +++ b/test_simpleeval.py @@ -930,6 +930,7 @@ def test_dict_attr_access_disabled(self): def test_object(self): """using an object for name lookup""" + # pylint: disable=attribute-defined-outside-init class TestObject(object): @@ -1424,5 +1425,296 @@ def bar(self): simple_eval(evil, names={"foo": Foo()}, allowed_attrs=extended_attrs) +class TestAttrChainFlattening(DRYTest): + class Namespace: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + class Point: + def __init__(self, x, y): + self.x = x + self.y = y + + def sum(self): + return self.x + self.y + + def product(self): + return self.x * self.y + + def sub(self): + return self.y - self.x + + def _parse_flatten_expr(self, code): + tree = ast.parse(code) + return self.s._flatten_expr(tree.body[0].value) + + def assertIsAttributeNode(self, r, attr=None): + self.assertIsInstance(r, ast.Attribute) + if attr is not None: + self.assertEqual(r.attr, attr) + + def assertIsNameNode(self, r, name=None): + self.assertIsInstance(r, ast.Name) + if name is not None: + self.assertEqual(r.id, name) + + def test_flatten_expr_func(self): + self.s.names.update( + { + "a": 40, + "a.b": 43, + "a.b.c": 44, + } + ) + # 1 parts chain in self.names + result = self._parse_flatten_expr("a") + self.assertIsNameNode(result, "a") + + # 2 parts chain in self.names + result = self._parse_flatten_expr("a.b") + self.assertIsNameNode(result, "a.b") + + # 3 parts chain in self.names + result = self._parse_flatten_expr("a.b.c") + self.assertIsNameNode(result, "a.b.c") + + # 3 parts chain, only 2 in self.names + result = self._parse_flatten_expr("a.b.d") + self.assertIsAttributeNode(result, "d") + self.assertIsNameNode(result.value, "a.b") + + # Name that does not exist in self.names + result = self._parse_flatten_expr("x") + self.assertIsNameNode(result, "x") + + # Ends with a chain that exist n self.names, should not be processed + result = self._parse_flatten_expr("x.a.b.c") + self.assertIsAttributeNode(result, "c") + self.assertIsAttributeNode(result.value, "b") + self.assertIsAttributeNode(result.value.value, "a") + self.assertIsNameNode(result.value.value.value, "x") + + def test_attribute_flattening_simple(self): + self.s.attr_chain_flattening = True + ns = self.Namespace + + self.s.names.update( + {"a": 40, "a.b": 43, "a.b.c": 44, "a.c": ns(d=46), "x": 45, "y": ns(a=ns(b=ns(c=47)))} + ) + + self.t("a", 40) + self.t("a.b", 43) + self.t("a.b.c", 44) + self.t("a.c.d", 46) + self.t("x", 45) + self.t("y.a.b.c", 47) + + def test_attribute_flattening_complex(self): + self.s.attr_chain_flattening = True + ns = self.Namespace + pt = self.Point + + self.s.names.update( + { + "a": 40, + "a.b": 43, + "a.b.c": 44, + "a.c": ns(d=46, pt1=pt(45, 46), pt2=pt(11, 13)), + "p": pt(14, 45), + "q": pt(78, 91), + "x": 45, + "y": ns(a=ns(b=ns(c=47, d=pt(47, 12))), b=pt(11, 12), c=pt(15, 44)), + } + ) + + self.t("a + a.b", 83) + self.t("a * a.b.c", 1760) + self.t("q.sum()", 169) + self.t("p.product()", 630) + self.t("a.c.pt1.product()", 2070) + self.t("y.a.b.d.sub()", -35) + self.t("a + a.b.c - a.c.d + a.c.pt1.x * a.c.pt2.sum() - x - y.c.y * y.a.b.d.sub()", 2613) + + +class TestAssignModifyNames(DRYTest): + def test_assign_simple(self): + self.s.assign_modify_names = True + + self.s.names.update( + { + "a": 40, + "b": 30, + } + ) + + self.t("c = a + b", 70) # simple assign + self.assertIn("c", self.s.names) + self.assertIn("c", self.s.results) + self.assertEqual(self.s.names["c"], 70) + self.assertEqual(self.s.results["c"], 70) + + self.t("x = y = a + b", 70) # multiple targets + self.assertIn("x", self.s.names) + self.assertIn("x", self.s.results) + self.assertIn("y", self.s.names) + self.assertIn("y", self.s.results) + self.assertEqual(self.s.results["x"], 70) + self.assertEqual(self.s.results["y"], 70) + + with self.assertRaises(FeatureNotAvailable): # attribute assign + self.s.eval("obj.attr = a + b") + + with self.assertRaises(FeatureNotAvailable): # Tuple assign + self.s.eval("z, w = a + b") + + self.t("a = a + b", 70) # update value + self.assertIn("a", self.s.names) + self.assertIn("a", self.s.results) + self.assertEqual(self.s.results["a"], 70) + + def test_assign_with_flatten_names(self): + self.s.assign_modify_names = True + self.s.attr_chain_flattening = True + + self.t("a.b = 100", 100) # simple assign + self.assertIn("a.b", self.s.names) + self.assertIn("a.b", self.s.results) + self.assertEqual(self.s.names["a.b"], 100) + self.assertEqual(self.s.results["a.b"], 100) + + self.t("a.c = a.b * 2", 200) # simple assign with flatten name in the expr + self.assertIn("a.c", self.s.results) + self.assertEqual(self.s.results["a.c"], 200) + + self.t("b.a = c.a = y = 70", 70) # multiple targets + self.assertIn("b.a", self.s.results) + self.assertIn("c.a", self.s.results) + self.assertIn("y", self.s.results) + self.assertEqual(self.s.results["b.a"], 70) + self.assertEqual(self.s.results["c.a"], 70) + self.assertEqual(self.s.results["y"], 70) + + with self.assertRaises(FeatureNotAvailable): # attribute assign + self.s.eval("a.b.func().c = 70") + + with self.assertRaises(FeatureNotAvailable): # Tuple assign + self.s.eval("z.y, w.x = 70") + + self.t("a.b = a.b - 30", 70) # update value + self.assertEqual(self.s.names["a.b"], 70) + + def test_multiple_assigns(self): + self.s.assign_modify_names = True + self.s.attr_chain_flattening = True + self.s.multiple_expression_support = True + + self.t("a = 10; a.b = 20;", 20) + self.assertEqual(self.s.names["a"], 10) + self.assertEqual(self.s.names["a.b"], 20) + + def test_aug_assign_simple(self): + self.s.assign_modify_names = True + + self.s.names.update( + { + "a": 40, + "b": 30, + } + ) + + self.t("a += b", 70) # simple aug assign + self.assertIn("a", self.s.results) + self.assertEqual(self.s.names["a"], 70) + self.assertEqual(self.s.results["a"], 70) + + with self.assertRaises(FeatureNotAvailable): # attribute assign + self.s.eval("obj.attr += a + b") + + self.t("a += 30", 100) # update value + self.assertIn("a", self.s.results) + self.assertEqual(self.s.results["a"], 100) + + def test_aug_assign_with_flatten_names(self): + self.s.assign_modify_names = True + self.s.attr_chain_flattening = True + + self.s.names.update( + { + "a.b": 40, + "a.c": 30, + } + ) + + self.t("a.b += 100", 140) # simple aug assign + self.assertIn("a.b", self.s.results) + self.assertEqual(self.s.results["a.b"], 140) + + self.t("a.c += a.b * 2", 310) # simple assign with flatten name in the expr + self.assertIn("a.c", self.s.results) + self.assertEqual(self.s.results["a.c"], 310) + + with self.assertRaises(FeatureNotAvailable): # attribute assign + self.s.eval("a.b.func().c += 70") + + def test_multiple_aug_assigns(self): + self.s.assign_modify_names = True + self.s.attr_chain_flattening = True + self.s.multiple_expression_support = True + + self.s.names.update( + { + "a": 40, + "a.c": 30, + } + ) + + self.t("a += a.c + 10; a.c += 20;", 50) + self.assertEqual(self.s.names["a"], 80) + self.assertEqual(self.s.names["a.c"], 50) + + def test_multiple_expression(self): + self.s.multiple_expression_support = True + self.s.assign_modify_names = True + + self.t("a = 5\nb = 10\na + b", 15) # with \n + self.t("a = 5;b = 10;a + b", 15) # with ; + + def test_options(self): + ns = TestAttrChainFlattening.Namespace + + # multiple_expression_support + # self.assertEqual(simple_eval("5 * 2; 6 * 2"), 10) # without + self.assertEqual(simple_eval("5 * 2; 6 * 2", multiple_expression_support=True), 12) # with + + # assign_modify_names + # names = dict() + # simple_eval("a = 10", names=names) + # self.assertIsNone(names.get('a')) # without + + names = dict() + simple_eval("a = 10", names=names, assign_modify_names=True) + self.assertEqual(names.get("a"), 10) # with + + # attr_chain_flattening + names = {"a": ns(b=5), "a.b": 10} + self.assertEqual(simple_eval("a.b", names=names), 5) # without + + names = {"a": ns(b=5), "a.b": 10} + self.assertEqual(simple_eval("a.b", names=names, attr_chain_flattening=True), 10) # with + + # evaluator + all options + names = {"a": ns(b=5, d=1), "a.b": 10, "a.c": 2} + evaluator = SimpleEval( + names=names, + attr_chain_flattening=True, + assign_modify_names=True, + multiple_expression_support=True, + ) + ret = evaluator.eval("c = a.b * a.c; d=a.d + c; d*=2") + self.assertEqual(ret, 42) + self.assertEqual(evaluator.results.get("d"), 42) + + if __name__ == "__main__": # pragma: no cover unittest.main()