Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions src/modeci_mdf/execution_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

"""
import ast
import builtins
import copy
import functools
import inspect
Expand Down Expand Up @@ -53,7 +54,7 @@

FORMAT_DEFAULT = FORMAT_NUMPY

KNOWN_PARAMETERS = ["constant"]
KNOWN_PARAMETERS = ["constant", "math", "numpy"] + dir(builtins)


time_scale_str_regex = r"(TimeScale)?\.(.*)"
Expand Down Expand Up @@ -778,8 +779,20 @@ def __init__(self, node: Node, verbose: Optional[bool] = False):

# If we are dealing with a list of symbols, each must treated separately
all_req_vars.extend(
get_required_variables_from_expression(arg_expr)
[
v
for v in get_required_variables_from_expression(arg_expr)
if v not in f.args
]
)
if f.value is not None:
all_req_vars.extend(
[
v
for v in get_required_variables_from_expression(f.value)
if f.args is None or v not in f.args
]
)

all_present = [v in all_known_vars for v in all_req_vars]
func_missing_vars[f.id] = {
Expand Down Expand Up @@ -839,14 +852,26 @@ def __init__(self, node: Node, verbose: Optional[bool] = False):
all_req_vars = []

if p.value is not None and type(p.value) == str:
all_req_vars.extend(get_required_variables_from_expression(p.value))
all_req_vars.extend(
[
v
for v in get_required_variables_from_expression(p.value)
if p.args is None or v not in p.args
]
)

if p.args is not None:
for arg in p.args:
arg_expr = p.args[arg]
if isinstance(arg_expr, str):
all_req_vars.extend(
get_required_variables_from_expression(arg_expr)
[
v
for v in get_required_variables_from_expression(
arg_expr
)
if v not in p.args
]
)

all_known_vars_plus_this = all_known_vars + [p.id]
Expand Down
47 changes: 42 additions & 5 deletions tests/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,9 @@
{"slope": 2, "intercept": 4, "variable0": "input"},
{"slope": 2, "intercept": 4, "variable0": "input"},
{"slope": "2 * input", "intercept": 4, "variable0": "input"},
# expressions as arg values referencing other args is not currently supported
pytest.param(
{"slope": 2, "intercept": "2 * slope", "variable0": "input"},
marks=pytest.mark.xfail,
),
{"slope": 2, "intercept": "2 * slope", "variable0": "input"},
{"slope": "math.sqrt(4)", "intercept": 4, "variable0": 1},
{"slope": "numpy.sqrt(4)", "intercept": 4, "variable0": 1},
],
)
def test_single_function_variations(create_model, args, function, value, result):
Expand Down Expand Up @@ -85,3 +83,42 @@ def test_condition_variations(create_model, node_specific, termination, result):
eg.evaluate(initializer={"A_input": 0})

assert eg.enodes["A"].evaluable_outputs["A_output"].curr_value == result


def test_dependency_in_function_value(create_model):
m = create_model(
[
mdf.Node(
id="N",
input_ports=[mdf.InputPort(id="input")],
functions=[
mdf.Function(id="f", value="g"),
mdf.Function(id="g", value="1"),
],
output_ports=[mdf.OutputPort(id="output", value="f")],
)
]
)

eg = EvaluableGraph(m.graphs[0])
eg.evaluate(initializer={"input": 1})


# NOTE: this is enabled by "don't include known args in required variable"
# but could not be tested until dependency checking for value
def test_available_arg_in_function_value(create_model):
m = create_model(
[
mdf.Node(
id="N",
input_ports=[mdf.InputPort(id="input")],
functions=[
mdf.Function(id="f", args={"g": 1}, value="g"),
],
output_ports=[mdf.OutputPort(id="output", value="f")],
)
]
)

eg = EvaluableGraph(m.graphs[0])
eg.evaluate(initializer={"input": 1})