Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(feat): Enabled mTLS in ExtractorAgent, Fixes #989 #1010

Merged
merged 14 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions python-sdk/indexify/common_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from typing import Optional, Union

import httpx
import yaml
from httpx import AsyncClient, Client


def get_httpx_client(
Default2882 marked this conversation as resolved.
Show resolved Hide resolved
config_path: Optional[str] = None,
make_async: Optional[bool] = False
) -> AsyncClient | Client:
"""
Creates and returns an httpx.Client instance, optionally configured with TLS settings from a YAML config file.

The function creates a basic httpx.Client by default. If a config path is provided and the config specifies
'use_tls' as True, it creates a TLS-enabled client with HTTP/2 support using the provided certificate settings.
To get https.AsyncClient, provide make_async as True.

Args:
config_path (Optional[str]): Path to a YAML configuration file. If provided, the file should contain TLS
configuration with the following structure:
{
"use_tls": bool,
"tls_config": {
"cert_path": str, # Path to client certificate
"key_path": str, # Path to client private key
"ca_bundle_path": str # Optional: Path to CA bundle for verification
}
}
make_async (Optional[bool]): Whether to make an asynchronous httpx client instance.

Returns:
httpx.Client: An initialized httpx client instance, either with basic settings or TLS configuration
if specified in the config file.

Example:
# Basic client without TLS
client = get_httpx_client()

# Client with TLS configuration from config file
client = get_httpx_client("/path/to/config.yaml")
"""
if config_path:
with open(config_path, "r") as file:
config = yaml.safe_load(file)
if config.get("use_tls", False):
print(f'Configuring client with TLS config: {config}')
tls_config = config["tls_config"]
return get_sync_or_async_client(make_async, **tls_config)
return get_sync_or_async_client(make_async)

def get_sync_or_async_client(
make_async: Optional[bool] = False,
cert_path: Optional[str] = None,
key_path: Optional[str] = None,
ca_bundle_path: Optional[str] = None,
) -> AsyncClient | Client:
"""
Creates and returns either a synchronous or asynchronous httpx client with optional TLS configuration.

Args:
make_async (Optional[bool]): If True, returns an AsyncClient; if False, returns a synchronous Client.
Defaults to False.
cert_path (Optional[str]): Path to the client certificate file. Required for TLS configuration
when key_path is also provided.
key_path (Optional[str]): Path to the client private key file. Required for TLS configuration
when cert_path is also provided.
ca_bundle_path (Optional[str]): Path to the CA bundle file for certificate verification.
If not provided, defaults to system CA certificates.
"""
if make_async:
if cert_path and key_path:
return httpx.AsyncClient(
http2=True,
cert=(cert_path, key_path),
verify=ca_bundle_path if ca_bundle_path else True,
)
else:
return httpx.AsyncClient()
else:
if cert_path and key_path:
return httpx.Client(
http2=True,
cert=(cert_path, key_path),
verify=ca_bundle_path if ca_bundle_path else True
)
else:
return httpx.Client()
55 changes: 19 additions & 36 deletions python-sdk/indexify/executor/agent.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import asyncio
import json
import ssl
import traceback
from concurrent.futures.process import BrokenProcessPool
from importlib.metadata import version
from typing import Dict, List, Optional
from pathlib import Path

import httpx
import yaml
from httpx_sse import aconnect_sse
from pydantic import BaseModel
from rich.console import Console
Expand All @@ -21,6 +19,7 @@
)
from indexify.functions_sdk.graph_definition import ComputeGraphMetadata
from indexify.http_client import IndexifyClient
from indexify.common_util import get_httpx_client

from ..functions_sdk.image import ImageInformation
from . import image_dependency_installer
Expand Down Expand Up @@ -58,15 +57,15 @@ def __init__(
self,
executor_id: str,
num_workers,
code_path: str,
code_path: Path,
server_addr: str = "localhost:8900",
config_path: Optional[str] = None,
name_alias: Optional[str] = None,
image_version: Optional[int] = None,
):
self.name_alias = name_alias
self.image_version = image_version

self._config_path = config_path
self._probe = RuntimeProbes()

runtime_probe: ProbeInfo = self._probe.probe()
Expand All @@ -82,39 +81,21 @@ def __init__(
)

