Skip to content

Commit

Permalink
Don't jump if any of jump's items isn't permitted to execute
Browse files Browse the repository at this point in the history
  • Loading branch information
soininen committed Sep 3, 2024
1 parent 6876546 commit 6ba35fc
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 35 deletions.
27 changes: 20 additions & 7 deletions spine_engine/spine_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,14 @@ def __init__(
Raises:
EngineInitFailed: Raised if initialization fails
"""
super().__init__()
self._queue = mp.Queue()
if items is None:
items = {}
self._items = items
if execution_permits is None:
execution_permits = {}
self._execution_permits = execution_permits
connections = list(map(Connection.from_dict, connections)) # Deserialize connections
connections = list(map(Connection.from_dict, connections))
project_item_loader = ProjectItemLoader()
self._executable_item_classes = project_item_loader.load_executable_item_classes(items_module_name)
required_items = required_items_for_execution(
Expand Down Expand Up @@ -137,8 +136,11 @@ def __init__(
self._item_names = list(self._dag) # Names of permitted items and their neighbors
if jumps is None:
jumps = []
self._jumps = list(map(Jump.from_dict, jumps))
validate_jumps(self._jumps, self._dag)
else:
jumps = list(map(Jump.from_dict, jumps))
items_by_jump = _get_items_by_jump(jumps, self._dag)
self._jumps = filter_unneeded_jumps(jumps, items_by_jump, execution_permits)
validate_jumps(self._jumps, items_by_jump, self._dag)
for x in self._connections + self._jumps:
x.make_logger(self._queue)
for x in self._jumps:
Expand Down Expand Up @@ -761,14 +763,25 @@ def _validate_dag(dag):
raise EngineInitFailed("DAG contains unconnected items.")


def validate_jumps(jumps, dag):
def filter_unneeded_jumps(jumps, items_by_jump, execution_permits):
"""Drops jumps whose items are not going to be executed.
Args:
jumps (Iterable of Jump): jumps to filter
items_by_jump (dict): mapping from jump to list of item names
execution_permits (dict): mapping from item name to boolean telling if its is permitted to execute
"""
return [jump for jump in jumps if all(execution_permits[item] for item in items_by_jump[jump])]


def validate_jumps(jumps, items_by_jump, dag):
"""Raises an exception in case jumps are not valid.
Args:
jumps (list of Jump): jumps
items_by_jump (dict): mapping from jump to list of item names
dag (DiGraph): jumps' DAG
"""
items_by_jump = _get_items_by_jump(jumps, dag)
for jump in jumps:
validate_single_jump(jump, jumps, dag, items_by_jump)

Expand All @@ -778,7 +791,7 @@ def validate_single_jump(jump, jumps, dag, items_by_jump=None):
Args:
jump (Jump): the jump to check
jumps (list of Jump): all jumps in dag
jumps (list of Jump): all jumps in DAG
dag (DiGraph): jumps' DAG
items_by_jump (dict, optional): mapping jumps to a set of items in between destination and source
"""
Expand Down
6 changes: 3 additions & 3 deletions spine_engine/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def make_connections(connections, permitted_items):
list of Connection: List of permitted Connections or an empty list if the DAG contains no connections
"""
if not connections:
return list()
return []
connections = connections_to_selected_items(connections, permitted_items)
return connections

Expand Down Expand Up @@ -433,10 +433,10 @@ def dag_edges(connections):
Returns:
dict: DAG edges. Mapping of source item (node) to a list of destination items (nodes)
"""
edges = dict()
edges = {}
for connection in connections:
source, destination = connection.source, connection.destination
edges.setdefault(source, list()).append(destination)
edges.setdefault(source, []).append(destination)
return edges


Expand Down
63 changes: 38 additions & 25 deletions tests/test_spine_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,31 +869,44 @@ def test_nested_jump_with_inner_self_jump(self):
self._assert_resource_args(item_c.execute.call_args_list, expected)

def test_stopping_execution_in_the_middle_of_a_loop_does_not_leave_multithread_executor_running(self):
with TemporaryDirectory() as temp_dir:
item_a = self._mock_item("a")
item_b = self._mock_item("b")
item_instances = {"a": [item_a, item_a, item_a, item_a], "b": [item_b]}
items = {
"a": {"type": "TestItem"},
"b": {"type": "TestItem"},
}
connections = [c.to_dict() for c in (Connection("a", "right", "b", "left"),)]
jumps = [Jump("a", "right", "a", "right", self._LOOP_FOREVER).to_dict()]
engine = self._create_engine(items, connections, item_instances, jumps=jumps)

def execute_item_a(loop_counter, *args, **kwargs):
if loop_counter[0] == 2:
engine.stop()
return ItemExecutionFinishState.STOPPED
loop_counter[0] += 1
return ItemExecutionFinishState.SUCCESS

loop_counter = [0]
item_a.execute.side_effect = partial(execute_item_a, loop_counter)
engine.run()
self.assertEqual(engine.state(), SpineEngineState.USER_STOPPED)
self.assertEqual(item_a.execute.call_count, 3)
item_b.execute.assert_not_called()
item_a = self._mock_item("a")
item_b = self._mock_item("b")
item_instances = {"a": [item_a, item_a, item_a, item_a], "b": [item_b]}
items = {
"a": {"type": "TestItem"},
"b": {"type": "TestItem"},
}
connections = [c.to_dict() for c in (Connection("a", "right", "b", "left"),)]
jumps = [Jump("a", "right", "a", "right", self._LOOP_FOREVER).to_dict()]
engine = self._create_engine(items, connections, item_instances, jumps=jumps)

def execute_item_a(loop_counter, *args, **kwargs):
if loop_counter[0] == 2:
engine.stop()
return ItemExecutionFinishState.STOPPED
loop_counter[0] += 1
return ItemExecutionFinishState.SUCCESS

loop_counter = [0]
item_a.execute.side_effect = partial(execute_item_a, loop_counter)
engine.run()
self.assertEqual(engine.state(), SpineEngineState.USER_STOPPED)
self.assertEqual(item_a.execute.call_count, 3)
item_b.execute.assert_not_called()

def test_executing_loop_source_item_only_does_not_execute_the_loop(self):
item_a = self._mock_item("a")
item_b = self._mock_item("b")
item_instances = {"a": [item_a, item_a, item_a, item_a], "b": [item_b, item_b, item_b, item_b]}
items = {
"a": {"type": "TestItem"},
"b": {"type": "TestItem"},
}
connections = [c.to_dict() for c in (Connection("a", "right", "b", "left"),)]
jumps = [Jump("b", "right", "a", "right", self._LOOP_FOREVER).to_dict()]
self._run_engine(items, connections, item_instances, execution_permits={"a": False, "b": True}, jumps=jumps)
self.assertEqual(item_a.execute.call_count, 0)
self.assertEqual(item_b.execute.call_count, 1)

def _assert_resource_args(self, arg_packs, expected_packs):
self.assertEqual(len(arg_packs), len(expected_packs))
Expand Down

0 comments on commit 6ba35fc

Please sign in to comment.