Commit

Merge branch 'main' into azliu/file-watching
azliu0 committed Dec 13, 2024
2 parents 0be3a8c + eb093de · commit bb07b51
Showing 17 changed files with 248 additions and 153 deletions.
16 changes: 3 additions & 13 deletions modal/_ipython.py
@@ -1,21 +1,11 @@
# Copyright Modal Labs 2022
import sys
import warnings

ipy_outstream = None
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
import ipykernel.iostream

ipy_outstream = ipykernel.iostream.OutStream
except ImportError:
pass


def is_notebook(stdout=None):
if ipy_outstream is None:
ipykernel_iostream = sys.modules.get("ipykernel.iostream")
if ipykernel_iostream is None:
return False
if stdout is None:
stdout = sys.stdout
return isinstance(stdout, ipy_outstream)
return isinstance(stdout, ipykernel_iostream.OutStream)
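
The rewritten check looks `ipykernel.iostream` up in sys.modules instead of importing it eagerly, so calling is_notebook() never pulls ipykernel in by itself. A self-contained sketch of that pattern (the function name below is illustrative, not Modal API):

import sys


def stdout_is_ipykernel_stream(stdout=None) -> bool:
    # Only consult the module if the running environment already imported it;
    # never trigger the ipykernel import (and its warnings) ourselves.
    ipykernel_iostream = sys.modules.get("ipykernel.iostream")
    if ipykernel_iostream is None:
        return False
    if stdout is None:
        stdout = sys.stdout
    return isinstance(stdout, ipykernel_iostream.OutStream)
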
4 changes: 4 additions & 0 deletions modal/_runtime/asgi.py
@@ -1,4 +1,8 @@
# Copyright Modal Labs 2022

# Note: this module isn't imported unless it's needed.
# This is because aiohttp is a pretty big dependency that adds significant latency when imported

import asyncio
from collections.abc import AsyncGenerator
from typing import Any, Callable, NoReturn, Optional, cast
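
To see roughly how much latency the deferred aiohttp import avoids, one can time the import directly (a quick check assuming aiohttp is installed; numbers vary by machine):

import importlib
import time

start = time.perf_counter()
importlib.import_module("aiohttp")  # the dependency the comment above refers to
print(f"importing aiohttp took {time.perf_counter() - start:.3f}s")
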
31 changes: 13 additions & 18 deletions modal/_runtime/user_code_imports.py
@@ -9,15 +9,6 @@
import modal.cls
import modal.object
from modal import Function
from modal._runtime.asgi import (
LifespanManager,
asgi_app_wrapper,
get_ip_address,
wait_for_web_server,
web_server_proxy,
webhook_asgi_app,
wsgi_app_wrapper,
)
from modal._utils.async_utils import synchronizer
from modal._utils.function_utils import LocalFunctionError, is_async as get_is_async, is_global_object
from modal.exception import ExecutionError, InvalidError
@@ -28,6 +19,7 @@
if typing.TYPE_CHECKING:
import modal.app
import modal.partial_function
from modal._runtime.asgi import LifespanManager


@dataclass
@@ -36,7 +28,7 @@ class FinalizedFunction:
is_async: bool
is_generator: bool
data_format: int # api_pb2.DataFormat
lifespan_manager: Optional[LifespanManager] = None
lifespan_manager: Optional["LifespanManager"] = None


class Service(metaclass=ABCMeta):
@@ -63,19 +55,22 @@ def construct_webhook_callable(
webhook_config: api_pb2.WebhookConfig,
container_io_manager: "modal._runtime.container_io_manager.ContainerIOManager",
):
# Note: aiohttp is a significant dependency of the `asgi` module, so we import it locally
from modal._runtime import asgi

# For webhooks, the user function is used to construct an asgi app:
if webhook_config.type == api_pb2.WEBHOOK_TYPE_ASGI_APP:
# Function returns an asgi_app, which we can use as a callable.
return asgi_app_wrapper(user_defined_callable(), container_io_manager)
return asgi.asgi_app_wrapper(user_defined_callable(), container_io_manager)

elif webhook_config.type == api_pb2.WEBHOOK_TYPE_WSGI_APP:
# Function returns an wsgi_app, which we can use as a callable.
return wsgi_app_wrapper(user_defined_callable(), container_io_manager)
# Function returns an wsgi_app, which we can use as a callable
return asgi.wsgi_app_wrapper(user_defined_callable(), container_io_manager)

elif webhook_config.type == api_pb2.WEBHOOK_TYPE_FUNCTION:
# Function is a webhook without an ASGI app. Create one for it.
return asgi_app_wrapper(
webhook_asgi_app(user_defined_callable, webhook_config.method, webhook_config.web_endpoint_docs),
return asgi.asgi_app_wrapper(
asgi.webhook_asgi_app(user_defined_callable, webhook_config.method, webhook_config.web_endpoint_docs),
container_io_manager,
)

@@ -86,11 +81,11 @@ def construct_webhook_callable(
# We intentionally try to connect to the external interface instead of the loopback
# interface here so users are forced to expose the server. This allows us to potentially
# change the implementation to use an external bridge in the future.
host = get_ip_address(b"eth0")
host = asgi.get_ip_address(b"eth0")
port = webhook_config.web_server_port
startup_timeout = webhook_config.web_server_startup_timeout
wait_for_web_server(host, port, timeout=startup_timeout)
return asgi_app_wrapper(web_server_proxy(host, port), container_io_manager)
asgi.wait_for_web_server(host, port, timeout=startup_timeout)
return asgi.asgi_app_wrapper(asgi.web_server_proxy(host, port), container_io_manager)
else:
raise InvalidError(f"Unrecognized web endpoint type {webhook_config.type}")

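
The combination used in this file (a type-only import guarded by TYPE_CHECKING plus a runtime import inside the function body) keeps module import cheap while preserving annotations. A self-contained sketch of the pattern, with decimal standing in for a heavy dependency such as aiohttp:

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Seen only by type checkers; costs nothing at runtime.
    from decimal import Decimal


def scale(value: float, factor: Optional["Decimal"] = None) -> float:
    # Deferred import: the cost is paid only when this code path actually runs.
    from decimal import Decimal

    return float(Decimal(str(value)) * (factor if factor is not None else Decimal(1)))
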
114 changes: 23 additions & 91 deletions modal/_utils/blob_utils.py
@@ -9,22 +9,22 @@
from collections.abc import AsyncIterator
from contextlib import AbstractContextManager, contextmanager
from pathlib import Path, PurePosixPath
from typing import Any, BinaryIO, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Optional, Union
from urllib.parse import urlparse

from aiohttp import BytesIOPayload
from aiohttp.abc import AbstractStreamWriter

from modal_proto import api_pb2
from modal_proto.modal_api_grpc import ModalClientModal

from ..exception import ExecutionError
from .async_utils import TaskContext, retry
from .grpc_utils import retry_transient_errors
from .hash_utils import UploadHashes, get_sha256_hex, get_upload_hashes
from .hash_utils import UploadHashes, get_upload_hashes
from .http_utils import ClientSessionRegistry
from .logger import logger

if TYPE_CHECKING:
from .bytes_io_segment_payload import BytesIOSegmentPayload

# Max size for function inputs and outputs.
MAX_OBJECT_SIZE_BYTES = 2 * 1024 * 1024 # 2 MiB

@@ -38,93 +38,16 @@
# read ~16MiB chunks by default
DEFAULT_SEGMENT_CHUNK_SIZE = 2**24


class BytesIOSegmentPayload(BytesIOPayload):
"""Modified bytes payload for concurrent sends of chunks from the same file.
Adds:
* read limit using remaining_bytes, in order to split files across streams
* larger read chunk (to prevent excessive read contention between parts)
* calculates an md5 for the segment
Feels like this should be in some standard lib...
"""

def __init__(
self,
bytes_io: BinaryIO, # should *not* be shared as IO position modification is not locked
segment_start: int,
segment_length: int,
chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
progress_report_cb: Optional[Callable] = None,
):
# not thread safe constructor!
super().__init__(bytes_io)
self.initial_seek_pos = bytes_io.tell()
self.segment_start = segment_start
self.segment_length = segment_length
# seek to start of file segment we are interested in, in order to make .size() evaluate correctly
self._value.seek(self.initial_seek_pos + segment_start)
assert self.segment_length <= super().size
self.chunk_size = chunk_size
self.progress_report_cb = progress_report_cb or (lambda *_, **__: None)
self.reset_state()

def reset_state(self):
self._md5_checksum = hashlib.md5()
self.num_bytes_read = 0
self._value.seek(self.initial_seek_pos)

@contextmanager
def reset_on_error(self):
try:
yield
except Exception as exc:
try:
self.progress_report_cb(reset=True)
except Exception as cb_exc:
raise cb_exc from exc
raise exc
finally:
self.reset_state()

@property
def size(self) -> int:
return self.segment_length

def md5_checksum(self):
return self._md5_checksum

async def write(self, writer: AbstractStreamWriter):
loop = asyncio.get_event_loop()

async def safe_read():
read_start = self.initial_seek_pos + self.segment_start + self.num_bytes_read
self._value.seek(read_start)
num_bytes = min(self.chunk_size, self.remaining_bytes())
chunk = await loop.run_in_executor(None, self._value.read, num_bytes)

await loop.run_in_executor(None, self._md5_checksum.update, chunk)
self.num_bytes_read += len(chunk)
return chunk

chunk = await safe_read()
while chunk and self.remaining_bytes() > 0:
await writer.write(chunk)
self.progress_report_cb(advance=len(chunk))
chunk = await safe_read()
if chunk:
await writer.write(chunk)
self.progress_report_cb(advance=len(chunk))

def remaining_bytes(self):
return self.segment_length - self.num_bytes_read
# Files larger than this will be multipart uploaded. The server might request multipart upload for smaller files as
# well, but the limit will never be raised.
# TODO(dano): remove this once we stop requiring md5 for blobs
MULTIPART_UPLOAD_THRESHOLD = 1024**3


@retry(n_attempts=5, base_delay=0.5, timeout=None)
async def _upload_to_s3_url(
upload_url,
payload: BytesIOSegmentPayload,
payload: "BytesIOSegmentPayload",
content_md5_b64: Optional[str] = None,
content_type: Optional[str] = "application/octet-stream", # set to None to force omission of ContentType header
) -> str:
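
As an illustration of how the MULTIPART_UPLOAD_THRESHOLD constant introduced above is meant to be read (the helper name is hypothetical):

MULTIPART_UPLOAD_THRESHOLD = 1024**3  # 1 GiB


def must_use_multipart(content_length: int) -> bool:
    # Above the threshold the client always goes multipart; below it the
    # server may still request multipart, but the threshold is never raised.
    return content_length > MULTIPART_UPLOAD_THRESHOLD
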
@@ -180,6 +103,8 @@ async def perform_multipart_upload(
upload_chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
progress_report_cb: Optional[Callable] = None,
) -> None:
from .bytes_io_segment_payload import BytesIOSegmentPayload

upload_coros = []
file_offset = 0
num_bytes_left = content_length
@@ -273,6 +198,8 @@ async def _blob_upload(
progress_report_cb=progress_report_cb,
)
else:
from .bytes_io_segment_payload import BytesIOSegmentPayload

payload = BytesIOSegmentPayload(
data, segment_start=0, segment_length=content_length, progress_report_cb=progress_report_cb
)
@@ -309,8 +236,9 @@ async def blob_upload_file(
stub: ModalClientModal,
progress_report_cb: Optional[Callable] = None,
sha256_hex: Optional[str] = None,
md5_hex: Optional[str] = None,
) -> str:
upload_hashes = get_upload_hashes(file_obj, sha256_hex=sha256_hex)
upload_hashes = get_upload_hashes(file_obj, sha256_hex=sha256_hex, md5_hex=md5_hex)
return await _blob_upload(upload_hashes, file_obj, stub, progress_report_cb)


@@ -369,6 +297,7 @@ class FileUploadSpec:
use_blob: bool
content: Optional[bytes] # typically None if using blob, required otherwise
sha256_hex: str
md5_hex: str
mode: int # file permission bits (last 12 bits of st_mode)
size: int

@@ -386,21 +315,24 @@ def _get_file_upload_spec(
fp.seek(0)

if size >= LARGE_FILE_LIMIT:
# TODO(dano): remove the placeholder md5 once we stop requiring md5 for blobs
md5_hex = "baadbaadbaadbaadbaadbaadbaadbaad" if size > MULTIPART_UPLOAD_THRESHOLD else None
use_blob = True
content = None
sha256_hex = get_sha256_hex(fp)
hashes = get_upload_hashes(fp, md5_hex=md5_hex)
else:
use_blob = False
content = fp.read()
sha256_hex = get_sha256_hex(content)
hashes = get_upload_hashes(content)

return FileUploadSpec(
source=source,
source_description=source_description,
mount_filename=mount_filename.as_posix(),
use_blob=use_blob,
content=content,
sha256_hex=sha256_hex,
sha256_hex=hashes.sha256_hex(),
md5_hex=hashes.md5_hex(),
mode=mode & 0o7777,
size=size,
)
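
For context on the get_upload_hashes calls above, here is a rough, self-contained sketch of a combined digest helper; this is an assumption about its behavior, and the real implementation lives in modal/_utils/hash_utils.py and also accepts precomputed sha256/md5 hex values:

import hashlib
from typing import BinaryIO


def compute_upload_hashes(fp: BinaryIO, chunk_size: int = 2**20) -> tuple[str, str]:
    # Stream the file once, feeding both digests, so large files are never
    # fully buffered in memory; restore the original position afterwards.
    sha256, md5 = hashlib.sha256(), hashlib.md5()
    pos = fp.tell()
    for chunk in iter(lambda: fp.read(chunk_size), b""):
        sha256.update(chunk)
        md5.update(chunk)
    fp.seek(pos)
    return sha256.hexdigest(), md5.hexdigest()
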
97 changes: 97 additions & 0 deletions modal/_utils/bytes_io_segment_payload.py
@@ -0,0 +1,97 @@
# Copyright Modal Labs 2024

import asyncio
import hashlib
from contextlib import contextmanager
from typing import BinaryIO, Callable, Optional

# Note: this module needs to import aiohttp in global scope
# This takes about 50ms and isn't needed in many cases for Modal execution
# To avoid this, we import it in local scope when needed (blob_utils.py)
from aiohttp import BytesIOPayload
from aiohttp.abc import AbstractStreamWriter

# read ~16MiB chunks by default
DEFAULT_SEGMENT_CHUNK_SIZE = 2**24


class BytesIOSegmentPayload(BytesIOPayload):
"""Modified bytes payload for concurrent sends of chunks from the same file.
Adds:
* read limit using remaining_bytes, in order to split files across streams
* larger read chunk (to prevent excessive read contention between parts)
* calculates an md5 for the segment
Feels like this should be in some standard lib...
"""

def __init__(
self,
bytes_io: BinaryIO, # should *not* be shared as IO position modification is not locked
segment_start: int,
segment_length: int,
chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
progress_report_cb: Optional[Callable] = None,
):
# not thread safe constructor!
super().__init__(bytes_io)
self.initial_seek_pos = bytes_io.tell()
self.segment_start = segment_start
self.segment_length = segment_length
# seek to start of file segment we are interested in, in order to make .size() evaluate correctly
self._value.seek(self.initial_seek_pos + segment_start)
assert self.segment_length <= super().size
self.chunk_size = chunk_size
self.progress_report_cb = progress_report_cb or (lambda *_, **__: None)
self.reset_state()

def reset_state(self):
self._md5_checksum = hashlib.md5()
self.num_bytes_read = 0
self._value.seek(self.initial_seek_pos)

@contextmanager
def reset_on_error(self):
try:
yield
except Exception as exc:
try:
self.progress_report_cb(reset=True)
except Exception as cb_exc:
raise cb_exc from exc
raise exc
finally:
self.reset_state()

@property
def size(self) -> int:
return self.segment_length

def md5_checksum(self):
return self._md5_checksum

async def write(self, writer: "AbstractStreamWriter"):
loop = asyncio.get_event_loop()

async def safe_read():
read_start = self.initial_seek_pos + self.segment_start + self.num_bytes_read
self._value.seek(read_start)
num_bytes = min(self.chunk_size, self.remaining_bytes())
chunk = await loop.run_in_executor(None, self._value.read, num_bytes)

await loop.run_in_executor(None, self._md5_checksum.update, chunk)
self.num_bytes_read += len(chunk)
return chunk

chunk = await safe_read()
while chunk and self.remaining_bytes() > 0:
await writer.write(chunk)
self.progress_report_cb(advance=len(chunk))
chunk = await safe_read()
if chunk:
await writer.write(chunk)
self.progress_report_cb(advance=len(chunk))

def remaining_bytes(self):
return self.segment_length - self.num_bytes_read
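
A minimal usage sketch of the relocated class; in real uploads the payload is handed to an aiohttp PUT request, and the sizes here are purely illustrative:

import io

from modal._utils.bytes_io_segment_payload import BytesIOSegmentPayload

buf = io.BytesIO(b"x" * (1024 * 1024))  # 1 MiB of dummy data

# Wrap bytes [256 KiB, 512 KiB) of the buffer as a standalone upload segment.
payload = BytesIOSegmentPayload(buf, segment_start=256 * 1024, segment_length=256 * 1024)
print(payload.size)               # 262144
print(payload.remaining_bytes())  # 262144
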