Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add text and line buffering mode to sandbox exec #2475

Merged
merged 22 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions modal/container_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class _ContainerProcess:
_stdout: _StreamReader
_stderr: _StreamReader
_stdin: _StreamWriter
_text: bool
_by_line: bool
_returncode: Optional[int] = None

def __init__(
Expand All @@ -27,14 +29,31 @@ def __init__(
client: _Client,
stdout: StreamType = StreamType.PIPE,
stderr: StreamType = StreamType.PIPE,
text: bool = True,
by_line: bool = False,
) -> None:
self._process_id = process_id
self._client = client
self._stdout = _StreamReader(
api_pb2.FILE_DESCRIPTOR_STDOUT, process_id, "container_process", self._client, stream_type=stdout
self._text = text
self._by_line = by_line
T = str if text else bytes
azliu0 marked this conversation as resolved.
Show resolved Hide resolved
self._stdout: _StreamReader[T] = _StreamReader(
api_pb2.FILE_DESCRIPTOR_STDOUT,
process_id,
"container_process",
self._client,
stream_type=stdout,
text=text,
by_line=by_line,
)
self._stderr = _StreamReader(
api_pb2.FILE_DESCRIPTOR_STDERR, process_id, "container_process", self._client, stream_type=stderr
self._stderr: _StreamReader[T] = _StreamReader(
api_pb2.FILE_DESCRIPTOR_STDERR,
process_id,
"container_process",
self._client,
stream_type=stderr,
text=text,
by_line=by_line,
)
self._stdin = _StreamWriter(process_id, "container_process", self._client)

Expand All @@ -57,7 +76,7 @@ def stdin(self) -> _StreamWriter:
return self._stdin

@property
def returncode(self) -> _StreamWriter:
def returncode(self) -> int:
if self._returncode is None:
raise InvalidError(
"You must call wait() before accessing the returncode. "
Expand Down
91 changes: 60 additions & 31 deletions modal/io_streams.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright Modal Labs 2022
import asyncio
from typing import TYPE_CHECKING, AsyncGenerator, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, AsyncGenerator, Generic, List, Literal, Optional, Tuple, TypeVar, Union

from grpclib import Status
from grpclib.exceptions import GRPCError, StreamTerminatedError
Expand All @@ -18,7 +18,7 @@


async def _sandbox_logs_iterator(
sandbox_id: str, file_descriptor: int, last_entry_id: Optional[str], client: _Client
sandbox_id: str, file_descriptor: api_pb2.FileDescriptor.ValueType, last_entry_id: Optional[str], client: _Client
) -> AsyncGenerator[Tuple[Optional[str], str], None]:
req = api_pb2.SandboxGetLogsRequest(
sandbox_id=sandbox_id,
Expand All @@ -37,22 +37,23 @@ async def _sandbox_logs_iterator(


async def _container_process_logs_iterator(
process_id: str, file_descriptor: int, client: _Client
) -> AsyncGenerator[Optional[str], None]:
process_id: str, file_descriptor: api_pb2.FileDescriptor.ValueType, client: _Client
) -> AsyncGenerator[Optional[bytes], None]:
req = api_pb2.ContainerExecGetOutputRequest(
exec_id=process_id,
timeout=55,
file_descriptor=file_descriptor,
exec_id=process_id, timeout=55, file_descriptor=file_descriptor, get_raw_bytes=True
)
async for batch in client.stub.ContainerExecGetOutput.unary_stream(req):
if batch.HasField("exit_code"):
yield None
break
for item in batch.items:
yield item.message
yield item.message_bytes


T = TypeVar("T", str, bytes)


class _StreamReader:
class _StreamReader(Generic[T]):
"""Provides an interface to buffer and fetch logs from a stream (`stdout` or `stderr`).

As an asynchronous iterable, the object supports the async for statement.
Expand All @@ -75,22 +76,24 @@ class _StreamReader:

def __init__(
self,
file_descriptor: int,
file_descriptor: api_pb2.FileDescriptor.ValueType,
object_id: str,
object_type: Literal["sandbox", "container_process"],
client: _Client,
stream_type: StreamType = StreamType.PIPE,
by_line: bool = False, # if True, streamed logs are further processed into complete lines.
text: bool = True,
by_line: bool = False,
) -> None:
"""mdmd:hidden"""

self._file_descriptor = file_descriptor
self._object_type = object_type
self._object_id = object_id
self._client = client
self._stream = None
self._last_entry_id = None
self._line_buffer = ""
self._last_entry_id: Optional[str] = None
self._string_line_buffer: str = ""
self._bytes_line_buffer: bytes = b""
self._text = text
self._by_line = by_line
# Whether the reader received an EOF. Once EOF is True, it returns
# an empty string for any subsequent reads (including async for)
Expand All @@ -109,14 +112,14 @@ def __init__(
# Container process streams need to be consumed as they are produced,
# otherwise the process will block. Use a buffer to store the stream
# until the client consumes it.
self._container_process_buffer = []
self._container_process_buffer: List[Optional[bytes]] = []
self._consume_container_process_task = asyncio.create_task(self._consume_container_process_stream())

@property
def file_descriptor(self):
return self._file_descriptor

async def read(self) -> str:
async def read(self) -> T:
"""Fetch and return contents of the entire stream. If EOF was received,
return an empty string.

Expand All @@ -132,8 +135,11 @@ async def read(self) -> str:
```

"""
data = ""
# TODO: maybe combine this with get_app_logs_loop
if self._text:
data = ""
else:
data = b""

async for message in self._get_logs_by_line():
if message is None:
break
Expand Down Expand Up @@ -177,7 +183,7 @@ async def _consume_container_process_stream(self):
break
raise exc

async def _stream_container_process(self) -> AsyncGenerator[Tuple[Optional[str], str], None]:
async def _stream_container_process(self) -> AsyncGenerator[Tuple[Optional[bytes], str], None]:
"""mdmd:hidden
Streams the container process buffer to the reader.
"""
Expand All @@ -198,7 +204,7 @@ async def _stream_container_process(self) -> AsyncGenerator[Tuple[Optional[str],

entry_id += 1

async def _get_logs(self) -> AsyncGenerator[Optional[str], None]:
async def _get_logs(self) -> AsyncGenerator[Optional[Union[bytes, str]], None]:
"""mdmd:hidden
Streams sandbox or process logs from the server to the reader.

Expand All @@ -220,7 +226,7 @@ async def _get_logs(self) -> AsyncGenerator[Optional[str], None]:
while not completed:
try:
if self._object_type == "sandbox":
iterator = _sandbox_logs_iterator(
iterator: AsyncGenerator[Tuple[Optional[Union[T, str]], str], None] = _sandbox_logs_iterator(
azliu0 marked this conversation as resolved.
Show resolved Hide resolved
self._object_id, self._file_descriptor, self._last_entry_id, self._client
)
else:
Expand All @@ -244,21 +250,35 @@ async def _get_logs(self) -> AsyncGenerator[Optional[str], None]:
continue
raise

async def _get_logs_by_line(self) -> AsyncGenerator[Optional[str], None]:
async def _get_logs_by_line(self) -> AsyncGenerator[Optional[Union[bytes, str]], None]:
"""mdmd:hidden
Processes logs from the server and yields complete lines only.
"""
async for message in self._get_logs():
if message is None:
if self._line_buffer:
yield self._line_buffer
self._line_buffer = ""
yield None
if self._object_type == "sandbox":
azliu0 marked this conversation as resolved.
Show resolved Hide resolved
if self._string_line_buffer:
yield self._string_line_buffer
self._string_line_buffer = ""
yield None
else:
if self._bytes_line_buffer:
yield self._bytes_line_buffer
self._bytes_line_buffer = b""
yield None
else:
self._line_buffer += message
while "\n" in self._line_buffer:
line, self._line_buffer = self._line_buffer.split("\n", 1)
yield line + "\n"
if self._object_type == "sandbox":
assert isinstance(message, str)
self._string_line_buffer += message
while "\n" in self._string_line_buffer:
line, self._string_line_buffer = self._string_line_buffer.split("\n", 1)
yield line + "\n"
else:
assert isinstance(message, bytes)
self._bytes_line_buffer += message
while b"\n" in self._bytes_line_buffer:
line, self._bytes_line_buffer = self._bytes_line_buffer.split(b"\n", 1)
yield line + b"\n"

def __aiter__(self):
"""mdmd:hidden"""
Expand All @@ -271,13 +291,22 @@ def __aiter__(self):

async def __anext__(self):
"""mdmd:hidden"""
assert self._stream is not None
value = await self._stream.__anext__()

# The stream yields None if it receives an EOF batch.
if value is None:
raise StopAsyncIteration

return value
# Sandbox logs are strings by default
if isinstance(value, str):
return value

# Container process logs are bytes by default
if self._text:
return value.decode("utf-8")
else:
return value


MAX_BUFFER_SIZE = 2 * 1024 * 1024
Expand Down
12 changes: 8 additions & 4 deletions modal/sandbox.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright Modal Labs 2022
import asyncio
import os
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Tuple, Union

from google.protobuf.message import Message
from grpclib import GRPCError, Status
Expand Down Expand Up @@ -403,13 +403,16 @@ async def _get_task_id(self):
async def exec(
self,
*cmds: str,
# Deprecated: internal use only
pty_info: Optional[api_pb2.PTYInfo] = None,
pty_info: Optional[api_pb2.PTYInfo] = None, # Deprecated: internal use only
stdout: StreamType = StreamType.PIPE,
stderr: StreamType = StreamType.PIPE,
timeout: Optional[int] = None,
workdir: Optional[str] = None,
secrets: Sequence[_Secret] = (),
# Encode output as text.
azliu0 marked this conversation as resolved.
Show resolved Hide resolved
text: bool = True,
# Control line-buffering output. The default differs from subprocess.Popen for backwards compatibility.
bufsize: Literal[-1, 1] = 1,
# Internal option to set terminal size and metadata
_pty_info: Optional[api_pb2.PTYInfo] = None,
):
Expand Down Expand Up @@ -449,7 +452,8 @@ async def exec(
secret_ids=[secret.object_id for secret in secrets],
)
)
return _ContainerProcess(resp.exec_id, self._client, stdout=stdout, stderr=stderr)
by_line = bufsize == 1
return _ContainerProcess(resp.exec_id, self._client, stdout=stdout, stderr=stderr, text=text, by_line=by_line)

@property
def stdout(self) -> _StreamReader:
Expand Down
4 changes: 3 additions & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,9 @@ async def ContainerExecGetOutput(self, stream):
api_pb2.RuntimeOutputBatch(
items=[
api_pb2.RuntimeOutputMessage(
message=message.decode("utf-8"), file_descriptor=request.file_descriptor
message=message.decode("utf-8"),
file_descriptor=request.file_descriptor,
message_bytes=message,
)
]
)
Expand Down
38 changes: 34 additions & 4 deletions test/io_streams_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async def sandbox_get_logs(servicer, stream):
ctx.set_responder("SandboxGetLogs", sandbox_get_logs)

with enable_output():
stdout = StreamReader(
stdout: StreamReader[str] = StreamReader(
file_descriptor=api_pb2.FILE_DESCRIPTOR_STDOUT,
object_id="sb-123",
object_type="sandbox",
Expand Down Expand Up @@ -54,7 +54,7 @@ async def sandbox_get_logs(servicer, stream):
ctx.set_responder("SandboxGetLogs", sandbox_get_logs)

with enable_output():
stdout = StreamReader(
stdout: StreamReader[str] = StreamReader(
file_descriptor=api_pb2.FILE_DESCRIPTOR_STDOUT,
object_id="sb-123",
object_type="sandbox",
Expand Down Expand Up @@ -88,7 +88,7 @@ async def sandbox_get_logs(servicer, stream):
ctx.set_responder("SandboxGetLogs", sandbox_get_logs)

with enable_output():
stdout = StreamReader(
stdout: StreamReader[str] = StreamReader(
file_descriptor=api_pb2.FILE_DESCRIPTOR_STDOUT,
object_id="sb-123",
object_type="sandbox",
Expand Down Expand Up @@ -134,7 +134,7 @@ async def sandbox_get_logs(servicer, stream):
ctx.set_responder("SandboxGetLogs", sandbox_get_logs)

with enable_output():
stdout = StreamReader(
stdout: StreamReader[str] = StreamReader(
file_descriptor=api_pb2.FILE_DESCRIPTOR_STDOUT,
object_id="sb-123",
object_type="sandbox",
Expand All @@ -147,3 +147,33 @@ async def sandbox_get_logs(servicer, stream):
out.append(line)

assert out == ["foobar\n", "baz"]


def test_stream_reader_bytes_mode(servicer, client):
"""Test that the stream reader works in bytes mode."""

async def sandbox_get_logs(servicer, stream):
await stream.recv_message()

log = api_pb2.TaskLogs(
data="foo\n",
file_descriptor=api_pb2.FILE_DESCRIPTOR_STDOUT,
)
await stream.send_message(api_pb2.TaskLogsBatch(entry_id="0", items=[log]))

# send EOF
await stream.send_message(api_pb2.TaskLogsBatch(eof=True))

with servicer.intercept() as ctx:
ctx.set_responder("SandboxGetLogs", sandbox_get_logs)

with enable_output():
stdout: StreamReader[bytes] = StreamReader(
file_descriptor=api_pb2.FILE_DESCRIPTOR_STDOUT,
object_id="sb-123",
object_type="sandbox",
client=client,
text=False,
)

assert stdout.read() == b"foo\n"
14 changes: 14 additions & 0 deletions test/sandbox_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,20 @@ async def test_sandbox_async_for(app, servicer):
assert await sb.stderr.read.aio() == ""


@skip_non_linux
def test_sandbox_exec_stdout_bytes_mode(app, servicer):
"""Test that the stream reader works in bytes mode."""

sb = Sandbox.create(app=app)

p = sb.exec("echo", "foo", text=False)
assert p.stdout.read() == b"foo\n"

p = sb.exec("echo", "foo", text=False)
for line in p.stdout:
assert line == b"foo\n"


@skip_non_linux
def test_app_sandbox(client, servicer):
image = Image.debian_slim().pip_install("xyz")
Expand Down
Loading