Skip to content

Commit

Permalink
feat: Memoization (#38)
Browse files Browse the repository at this point in the history
- added a memoization map that is shared between processes
- added function to either call a function or return the memoized result
- added functions to convert a list to a tuple recursively and to return the
last modification time of a file
- added tests

---------

Co-authored-by: megalinter-bot <[email protected]>
Co-authored-by: Lars Reimann <[email protected]>
  • Loading branch information
3 people authored Jan 24, 2024
1 parent dae57dc commit 2a26b48
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 0 deletions.
99 changes: 99 additions & 0 deletions src/safeds_runner/server/pipeline_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import queue
import runpy
import threading
import typing
from functools import cached_property
from multiprocessing.managers import SyncManager
from pathlib import Path
from typing import Any

import simple_websocket
Expand All @@ -25,6 +27,8 @@
)
from safeds_runner.server.module_manager import InMemoryFinder

MemoizationMap: typing.TypeAlias = dict[tuple[str, tuple[Any], tuple[Any]], Any]


class PipelineManager:
"""
Expand Down Expand Up @@ -56,6 +60,10 @@ def _messages_queue_thread(self) -> threading.Thread:
daemon=True,
)

@cached_property
def _memoization_map(self) -> MemoizationMap:
    """Create the memoization map via the multiprocessing manager; computed once and then cached."""
    # The manager returns a dict proxy, which the type checker does not accept as a plain dict.
    manager_dict: MemoizationMap = self._multiprocessing_manager.dict()  # type: ignore[assignment]
    return manager_dict

def startup(self) -> None:
"""
Prepare the runner for running Safe-DS pipelines.
Expand Down Expand Up @@ -132,6 +140,7 @@ def execute_pipeline(
execution_id,
self._messages_queue,
self._placeholder_map[execution_id],
self._memoization_map,
)
process.execute()

Expand Down Expand Up @@ -176,6 +185,7 @@ def __init__(
execution_id: str,
messages_queue: queue.Queue[Message],
placeholder_map: dict[str, Any],
memoization_map: MemoizationMap,
):
"""
Create a new process which will execute the given pipeline, when started.
Expand All @@ -190,11 +200,14 @@ def __init__(
A queue to write outgoing messages to.
placeholder_map : dict[str, Any]
A map to save calculated placeholders in.
memoization_map : MemoizationMap
A map to save memoizable functions in.
"""
self._pipeline = pipeline
self._id = execution_id
self._messages_queue = messages_queue
self._placeholder_map = placeholder_map
self._memoization_map = memoization_map
self._process = multiprocessing.Process(target=self._execute, daemon=True)

def _send_message(self, message_type: str, value: dict[Any, Any] | str) -> None:
Expand Down Expand Up @@ -222,6 +235,17 @@ def save_placeholder(self, placeholder_name: str, value: Any) -> None:
create_placeholder_description(placeholder_name, placeholder_type),
)

def get_memoization_map(self) -> MemoizationMap:
    """
    Provide access to the memoization map held by this process.

    Returns
    -------
    MemoizationMap
        The memoization map stored on this instance.
    """
    memoization_map = self._memoization_map
    return memoization_map