self.num_workers = num_workers
self._use_tls = False
if config_path:
with open(config_path, "r") as f:
config = yaml.safe_load(f)
self._config = config
if config.get("use_tls", False):
console.print(
"Running the extractor with TLS enabled", style="cyan bold"
)
self._use_tls = True
tls_config = config["tls_config"]
self._ssl_context = ssl.create_default_context(
ssl.Purpose.SERVER_AUTH, cafile=tls_config["ca_bundle_path"]
)
self._ssl_context.load_cert_chain(
certfile=tls_config["cert_path"], keyfile=tls_config["key_path"]
)
self._protocol = "wss"
self._tls_config = tls_config
else:
self._ssl_context = None
self._protocol = "ws"
console.print(
"Running the extractor with TLS enabled", style="cyan bold"
)
self._protocol = "https"
else:
self._ssl_context = None
self._protocol = "http"
self._config = {}

self._task_store: TaskStore = TaskStore()
self._executor_id = executor_id
self._function_worker = FunctionWorker(
workers=num_workers,
indexify_client=IndexifyClient(
service_url=f"{self._protocol}://{server_addr}"
service_url=f"{self._protocol}://{server_addr}",
config_path=config_path,
),
)
self._has_registered = False
Expand All @@ -124,7 +105,9 @@ def __init__(
self._downloader = Downloader(code_path=code_path, base_url=self._base_url)
self._max_queued_tasks = 10
self._task_reporter = TaskReporter(
base_url=self._base_url, executor_id=self._executor_id
base_url=self._base_url,
executor_id=self._executor_id,
config_path=self._config_path
)

async def task_completion_reporter(self):
Expand Down Expand Up @@ -197,7 +180,7 @@ async def task_launcher(self):
if self._require_image_bootstrap:
try:
image_info = await _get_image_info_for_compute_graph(
task, self._protocol, self._server_addr
task, self._protocol, self._server_addr, self._config_path
)
image_dependency_installer.executor_image_builder(
image_info, self.name_alias, self.image_version
Expand Down Expand Up @@ -365,8 +348,8 @@ async def run(self):
asyncio.create_task(self.task_completion_reporter())
self._should_run = True
while self._should_run:
self._protocol = "http"
url = f"{self._protocol}://{self._server_addr}/internal/executors/{self._executor_id}/tasks"
print(f'calling url: {url}')

def to_sentence_case(snake_str):
words = snake_str.split("_")
Expand Down Expand Up @@ -407,9 +390,8 @@ def to_sentence_case(snake_str):
border_style="cyan",
)
)

try:
async with httpx.AsyncClient() as client:
async with get_httpx_client(self._config_path, True) as client:
async with aconnect_sse(
client,
"POST",
Expand Down Expand Up @@ -453,14 +435,15 @@ def shutdown(self, loop):


async def _get_image_info_for_compute_graph(
task: Task, protocol, server_addr
task: Task, protocol, server_addr, config_path: str
) -> ImageInformation:
namespace = task.namespace
graph_name: str = task.compute_graph
compute_fn_name: str = task.compute_fn

http_client = IndexifyClient(
service_url=f"{protocol}://{server_addr}", namespace=namespace
service_url=f"{protocol}://{server_addr}", namespace=namespace,
config_path=config_path
)
compute_graph: ComputeGraphMetadata = http_client.graph(graph_name)

Expand Down
8 changes: 5 additions & 3 deletions python-sdk/indexify/executor/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from indexify.functions_sdk.object_serializer import MsgPackSerializer

from .api_objects import Task
from ..common_util import get_httpx_client

custom_theme = Theme(
{
Expand All @@ -29,9 +30,10 @@ class DownloadedInputs(BaseModel):


class Downloader:
def __init__(self, code_path: str, base_url: str):
def __init__(self, code_path: str, base_url: str, config_path: Optional[str] = None):
self.code_path = code_path
self.base_url = base_url
self._client = get_httpx_client(config_path)

async def download_graph(self, namespace: str, name: str, version: int) -> str:
path = os.path.join(self.code_path, namespace, f"{name}.{version}")
Expand All @@ -46,7 +48,7 @@ async def download_graph(self, namespace: str, name: str, version: int) -> str:
)
)

response = httpx.get(
response = self._client.get(
f"{self.base_url}/internal/namespaces/{namespace}/compute_graphs/{name}/code"
)
try:
Expand Down Expand Up @@ -85,7 +87,7 @@ async def download_input(self, task: Task) -> IndexifyData:
)
)

