Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add text and line buffering mode to sandbox exec #2475

Merged
merged 22 commits on Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions modal/container_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class _ContainerProcess:
_stdout: _StreamReader
_stderr: _StreamReader
_stdin: _StreamWriter
_text: bool
_by_line: bool
_returncode: Optional[int] = None

def __init__(
Expand All @@ -27,14 +29,31 @@ def __init__(
client: _Client,
stdout: StreamType = StreamType.PIPE,
stderr: StreamType = StreamType.PIPE,
text: bool = True,
by_line: bool = False,
) -> None:
self._process_id = process_id
self._client = client
self._stdout = _StreamReader(
api_pb2.FILE_DESCRIPTOR_STDOUT, process_id, "container_process", self._client, stream_type=stdout
self._text = text
self._by_line = by_line
T = str if text else bytes
azliu0 marked this conversation as resolved.
Show resolved Hide resolved
self._stdout: _StreamReader[T] = _StreamReader(
api_pb2.FILE_DESCRIPTOR_STDOUT,
process_id,
"container_process",
self._client,
stream_type=stdout,
text=text,
by_line=by_line,
)
self._stderr = _StreamReader(
api_pb2.FILE_DESCRIPTOR_STDERR, process_id, "container_process", self._client, stream_type=stderr
self._stderr: _StreamReader[T] = _StreamReader(
api_pb2.FILE_DESCRIPTOR_STDERR,
process_id,
"container_process",
self._client,
stream_type=stderr,
text=text,
by_line=by_line,
)
self._stdin = _StreamWriter(process_id, "container_process", self._client)

Expand All @@ -57,7 +76,7 @@ def stdin(self) -> _StreamWriter:
return self._stdin

@property
def returncode(self) -> _StreamWriter:
def returncode(self) -> int:
if self._returncode is None:
raise InvalidError(
"You must call wait() before accessing the returncode. "
Expand Down
105 changes: 77 additions & 28 deletions modal/io_streams.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright Modal Labs 2022
import asyncio
from typing import TYPE_CHECKING, AsyncGenerator, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, AsyncGenerator, Generic, List, Literal, Optional, Tuple, TypeVar, Union

from grpclib import Status
from grpclib.exceptions import GRPCError, StreamTerminatedError
Expand Down Expand Up @@ -38,21 +38,22 @@ async def _sandbox_logs_iterator(

async def _container_process_logs_iterator(
process_id: str, file_descriptor: int, client: _Client
) -> AsyncGenerator[Optional[str], None]:
) -> AsyncGenerator[Optional[bytes], None]:
req = api_pb2.ContainerExecGetOutputRequest(
exec_id=process_id,
timeout=55,
file_descriptor=file_descriptor,
exec_id=process_id, timeout=55, file_descriptor=file_descriptor, get_raw_bytes=True
)
async for batch in client.stub.ContainerExecGetOutput.unary_stream(req):
if batch.HasField("exit_code"):
yield None
break
for item in batch.items:
yield item.message
yield item.message_bytes


T = TypeVar("T", str, bytes)

class _StreamReader:

class _StreamReader(Generic[T]):
"""Provides an interface to buffer and fetch logs from a stream (`stdout` or `stderr`).

As an asynchronous iterable, the object supports the async for statement.
Expand Down Expand Up @@ -80,18 +81,26 @@ def __init__(
object_type: Literal["sandbox", "container_process"],
client: _Client,
stream_type: StreamType = StreamType.PIPE,
by_line: bool = False, # if True, streamed logs are further processed into complete lines.
text: bool = True,
by_line: bool = False,
) -> None:
"""mdmd:hidden"""

self._file_descriptor = file_descriptor
self._object_type = object_type
self._object_id = object_id
self._client = client
self._stream = None
self._last_entry_id = None
self._line_buffer = ""
self._last_entry_id: Optional[str] = None
self._string_line_buffer: str = ""
self._bytes_line_buffer: bytes = b""

# Sandbox logs are streamed to the client as strings, so StreamReaders reading
# them must have text mode enabled.
if object_type == "sandbox" and not text:
raise ValueError("Sandbox streams must have text mode enabled.")
self._text = text
self._by_line = by_line

# Whether the reader received an EOF. Once EOF is True, it returns
# an empty string for any subsequent reads (including async for)
self.eof = False
Expand All @@ -109,14 +118,14 @@ def __init__(
# Container process streams need to be consumed as they are produced,
# otherwise the process will block. Use a buffer to store the stream
# until the client consumes it.
self._container_process_buffer = []
self._container_process_buffer: List[Optional[bytes]] = []
self._consume_container_process_task = asyncio.create_task(self._consume_container_process_stream())

@property
def file_descriptor(self):
return self._file_descriptor

async def read(self) -> str:
async def read(self) -> T:
"""Fetch and return contents of the entire stream. If EOF was received,
return an empty string.

Expand All @@ -132,13 +141,30 @@ async def read(self) -> str:
```

"""
if self._object_type == "sandbox":
return await self._read_sandbox()
else:
return await self._read_container_process()

async def _read_sandbox(self) -> str:
azliu0 marked this conversation as resolved.
Show resolved Hide resolved
assert self._object_type == "sandbox"
data = ""
# TODO: maybe combine this with get_app_logs_loop
async for message in self._get_logs_by_line():
async for message in self._get_logs():
if message is None:
break
data += message
return data

async def _read_container_process(self) -> T:
assert self._object_type == "container_process"
data = "" if self._text else b""
async for message in self._get_logs():
if message is None:
break
if self._text:
data += message.decode("utf-8")
else:
data += message
return data

async def _consume_container_process_stream(self):
Expand Down Expand Up @@ -177,7 +203,7 @@ async def _consume_container_process_stream(self):
break
raise exc

async def _stream_container_process(self) -> AsyncGenerator[Tuple[Optional[str], str], None]:
async def _stream_container_process(self) -> AsyncGenerator[Tuple[Optional[bytes], str], None]:
"""mdmd:hidden
Streams the container process buffer to the reader.
"""
Expand All @@ -198,7 +224,7 @@ async def _stream_container_process(self) -> AsyncGenerator[Tuple[Optional[str],

entry_id += 1

async def _get_logs(self) -> AsyncGenerator[Optional[str], None]:
async def _get_logs(self) -> AsyncGenerator[Optional[Union[bytes, str]], None]:
"""mdmd:hidden
Streams sandbox or process logs from the server to the reader.

Expand All @@ -220,7 +246,7 @@ async def _get_logs(self) -> AsyncGenerator[Optional[str], None]:
while not completed:
try:
if self._object_type == "sandbox":
iterator = _sandbox_logs_iterator(
iterator: AsyncGenerator[Tuple[Optional[Union[T, str]], str], None] = _sandbox_logs_iterator(
azliu0 marked this conversation as resolved.
Show resolved Hide resolved
self._object_id, self._file_descriptor, self._last_entry_id, self._client
)
else:
Expand All @@ -244,21 +270,35 @@ async def _get_logs(self) -> AsyncGenerator[Optional[str], None]:
continue
raise

async def _get_logs_by_line(self) -> AsyncGenerator[Optional[str], None]:
async def _get_logs_by_line(self) -> AsyncGenerator[Optional[Union[bytes, str]], None]:
"""mdmd:hidden
Processes logs from the server and yields complete lines only.
"""
async for message in self._get_logs():
if message is None:
if self._line_buffer:
yield self._line_buffer
self._line_buffer = ""
yield None
if self._object_type == "sandbox":
azliu0 marked this conversation as resolved.
Show resolved Hide resolved
if self._string_line_buffer:
yield self._string_line_buffer
self._string_line_buffer = ""
yield None
else:
if self._bytes_line_buffer:
yield self._bytes_line_buffer
self._bytes_line_buffer = b""
yield None
else:
self._line_buffer += message
while "\n" in self._line_buffer:
line, self._line_buffer = self._line_buffer.split("\n", 1)
yield line + "\n"
if self._object_type == "sandbox":
assert isinstance(message, str)
self._string_line_buffer += message
while "\n" in self._string_line_buffer:
line, self._string_line_buffer = self._string_line_buffer.split("\n", 1)
yield line + "\n"
else:
assert isinstance(message, bytes)
self._bytes_line_buffer += message
while b"\n" in self._bytes_line_buffer:
line, self._bytes_line_buffer = self._bytes_line_buffer.split(b"\n", 1)
yield line + b"\n"

def __aiter__(self):
"""mdmd:hidden"""
Expand All @@ -271,13 +311,22 @@ def __aiter__(self):

async def __anext__(self):
"""mdmd:hidden"""
assert self._stream is not None
value = await self._stream.__anext__()

# The stream yields None if it receives an EOF batch.
if value is None:
raise StopAsyncIteration

return value
# Sandbox logs are strings
azliu0 marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(value, str):
return value

# Container process logs are bytes
if self._text:
return value.decode("utf-8")
else:
return value


MAX_BUFFER_SIZE = 2 * 1024 * 1024
Expand Down
20 changes: 12 additions & 8 deletions modal/sandbox.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright Modal Labs 2022
import asyncio
import os
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Tuple, Union

from google.protobuf.message import Message
from grpclib import GRPCError, Status
Expand Down Expand Up @@ -46,8 +46,8 @@ class _Sandbox(_Object, type_prefix="sb"):
"""

_result: Optional[api_pb2.GenericResult]
_stdout: _StreamReader
_stderr: _StreamReader
_stdout: _StreamReader[str]
_stderr: _StreamReader[str]
_stdin: _StreamWriter
_task_id: Optional[str] = None
_tunnels: Optional[Dict[int, Tunnel]] = None
Expand Down Expand Up @@ -280,10 +280,10 @@ async def create(
return obj

def _hydrate_metadata(self, handle_metadata: Optional[Message]):
self._stdout = StreamReader(
self._stdout: _StreamReader[str] = StreamReader(
api_pb2.FILE_DESCRIPTOR_STDOUT, self.object_id, "sandbox", self._client, by_line=True
)
self._stderr = StreamReader(
self._stderr: _StreamReader[str] = StreamReader(
api_pb2.FILE_DESCRIPTOR_STDERR, self.object_id, "sandbox", self._client, by_line=True
)
self._stdin = StreamWriter(self.object_id, "sandbox", self._client)
Expand Down Expand Up @@ -403,13 +403,16 @@ async def _get_task_id(self):
async def exec(
self,
*cmds: str,
# Deprecated: internal use only
pty_info: Optional[api_pb2.PTYInfo] = None,
pty_info: Optional[api_pb2.PTYInfo] = None, # Deprecated: internal use only
stdout: StreamType = StreamType.PIPE,
stderr: StreamType = StreamType.PIPE,
timeout: Optional[int] = None,
workdir: Optional[str] = None,
secrets: Sequence[_Secret] = (),
# Encode output as text.
azliu0 marked this conversation as resolved.
Show resolved Hide resolved
text: bool = True,
# Control line-buffered output.
bufsize: Literal[-1, 1] = -1,
azliu0 marked this conversation as resolved.
Show resolved Hide resolved
# Internal option to set terminal size and metadata
_pty_info: Optional[api_pb2.PTYInfo] = None,
):
Expand Down Expand Up @@ -449,7 +452,8 @@ async def exec(
secret_ids=[secret.object_id for secret in secrets],
)
)
return _ContainerProcess(resp.exec_id, self._client, stdout=stdout, stderr=stderr)
by_line = bufsize == 1
return _ContainerProcess(resp.exec_id, self._client, stdout=stdout, stderr=stderr, text=text, by_line=by_line)

@property
def stdout(self) -> _StreamReader:
Expand Down
8 changes: 4 additions & 4 deletions test/cli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def test_serve(servicer, set_env_client, server_url_env, test_dir):

@pytest.fixture
def mock_shell_pty(servicer):
servicer.shell_prompt = "TEST_PROMPT# "
servicer.shell_prompt = b"TEST_PROMPT# "

def mock_get_pty_info(shell: bool) -> api_pb2.PTYInfo:
rows, cols = (64, 128)
Expand Down Expand Up @@ -417,7 +417,7 @@ def test_shell(servicer, set_env_client, test_dir, mock_shell_pty):
fake_stdin.clear()
fake_stdin.extend([b'echo "Hello World"\n', b"exit\n"])

shell_prompt = servicer.shell_prompt.encode("utf-8")
shell_prompt = servicer.shell_prompt

# Function is explicitly specified
_run(["shell", app_file.as_posix() + "::foo"])
Expand Down Expand Up @@ -445,7 +445,7 @@ def test_shell(servicer, set_env_client, test_dir, mock_shell_pty):
def test_shell_cmd(servicer, set_env_client, test_dir, mock_shell_pty):
app_file = test_dir / "supports" / "app_run_tests" / "default_app.py"
_, captured_out = mock_shell_pty
shell_prompt = servicer.shell_prompt.encode("utf-8")
shell_prompt = servicer.shell_prompt
_run(["shell", "--cmd", "pwd", app_file.as_posix() + "::foo"])
expected_output = subprocess.run(["pwd"], capture_output=True, check=True).stdout
assert captured_out == [(1, shell_prompt), (1, expected_output)]
Expand All @@ -456,7 +456,7 @@ def test_shell_preserve_token(servicer, set_env_client, mock_shell_pty, monkeypa
monkeypatch.setenv("MODAL_TOKEN_ID", "my-token-id")

fake_stdin, captured_out = mock_shell_pty
shell_prompt = servicer.shell_prompt.encode("utf-8")
shell_prompt = servicer.shell_prompt

fake_stdin.clear()
fake_stdin.extend([b'echo "$MODAL_TOKEN_ID"\n', b"exit\n"])
Expand Down
7 changes: 5 additions & 2 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,8 @@ async def ContainerExecGetOutput(self, stream):
api_pb2.RuntimeOutputBatch(
items=[
api_pb2.RuntimeOutputMessage(
message=self.shell_prompt, file_descriptor=request.file_descriptor
file_descriptor=request.file_descriptor,
message_bytes=self.shell_prompt,
)
]
)
Expand All @@ -723,7 +724,9 @@ async def ContainerExecGetOutput(self, stream):
api_pb2.RuntimeOutputBatch(
items=[
api_pb2.RuntimeOutputMessage(
message=message.decode("utf-8"), file_descriptor=request.file_descriptor
message=message.decode("utf-8"),
file_descriptor=request.file_descriptor,
message_bytes=message,
)
]
)
Expand Down
Loading
Loading