def _execute(self) -> None:
logging.info(
"Executing %s.%s.%s...",
Expand Down Expand Up @@ -278,6 +302,81 @@ def runner_save_placeholder(placeholder_name: str, value: Any) -> None:
current_pipeline.save_placeholder(placeholder_name, value)


def runner_memoized_function_call(
    function_name: str,
    function_callable: typing.Callable,
    parameters: list[Any],
    hidden_parameters: list[Any],
) -> Any:
    """
    Call a function that can be memoized and save the result.

    If a result is already stored for the same function name, parameters and hidden parameters, that result is
    returned and the function is not called again.

    Parameters
    ----------
    function_name : str
        Fully qualified function name
    function_callable : typing.Callable
        Function that is called and memoized if the result was not found in the memoization map
    parameters : list[Any]
        List of parameters for the function
    hidden_parameters : list[Any]
        List of hidden parameters for the function. This is used for memoizing some impure functions.
        They are part of the memoization key but are not passed to the function.

    Returns
    -------
    Any
        The memoized or newly computed result of the specified function. None if no pipeline is currently set.
    """
    if current_pipeline is None:
        return None  # pragma: no cover
    memoization_map = current_pipeline.get_memoization_map()
    key = (function_name, _convert_list_to_tuple(parameters), _convert_list_to_tuple(hidden_parameters))
    # EAFP: one lookup instead of a membership test followed by a second lookup. When the map is a
    # manager-backed proxy, every access is a round trip to the manager, so this halves the cost of a hit.
    try:
        return memoization_map[key]
    except KeyError:
        pass
    # The call is kept outside the try block so that a KeyError raised by the function itself propagates.
    result = function_callable(*parameters)
    memoization_map[key] = result
    return result


def _convert_list_to_tuple(values: list) -> tuple:
"""
Recursively convert a mutable list of values to an immutable tuple containing the same values, to make the values hashable.
Parameters
----------
values : list
Values that should be converted to a tuple
Returns
-------
tuple
Converted list containing all the elements of the provided list
"""
return tuple(_convert_list_to_tuple(value) if isinstance(value, list) else value for value in values)


def runner_filemtime(filename: str) -> int | None:
"""
Get the last modification timestamp of the provided file.
Parameters
----------
filename: str
Name of the file
Returns
-------
int | None
Last modification timestamp if the provided file exists, otherwise None
"""
try:
return Path(filename).stat().st_mtime_ns
except FileNotFoundError:
return None


def get_backtrace_info(error: BaseException) -> list[dict[str, Any]]:
"""
Create a simplified backtrace from an exception.
Expand Down
83 changes: 83 additions & 0 deletions tests/safeds_runner/server/test_memoization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import tempfile
import typing
from datetime import UTC, datetime
from queue import Queue
from typing import Any

import pytest
from safeds_runner.server import pipeline_manager
from safeds_runner.server.messages import MessageDataProgram, ProgramMainInformation
from safeds_runner.server.pipeline_manager import PipelineProcess


@pytest.mark.parametrize(
    argnames="function_name,params,hidden_params,expected_result",
    argvalues=[
        ("function_pure", [1, 2, 3], [], "abc"),
        ("function_impure_readfile", ["filea.txt"], [1234567891], "abc"),
    ],
    ids=["function_pure", "function_impure_readfile"],
)
def test_memoization_already_present_values(
    function_name: str,
    params: list,
    hidden_params: list,
    expected_result: Any,
) -> None:
    """A result already stored under the call's key is returned by the memoized function call."""
    process = PipelineProcess(
        MessageDataProgram({}, ProgramMainInformation("", "", "")),
        "",
        Queue(),
        {},
        {},
    )
    pipeline_manager.current_pipeline = process

    # Pre-populate the map with the key that the memoized call builds for these arguments.
    memoization_key = (
        function_name,
        pipeline_manager._convert_list_to_tuple(params),
        pipeline_manager._convert_list_to_tuple(hidden_params),
    )
    process.get_memoization_map()[memoization_key] = expected_result

    # The callable returns None, so a matching result can only come from the map.
    actual_result = pipeline_manager.runner_memoized_function_call(
        function_name,
        lambda *_: None,
        params,
        hidden_params,
    )
    assert actual_result == expected_result


@pytest.mark.parametrize(
    argnames="function_name,function,params,hidden_params,expected_result",
    argvalues=[
        ("function_pure", lambda a, b, c: a + b + c, [1, 2, 3], [], 6),
        ("function_impure_readfile", lambda filename: filename.split(".")[0], ["abc.txt"], [1234567891], "abc"),
    ],
    ids=["function_pure", "function_impure_readfile"],
)
def test_memoization_not_present_values(
    function_name: str,
    function: typing.Callable,
    params: list,
    hidden_params: list,
    expected_result: Any,
) -> None:
    """A result that is not yet stored is computed, saved and reused by a later call with the same key."""
    pipeline_manager.current_pipeline = PipelineProcess(
        MessageDataProgram({}, ProgramMainInformation("", "", "")),
        "",
        Queue(),
        {},
        {},
    )

    # First call: nothing is stored yet, so the provided function produces the result.
    computed = pipeline_manager.runner_memoized_function_call(function_name, function, params, hidden_params)
    assert computed == expected_result

    # Second call: the callable returns None, so the expected result can only come from the map.
    reused = pipeline_manager.runner_memoized_function_call(
        function_name,
        lambda *_: None,
        params,
        hidden_params,
    )
    assert reused == expected_result


def test_file_mtime_exists() -> None:
    """A modification time is reported for an existing temporary file."""
    with tempfile.NamedTemporaryFile() as temporary_file:
        reported_mtime = pipeline_manager.runner_filemtime(temporary_file.name)
        assert reported_mtime is not None


def test_file_mtime_not_exists() -> None:
    """No modification time is reported for a file name that is not expected to exist."""
    # The current timestamp as extension makes an accidental match with a real file unlikely.
    current_timestamp = datetime.now(tz=UTC).timestamp()
    missing_filename = f"file_not_exists.{current_timestamp}"
    assert pipeline_manager.runner_filemtime(missing_filename) is None

0 comments on commit 2a26b48

Please sign in to comment.