Skip to content

Commit

Permalink
Add text and line buffering mode to sandbox exec (#2475)
Browse files Browse the repository at this point in the history
* add line buffering and text mode to exec

* save progress

* address comments

* fix tests

* adjust text

* fix type

* pass type check

* small wording clarification

* no need for backwards compat here

* fix tests

* pass 3.9

* simplify branching logic

* make containerprocess generic

* fix sync

* fix type signature

* fix the overload spec

* remove unnecessary diff

* by_lines return bytes

* add tests

* iterator
  • Loading branch information
azliu0 authored Nov 19, 2024
1 parent 5c7239c commit ff1530a
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 67 deletions.
52 changes: 34 additions & 18 deletions modal/container_process.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright Modal Labs 2024
import asyncio
import platform
from typing import Optional
from typing import Generic, Optional, TypeVar

from modal_proto import api_pb2

Expand All @@ -13,12 +13,16 @@
from .io_streams import _StreamReader, _StreamWriter
from .stream_type import StreamType

T = TypeVar("T", str, bytes)

class _ContainerProcess:

class _ContainerProcess(Generic[T]):
_process_id: Optional[str] = None
_stdout: _StreamReader
_stderr: _StreamReader
_stdout: _StreamReader[T]
_stderr: _StreamReader[T]
_stdin: _StreamWriter
_text: bool
_by_line: bool
_returncode: Optional[int] = None

def __init__(
Expand All @@ -27,43 +31,55 @@ def __init__(
client: _Client,
stdout: StreamType = StreamType.PIPE,
stderr: StreamType = StreamType.PIPE,
text: bool = True,
by_line: bool = False,
) -> None:
self._process_id = process_id
self._client = client
self._stdout = _StreamReader(
api_pb2.FILE_DESCRIPTOR_STDOUT, process_id, "container_process", self._client, stream_type=stdout
self._text = text
self._by_line = by_line
self._stdout = _StreamReader[T](
api_pb2.FILE_DESCRIPTOR_STDOUT,
process_id,
"container_process",
self._client,
stream_type=stdout,
text=text,
by_line=by_line,
)
self._stderr = _StreamReader(
api_pb2.FILE_DESCRIPTOR_STDERR, process_id, "container_process", self._client, stream_type=stderr
self._stderr = _StreamReader[T](
api_pb2.FILE_DESCRIPTOR_STDERR,
process_id,
"container_process",
self._client,
stream_type=stderr,
text=text,
by_line=by_line,
)
self._stdin = _StreamWriter(process_id, "container_process", self._client)

@property
def stdout(self) -> _StreamReader:
"""`StreamReader` for the container process's stdout stream."""

def stdout(self) -> _StreamReader[T]:
"""StreamReader for the container process's stdout stream."""
return self._stdout

@property
def stderr(self) -> _StreamReader:
"""`StreamReader` for the container process's stderr stream."""

def stderr(self) -> _StreamReader[T]:
"""StreamReader for the container process's stderr stream."""
return self._stderr

@property
def stdin(self) -> _StreamWriter:
"""`StreamWriter` for the container process's stdin stream."""

"""StreamWriter for the container process's stdin stream."""
return self._stdin

@property
def returncode(self) -> _StreamWriter:
def returncode(self) -> int:
if self._returncode is None:
raise InvalidError(
"You must call wait() before accessing the returncode. "
"To poll for the status of a running process, use poll() instead."
)

return self._returncode

async def poll(self) -> Optional[int]:
Expand Down
79 changes: 49 additions & 30 deletions modal/io_streams.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright Modal Labs 2022
import asyncio
from typing import TYPE_CHECKING, AsyncGenerator, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, AsyncGenerator, Generic, List, Literal, Optional, Tuple, TypeVar, Union

from grpclib import Status
from grpclib.exceptions import GRPCError, StreamTerminatedError
Expand All @@ -19,7 +19,7 @@

async def _sandbox_logs_iterator(
sandbox_id: str, file_descriptor: int, last_entry_id: Optional[str], client: _Client
) -> AsyncGenerator[Tuple[Optional[str], str], None]:
) -> AsyncGenerator[Tuple[Optional[bytes], str], None]:
req = api_pb2.SandboxGetLogsRequest(
sandbox_id=sandbox_id,
file_descriptor=file_descriptor,
Expand All @@ -30,29 +30,30 @@ async def _sandbox_logs_iterator(
last_entry_id = log_batch.entry_id

for message in log_batch.items:
yield (message.data, last_entry_id)
yield (message.data.encode("utf-8"), last_entry_id)
if log_batch.eof:
yield (None, last_entry_id)
break


async def _container_process_logs_iterator(
process_id: str, file_descriptor: int, client: _Client
) -> AsyncGenerator[Optional[str], None]:
) -> AsyncGenerator[Optional[bytes], None]:
req = api_pb2.ContainerExecGetOutputRequest(
exec_id=process_id,
timeout=55,
file_descriptor=file_descriptor,
exec_id=process_id, timeout=55, file_descriptor=file_descriptor, get_raw_bytes=True
)
async for batch in client.stub.ContainerExecGetOutput.unary_stream(req):
if batch.HasField("exit_code"):
yield None
break
for item in batch.items:
yield item.message
yield item.message_bytes


T = TypeVar("T", str, bytes)


class _StreamReader:
class _StreamReader(Generic[T]):
"""Provides an interface to buffer and fetch logs from a stream (`stdout` or `stderr`).
As an asynchronous iterable, the object supports the async for statement.
Expand Down Expand Up @@ -80,18 +81,30 @@ def __init__(
object_type: Literal["sandbox", "container_process"],
client: _Client,
stream_type: StreamType = StreamType.PIPE,
by_line: bool = False, # if True, streamed logs are further processed into complete lines.
text: bool = True,
by_line: bool = False,
) -> None:
"""mdmd:hidden"""

self._file_descriptor = file_descriptor
self._object_type = object_type
self._object_id = object_id
self._client = client
self._stream = None
self._last_entry_id = None
self._line_buffer = ""
self._last_entry_id: Optional[str] = None
self._line_buffer = b""

# Sandbox logs are streamed to the client as strings, so StreamReaders reading
# them must have text mode enabled.
if object_type == "sandbox" and not text:
raise ValueError("Sandbox streams must have text mode enabled.")

# line-buffering is only supported when text=True
if by_line and not text:
raise ValueError("line-buffering is only supported when text=True")

self._text = text
self._by_line = by_line

# Whether the reader received an EOF. Once EOF is True, it returns
# an empty string for any subsequent reads (including async for)
self.eof = False
Expand All @@ -109,14 +122,14 @@ def __init__(
# Container process streams need to be consumed as they are produced,
# otherwise the process will block. Use a buffer to store the stream
# until the client consumes it.
self._container_process_buffer = []
self._container_process_buffer: List[Optional[bytes]] = []
self._consume_container_process_task = asyncio.create_task(self._consume_container_process_stream())

@property
def file_descriptor(self):
return self._file_descriptor

async def read(self) -> str:
async def read(self) -> T:
"""Fetch and return contents of the entire stream. If EOF was received,
return an empty string.
Expand All @@ -132,13 +145,14 @@ async def read(self) -> str:
```
"""
data = ""
# TODO: maybe combine this with get_app_logs_loop
async for message in self._get_logs_by_line():
data = "" if self._text else b""
async for message in self._get_logs():
if message is None:
break
data += message

if self._text:
data += message.decode("utf-8")
else:
data += message
return data

async def _consume_container_process_stream(self):
Expand Down Expand Up @@ -177,7 +191,7 @@ async def _consume_container_process_stream(self):
break
raise exc

async def _stream_container_process(self) -> AsyncGenerator[Tuple[Optional[str], str], None]:
async def _stream_container_process(self) -> AsyncGenerator[Tuple[Optional[bytes], str], None]:
"""mdmd:hidden
Streams the container process buffer to the reader.
"""
Expand All @@ -198,7 +212,7 @@ async def _stream_container_process(self) -> AsyncGenerator[Tuple[Optional[str],

entry_id += 1

async def _get_logs(self) -> AsyncGenerator[Optional[str], None]:
async def _get_logs(self) -> AsyncGenerator[Optional[bytes], None]:
"""mdmd:hidden
Streams sandbox or process logs from the server to the reader.
Expand Down Expand Up @@ -244,23 +258,24 @@ async def _get_logs(self) -> AsyncGenerator[Optional[str], None]:
continue
raise

async def _get_logs_by_line(self) -> AsyncGenerator[Optional[str], None]:
async def _get_logs_by_line(self) -> AsyncGenerator[Optional[bytes], None]:
"""mdmd:hidden
Processes logs from the server and yields complete lines only.
"""
async for message in self._get_logs():
if message is None:
if self._line_buffer:
yield self._line_buffer
self._line_buffer = ""
self._line_buffer = b""
yield None
else:
assert isinstance(message, bytes)
self._line_buffer += message
while "\n" in self._line_buffer:
line, self._line_buffer = self._line_buffer.split("\n", 1)
yield line + "\n"
while b"\n" in self._line_buffer:
line, self._line_buffer = self._line_buffer.split(b"\n", 1)
yield line + b"\n"

def __aiter__(self):
def __aiter__(self) -> AsyncGenerator[T, None]:
"""mdmd:hidden"""
if not self._stream:
if self._by_line:
Expand All @@ -269,15 +284,19 @@ def __aiter__(self):
self._stream = self._get_logs()
return self

async def __anext__(self):
async def __anext__(self) -> T:
"""mdmd:hidden"""
assert self._stream is not None
value = await self._stream.__anext__()

# The stream yields None if it receives an EOF batch.
if value is None:
raise StopAsyncIteration

return value
if self._text:
return value.decode("utf-8")
else:
return value


MAX_BUFFER_SIZE = 2 * 1024 * 1024
Expand Down
Loading

0 comments on commit ff1530a

Please sign in to comment.