Skip to content

Commit

Permalink
fix(backend): Fix DatabaseManager usage by calling it on-demand (#8404)
Browse files Browse the repository at this point in the history
  • Loading branch information
majdyz authored Oct 23, 2024
1 parent 7f31868 commit 17e79ad
Show file tree
Hide file tree
Showing 11 changed files with 67 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from redis import Redis
from backend.executor.database import DatabaseManager

from autogpt_libs.utils.cache import thread_cached_property
from autogpt_libs.utils.synchronize import RedisKeyedMutex

from .types import (
Expand All @@ -18,9 +19,14 @@


class SupabaseIntegrationCredentialsStore:
def __init__(self, redis: "Redis", db: "DatabaseManager"):
self.db_manager: DatabaseManager = db
def __init__(self, redis: "Redis"):
self.locks = RedisKeyedMutex(redis)

@thread_cached_property
def db_manager(self) -> "DatabaseManager":
from backend.executor.database import DatabaseManager
from backend.util.service import get_service_client
return get_service_client(DatabaseManager)

def add_creds(self, user_id: str, credentials: Credentials) -> None:
with self.locked_user_metadata(user_id):
Expand Down
6 changes: 5 additions & 1 deletion autogpt_platform/backend/backend/executor/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@
class DatabaseManager(AppService):

def __init__(self):
super().__init__(port=Config().database_api_port)
super().__init__()
self.use_db = True
self.use_redis = True
self.event_queue = RedisEventQueue()

@classmethod
def get_port(cls) -> int:
return Config().database_api_port

@expose
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
self.event_queue.put(ExecutionResult(**execution_result_dict))
Expand Down
15 changes: 10 additions & 5 deletions autogpt_platform/backend/backend/executor/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
if TYPE_CHECKING:
from backend.executor import DatabaseManager

from autogpt_libs.utils.cache import thread_cached

from backend.data import redis
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
from backend.data.execution import (
Expand All @@ -31,7 +33,6 @@
from backend.data.model import CREDENTIALS_FIELD_NAME, CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util import json
from backend.util.cache import thread_cached
from backend.util.decorator import error_logged, time_measured
from backend.util.logging import configure_logging
from backend.util.process import set_service_name
Expand Down Expand Up @@ -419,7 +420,7 @@ def on_node_executor_start(cls):
redis.connect()
cls.pid = os.getpid()
cls.db_client = get_db_client()
cls.creds_manager = IntegrationCredentialsManager(db_manager=cls.db_client)
cls.creds_manager = IntegrationCredentialsManager()

# Set up shutdown handlers
cls.shutdown_lock = threading.Lock()
Expand Down Expand Up @@ -659,20 +660,24 @@ def callback(_):
class ExecutionManager(AppService):

def __init__(self):
super().__init__(port=settings.config.execution_manager_port)
super().__init__()
self.use_redis = True
self.use_supabase = True
self.pool_size = settings.config.num_graph_workers
self.queue = ExecutionQueue[GraphExecution]()
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}

@classmethod
def get_port(cls) -> int:
return settings.config.execution_manager_port

def run_service(self):
from autogpt_libs.supabase_integration_credentials_store import (
SupabaseIntegrationCredentialsStore,
)

self.credentials_store = SupabaseIntegrationCredentialsStore(
redis=redis.get_redis(), db=self.db_client
redis=redis.get_redis()
)
self.executor = ProcessPoolExecutor(
max_workers=self.pool_size,
Expand Down Expand Up @@ -863,7 +868,7 @@ def _validate_node_input_credentials(self, graph: Graph, user_id: str):
def get_db_client() -> "DatabaseManager":
from backend.executor import DatabaseManager

return get_service_client(DatabaseManager, settings.config.database_api_port)
return get_service_client(DatabaseManager)


@contextmanager
Expand Down
10 changes: 7 additions & 3 deletions autogpt_platform/backend/backend/executor/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from autogpt_libs.utils.cache import thread_cached_property

from backend.data.block import BlockInput
from backend.data.schedule import (
Expand All @@ -14,7 +15,6 @@
update_schedule,
)
from backend.executor.manager import ExecutionManager
from backend.util.cache import thread_cached_property
from backend.util.service import AppService, expose, get_service_client
from backend.util.settings import Config

Expand All @@ -28,14 +28,18 @@ def log(msg, **kwargs):
class ExecutionScheduler(AppService):

def __init__(self, refresh_interval=10):
super().__init__(port=Config().execution_scheduler_port)
super().__init__()
self.use_db = True
self.last_check = datetime.min
self.refresh_interval = refresh_interval

@classmethod
def get_port(cls) -> int:
return Config().execution_scheduler_port

@thread_cached_property
def execution_client(self) -> ExecutionManager:
return get_service_client(ExecutionManager, Config().execution_manager_port)
return get_service_client(ExecutionManager)

def run_service(self):
scheduler = BackgroundScheduler()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from redis.lock import Lock as RedisLock

from backend.data import redis
from backend.executor.database import DatabaseManager
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.util.settings import Settings

Expand Down Expand Up @@ -50,12 +49,10 @@ class IntegrationCredentialsManager:
cause so much latency that it's worth implementing.
"""

def __init__(self, db_manager: DatabaseManager):
def __init__(self):
redis_conn = redis.get_redis()
self._locks = RedisKeyedMutex(redis_conn)
self.store = SupabaseIntegrationCredentialsStore(
redis=redis_conn, db=db_manager
)
self.store = SupabaseIntegrationCredentialsStore(redis=redis_conn)

def create(self, user_id: str, credentials: Credentials) -> None:
return self.store.add_creds(user_id, credentials)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel, Field, SecretStr

from backend.executor.manager import get_db_client
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.util.settings import Settings
Expand All @@ -21,7 +20,7 @@
settings = Settings()
router = APIRouter()

creds_manager = IntegrationCredentialsManager(db_manager=get_db_client())
creds_manager = IntegrationCredentialsManager()


class LoginResponse(BaseModel):
Expand Down
15 changes: 8 additions & 7 deletions autogpt_platform/backend/backend/server/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import uvicorn
from autogpt_libs.auth.middleware import auth_middleware
from autogpt_libs.utils.cache import thread_cached_property
from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
Expand All @@ -19,10 +20,7 @@
from backend.data.credit import get_block_costs, get_user_credit_model
from backend.data.user import get_or_create_user
from backend.executor import ExecutionManager, ExecutionScheduler
from backend.executor.manager import get_db_client
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.server.model import CreateGraph, SetGraphActiveVersion
from backend.util.cache import thread_cached_property
from backend.util.service import AppService, get_service_client
from backend.util.settings import AppEnvironment, Config, Settings

Expand All @@ -37,9 +35,13 @@ class AgentServer(AppService):
_user_credit_model = get_user_credit_model()

def __init__(self):
super().__init__(port=Config().agent_server_port)
super().__init__()
self.use_redis = True

@classmethod
def get_port(cls) -> int:
return Config().agent_server_port

@asynccontextmanager
async def lifespan(self, _: FastAPI):
await db.connect()
Expand Down Expand Up @@ -98,7 +100,6 @@ def run_service(self):
tags=["integrations"],
dependencies=[Depends(auth_middleware)],
)
self.integration_creds_manager = IntegrationCredentialsManager(get_db_client())

api_router.include_router(
backend.server.routers.analytics.router,
Expand Down Expand Up @@ -308,11 +309,11 @@ async def wrapper(*args, **kwargs):

@thread_cached_property
def execution_manager_client(self) -> ExecutionManager:
return get_service_client(ExecutionManager, Config().execution_manager_port)
return get_service_client(ExecutionManager)

@thread_cached_property
def execution_scheduler_client(self) -> ExecutionScheduler:
return get_service_client(ExecutionScheduler, Config().execution_scheduler_port)
return get_service_client(ExecutionScheduler)

@classmethod
def handle_internal_http_error(cls, request: Request, exc: Exception):
Expand Down
28 changes: 19 additions & 9 deletions autogpt_platform/backend/backend/util/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import threading
import time
import typing
from abc import ABC, abstractmethod
from enum import Enum
from types import NoneType, UnionType
from typing import (
Expand Down Expand Up @@ -99,16 +100,24 @@ def custom_dict_to_class(qualname, data: dict):
return custom_dict_to_class


class AppService(AppProcess):
class AppService(AppProcess, ABC):
shared_event_loop: asyncio.AbstractEventLoop
use_db: bool = False
use_redis: bool = False
use_supabase: bool = False

def __init__(self, port):
self.port = port
def __init__(self):
self.uri = None

@classmethod
@abstractmethod
def get_port(cls) -> int:
pass

@classmethod
def get_host(cls) -> str:
return os.environ.get(f"{cls.service_name.upper()}_HOST", Config().pyro_host)

def run_service(self) -> None:
while True:
time.sleep(10)
Expand Down Expand Up @@ -157,8 +166,7 @@ def cleanup(self):

@conn_retry("Pyro", "Starting Pyro Service")
def __start_pyro(self):
host = Config().pyro_host
daemon = Pyro5.api.Daemon(host=host, port=self.port)
daemon = Pyro5.api.Daemon(host=self.get_host(), port=self.get_port())
self.uri = daemon.register(self, objectId=self.service_name)
logger.info(f"[{self.service_name}] Connected to Pyro; URI = {self.uri}")
daemon.requestLoop()
Expand All @@ -167,16 +175,20 @@ def __start_async_loop(self):
self.shared_event_loop.run_forever()


# --------- UTILITIES --------- #


AS = TypeVar("AS", bound=AppService)


def get_service_client(service_type: Type[AS], port: int) -> AS:
def get_service_client(service_type: Type[AS]) -> AS:
service_name = service_type.service_name

class DynamicClient:
@conn_retry("Pyro", f"Connecting to [{service_name}]")
def __init__(self):
host = os.environ.get(f"{service_name.upper()}_HOST", "localhost")
host = service_type.get_host()
port = service_type.get_port()
uri = f"PYRO:{service_type.service_name}@{host}:{port}"
logger.debug(f"Connecting to service [{service_name}]. URI = {uri}")
self.proxy = Pyro5.api.Proxy(uri)
Expand All @@ -191,8 +203,6 @@ def __getattr__(self, name: str) -> Callable[..., Any]:
return cast(AS, DynamicClient())


# --------- UTILITIES --------- #

builtin_types = [*vars(builtins).values(), NoneType, Enum]


Expand Down
6 changes: 1 addition & 5 deletions autogpt_platform/backend/test/executor/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from backend.server.model import CreateGraph
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.service import get_service_client
from backend.util.settings import Config
from backend.util.test import SpinTestServer


Expand All @@ -19,10 +18,7 @@ async def test_agent_schedule(server: SpinTestServer):
user_id=test_user.id,
)

scheduler = get_service_client(
ExecutionScheduler, Config().execution_scheduler_port
)

scheduler = get_service_client(ExecutionScheduler)
schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
assert len(schedules) == 0

Expand Down
8 changes: 6 additions & 2 deletions autogpt_platform/backend/test/util/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

class ServiceTest(AppService):
def __init__(self):
super().__init__(port=TEST_SERVICE_PORT)
super().__init__()

@classmethod
def get_port(cls) -> int:
return TEST_SERVICE_PORT

@expose
def add(self, a: int, b: int) -> int:
Expand All @@ -28,7 +32,7 @@ async def add_async(a: int, b: int) -> int:
@pytest.mark.asyncio(scope="session")
async def test_service_creation(server):
with ServiceTest():
client = get_service_client(ServiceTest, TEST_SERVICE_PORT)
client = get_service_client(ServiceTest)
assert client.add(5, 3) == 8
assert client.subtract(10, 4) == 6
assert client.fun_with_async(5, 3) == 8

0 comments on commit 17e79ad

Please sign in to comment.