Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add text and line buffering mode to sandbox exec #2475

Merged
merged 22 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions modal/container_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class _ContainerProcess:
_stdout: _StreamReader
_stderr: _StreamReader
_stdin: _StreamWriter
_text: bool
_by_line: bool
_returncode: Optional[int] = None

def __init__(
Expand All @@ -27,14 +29,31 @@ def __init__(
client: _Client,
stdout: StreamType = StreamType.PIPE,
stderr: StreamType = StreamType.PIPE,
text: bool = True,
by_line: bool = False,
) -> None:
self._process_id = process_id
self._client = client
self._stdout = _StreamReader(
api_pb2.FILE_DESCRIPTOR_STDOUT, process_id, "container_process", self._client, stream_type=stdout
self._text = text
self._by_line = by_line
T = str if text else bytes
azliu0 marked this conversation as resolved.
Show resolved Hide resolved
self._stdout: _StreamReader[T] = _StreamReader(
api_pb2.FILE_DESCRIPTOR_STDOUT,
process_id,
"container_process",
self._client,
stream_type=stdout,
text=text,
by_line=by_line,
)
self._stderr = _StreamReader(
api_pb2.FILE_DESCRIPTOR_STDERR, process_id, "container_process", self._client, stream_type=stderr
self._stderr: _StreamReader[T] = _StreamReader(
api_pb2.FILE_DESCRIPTOR_STDERR,
process_id,
"container_process",
self._client,
stream_type=stderr,
text=text,
by_line=by_line,
)
self._stdin = _StreamWriter(process_id, "container_process", self._client)

Expand All @@ -57,7 +76,7 @@ def stdin(self) -> _StreamWriter:
return self._stdin

@property
def returncode(self) -> _StreamWriter:
def returncode(self) -> int:
if self._returncode is None:
raise InvalidError(
"You must call wait() before accessing the returncode. "
Expand Down
67 changes: 43 additions & 24 deletions modal/io_streams.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright Modal Labs 2022
import asyncio
from typing import TYPE_CHECKING, AsyncGenerator, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, AsyncGenerator, Generic, List, Literal, Optional, Tuple, TypeVar, Union

from grpclib import Status
from grpclib.exceptions import GRPCError, StreamTerminatedError
Expand All @@ -19,7 +19,7 @@

async def _sandbox_logs_iterator(
sandbox_id: str, file_descriptor: int, last_entry_id: Optional[str], client: _Client
) -> AsyncGenerator[Tuple[Optional[str], str], None]:
) -> AsyncGenerator[Tuple[Optional[bytes], str], None]:
req = api_pb2.SandboxGetLogsRequest(
sandbox_id=sandbox_id,
file_descriptor=file_descriptor,
Expand All @@ -30,29 +30,30 @@ async def _sandbox_logs_iterator(
last_entry_id = log_batch.entry_id

for message in log_batch.items:
yield (message.data, last_entry_id)
yield (message.data.encode("utf-8"), last_entry_id)
if log_batch.eof:
yield (None, last_entry_id)
break


async def _container_process_logs_iterator(
process_id: str, file_descriptor: int, client: _Client
) -> AsyncGenerator[Optional[str], None]:
) -> AsyncGenerator[Optional[bytes], None]:
req = api_pb2.ContainerExecGetOutputRequest(
exec_id=process_id,
timeout=55,
file_descriptor=file_descriptor,
exec_id=process_id, timeout=55, file_descriptor=file_descriptor, get_raw_bytes=True
)
async for batch in client.stub.ContainerExecGetOutput.unary_stream(req):
if batch.HasField("exit_code"):
yield None
break
for item in batch.items:
yield item.message
yield item.message_bytes


T = TypeVar("T", str, bytes)


class _StreamReader:
class _StreamReader(Generic[T]):
"""Provides an interface to buffer and fetch logs from a stream (`stdout` or `stderr`).

As an asynchronous iterable, the object supports the async for statement.
Expand Down Expand Up @@ -80,18 +81,30 @@ def __init__(
object_type: Literal["sandbox", "container_process"],
client: _Client,
stream_type: StreamType = StreamType.PIPE,
by_line: bool = False, # if True, streamed logs are further processed into complete lines.
text: bool = True,
by_line: bool = False,
) -> None:
"""mdmd:hidden"""

self._file_descriptor = file_descriptor
self._object_type = object_type
self._object_id = object_id
self._client = client
self._stream = None
self._last_entry_id = None
self._last_entry_id: Optional[str] = None
self._line_buffer = ""

# Sandbox logs are streamed to the client as strings, so StreamReaders reading
# them must have text mode enabled.
if object_type == "sandbox" and not text:
raise ValueError("Sandbox streams must have text mode enabled.")

# line-buffering is only supported when text=True
if by_line and not text:
raise ValueError("line-buffering is only supported when text=True")

self._text = text
self._by_line = by_line

# Whether the reader received an EOF. Once EOF is True, it returns
# an empty string for any subsequent reads (including async for)
self.eof = False
Expand All @@ -109,14 +122,14 @@ def __init__(
# Container process streams need to be consumed as they are produced,
# otherwise the process will block. Use a buffer to store the stream
# until the client consumes it.
self._container_process_buffer = []
self._container_process_buffer: List[Optional[bytes]] = []
self._consume_container_process_task = asyncio.create_task(self._consume_container_process_stream())

@property
def file_descriptor(self):
return self._file_descriptor

async def read(self) -> str:
async def read(self) -> T:
"""Fetch and return contents of the entire stream. If EOF was received,
return an empty string (or empty bytes when text mode is disabled).

Expand All @@ -132,13 +145,14 @@ async def read(self) -> str:
```

"""
data = ""
# TODO: maybe combine this with get_app_logs_loop
async for message in self._get_logs_by_line():
data = "" if self._text else b""
async for message in self._get_logs():
if message is None:
break
data += message

if self._text:
data += message.decode("utf-8")
else:
data += message
return data

async def _consume_container_process_stream(self):
Expand Down Expand Up @@ -177,7 +191,7 @@ async def _consume_container_process_stream(self):
break
raise exc

async def _stream_container_process(self) -> AsyncGenerator[Tuple[Optional[str], str], None]:
async def _stream_container_process(self) -> AsyncGenerator[Tuple[Optional[bytes], str], None]:
"""mdmd:hidden
Streams the container process buffer to the reader.
"""
Expand All @@ -198,7 +212,7 @@ async def _stream_container_process(self) -> AsyncGenerator[Tuple[Optional[str],

entry_id += 1

async def _get_logs(self) -> AsyncGenerator[Optional[str], None]:
async def _get_logs(self) -> AsyncGenerator[Optional[Union[bytes, str]], None]:
"""mdmd:hidden
Streams sandbox or process logs from the server to the reader.

Expand Down Expand Up @@ -255,7 +269,8 @@ async def _get_logs_by_line(self) -> AsyncGenerator[Optional[str], None]:
self._line_buffer = ""
yield None
else:
self._line_buffer += message
assert isinstance(message, bytes)
self._line_buffer += message.decode("utf-8")
while "\n" in self._line_buffer:
line, self._line_buffer = self._line_buffer.split("\n", 1)
yield line + "\n"
Expand All @@ -269,15 +284,19 @@ def __aiter__(self):
self._stream = self._get_logs()
return self

async def __anext__(self):
async def __anext__(self) -> T:
"""mdmd:hidden"""
assert self._stream is not None
value = await self._stream.__anext__()

# The stream yields None if it receives an EOF batch.
if value is None:
raise StopAsyncIteration

return value
if self._text:
return value.decode("utf-8")
else:
return value


MAX_BUFFER_SIZE = 2 * 1024 * 1024
Expand Down
21 changes: 13 additions & 8 deletions modal/sandbox.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright Modal Labs 2022
import asyncio
import os
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Tuple, Union

from google.protobuf.message import Message
from grpclib import GRPCError, Status
Expand Down Expand Up @@ -46,8 +46,8 @@ class _Sandbox(_Object, type_prefix="sb"):
"""

_result: Optional[api_pb2.GenericResult]
_stdout: _StreamReader
_stderr: _StreamReader
_stdout: _StreamReader[str]
_stderr: _StreamReader[str]
_stdin: _StreamWriter
_task_id: Optional[str] = None
_tunnels: Optional[Dict[int, Tunnel]] = None
Expand Down Expand Up @@ -280,10 +280,10 @@ async def create(
return obj

def _hydrate_metadata(self, handle_metadata: Optional[Message]):
self._stdout = StreamReader(
self._stdout: _StreamReader[str] = StreamReader(
api_pb2.FILE_DESCRIPTOR_STDOUT, self.object_id, "sandbox", self._client, by_line=True
)
self._stderr = StreamReader(
self._stderr: _StreamReader[str] = StreamReader(
api_pb2.FILE_DESCRIPTOR_STDERR, self.object_id, "sandbox", self._client, by_line=True
)
self._stdin = StreamWriter(self.object_id, "sandbox", self._client)
Expand Down Expand Up @@ -403,13 +403,17 @@ async def _get_task_id(self):
async def exec(
self,
*cmds: str,
# Deprecated: internal use only
pty_info: Optional[api_pb2.PTYInfo] = None,
pty_info: Optional[api_pb2.PTYInfo] = None, # Deprecated: internal use only
stdout: StreamType = StreamType.PIPE,
stderr: StreamType = StreamType.PIPE,
timeout: Optional[int] = None,
workdir: Optional[str] = None,
secrets: Sequence[_Secret] = (),
# Decode output as text (`str`). If False, output is returned as raw `bytes`.
azliu0 marked this conversation as resolved.
Show resolved Hide resolved
text: bool = True,
# Control line-buffered output.
# -1 means unbuffered, 1 means line-buffered (only available if `text=True`).
bufsize: Literal[-1, 1] = -1,
azliu0 marked this conversation as resolved.
Show resolved Hide resolved
# Internal option to set terminal size and metadata
_pty_info: Optional[api_pb2.PTYInfo] = None,
):
Expand Down Expand Up @@ -449,7 +453,8 @@ async def exec(
secret_ids=[secret.object_id for secret in secrets],
)
)
return _ContainerProcess(resp.exec_id, self._client, stdout=stdout, stderr=stderr)
by_line = bufsize == 1
return _ContainerProcess(resp.exec_id, self._client, stdout=stdout, stderr=stderr, text=text, by_line=by_line)

@property
def stdout(self) -> _StreamReader:
Expand Down
8 changes: 4 additions & 4 deletions test/cli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def test_serve(servicer, set_env_client, server_url_env, test_dir):

@pytest.fixture
def mock_shell_pty(servicer):
servicer.shell_prompt = "TEST_PROMPT# "
servicer.shell_prompt = b"TEST_PROMPT# "

def mock_get_pty_info(shell: bool) -> api_pb2.PTYInfo:
rows, cols = (64, 128)
Expand Down Expand Up @@ -417,7 +417,7 @@ def test_shell(servicer, set_env_client, test_dir, mock_shell_pty):
fake_stdin.clear()
fake_stdin.extend([b'echo "Hello World"\n', b"exit\n"])

shell_prompt = servicer.shell_prompt.encode("utf-8")
shell_prompt = servicer.shell_prompt

# Function is explicitly specified
_run(["shell", app_file.as_posix() + "::foo"])
Expand Down Expand Up @@ -445,7 +445,7 @@ def test_shell(servicer, set_env_client, test_dir, mock_shell_pty):
def test_shell_cmd(servicer, set_env_client, test_dir, mock_shell_pty):
app_file = test_dir / "supports" / "app_run_tests" / "default_app.py"
_, captured_out = mock_shell_pty
shell_prompt = servicer.shell_prompt.encode("utf-8")
shell_prompt = servicer.shell_prompt
_run(["shell", "--cmd", "pwd", app_file.as_posix() + "::foo"])
expected_output = subprocess.run(["pwd"], capture_output=True, check=True).stdout
assert captured_out == [(1, shell_prompt), (1, expected_output)]
Expand All @@ -456,7 +456,7 @@ def test_shell_preserve_token(servicer, set_env_client, mock_shell_pty, monkeypa
monkeypatch.setenv("MODAL_TOKEN_ID", "my-token-id")

fake_stdin, captured_out = mock_shell_pty
shell_prompt = servicer.shell_prompt.encode("utf-8")
shell_prompt = servicer.shell_prompt

fake_stdin.clear()
fake_stdin.extend([b'echo "$MODAL_TOKEN_ID"\n', b"exit\n"])
Expand Down
7 changes: 5 additions & 2 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,8 @@ async def ContainerExecGetOutput(self, stream):
api_pb2.RuntimeOutputBatch(
items=[
api_pb2.RuntimeOutputMessage(
message=self.shell_prompt, file_descriptor=request.file_descriptor
file_descriptor=request.file_descriptor,
message_bytes=self.shell_prompt,
)
]
)
Expand All @@ -723,7 +724,9 @@ async def ContainerExecGetOutput(self, stream):
api_pb2.RuntimeOutputBatch(
items=[
api_pb2.RuntimeOutputMessage(
message=message.decode("utf-8"), file_descriptor=request.file_descriptor
message=message.decode("utf-8"),
file_descriptor=request.file_descriptor,
message_bytes=message,
)
]
)
Expand Down
Loading
Loading