Skip to content

Commit

Permalink
make containerprocess generic
Browse files Browse the repository at this point in the history
  • Loading branch information
azliu0 committed Nov 18, 2024
1 parent 1b1582d commit 4933489
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 28 deletions.
36 changes: 14 additions & 22 deletions modal/container_process.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright Modal Labs 2024
import asyncio
import platform
from typing import Optional
from typing import Generic, Optional, TypeVar

from modal_proto import api_pb2

Expand All @@ -13,11 +13,13 @@
from .io_streams import _StreamReader, _StreamWriter
from .stream_type import StreamType

T = TypeVar("T", str, bytes)

class _ContainerProcess:

class _ContainerProcess(Generic[T]):
_process_id: Optional[str] = None
_stdout: _StreamReader
_stderr: _StreamReader
_stdout: _StreamReader[T]
_stderr: _StreamReader[T]
_stdin: _StreamWriter
_text: bool
_by_line: bool
Expand All @@ -36,8 +38,7 @@ def __init__(
self._client = client
self._text = text
self._by_line = by_line
T = str if text else bytes
self._stdout: _StreamReader[T] = _StreamReader(
self._stdout = _StreamReader[T](
api_pb2.FILE_DESCRIPTOR_STDOUT,
process_id,
"container_process",
Expand All @@ -46,7 +47,7 @@ def __init__(
text=text,
by_line=by_line,
)
self._stderr: _StreamReader[T] = _StreamReader(
self._stderr = _StreamReader[T](
api_pb2.FILE_DESCRIPTOR_STDERR,
process_id,
"container_process",
Expand All @@ -58,21 +59,18 @@ def __init__(
self._stdin = _StreamWriter(process_id, "container_process", self._client)

@property
def stdout(self) -> _StreamReader:
"""`StreamReader` for the container process's stdout stream."""

def stdout(self) -> _StreamReader[T]:
"""StreamReader for the container process's stdout stream."""
return self._stdout

@property
def stderr(self) -> _StreamReader:
"""`StreamReader` for the container process's stderr stream."""

def stderr(self) -> _StreamReader[T]:
"""StreamReader for the container process's stderr stream."""
return self._stderr

@property
def stdin(self) -> _StreamWriter:
"""`StreamWriter` for the container process's stdin stream."""

"""StreamWriter for the container process's stdin stream."""
return self._stdin

@property
Expand All @@ -82,7 +80,6 @@ def returncode(self) -> int:
"You must call wait() before accessing the returncode. "
"To poll for the status of a running process, use poll() instead."
)

return self._returncode

async def poll(self) -> Optional[int]:
Expand All @@ -104,7 +101,6 @@ async def poll(self) -> Optional[int]:

async def wait(self) -> int:
"""Wait for the container process to finish running. Returns the exit code."""

if self._returncode is not None:
return self._returncode

Expand Down Expand Up @@ -154,15 +150,11 @@ async def _handle_input(data: bytes, message_index: int):
await stdout_task
await stderr_task

# TODO: this doesn't work right now.
# if exit_status != 0:
# raise ExecutionError(f"Process exited with status code {exit_status}")

except (asyncio.TimeoutError, TimeoutError):
connecting_status.stop()
stdout_task.cancel()
stderr_task.cancel()
raise InteractiveTimeoutError("Failed to establish connection to container. Please try again.")


ContainerProcess = synchronize_api(_ContainerProcess)
ContainerProcess = synchronize_api(_ContainerProcess[T])
34 changes: 33 additions & 1 deletion modal/sandbox.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright Modal Labs 2022
import asyncio
import os
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Tuple, Union, overload

from google.protobuf.message import Message
from grpclib import GRPCError, Status
Expand Down Expand Up @@ -400,6 +400,38 @@ async def _get_task_id(self):
await asyncio.sleep(0.5)
return self._task_id

@overload
async def exec(
self,
*cmds: str,
pty_info: Optional[api_pb2.PTYInfo] = None,
stdout: StreamType = StreamType.PIPE,
stderr: StreamType = StreamType.PIPE,
timeout: Optional[int] = None,
workdir: Optional[str] = None,
secrets: Sequence[_Secret] = (),
text: Literal[True],
bufsize: Literal[-1, 1] = -1,
_pty_info: Optional[api_pb2.PTYInfo] = None,
) -> _ContainerProcess[str]:
...

@overload
async def exec(
self,
*cmds: str,
pty_info: Optional[api_pb2.PTYInfo] = None,
stdout: StreamType = StreamType.PIPE,
stderr: StreamType = StreamType.PIPE,
timeout: Optional[int] = None,
workdir: Optional[str] = None,
secrets: Sequence[_Secret] = (),
text: Literal[False],
bufsize: Literal[-1, 1] = -1,
_pty_info: Optional[api_pb2.PTYInfo] = None,
) -> _ContainerProcess[bytes]:
...

async def exec(
self,
*cmds: str,
Expand Down
5 changes: 0 additions & 5 deletions test/supports/type_assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,3 @@ async def async_block() -> None:
assert_type(should_also_be_str, str)
should_be_int = await instance.bar.local("bar")
assert_type(should_be_int, int)


sb_app = modal.App.lookup("sandbox", create_if_missing=True)
sb = modal.Sandbox(app=sb_app)
cp = sb.exec(bufsize=-1)

0 comments on commit 4933489

Please sign in to comment.