response = httpx.get(url)
response = self._client.get(url)
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
Expand Down
6 changes: 4 additions & 2 deletions python-sdk/indexify/executor/task_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import nanoid
from rich import print

from indexify.common_util import get_httpx_client
from indexify.executor.api_objects import RouterOutput as ApiRouterOutput
from indexify.executor.api_objects import Task, TaskResult
from indexify.executor.task_store import CompletedTask
Expand All @@ -22,9 +23,10 @@ def __bool__(self):


class TaskReporter:
def __init__(self, base_url: str, executor_id: str):
def __init__(self, base_url: str, executor_id: str, config_path: Optional[str] = None):
self._base_url = base_url
self._executor_id = executor_id
self._client = get_httpx_client(config_path)

def report_task_outcome(self, completed_task: CompletedTask):
fn_outputs = []
Expand Down Expand Up @@ -84,7 +86,7 @@ def report_task_outcome(self, completed_task: CompletedTask):
else:
kwargs["files"] = FORCE_MULTIPART
try:
response = httpx.post(
response = self._client.post(
url=f"{self._base_url}/internal/ingest_files",
**kwargs,
)
Expand Down
36 changes: 12 additions & 24 deletions python-sdk/indexify/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import cloudpickle
import httpx
import msgpack
import yaml
from httpx_sse import connect_sse
from pydantic import BaseModel, Json
from rich import print

from indexify.common_util import get_httpx_client, get_sync_or_async_client
from indexify.error import ApiException, GraphStillProcessing
from indexify.functions_sdk.data_objects import IndexifyData
from indexify.functions_sdk.graph import ComputeGraphMetadata, Graph
Expand Down Expand Up @@ -55,18 +55,8 @@ def __init__(
service_url = os.environ["INDEXIFY_URL"]

self.service_url = service_url
self._client = httpx.Client()
if config_path:
with open(config_path, "r") as file:
config = yaml.safe_load(file)

if config.get("use_tls", False):
tls_config = config["tls_config"]
self._client = httpx.Client(
http2=True,
cert=(tls_config["cert_path"], tls_config["key_path"]),
verify=tls_config.get("ca_bundle_path", True),
)
self._config_path = config_path
self._client = get_httpx_client(config_path)

self.namespace: str = namespace
self.compute_graphs: List[Graph] = []
Expand Down Expand Up @@ -135,17 +125,15 @@ def with_mtls(
if not (cert_path and key_path):
raise ValueError("Both cert and key must be provided for mTLS")

client_certs = (cert_path, key_path)
verify_option = ca_bundle_path if ca_bundle_path else True
client = IndexifyClient(
*args,
**kwargs,
service_url=service_url,
http2=True,
cert=client_certs,
verify=verify_option,
client = get_sync_or_async_client(
cert_path=cert_path,
key_path=key_path,
ca_bundle_path=ca_bundle_path
)
return client

indexify_client = IndexifyClient(service_url, *args, **kwargs)
indexify_client._client = client
return indexify_client

def _add_api_key(self, kwargs):
if self._api_key:
Expand Down Expand Up @@ -279,7 +267,7 @@ def invoke_graph_with_object(
"params": params,
}
self._add_api_key(kwargs)
with httpx.Client() as client:
with get_httpx_client(self._config_path) as client:
with connect_sse(
client,
"POST",
Expand Down
27 changes: 27 additions & 0 deletions python-sdk/tests/test_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from indexify.executor.api_objects import Task

tls_config = {
"use_tls": True,
"tls_config": {
"ca_bundle_path": "/path/to/ca_bundle.pem",
"cert_path": "/path/to/cert.pem",
"key_path": "/path/to/key.pem"
}
}

cert_path = tls_config["tls_config"]["cert_path"]
key_path = tls_config["tls_config"]["key_path"]
ca_bundle_path = tls_config["tls_config"]["ca_bundle_path"]
service_url = "localhost:8900"
config_path = "test/config/path"
code_path = "test/code_path"
task = Task(
id = "test_id",
namespace = "default",
compute_graph = "test_compute_graph",
compute_fn = "test_compute_fn",
invocation_id = "test_invocation_id",
input_key = "test|input|key",
requester_output_id = "test_output_id",
graph_version = 1,
)
Loading
Loading