Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Have client retry lost inputs #2600

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions modal/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,14 @@ class ServerWarning(UserWarning):
"""Warning originating from the Modal server and re-issued in client code."""


class LostInputsError(Error):
"""Raised when the server reports that it is no longer processing specified inputs."""

def __init__(self, lost_inputs: list[str]):
self.lost_inputs = lost_inputs
super().__init__()


class _CliUserExecutionError(Exception):
"""mdmd:hidden
Private wrapper for exceptions during when importing or running stubs from the CLI.
Expand Down
19 changes: 14 additions & 5 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
ExecutionError,
FunctionTimeoutError,
InvalidError,
LostInputsError,
NotFoundError,
OutputExpiredError,
deprecation_warning,
Expand Down Expand Up @@ -180,7 +181,7 @@ async def create(
return _Invocation(client.stub, function_call_id, client, retry_context)

async def pop_function_call_outputs(
self, timeout: Optional[float], clear_on_success: bool
self, timeout: Optional[float], clear_on_success: bool, expected_input_ids: Optional[list[str]] = None
) -> api_pb2.FunctionGetOutputsResponse:
t0 = time.time()
if timeout is None:
Expand All @@ -196,13 +197,17 @@ async def pop_function_call_outputs(
last_entry_id="0-0",
clear_on_success=clear_on_success,
requested_at=time.time(),
expected_input_ids=expected_input_ids,
)
response: api_pb2.FunctionGetOutputsResponse = await retry_transient_errors(
self.stub.FunctionGetOutputs,
request,
attempt_timeout=backend_timeout + ATTEMPT_TIMEOUT_GRACE_PERIOD,
)

if response.lost_input_ids:
raise LostInputsError(list(response.lost_input_ids))

if len(response.outputs) > 0:
return response

Expand All @@ -225,10 +230,14 @@ async def _retry_input(self) -> None:
request,
)

async def _get_single_output(self) -> Any:
async def _get_single_output(self, expected_input_id: Optional[str] = None) -> Any:
# waits indefinitely for a single result for the function, and clear the outputs buffer after
item: api_pb2.FunctionGetOutputsItem = (
await self.pop_function_call_outputs(timeout=None, clear_on_success=True)
await self.pop_function_call_outputs(
timeout=None,
clear_on_success=True,
expected_input_ids=[expected_input_id] if expected_input_id else None,
)
).outputs[0]
return await _process_result(item.result, item.data_format, self.stub, self.client)

Expand All @@ -248,8 +257,8 @@ async def run_function(self) -> Any:

while True:
try:
return await self._get_single_output()
except (UserCodeException, FunctionTimeoutError) as exc:
return await self._get_single_output(ctx.input_id)
except (UserCodeException, FunctionTimeoutError, LostInputsError) as exc:
await user_retry_manager.raise_or_sleep(exc)
await self._retry_input()

Expand Down
5 changes: 5 additions & 0 deletions modal_proto/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1457,13 +1457,18 @@ message FunctionGetOutputsRequest {
string last_entry_id = 6;
bool clear_on_success = 7; // expires *any* remaining outputs soon after this call, not just the returned ones
double requested_at = 8; // Used for waypoints.
// The inputs ids the client expects the server to be processing. This is optional and used for sync inputs only.
repeated string expected_input_ids = 9;
}

message FunctionGetOutputsResponse {
repeated int32 idxs = 3;
repeated FunctionGetOutputsItem outputs = 4;
string last_entry_id = 5;
int32 num_unfinished_inputs = 6;
// A subset of expected_input_ids in the request which the server has no record of.
// Client should retry these inputs. Used for sync inputs only.
repeated string lost_input_ids = 7;
}

message FunctionGetRequest {
Expand Down
9 changes: 9 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def __init__(self, blob_host, blobs, credentials):
self.done = False
self.rate_limit_sleep_duration = None
self.fail_get_inputs = False
self.fail_get_outputs_with_lost_inputs = False
self.slow_put_inputs = False
self.container_inputs = []
self.container_outputs = []
Expand Down Expand Up @@ -1106,6 +1107,14 @@ async def FunctionGetOutputs(self, stream):
input_id=input_id, idx=idx, result=result, data_format=api_pb2.DATA_FORMAT_PICKLE
)

if self.fail_get_outputs_with_lost_inputs:
# We fail the output after invoking the input's function because so our tests use the number of function
# invocations to assert the function was retried the correct number of times.
await stream.send_message(
api_pb2.FunctionGetOutputsResponse(num_unfinished_inputs=1, lost_input_ids=[input_id])
)
return

if output_exc:
output = output_exc
else:
Expand Down
17 changes: 17 additions & 0 deletions test/function_retry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import modal
from modal import App
from modal.exception import LostInputsError
from modal.retries import RetryManager
from modal_proto import api_pb2

Expand Down Expand Up @@ -56,7 +57,9 @@ def test_all_retries_fail_raises_error(client, setup_app_and_function, monkeypat
app, f = setup_app_and_function
with app.run(client=client):
with pytest.raises(FunctionCallCountException) as exc_info:
# The client should give up after the 4th call.
f.remote(5)
# Assert the function was called 4 times - the original call plus 3 retries
assert exc_info.value.function_call_count == 4


Expand Down Expand Up @@ -85,3 +88,17 @@ def test_retry_dealy_ms():

retry_policy = api_pb2.FunctionRetryPolicy(retries=2, backoff_coefficient=3, initial_delay_ms=2000)
assert RetryManager._retry_delay_ms(2, retry_policy) == 6000


def test_lost_inputs_retried(client, setup_app_and_function, monkeypatch, servicer):
monkeypatch.setenv("MODAL_CLIENT_RETRIES", "true")
app, f = setup_app_and_function
# This flag forces the fake server always report 1 lost input, and no successful outputs.
servicer.fail_get_outputs_with_lost_inputs = True
with app.run(client=client):
with pytest.raises(LostInputsError):
# The value we pass to the function doesn't matter. The call to GetOutputs will always fail due to lost
# inputs. We use function as a way to track how many times the function was retried.
f.remote(5)
# Assert the function was called 4 times - the original call plus 3 retries
assert function_call_count == 4
Loading