Skip to content

Commit

Permalink
Merge pull request #4 from UKGovernmentBEIS/craig/migrate-project
Browse files Browse the repository at this point in the history
Initial `k8s_sandbox` project
  • Loading branch information
craigwalton-dsit authored Dec 18, 2024
2 parents bf61269 + a6b63bd commit 01318da
Show file tree
Hide file tree
Showing 88 changed files with 7,554 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ repos:
- id: helm-docs-built
args:
# Make the tool search for charts only under the `charts` directory
- --chart-search-root=src/aisitools/k8s_sandbox/resources/helm/agent-env
- --chart-search-root=src/k8s_sandbox/resources/helm/agent-env
2,680 changes: 2,680 additions & 0 deletions poetry.lock

Large diffs are not rendered by default.

51 changes: 51 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
[tool.poetry]
name = "inspect-k8s-sandbox"
version = "0.1.0"
description = "A Kubernetes Sandbox Environment for Inspect"
authors = ["UK AI Safety Institute"]
readme = "README.md"
packages = [
{include = "k8s_sandbox", from = "src"},
]

[tool.poetry.dependencies]
python = "^3.10"
inspect-ai = ">=0.3.50"
kubernetes = "^31.0.0"

[tool.poetry.group.dev.dependencies]
mypy = "^1.9.0"
pre-commit = "^3.6.2"
pytest = "^8.1.1"
pytest-asyncio = "^0.23.7"
pytest-repeat = "^0.9.3"
ruff = "^0.6.0"
types-pyyaml = "^6.0.12"

[tool.poetry.plugins.inspect_ai]
k8s-sandbox = "k8s_sandbox._sandbox_environment"


