Configure read timeout based on wait parameter (#373)
This fixes a bug in the new `wait` implementation where the default read
timeout for the HTTP client is shorter than the timeout on the server. As a
result, the client errors out before the server has had the opportunity to
respond with a partial prediction.

This commit provides a custom timeout for the `predictions.create` request
based on the `wait` parameter provided. We add a 500ms buffer to the timeout
to account for small discrepancies between server and client timings.

I attempted to refactor the shared code between models, deployments, and
predictions but gave up. Instead, we now have a single function that creates
the `headers` and `timeout` params, which are then passed in at the various
call sites.
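
For illustration, a minimal sketch of the new behavior from the caller's side
(the version id and prompt here are placeholders, not part of the commit):

import replicate

# Hold the request open for up to 30 seconds. With this change the HTTP
# read timeout becomes 30.5s (wait plus the 500ms buffer) instead of the
# shorter client default, so the connection no longer drops before the
# server can respond with a partial prediction.
prediction = replicate.predictions.create(
    version="...",  # placeholder model version id
    input={"prompt": "hello"},
    wait=30,
)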
aron authored Oct 16, 2024
1 parent c59bb32 commit 7cfd984
Showing 3 changed files with 86 additions and 65 deletions.
52 changes: 25 additions & 27 deletions replicate/deployment.py
@@ -8,7 +8,7 @@
from replicate.prediction import (
Prediction,
_create_prediction_body,
_create_prediction_headers,
_create_prediction_request_params,
_json_to_prediction,
)
from replicate.resource import Namespace, Resource
@@ -421,21 +421,25 @@ def create(
Create a new prediction with the deployment.
"""

wait = params.pop("wait", None)
file_encoding_strategy = params.pop("file_encoding_strategy", None)

if input is not None:
input = encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
headers = _create_prediction_headers(wait=params.pop("wait", None))
body = _create_prediction_body(version=None, input=input, **params)

body = _create_prediction_body(version=None, input=input, **params)
extras = _create_prediction_request_params(
wait=wait,
)
resp = self._client._request(
"POST",
f"/v1/deployments/{self._deployment.owner}/{self._deployment.name}/predictions",
json=body,
headers=headers,
**extras,
)

return _json_to_prediction(self._client, resp.json())
@@ -449,21 +453,24 @@ async def async_create(
Create a new prediction with the deployment.
"""

wait = params.pop("wait", None)
file_encoding_strategy = params.pop("file_encoding_strategy", None)
if input is not None:
input = await async_encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
headers = _create_prediction_headers(wait=params.pop("wait", None))
body = _create_prediction_body(version=None, input=input, **params)

body = _create_prediction_body(version=None, input=input, **params)
extras = _create_prediction_request_params(
wait=wait,
)
resp = await self._client._async_request(
"POST",
f"/v1/deployments/{self._deployment.owner}/{self._deployment.name}/predictions",
json=body,
headers=headers,
**extras,
)

return _json_to_prediction(self._client, resp.json())
@@ -484,24 +491,20 @@ def create(
Create a new prediction with the deployment.
"""

url = _create_prediction_url_from_deployment(deployment)

wait = params.pop("wait", None)
file_encoding_strategy = params.pop("file_encoding_strategy", None)

url = _create_prediction_url_from_deployment(deployment)
if input is not None:
input = encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
headers = _create_prediction_headers(wait=params.pop("wait", None))
body = _create_prediction_body(version=None, input=input, **params)

resp = self._client._request(
"POST",
url,
json=body,
headers=headers,
)
body = _create_prediction_body(version=None, input=input, **params)
extras = _create_prediction_request_params(wait=wait)
resp = self._client._request("POST", url, json=body, **extras)

return _json_to_prediction(self._client, resp.json())

@@ -515,25 +518,20 @@ async def async_create(
Create a new prediction with the deployment.
"""

url = _create_prediction_url_from_deployment(deployment)

wait = params.pop("wait", None)
file_encoding_strategy = params.pop("file_encoding_strategy", None)

url = _create_prediction_url_from_deployment(deployment)
if input is not None:
input = await async_encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)

headers = _create_prediction_headers(wait=params.pop("wait", None))
body = _create_prediction_body(version=None, input=input, **params)

resp = await self._client._async_request(
"POST",
url,
json=body,
headers=headers,
)
extras = _create_prediction_request_params(wait=wait)
resp = await self._client._async_request("POST", url, json=body, **extras)

return _json_to_prediction(self._client, resp.json())

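As a sketch, the same change seen from the deployment side (the deployment
name here is a placeholder):

deployment = replicate.deployments.get("acme/text-generator")  # placeholder

# wait=True asks for the server's default wait window; per the helper in
# replicate/prediction.py below, the client read timeout becomes 60.5s
# (60s plus the 500ms buffer).
prediction = deployment.predictions.create(
    input={"prompt": "hello"},
    wait=True,
)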
37 changes: 15 additions & 22 deletions replicate/model.py
@@ -9,7 +9,7 @@
from replicate.prediction import (
Prediction,
_create_prediction_body,
_create_prediction_headers,
_create_prediction_request_params,
_json_to_prediction,
)
from replicate.resource import Namespace, Resource
@@ -389,24 +389,20 @@ def create(
Create a new prediction with the model.
"""

url = _create_prediction_url_from_model(model)

wait = params.pop("wait", None)
file_encoding_strategy = params.pop("file_encoding_strategy", None)

path = _create_prediction_path_from_model(model)
if input is not None:
input = encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
headers = _create_prediction_headers(wait=params.pop("wait", None))
body = _create_prediction_body(version=None, input=input, **params)

resp = self._client._request(
"POST",
url,
json=body,
headers=headers,
)
body = _create_prediction_body(version=None, input=input, **params)
extras = _create_prediction_request_params(wait=wait)
resp = self._client._request("POST", path, json=body, **extras)

return _json_to_prediction(self._client, resp.json())

@@ -420,24 +416,21 @@ async def async_create(
Create a new prediction with the model.
"""

url = _create_prediction_url_from_model(model)

wait = params.pop("wait", None)
file_encoding_strategy = params.pop("file_encoding_strategy", None)

path = _create_prediction_path_from_model(model)

if input is not None:
input = await async_encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
headers = _create_prediction_headers(wait=params.pop("wait", None))
body = _create_prediction_body(version=None, input=input, **params)

resp = await self._client._async_request(
"POST",
url,
json=body,
headers=headers,
)
body = _create_prediction_body(version=None, input=input, **params)
extras = _create_prediction_request_params(wait=wait)
resp = await self._client._async_request("POST", path, json=body, **extras)

return _json_to_prediction(self._client, resp.json())

@@ -522,7 +515,7 @@ def _json_to_model(client: "Client", json: Dict[str, Any]) -> Model:
return model


def _create_prediction_url_from_model(
def _create_prediction_path_from_model(
model: Union[str, Tuple[str, str], "Model"],
) -> str:
owner, name = None, None
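And from the model side, a sketch of opting out of blocking entirely (the
model name is a placeholder):

# wait=False disables blocking: the timeout helper returns None so the
# client keeps its default timeout (and, presumably, no Prefer header is
# sent), returning the prediction immediately in its initial state.
prediction = replicate.models.predictions.create(
    model="acme/text-generator",  # placeholder
    input={"prompt": "hello"},
    wait=False,
)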
62 changes: 46 additions & 16 deletions replicate/prediction.py
@@ -16,6 +16,7 @@
overload,
)

import httpx
from typing_extensions import NotRequired, TypedDict, Unpack

from replicate.exceptions import ModelError, ReplicateError
@@ -446,6 +447,9 @@ def create( # type: ignore
Create a new prediction for the specified model, version, or deployment.
"""

wait = params.pop("wait", None)
file_encoding_strategy = params.pop("file_encoding_strategy", None)

if args:
version = args[0] if len(args) > 0 else None
input = args[1] if len(args) > 1 else input
@@ -477,26 +481,20 @@
**params,
)

file_encoding_strategy = params.pop("file_encoding_strategy", None)
if input is not None:
input = encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
headers = _create_prediction_headers(wait=params.pop("wait", None))

body = _create_prediction_body(
version,
input,
**params,
)

resp = self._client._request(
"POST",
"/v1/predictions",
headers=headers,
json=body,
)
extras = _create_prediction_request_params(wait=wait)
resp = self._client._request("POST", "/v1/predictions", json=body, **extras)

return _json_to_prediction(self._client, resp.json())

@@ -538,6 +536,8 @@ async def async_create( # type: ignore
"""
Create a new prediction for the specified model, version, or deployment.
"""
wait = params.pop("wait", None)
file_encoding_strategy = params.pop("file_encoding_strategy", None)

if args:
version = args[0] if len(args) > 0 else None
@@ -570,25 +570,21 @@
**params,
)

file_encoding_strategy = params.pop("file_encoding_strategy", None)
if input is not None:
input = await async_encode_json(
input,
client=self._client,
file_encoding_strategy=file_encoding_strategy,
)
headers = _create_prediction_headers(wait=params.pop("wait", None))

body = _create_prediction_body(
version,
input,
**params,
)

extras = _create_prediction_request_params(wait=wait)
resp = await self._client._async_request(
"POST",
"/v1/predictions",
headers=headers,
json=body,
"POST", "/v1/predictions", json=body, **extras
)

return _json_to_prediction(self._client, resp.json())
@@ -628,6 +624,40 @@ async def async_cancel(self, id: str) -> Prediction:
return _json_to_prediction(self._client, resp.json())


class CreatePredictionRequestParams(TypedDict):
headers: NotRequired[Optional[dict]]
timeout: NotRequired[Optional[httpx.Timeout]]


def _create_prediction_request_params(
wait: Optional[Union[int, bool]],
) -> CreatePredictionRequestParams:
timeout = _create_prediction_timeout(wait=wait)
headers = _create_prediction_headers(wait=wait)

return {
"headers": headers,
"timeout": timeout,
}


def _create_prediction_timeout(
*, wait: Optional[Union[int, bool]] = None
) -> Union[httpx.Timeout, None]:
"""
Returns an `httpx.Timeout` instance appropriate for the optional
`Prefer: wait=x` header that can be provided with the request. This
will ensure that we give the server enough time to respond with
a partial prediction in the event that the request times out.
"""

if not wait:
return None

read_timeout = 60.0 if isinstance(wait, bool) else wait
return httpx.Timeout(5.0, read=read_timeout + 0.5)


def _create_prediction_headers(
*,
wait: Optional[Union[int, bool]] = None,
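Putting the pieces together, a sketch of what the new helper produces for
different `wait` values (the body of `_create_prediction_headers` is truncated
above, so the header values shown are assumptions):

extras = _create_prediction_request_params(wait=True)
# assumed: {"headers": {"Prefer": "wait"}, "timeout": httpx.Timeout(5.0, read=60.5)}

extras = _create_prediction_request_params(wait=10)
# assumed: {"headers": {"Prefer": "wait=10"}, "timeout": httpx.Timeout(5.0, read=10.5)}

extras = _create_prediction_request_params(wait=None)
# assumed: {"headers": {}, "timeout": None}  # falls back to the client default

resp = client._request("POST", "/v1/predictions", json=body, **extras)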
