Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
azliu0 committed Nov 16, 2024
1 parent 15edece commit 5fea951
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 25 deletions.
28 changes: 24 additions & 4 deletions modal/io_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,14 @@ def __init__(
self._last_entry_id: Optional[str] = None
self._string_line_buffer: str = ""
self._bytes_line_buffer: bytes = b""

# Sandbox logs are streamed to the client as strings, so StreamReaders reading
# them must have text mode enabled.
if object_type == "sandbox" and not text:
raise ValueError("Sandbox streams must have text mode enabled.")
self._text = text
self._by_line = by_line

# Whether the reader received an EOF. Once EOF is True, it returns
# an empty string for any subsequent reads (including async for)
self.eof = False
Expand Down Expand Up @@ -135,16 +141,30 @@ async def read(self) -> T:
```
"""
if self._text:
data = ""
if self._object_type == "sandbox":
return await self._read_sandbox()
else:
data = b""
return await self._read_container_process()

async for message in self._get_logs_by_line():
async def _read_sandbox(self) -> str:
assert self._object_type == "sandbox"
data = ""
async for message in self._get_logs():
if message is None:
break
data += message
return data

async def _read_container_process(self) -> T:
assert self._object_type == "container_process"
data = "" if self._text else b""
async for message in self._get_logs():
if message is None:
break
if self._text:
data += message.decode("utf-8")
else:
data += message
return data

async def _consume_container_process_stream(self):
Expand Down
8 changes: 4 additions & 4 deletions modal/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ class _Sandbox(_Object, type_prefix="sb"):
"""

_result: Optional[api_pb2.GenericResult]
_stdout: _StreamReader
_stderr: _StreamReader
_stdout: _StreamReader[str]
_stderr: _StreamReader[str]
_stdin: _StreamWriter
_task_id: Optional[str] = None
_tunnels: Optional[Dict[int, Tunnel]] = None
Expand Down Expand Up @@ -280,10 +280,10 @@ async def create(
return obj

def _hydrate_metadata(self, handle_metadata: Optional[Message]):
self._stdout = StreamReader(
self._stdout: _StreamReader[str] = StreamReader(
api_pb2.FILE_DESCRIPTOR_STDOUT, self.object_id, "sandbox", self._client, by_line=True
)
self._stderr = StreamReader(
self._stderr: _StreamReader[str] = StreamReader(
api_pb2.FILE_DESCRIPTOR_STDERR, self.object_id, "sandbox", self._client, by_line=True
)
self._stdin = StreamWriter(self.object_id, "sandbox", self._client)
Expand Down
8 changes: 4 additions & 4 deletions test/cli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def test_serve(servicer, set_env_client, server_url_env, test_dir):

@pytest.fixture
def mock_shell_pty(servicer):
servicer.shell_prompt = "TEST_PROMPT# "
servicer.shell_prompt = b"TEST_PROMPT# "

def mock_get_pty_info(shell: bool) -> api_pb2.PTYInfo:
rows, cols = (64, 128)
Expand Down Expand Up @@ -417,7 +417,7 @@ def test_shell(servicer, set_env_client, test_dir, mock_shell_pty):
fake_stdin.clear()
fake_stdin.extend([b'echo "Hello World"\n', b"exit\n"])

shell_prompt = servicer.shell_prompt.encode("utf-8")
shell_prompt = servicer.shell_prompt

# Function is explicitly specified
_run(["shell", app_file.as_posix() + "::foo"])
Expand Down Expand Up @@ -445,7 +445,7 @@ def test_shell(servicer, set_env_client, test_dir, mock_shell_pty):
def test_shell_cmd(servicer, set_env_client, test_dir, mock_shell_pty):
app_file = test_dir / "supports" / "app_run_tests" / "default_app.py"
_, captured_out = mock_shell_pty
shell_prompt = servicer.shell_prompt.encode("utf-8")
shell_prompt = servicer.shell_prompt
_run(["shell", "--cmd", "pwd", app_file.as_posix() + "::foo"])
expected_output = subprocess.run(["pwd"], capture_output=True, check=True).stdout
assert captured_out == [(1, shell_prompt), (1, expected_output)]
Expand All @@ -456,7 +456,7 @@ def test_shell_preserve_token(servicer, set_env_client, mock_shell_pty, monkeypa
monkeypatch.setenv("MODAL_TOKEN_ID", "my-token-id")

fake_stdin, captured_out = mock_shell_pty
shell_prompt = servicer.shell_prompt.encode("utf-8")
shell_prompt = servicer.shell_prompt

fake_stdin.clear()
fake_stdin.extend([b'echo "$MODAL_TOKEN_ID"\n', b"exit\n"])
Expand Down
3 changes: 2 additions & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,8 @@ async def ContainerExecGetOutput(self, stream):
api_pb2.RuntimeOutputBatch(
items=[
api_pb2.RuntimeOutputMessage(
message=self.shell_prompt, file_descriptor=request.file_descriptor
file_descriptor=request.file_descriptor,
message_bytes=self.shell_prompt,
)
]
)
Expand Down
24 changes: 12 additions & 12 deletions test/io_streams_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright Modal Labs 2024
import pytest

from modal import enable_output
from modal.io_streams import StreamReader
from modal_proto import api_pb2
Expand Down Expand Up @@ -149,31 +151,29 @@ async def sandbox_get_logs(servicer, stream):
assert out == ["foobar\n", "baz"]


def test_stream_reader_bytes_mode(servicer, client):
@pytest.mark.asyncio
async def test_stream_reader_bytes_mode(servicer, client):
"""Test that the stream reader works in bytes mode."""

async def sandbox_get_logs(servicer, stream):
async def container_exec_get_output(servicer, stream):
await stream.recv_message()

log = api_pb2.TaskLogs(
data="foo\n",
file_descriptor=api_pb2.FILE_DESCRIPTOR_STDOUT,
await stream.send_message(
api_pb2.RuntimeOutputBatch(batch_index=0, items=[api_pb2.RuntimeOutputMessage(message_bytes=b"foo\n")])
)
await stream.send_message(api_pb2.TaskLogsBatch(entry_id="0", items=[log]))

# send EOF
await stream.send_message(api_pb2.TaskLogsBatch(eof=True))
await stream.send_message(api_pb2.RuntimeOutputBatch(exit_code=0))

with servicer.intercept() as ctx:
ctx.set_responder("SandboxGetLogs", sandbox_get_logs)
ctx.set_responder("ContainerExecGetOutput", container_exec_get_output)

with enable_output():
stdout: StreamReader[bytes] = StreamReader(
file_descriptor=api_pb2.FILE_DESCRIPTOR_STDOUT,
object_id="sb-123",
object_type="sandbox",
object_id="tp-123",
object_type="container_process",
client=client,
text=False,
)

assert stdout.read() == b"foo\n"
assert await stdout.read.aio() == b"foo\n"

0 comments on commit 5fea951

Please sign in to comment.