[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"


[tool.ruff.lint]
select = ["E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # Pyflakes
"D", # pydocstyle
"I", # isort
]
ignore = ["E203", "D10", "D203", "D212"]

[tool.ruff.lint.pydocstyle]
convention = "google"


[tool.pytest.ini_options]
asyncio_mode = "auto"
markers = [
"req_k8s: marks tests as requiring a test Kubernetes cluster (deselect with '-m \"not req_k8s\"')"
]
14 changes: 14 additions & 0 deletions src/k8s_sandbox/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from k8s_sandbox._pod import GetReturncodeError, PodError
from k8s_sandbox._sandbox_environment import (
K8sError,
K8sSandboxEnvironment,
K8sSandboxEnvironmentConfig,
)

# Public API of the package: the names imported above from the private
# `_pod` and `_sandbox_environment` modules.
__all__ = [
"GetReturncodeError",
"PodError",
"K8sError",
"K8sSandboxEnvironment",
"K8sSandboxEnvironmentConfig",
]
238 changes: 238 additions & 0 deletions src/k8s_sandbox/_helm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
import asyncio
import logging
import os
import re
from pathlib import Path
from typing import Any, NoReturn

from inspect_ai.util import ExecResult, concurrency
from kubernetes.client.rest import ApiException # type: ignore
from shortuuid import uuid

from k8s_sandbox._kubernetes_api import (
get_current_context_namespace,
k8s_client,
)
from k8s_sandbox._logger import format_log_message, sandbox_log
from k8s_sandbox._pod import Pod

# Helm chart shipped alongside this module; used when a Release is not given an
# explicit chart path.
DEFAULT_CHART = Path(__file__).parent / "resources" / "helm" / "agent-env"
# Timeout (in seconds) passed to helm's --timeout flag unless the
# INSPECT_HELM_TIMEOUT environment variable is set.
DEFAULT_TIMEOUT = 300
# Upper bound on install attempts when the resource quota conflict error occurs.
MAX_INSTALL_ATTEMPTS = 3
# Seconds to sleep between install attempts.
INSTALL_RETRY_DELAY_SECONDS = 5


logger = logging.getLogger(__name__)


class _ResourceQuotaModifiedError(Exception):
    """Raised when a helm install fails due to a resourcequota modification conflict."""


class Release:
    """A release of a Helm chart in the current context's namespace.

    Attributes:
        task_name: Name of the task this release belongs to. It is passed to the
            chart as the `annotations.inspectTaskName` value.
        release_name: Generated 8 character name of the Helm release.
    """

    def __init__(
        self,
        task_name: str,
        chart_path: Path | None = None,
        values_path: Path | None = None,
    ) -> None:
        """Create a (not yet installed) release.

        Args:
            task_name: Name of the task the release is created for.
            chart_path: Path to the Helm chart; falls back to DEFAULT_CHART.
            values_path: Optional path to a Helm values file.
        """
        self.task_name = task_name
        self._chart_path = chart_path or DEFAULT_CHART
        self._values_path = values_path
        self._namespace = get_current_context_namespace()
        # The release name is used in pod names too, so limit it to 8 chars.
        self.release_name = self._generate_release_name()

    def _generate_release_name(self) -> str:
        """Return a lowercase identifier truncated to 8 characters."""
        return uuid().lower()[:8]

    async def install(self) -> None:
        """Install the chart, retrying on resource quota conflicts.

        At most MAX_INSTALL_ATTEMPTS attempts are made, with a pause of
        INSTALL_RETRY_DELAY_SECONDS between them. Retries use
        `helm upgrade --install` rather than `helm install`.

        Raises:
            _ResourceQuotaModifiedError: If the conflict persists on the final
                attempt.
            RuntimeError: If helm fails for any other reason.
        """
        async with _install_semaphore():
            sandbox_log(
                "Installing helm chart.",
                chart=self._chart_path,
                release=self.release_name,
                values=self._values_path,
                namespace=self._namespace,
                task=self.task_name,
            )
            attempt = 1
            while True:
                try:
                    await self._install(upgrade=attempt > 1)
                    break
                except _ResourceQuotaModifiedError:
                    if attempt >= MAX_INSTALL_ATTEMPTS:
                        raise
                    attempt += 1
                    await asyncio.sleep(INSTALL_RETRY_DELAY_SECONDS)

    async def uninstall(self, quiet: bool) -> None:
        """Uninstall this release; see the module-level `uninstall` function."""
        await uninstall(self.release_name, quiet)

    async def get_sandbox_pods(self) -> dict[str, Pod]:
        """Find the sandbox pods belonging to this release.

        Returns:
            A mapping from each pod's `inspect/service` label value to a Pod built
            from the pod name, the namespace and the name of its first container.
            Pods without that label are excluded.

        Raises:
            RuntimeError: If the pods cannot be listed or none are found.
        """

        def list_release_pods() -> Any:
            # Resolve the client inside the worker thread: k8s_client() hands out
            # thread-local clients so that one client is never shared between
            # threads. Fetching it on the event loop thread and using it here
            # would defeat that.
            return k8s_client().list_namespaced_pod(
                self._namespace,
                label_selector=f"app.kubernetes.io/instance={self.release_name}",
            )

        loop = asyncio.get_running_loop()
        try:
            pods = await loop.run_in_executor(None, list_release_pods)
        except ApiException as e:
            _raise_runtime_error(
                "Failed to list pods.", release=self.release_name, from_exception=e
            )
        if not pods.items:
            _raise_runtime_error("No pods found.", release=self.release_name)
        sandboxes: dict[str, Pod] = {}
        for pod in pods.items:
            service_name = pod.metadata.labels.get("inspect/service")
            # Depending on the Helm chart, some Pods may not have a service label.
            # These should not be considered to be a sandbox pod (as per our docs).
            if service_name is not None:
                default_container_name = pod.spec.containers[0].name
                sandboxes[service_name] = Pod(
                    pod.metadata.name, self._namespace, default_container_name
                )
        return sandboxes

    async def _install(self, upgrade: bool) -> None:
        """Run a single helm install (or upgrade) attempt.

        Args:
            upgrade: If True, run `helm upgrade --install` instead of
                `helm install`.

        Raises:
            _ResourceQuotaModifiedError: If helm reports a resourcequota conflict.
            RuntimeError: If helm fails for any other reason.
        """
        # Whilst `upgrade --install` could always be used, prefer explicitly using
        # `install` for the first attempt.
        subcommand = ["upgrade", "--install"] if upgrade else ["install"]
        values = ["--values", str(self._values_path)] if self._values_path else []
        result = await _run_subprocess(
            "helm",
            subcommand
            + [
                self.release_name,
                str(self._chart_path),
                "--namespace",
                self._namespace,
                "--wait",
                "--timeout",
                f"{_get_timeout()}s",
                "--set",
                # Annotations do not have strict length requirements.
                # Quoting/escaping is handled by asyncio.create_subprocess_exec.
                f"annotations.inspectTaskName={self.task_name}",
            ]
            + values,
            capture_output=True,
        )
        if not result.success:
            self._raise_install_error(result)

    def _raise_install_error(self, result: ExecResult[str]) -> NoReturn:
        """Raise the appropriate error for a failed helm install.

        Args:
            result: The result of the failed helm subprocess.

        Raises:
            _ResourceQuotaModifiedError: If stderr contains the resourcequota
                "object has been modified" message (the caller may retry).
            RuntimeError: For every other failure.
        """
        # When concurrent helm operations are modifying the same resource quota, the
        # following error occasionally occurs. Retry.
        if re.search(
            r"Operation cannot be fulfilled on resourcequotas \".*\": the object has "
            r"been modified; please apply your changes to the latest version and try "
            r"again",
            result.stderr,
        ):
            sandbox_log(
                "resourcequota modified error whilst installing helm chart.",
                release=self.release_name,
                error=result.stderr,
            )
            raise _ResourceQuotaModifiedError(result.stderr)
        _raise_runtime_error(
            "Helm install failed.", release=self.release_name, result=result
        )


async def uninstall(release_name: str, quiet: bool) -> None:
    """Uninstall a Helm release from the current context's namespace.

    Args:
        release_name: Name of the Helm release to uninstall.
        quiet: If True, helm's stdout and stderr are captured instead of being
            left attached to this process's own streams.

    Raises:
        RuntimeError: If the `helm uninstall` command does not succeed.
    """
    namespace = get_current_context_namespace()
    async with _uninstall_semaphore():
        sandbox_log(
            "Uninstalling helm release.", release=release_name, namespace=namespace
        )
        result = await _run_subprocess(
            "helm",
            [
                "uninstall",
                release_name,
                "--namespace",
                namespace,
                "--wait",
                "--timeout",
                f"{_get_timeout()}s",
            ],
            capture_output=quiet,
        )
        if not result.success:
            # Output is only available when it was captured, i.e. when quiet is
            # True. Otherwise helm wrote directly to the inherited streams and the
            # result's stdout/stderr are empty.
            failure_detail = result if quiet else "not captured"
            _raise_runtime_error(
                "Helm uninstall failed.", release=release_name, result=failure_detail
            )


def _raise_runtime_error(
    message: str, from_exception: Exception | None = None, **kwargs: Any
) -> NoReturn:
    """Log a formatted error message and raise it as a RuntimeError.

    Args:
        message: Human readable description of the failure.
        from_exception: Optional exception to chain as the cause.
        **kwargs: Extra context handed to format_log_message.

    Raises:
        RuntimeError: Always.
    """
    text = format_log_message(message, **kwargs)
    logger.error(text)
    error = RuntimeError(text)
    if not from_exception:
        raise error
    raise error from from_exception


async def _run_subprocess(
    cmd: str, args: list[str], capture_output: bool
) -> ExecResult[str]:
    """Run a command to completion without a shell and collect its result.

    Args:
        cmd: The executable to run.
        args: Arguments passed to the executable.
        capture_output: If True, stdout and stderr are piped and returned in the
            result. If False, they are left attached to this process's streams
            and the result's stdout/stderr are empty strings.

    Returns:
        An ExecResult whose `success` is True only for a zero exit code.
    """
    pipe = asyncio.subprocess.PIPE if capture_output else None
    proc = await asyncio.create_subprocess_exec(cmd, *args, stdout=pipe, stderr=pipe)
    stdout, stderr = await proc.communicate()
    # Report the real exit code. `proc.returncode or 1` would turn a successful 0
    # into 1; only fall back to 1 if the code is unexpectedly unavailable.
    returncode = proc.returncode if proc.returncode is not None else 1
    return ExecResult(
        success=returncode == 0,
        returncode=returncode,
        stdout=stdout.decode() if stdout else "",
        stderr=stderr.decode() if stderr else "",
    )


def _get_timeout() -> int:
    """Return the helm timeout in seconds.

    Reads the INSPECT_HELM_TIMEOUT environment variable; when it is unset or
    empty, DEFAULT_TIMEOUT is used.

    Raises:
        ValueError: If INSPECT_HELM_TIMEOUT is set but is not a positive integer.
    """
    user_configured_timeout = os.environ.get("INSPECT_HELM_TIMEOUT")
    if not user_configured_timeout:
        return DEFAULT_TIMEOUT
    error_message = (
        f"INSPECT_HELM_TIMEOUT must be a positive int: {user_configured_timeout}"
    )
    try:
        timeout_int = int(user_configured_timeout)
    except ValueError as err:
        # Name the offending variable rather than surfacing the bare int() error.
        raise ValueError(error_message) from err
    if timeout_int <= 0:
        raise ValueError(error_message)
    return timeout_int


def _install_semaphore() -> asyncio.Semaphore:
    """Return the semaphore bounding concurrent `helm install` subprocesses.

    Installs and uninstalls deliberately use separate semaphores. With a shared
    one, every permit could be held by installs that are waiting for cluster
    resources which only pending uninstalls would free, causing a deadlock.

    The limit comes from INSPECT_MAX_HELM_INSTALL (default 8).
    """
    max_installs = _get_environ_int("INSPECT_MAX_HELM_INSTALL", 8)
    # Inspect's concurrency function is used because it ensures each
    # asyncio.Semaphore is unique per event loop.
    return concurrency("helm-install", max_installs)


def _uninstall_semaphore() -> asyncio.Semaphore:
    """Return the semaphore bounding concurrent `helm uninstall` subprocesses.

    The limit comes from INSPECT_MAX_HELM_UNINSTALL (default 8).
    """
    max_uninstalls = _get_environ_int("INSPECT_MAX_HELM_UNINSTALL", 8)
    return concurrency("helm-uninstall", max_uninstalls)


def _get_environ_int(name: str, default: int) -> int:
    """Read an integer from an environment variable.

    Args:
        name: Name of the environment variable.
        default: Value returned when the variable is unset or not an integer.

    Returns:
        The parsed integer, or `default`.
    """
    raw_value = os.environ.get(name)
    if raw_value is None:
        return default
    try:
        parsed = int(raw_value)
    except ValueError:
        return default
    return parsed
44 changes: 44 additions & 0 deletions src/k8s_sandbox/_kubernetes_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

import logging
import threading

from kubernetes import client, config # type: ignore

logger = logging.getLogger(__name__)

# Per-thread storage; k8s_client() keeps each thread's CoreV1Api instance here.
_thread_local = threading.local()
# Serialises the one-time loading of the kube config in _ensure_config_loaded().
_load_config_lock = threading.Lock()
# Becomes True once config.load_kube_config() has been called.
_config_loaded = False


def k8s_client() -> client.CoreV1Api:
    """Return the calling thread's Kubernetes client, creating it on first use.

    The Kubernetes configuration is loaded (once) before a client is handed out.
    Clients are kept per thread because a Kubernetes client cannot be used
    simultaneously from multiple threads, and threads are used because the
    kubernetes client is not async.
    """
    _ensure_config_loaded()
    try:
        return _thread_local.client
    except AttributeError:
        # First call on this thread: create and remember a dedicated client.
        api = client.CoreV1Api()
        _thread_local.client = api
        return api


def get_current_context_namespace() -> str:
    """Get the current context's namespace from the Kubernetes configuration.

    Returns:
        The namespace named by the current kubeconfig context, or "default" when
        the context does not specify one.

    Raises:
        TypeError: If the configured namespace is not a string.
    """
    _ensure_config_loaded()
    _, current_ctx = config.list_kube_config_contexts()
    # A kubeconfig context is not required to name a namespace. Fall back to
    # "default" instead of failing with a KeyError.
    namespace = current_ctx["context"].get("namespace") or "default"
    # Raise explicitly rather than assert, so the check survives `python -O`.
    if not isinstance(namespace, str):
        raise TypeError(
            f"Expected the context namespace to be a str, got {type(namespace)!r}"
        )
    return namespace


def _ensure_config_loaded() -> None:
    """Load the kube config on the first call; later calls do nothing.

    The check and the load both happen while holding `_load_config_lock`.
    """
    global _config_loaded
    with _load_config_lock:
        if _config_loaded:
            return
        config.load_kube_config()
        _config_loaded = True
Loading

0 comments on commit 01318da

Please sign in to comment.