feat(backend): RedisEventQueue into Pub/Sub #8387

Merged · 8 commits · Oct 27, 2024
91 changes: 66 additions & 25 deletions autogpt_platform/backend/backend/data/queue.py
@@ -1,12 +1,14 @@
 import json
 import logging
-from abc import ABC, abstractmethod
 from datetime import datetime
+from typing import Any, AsyncGenerator, Generator
 
 from backend.data import redis
 from backend.data.execution import ExecutionResult
+from backend.util.settings import Config
 
 logger = logging.getLogger(__name__)
+config = Config()
 
 
 class DateTimeEncoder(json.JSONEncoder):
@@ -16,35 +18,74 @@ def default(self, o):
         return super().default(o)
 
 
-class AbstractEventQueue(ABC):
-    @abstractmethod
-    def put(self, execution_result: ExecutionResult):
-        pass
-
-    @abstractmethod
-    def get(self) -> ExecutionResult | None:
-        pass
-
-
-class RedisEventQueue(AbstractEventQueue):
+class BaseRedisEventBus:
     def __init__(self):
-        self.queue_name = redis.QUEUE_NAME
-
-    @property
-    def connection(self):
-        return redis.get_redis()
+        self.event_bus_name = config.execution_event_bus_name
 
-    def put(self, execution_result: ExecutionResult):
+    def _serialize_message(self, execution_result: ExecutionResult) -> tuple[str, str]:
         message = json.dumps(execution_result.model_dump(), cls=DateTimeEncoder)
-        logger.info(f"Putting execution result to Redis {message}")
-        self.connection.lpush(self.queue_name, message)
+        channel_name = f"{self.event_bus_name}-{execution_result.graph_id}-{execution_result.graph_exec_id}"
+        return message, channel_name
 
-    def get(self) -> ExecutionResult | None:
-        message = self.connection.rpop(self.queue_name)
-        if message is not None and isinstance(message, (str, bytes, bytearray)):
-            data = json.loads(message)
+    @staticmethod
+    def _deserialize_message(msg: Any) -> ExecutionResult | None:
+        if msg["type"] not in ("message", "pmessage"):
+            return None
+        try:
+            data = json.loads(msg["data"])
             logger.info(f"Getting execution result from Redis {data}")
             return ExecutionResult(**data)
-        elif message is not None:
-            logger.error(f"Failed to get execution result from Redis {message}")
-        return None
+        except Exception as e:
+            logger.error(f"Failed to get execution result from Redis {msg} {e}")
+
+
+class RedisEventBus(BaseRedisEventBus):
+    @property
+    def connection(self) -> redis.Redis:
+        return redis.get_redis()
+
+    def publish(self, execution_result: ExecutionResult):
+        message, channel_name = self._serialize_message(execution_result)
+        self.connection.publish(channel_name, message)
+
+    def listen(
+        self, graph_id: str = "*", execution_id: str = "*"
+    ) -> Generator[ExecutionResult, None, None]:
+        pubsub = self.connection.pubsub()
+        channel_name = f"{self.event_bus_name}-{graph_id}-{execution_id}"
+        if "*" in channel_name:
+            pubsub.psubscribe(channel_name)
+        else:
+            pubsub.subscribe(channel_name)
+
+        for message in pubsub.listen():
+            if execution_result := self._deserialize_message(message):
+                yield execution_result
+
+
+class AsyncRedisEventBus(BaseRedisEventBus):
+    @property
+    async def connection(self) -> redis.AsyncRedis:
+        return await redis.get_redis_async()
+
+    async def publish(self, execution_result: ExecutionResult):
+        message, channel_name = self._serialize_message(execution_result)
+        connection = await self.connection
+        await connection.publish(channel_name, message)
+
+    async def listen(
+        self, graph_id: str = "*", execution_id: str = "*"
+    ) -> AsyncGenerator[ExecutionResult, None]:
+        connection = await self.connection
+        pubsub = connection.pubsub()
+        channel_name = f"{self.event_bus_name}-{graph_id}-{execution_id}"
+        if "*" in channel_name:
+            await pubsub.psubscribe(channel_name)
+        else:
+            await pubsub.subscribe(channel_name)
+
+        async for message in pubsub.listen():
+            if execution_result := self._deserialize_message(message):
+                yield execution_result
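
For context, a minimal usage sketch of the new pub/sub API (illustrative only, not part of this diff). It assumes a Redis connection via backend.data.redis and that an ExecutionResult instance is produced elsewhere in the backend; forward_result, consume_all, and consume_one are hypothetical names.

# Illustrative sketch -- not part of this PR. Only publish/listen are shown;
# constructing an ExecutionResult is out of scope here.
from backend.data import redis
from backend.data.queue import AsyncRedisEventBus, RedisEventBus

redis.connect()
bus = RedisEventBus()

def forward_result(result):
    # Each result is published to "<event_bus_name>-<graph_id>-<graph_exec_id>".
    bus.publish(result)

def consume_all():
    # Wildcards put a "*" into the channel name, so listen() uses PSUBSCRIBE
    # and yields results for every graph and execution until interrupted.
    for event in bus.listen(graph_id="*", execution_id="*"):
        print(event.graph_exec_id)

async def consume_one(graph_id: str, graph_exec_id: str):
    # Async variant scoped to a single execution (plain SUBSCRIBE).
    async_bus = AsyncRedisEventBus()
    async for event in async_bus.listen(graph_id=graph_id, execution_id=graph_exec_id):
        print(event.graph_exec_id)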
46 changes: 41 additions & 5 deletions autogpt_platform/backend/backend/data/redis.py
@@ -3,6 +3,7 @@
 
 from dotenv import load_dotenv
 from redis import Redis
+from redis.asyncio import Redis as AsyncRedis
 
 from backend.util.retry import conn_retry
 
@@ -11,10 +12,10 @@
 HOST = os.getenv("REDIS_HOST", "localhost")
 PORT = int(os.getenv("REDIS_PORT", "6379"))
 PASSWORD = os.getenv("REDIS_PASSWORD", "password")
-QUEUE_NAME = os.getenv("REDIS_QUEUE", "execution_events")
 
 logger = logging.getLogger(__name__)
 connection: Redis | None = None
+connection_async: AsyncRedis | None = None
 
 
 @conn_retry("Redis", "Acquiring connection")
@@ -42,7 +43,42 @@ def disconnect():
     connection = None
 
 
-def get_redis() -> Redis:
-    if not connection:
-        raise RuntimeError("Redis connection is not established")
-    return connection
+def get_redis(auto_connect: bool = True) -> Redis:
+    if connection:
+        return connection
+    if auto_connect:
+        return connect()
+    raise RuntimeError("Redis connection is not established")
+
+
+@conn_retry("AsyncRedis", "Acquiring connection")
+async def connect_async() -> AsyncRedis:
+    global connection_async
+    if connection_async:
+        return connection_async
+
+    c = AsyncRedis(
+        host=HOST,
+        port=PORT,
+        password=PASSWORD,
+        decode_responses=True,
+    )
+    await c.ping()
+    connection_async = c
+    return connection_async
+
+
+@conn_retry("AsyncRedis", "Releasing connection")
+async def disconnect_async():
+    global connection_async
+    if connection_async:
+        await connection_async.close()
+    connection_async = None
+
+
+async def get_redis_async(auto_connect: bool = True) -> AsyncRedis:
+    if connection_async:
+        return connection_async
+    if auto_connect:
+        return await connect_async()
+    raise RuntimeError("AsyncRedis connection is not established")
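
A short sketch of the connection lifecycle these helpers imply (illustrative, not part of the diff): get_redis and get_redis_async now connect lazily by default, and auto_connect=False restores the old fail-fast behaviour. The channel name below is a placeholder.

# Illustrative sketch of the lazy-connect helpers added above.
import asyncio

from backend.data import redis

async def main():
    r = await redis.get_redis_async()       # creates the shared AsyncRedis client on first use
    await r.publish("some-channel", "hello")
    await redis.disconnect_async()           # closes and clears the shared client

    try:
        await redis.get_redis_async(auto_connect=False)  # opt out of implicit connects
    except RuntimeError:
        print("no async connection established")

asyncio.run(main())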
6 changes: 3 additions & 3 deletions autogpt_platform/backend/backend/executor/database.py
@@ -15,7 +15,7 @@
     upsert_execution_output,
 )
 from backend.data.graph import get_graph, get_node
-from backend.data.queue import RedisEventQueue
+from backend.data.queue import RedisEventBus
 from backend.util.service import AppService, expose
 from backend.util.settings import Config
 
@@ -29,11 +29,11 @@ def __init__(self):
         super().__init__(port=Config().database_api_port)
         self.use_db = True
         self.use_redis = True
-        self.event_queue = RedisEventQueue()
+        self.event_queue = RedisEventBus()
 
     @expose
     def send_execution_update(self, execution_result_dict: dict[Any, Any]):
-        self.event_queue.put(ExecutionResult(**execution_result_dict))
+        self.event_queue.publish(ExecutionResult(**execution_result_dict))
 
     @staticmethod
     def exposed_run_and_wait(
12 changes: 4 additions & 8 deletions autogpt_platform/backend/backend/server/ws_api.py
@@ -8,7 +8,7 @@
 from fastapi.middleware.cors import CORSMiddleware
 
 from backend.data import redis
-from backend.data.queue import RedisEventQueue
+from backend.data.queue import AsyncRedisEventBus
 from backend.data.user import DEFAULT_USER_ID
 from backend.server.conn_manager import ConnectionManager
 from backend.server.model import ExecutionSubscription, Methods, WsMessage
@@ -51,13 +51,9 @@ def get_connection_manager():
 async def event_broadcaster(manager: ConnectionManager):
     try:
         redis.connect()
-        event_queue = RedisEventQueue()
-        while True:
-            event = event_queue.get()
-            if event:
-                await manager.send_execution_result(event)
-            else:
-                await asyncio.sleep(0.1)
+        event_queue = AsyncRedisEventBus()
+        async for event in event_queue.listen():
+            await manager.send_execution_result(event)
     except Exception as e:
         logger.exception(f"Event broadcaster error: {e}")
         raise
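
One consequence of the listen() signature worth noting (illustrative sketch, not from this diff): a broadcaster can also be scoped to a single execution instead of relaying everything. single_execution_broadcaster below is a hypothetical name.

# Hypothetical, narrower variant of event_broadcaster: relays only one
# execution's events; graph_id/graph_exec_id are supplied by the caller.
from backend.data.queue import AsyncRedisEventBus
from backend.server.conn_manager import ConnectionManager

async def single_execution_broadcaster(
    manager: ConnectionManager, graph_id: str, graph_exec_id: str
):
    bus = AsyncRedisEventBus()
    # No wildcard in the channel name, so listen() uses a plain SUBSCRIBE.
    async for event in bus.listen(graph_id=graph_id, execution_id=graph_exec_id):
        await manager.send_execution_result(event)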
5 changes: 5 additions & 0 deletions autogpt_platform/backend/backend/util/settings.py
@@ -148,6 +148,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
         description="What environment to behave as: local or cloud",
     )
 
+    execution_event_bus_name: str = Field(
+        default="execution_event",
+        description="Name of the event bus",
+    )
+
     backend_cors_allow_origins: List[str] = Field(default_factory=list)
 
     @field_validator("backend_cors_allow_origins")
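
A quick sketch of how the new setting flows into channel naming (illustrative; the EXECUTION_EVENT_BUS_NAME variable assumes pydantic-settings' default field-to-env mapping with no prefix, which may differ in this Config):

# Illustrative only: override the bus name and show the resulting channel.
import os

os.environ["EXECUTION_EVENT_BUS_NAME"] = "execution_event_test"  # assumed env mapping

from backend.util.settings import Config

config = Config()
graph_id, graph_exec_id = "graph-id", "exec-id"
channel = f"{config.execution_event_bus_name}-{graph_id}-{graph_exec_id}"
print(channel)  # execution_event_test-graph-id-exec-id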