diff --git a/src/modeci_mdf/execution_engine.py b/src/modeci_mdf/execution_engine.py index e2abde9c..2f55cdd8 100644 --- a/src/modeci_mdf/execution_engine.py +++ b/src/modeci_mdf/execution_engine.py @@ -12,6 +12,7 @@ """ import ast +import builtins import copy import functools import inspect @@ -53,7 +54,7 @@ FORMAT_DEFAULT = FORMAT_NUMPY -KNOWN_PARAMETERS = ["constant"] +KNOWN_PARAMETERS = ["constant", "math", "numpy"] + dir(builtins) time_scale_str_regex = r"(TimeScale)?\.(.*)" @@ -778,8 +779,20 @@ def __init__(self, node: Node, verbose: Optional[bool] = False): # If we are dealing with a list of symbols, each must treated separately all_req_vars.extend( - get_required_variables_from_expression(arg_expr) + [ + v + for v in get_required_variables_from_expression(arg_expr) + if v not in f.args + ] ) + if f.value is not None: + all_req_vars.extend( + [ + v + for v in get_required_variables_from_expression(f.value) + if f.args is None or v not in f.args + ] + ) all_present = [v in all_known_vars for v in all_req_vars] func_missing_vars[f.id] = { @@ -839,14 +852,26 @@ def __init__(self, node: Node, verbose: Optional[bool] = False): all_req_vars = [] if p.value is not None and type(p.value) == str: - all_req_vars.extend(get_required_variables_from_expression(p.value)) + all_req_vars.extend( + [ + v + for v in get_required_variables_from_expression(p.value) + if p.args is None or v not in p.args + ] + ) if p.args is not None: for arg in p.args: arg_expr = p.args[arg] if isinstance(arg_expr, str): all_req_vars.extend( - get_required_variables_from_expression(arg_expr) + [ + v + for v in get_required_variables_from_expression( + arg_expr + ) + if v not in p.args + ] ) all_known_vars_plus_this = all_known_vars + [p.id] diff --git a/tests/test_execution.py b/tests/test_execution.py index 22b7e187..6b30328c 100644 --- a/tests/test_execution.py +++ b/tests/test_execution.py @@ -22,11 +22,9 @@ {"slope": 2, "intercept": 4, "variable0": "input"}, {"slope": 2, "intercept": 4, "variable0": "input"}, {"slope": "2 * input", "intercept": 4, "variable0": "input"}, - # expressions as arg values referencing other args is not currently supported - pytest.param( - {"slope": 2, "intercept": "2 * slope", "variable0": "input"}, - marks=pytest.mark.xfail, - ), + {"slope": 2, "intercept": "2 * slope", "variable0": "input"}, + {"slope": "math.sqrt(4)", "intercept": 4, "variable0": 1}, + {"slope": "numpy.sqrt(4)", "intercept": 4, "variable0": 1}, ], ) def test_single_function_variations(create_model, args, function, value, result): @@ -85,3 +83,42 @@ def test_condition_variations(create_model, node_specific, termination, result): eg.evaluate(initializer={"A_input": 0}) assert eg.enodes["A"].evaluable_outputs["A_output"].curr_value == result + + +def test_dependency_in_function_value(create_model): + m = create_model( + [ + mdf.Node( + id="N", + input_ports=[mdf.InputPort(id="input")], + functions=[ + mdf.Function(id="f", value="g"), + mdf.Function(id="g", value="1"), + ], + output_ports=[mdf.OutputPort(id="output", value="f")], + ) + ] + ) + + eg = EvaluableGraph(m.graphs[0]) + eg.evaluate(initializer={"input": 1}) + + +# NOTE: this is enabled by "don't include known args in required variable" +# but could not be tested until dependency checking for value +def test_available_arg_in_function_value(create_model): + m = create_model( + [ + mdf.Node( + id="N", + input_ports=[mdf.InputPort(id="input")], + functions=[ + mdf.Function(id="f", args={"g": 1}, value="g"), + ], + output_ports=[mdf.OutputPort(id="output", value="f")], + ) + ] + ) + + eg = EvaluableGraph(m.graphs[0]) + eg.evaluate(initializer={"input": 1})