Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add text and line buffering mode to sandbox exec #2475

Merged
merged 22 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 34 additions & 18 deletions modal/container_process.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright Modal Labs 2024
import asyncio
import platform
from typing import Optional
from typing import Generic, Optional, TypeVar

from modal_proto import api_pb2

Expand All @@ -13,12 +13,16 @@
from .io_streams import _StreamReader, _StreamWriter
from .stream_type import StreamType

T = TypeVar("T", str, bytes)

class _ContainerProcess:

class _ContainerProcess(Generic[T]):
_process_id: Optional[str] = None
_stdout: _StreamReader
_stderr: _StreamReader
_stdout: _StreamReader[T]
_stderr: _StreamReader[T]
_stdin: _StreamWriter
_text: bool
_by_line: bool
_returncode: Optional[int] = None

def __init__(
Expand All @@ -27,43 +31,55 @@ def __init__(
client: _Client,
stdout: StreamType = StreamType.PIPE,
stderr: StreamType = StreamType.PIPE,
text: bool = True,
by_line: bool = False,
) -> None:
self._process_id = process_id
self._client = client
self._stdout = _StreamReader(
api_pb2.FILE_DESCRIPTOR_STDOUT, process_id, "container_process", self._client, stream_type=stdout
self._text = text
self._by_line = by_line
self._stdout = _StreamReader[T](
api_pb2.FILE_DESCRIPTOR_STDOUT,
process_id,
"container_process",
self._client,
stream_type=stdout,
text=text,
by_line=by_line,
)
self._stderr = _StreamReader(
api_pb2.FILE_DESCRIPTOR_STDERR, process_id, "container_process", self._client, stream_type=stderr
self._stderr = _StreamReader[T](
api_pb2.FILE_DESCRIPTOR_STDERR,
process_id,
"container_process",
self._client,
stream_type=stderr,
text=text,
by_line=by_line,
)
self._stdin = _StreamWriter(process_id, "container_process", self._client)

@property
def stdout(self) -> _StreamReader:
"""`StreamReader` for the container process's stdout stream."""

def stdout(self) -> _StreamReader[T]:
"""StreamReader for the container process's stdout stream."""
return self._stdout

@property
def stderr(self) -> _StreamReader:
"""`StreamReader` for the container process's stderr stream."""

def stderr(self) -> _StreamReader[T]:
"""StreamReader for the container process's stderr stream."""
return self._stderr

@property
def stdin(self) -> _StreamWriter:
"""`StreamWriter` for the container process's stdin stream."""

"""StreamWriter for the container process's stdin stream."""
return self._stdin

@property
def returncode(self) -> _StreamWriter:
def returncode(self) -> int:
if self._returncode is None:
raise InvalidError(
"You must call wait() before accessing the returncode. "
"To poll for the status of a running process, use poll() instead."
)

return self._returncode

async def poll(self) -> Optional[int]:
Expand Down
79 changes: 49 additions & 30 deletions modal/io_streams.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright Modal Labs 2022
import asyncio
from typing import TYPE_CHECKING, AsyncGenerator, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, AsyncGenerator, Generic, List, Literal, Optional, Tuple, TypeVar, Union

from grpclib import Status
from grpclib.exceptions import GRPCError, StreamTerminatedError
Expand All @@ -19,7 +19,7 @@

async def _sandbox_logs_iterator(
sandbox_id: str, file_descriptor: int, last_entry_id: Optional[str], client: _Client
) -> AsyncGenerator[Tuple[Optional[str], str], None]:
) -> AsyncGenerator[Tuple[Optional[bytes], str], None]:
req = api_pb2.SandboxGetLogsRequest(
sandbox_id=sandbox_id,
file_descriptor=file_descriptor,
Expand All @@ -30,29 +30,30 @@ async def _sandbox_logs_iterator(
last_entry_id = log_batch.entry_id

for message in log_batch.items:
yield (message.data, last_entry_id)
yield (message.data.encode("utf-8"), last_entry_id)
if log_batch.eof:
yield (None, last_entry_id)
break


async def _container_process_logs_iterator(
process_id: str, file_descriptor: int, client: _Client
) -> AsyncGenerator[Optional[str], None]:
) -> AsyncGenerator[Optional[bytes], None]:
req = api_pb2.ContainerExecGetOutputRequest(
exec_id=process_id,
timeout=55,
file_descriptor=file_descriptor,
exec_id=process_id, timeout=55, file_descriptor=file_descriptor, get_raw_bytes=True
)
async for batch in client.stub.ContainerExecGetOutput.unary_stream(req):
if batch.HasField("exit_code"):
yield None
break
for item in batch.items:
yield item.message
yield item.message_bytes


T = TypeVar("T", str, bytes)


class _StreamReader:
class _StreamReader(Generic[T]):
"""Provides an interface to buffer and fetch logs from a stream (`stdout` or `stderr`).

As an asynchronous iterable, the object supports the async for statement.
Expand Down Expand Up @@ -80,18 +81,30 @@ def __init__(
object_type: Literal["sandbox", "container_process"],
client: _Client,
stream_type: StreamType = StreamType.PIPE,
by_line: bool = False, # if True, streamed logs are further processed into complete lines.
text: bool = True,
by_line: bool = False,
) -> None:
"""mdmd:hidden"""

self._file_descriptor = file_descriptor
self._object_type = object_type
self._object_id = object_id
self._client = client
self._stream = None
self._last_entry_id = None
self._line_buffer = ""
self._last_entry_id: Optional[str] = None
self._line_buffer = b""

# Sandbox logs are streamed to the client as strings, so StreamReaders reading
# them must have text mode enabled.
if object_type == "sandbox" and not text:
raise ValueError("Sandbox streams must have text mode enabled.")

# line-buffering is only supported when text=True
if by_line and not text:
raise ValueError("line-buffering is only supported when text=True")

self._text = text
self._by_line = by_line

# Whether the reader received an EOF. Once EOF is True, it returns
# an empty string for any subsequent reads (including async for)
self.eof = False
Expand All @@ -109,14 +122,14 @@ def __init__(
# Container process streams need to be consumed as they are produced,
# otherwise the process will block. Use a buffer to store the stream
# until the client consumes it.
self._container_process_buffer = []
self._container_process_buffer: List[Optional[bytes]] = []
self._consume_container_process_task = asyncio.create_task(self._consume_container_process_stream())

@property
def file_descriptor(self):
return self._file_descriptor

async def read(self) -> str:
async def read(self) -> T:
"""Fetch and return contents of the entire stream. If EOF was received,
return an empty string.

Expand All @@ -132,13 +145,14 @@ async def read(self) -> str:
```

"""
data = ""
# TODO: maybe combine this with get_app_logs_loop
async for message in self._get_logs_by_line():
data = "" if self._text else b""
async for message in self._get_logs():
if message is None:
break
data += message

if self._text:
data += message.decode("utf-8")
else:
data += message
return data

async def _consume_container_process_stream(self):
Expand Down Expand Up @@ -177,7 +191,7 @@ async def _consume_container_process_stream(self):
break
raise exc

async def _stream_container_process(self) -> AsyncGenerator[Tuple[Optional[str], str], None]:
async def _stream_container_process(self) -> AsyncGenerator[Tuple[Optional[bytes], str], None]:
"""mdmd:hidden
Streams the container process buffer to the reader.
"""
Expand All @@ -198,7 +212,7 @@ async def _stream_container_process(self) -> AsyncGenerator[Tuple[Optional[str],

entry_id += 1

async def _get_logs(self) -> AsyncGenerator[Optional[str], None]:
async def _get_logs(self) -> AsyncGenerator[Optional[bytes], None]:
"""mdmd:hidden
Streams sandbox or process logs from the server to the reader.

Expand Down Expand Up @@ -244,23 +258,24 @@ async def _get_logs(self) -> AsyncGenerator[Optional[str], None]:
continue
raise

async def _get_logs_by_line(self) -> AsyncGenerator[Optional[str], None]:
async def _get_logs_by_line(self) -> AsyncGenerator[Optional[bytes], None]:
"""mdmd:hidden
Processes logs from the server and yields complete lines only.
"""
async for message in self._get_logs():
if message is None:
if self._line_buffer:
yield self._line_buffer
self._line_buffer = ""
self._line_buffer = b""
yield None
else:
assert isinstance(message, bytes)
self._line_buffer += message
while "\n" in self._line_buffer:
line, self._line_buffer = self._line_buffer.split("\n", 1)
yield line + "\n"
while b"\n" in self._line_buffer:
line, self._line_buffer = self._line_buffer.split(b"\n", 1)
yield line + b"\n"

def __aiter__(self):
def __aiter__(self) -> AsyncGenerator[T, None]:
"""mdmd:hidden"""
if not self._stream:
if self._by_line:
Expand All @@ -269,15 +284,19 @@ def __aiter__(self):
self._stream = self._get_logs()
return self

async def __anext__(self):
async def __anext__(self) -> T:
"""mdmd:hidden"""
assert self._stream is not None
value = await self._stream.__anext__()

# The stream yields None if it receives an EOF batch.
if value is None:
raise StopAsyncIteration

return value
if self._text:
return value.decode("utf-8")
else:
return value


MAX_BUFFER_SIZE = 2 * 1024 * 1024
Expand Down
Loading
Loading