Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(backend): RedisEventQueue into Pub/Sub #8387

Merged
merged 8 commits into from
Oct 27, 2024
140 changes: 117 additions & 23 deletions autogpt_platform/backend/backend/data/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,18 @@
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, AsyncGenerator, Generator, Generic, TypeVar

from pydantic import BaseModel
from redis.asyncio.client import PubSub as AsyncPubSub
from redis.client import PubSub

from backend.data import redis
from backend.data.execution import ExecutionResult
from backend.util.settings import Config

logger = logging.getLogger(__name__)
config = Config()


class DateTimeEncoder(json.JSONEncoder):
Expand All @@ -16,35 +23,122 @@ def default(self, o):
return super().default(o)


class AbstractEventQueue(ABC):
@abstractmethod
def put(self, execution_result: ExecutionResult):
pass
M = TypeVar("M", bound=BaseModel)


class BaseRedisEventBus(Generic[M], ABC):
Model: type[M]

@property
@abstractmethod
def get(self) -> ExecutionResult | None:
def event_bus_name(self) -> str:
pass

def _serialize_message(self, item: M, channel_key: str) -> tuple[str, str]:
message = json.dumps(item.model_dump(), cls=DateTimeEncoder)
channel_name = f"{self.event_bus_name}-{channel_key}"
logger.info(f"[{channel_name}] Publishing an event to Redis {message}")
return message, channel_name

def _deserialize_message(self, msg: Any, channel_key: str) -> M | None:
message_type = "pmessage" if "*" in channel_key else "message"
if msg["type"] != message_type:
return None
try:
data = json.loads(msg["data"])
logger.info(f"Consuming an event from Redis {data}")
return self.Model(**data)
except Exception as e:
logger.error(f"Failed to parse event result from Redis {msg} {e}")

def _subscribe(
self, connection: redis.Redis | redis.AsyncRedis, channel_key: str
) -> tuple[PubSub | AsyncPubSub, str]:
channel_name = f"{self.event_bus_name}-{channel_key}"
pubsub = connection.pubsub()
return pubsub, channel_name

class RedisEventQueue(AbstractEventQueue):
def __init__(self):
self.queue_name = redis.QUEUE_NAME

class RedisEventBus(BaseRedisEventBus[M], ABC):
Model: type[M]

@property
def connection(self):
def connection(self) -> redis.Redis:
return redis.get_redis()

def put(self, execution_result: ExecutionResult):
message = json.dumps(execution_result.model_dump(), cls=DateTimeEncoder)
logger.info(f"Putting execution result to Redis {message}")
self.connection.lpush(self.queue_name, message)

def get(self) -> ExecutionResult | None:
message = self.connection.rpop(self.queue_name)
if message is not None and isinstance(message, (str, bytes, bytearray)):
data = json.loads(message)
logger.info(f"Getting execution result from Redis {data}")
return ExecutionResult(**data)
elif message is not None:
logger.error(f"Failed to get execution result from Redis {message}")
return None
def publish_event(self, event: M, channel_key: str):
message, channel_name = self._serialize_message(event, channel_key)
self.connection.publish(channel_name, message)

def listen_events(self, channel_key: str) -> Generator[M, None, None]:
pubsub, channel_name = self._subscribe(self.connection, channel_key)
assert isinstance(pubsub, PubSub)

if "*" in channel_key:
pubsub.psubscribe(channel_name)
else:
pubsub.subscribe(channel_name)

for message in pubsub.listen():
if event := self._deserialize_message(message, channel_key):
yield event


class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
Model: type[M]

@property
async def connection(self) -> redis.AsyncRedis:
return await redis.get_redis_async()

async def publish_event(self, event: M, channel_key: str):
message, channel_name = self._serialize_message(event, channel_key)
connection = await self.connection
await connection.publish(channel_name, message)

async def listen_events(self, channel_key: str) -> AsyncGenerator[M, None]:
pubsub, channel_name = self._subscribe(await self.connection, channel_key)
assert isinstance(pubsub, AsyncPubSub)

if "*" in channel_key:
await pubsub.psubscribe(channel_name)
else:
await pubsub.subscribe(channel_name)

async for message in pubsub.listen():
if event := self._deserialize_message(message, channel_key):
yield event


class RedisExecutionEventBus(RedisEventBus[ExecutionResult]):
Model = ExecutionResult

@property
def event_bus_name(self) -> str:
return config.execution_event_bus_name

def publish(self, res: ExecutionResult):
self.publish_event(res, f"{res.graph_id}-{res.graph_exec_id}")

def listen(
self, graph_id: str = "*", graph_exec_id: str = "*"
) -> Generator[ExecutionResult, None, None]:
for execution_result in self.listen_events(f"{graph_id}-{graph_exec_id}"):
yield execution_result


class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionResult]):
Model = ExecutionResult

@property
def event_bus_name(self) -> str:
return config.execution_event_bus_name

async def publish(self, res: ExecutionResult):
await self.publish_event(res, f"{res.graph_id}-{res.graph_exec_id}")

async def listen(
self, graph_id: str = "*", graph_exec_id: str = "*"
) -> AsyncGenerator[ExecutionResult, None]:
async for execution_result in self.listen_events(f"{graph_id}-{graph_exec_id}"):
yield execution_result
46 changes: 41 additions & 5 deletions autogpt_platform/backend/backend/data/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from dotenv import load_dotenv
from redis import Redis
from redis.asyncio import Redis as AsyncRedis

from backend.util.retry import conn_retry

Expand All @@ -11,10 +12,10 @@
HOST = os.getenv("REDIS_HOST", "localhost")
PORT = int(os.getenv("REDIS_PORT", "6379"))
PASSWORD = os.getenv("REDIS_PASSWORD", "password")
QUEUE_NAME = os.getenv("REDIS_QUEUE", "execution_events")

logger = logging.getLogger(__name__)
connection: Redis | None = None
connection_async: AsyncRedis | None = None


@conn_retry("Redis", "Acquiring connection")
Expand Down Expand Up @@ -42,7 +43,42 @@ def disconnect():
connection = None


def get_redis() -> Redis:
if not connection:
raise RuntimeError("Redis connection is not established")
return connection
def get_redis(auto_connect: bool = True) -> Redis:
if connection:
return connection
if auto_connect:
return connect()
raise RuntimeError("Redis connection is not established")


@conn_retry("AsyncRedis", "Acquiring connection")
async def connect_async() -> AsyncRedis:
global connection_async
if connection_async:
return connection_async

c = AsyncRedis(
host=HOST,
port=PORT,
password=PASSWORD,
decode_responses=True,
)
await c.ping()
connection_async = c
return connection_async


@conn_retry("AsyncRedis", "Releasing connection")
async def disconnect_async():
global connection_async
if connection_async:
await connection_async.close()
connection_async = None


async def get_redis_async(auto_connect: bool = True) -> AsyncRedis:
if connection_async:
return connection_async
if auto_connect:
return await connect_async()
raise RuntimeError("AsyncRedis connection is not established")
6 changes: 3 additions & 3 deletions autogpt_platform/backend/backend/executor/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
upsert_execution_output,
)
from backend.data.graph import get_graph, get_node
from backend.data.queue import RedisEventQueue
from backend.data.queue import RedisExecutionEventBus
from backend.data.user import get_user_metadata, update_user_metadata
from backend.util.service import AppService, expose
from backend.util.settings import Config
Expand All @@ -30,15 +30,15 @@ def __init__(self):
super().__init__()
self.use_db = True
self.use_redis = True
self.event_queue = RedisEventQueue()
self.event_queue = RedisExecutionEventBus()

@classmethod
def get_port(cls) -> int:
return Config().database_api_port

@expose
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
self.event_queue.put(ExecutionResult(**execution_result_dict))
self.event_queue.publish(ExecutionResult(**execution_result_dict))

@staticmethod
def exposed_run_and_wait(
Expand Down
12 changes: 4 additions & 8 deletions autogpt_platform/backend/backend/server/ws_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from fastapi.middleware.cors import CORSMiddleware

from backend.data import redis
from backend.data.queue import RedisEventQueue
from backend.data.queue import AsyncRedisExecutionEventBus
from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager
from backend.server.model import ExecutionSubscription, Methods, WsMessage
Expand Down Expand Up @@ -51,13 +51,9 @@ def get_connection_manager():
async def event_broadcaster(manager: ConnectionManager):
try:
redis.connect()
event_queue = RedisEventQueue()
while True:
event = event_queue.get()
if event:
await manager.send_execution_result(event)
else:
await asyncio.sleep(0.1)
event_queue = AsyncRedisExecutionEventBus()
async for event in event_queue.listen():
await manager.send_execution_result(event)
except Exception as e:
logger.exception(f"Event broadcaster error: {e}")
raise
Expand Down
5 changes: 5 additions & 0 deletions autogpt_platform/backend/backend/util/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="What environment to behave as: local or cloud",
)

execution_event_bus_name: str = Field(
default="execution_event",
description="Name of the event bus",
)

backend_cors_allow_origins: List[str] = Field(default_factory=list)

@field_validator("backend_cors_allow_origins")
Expand Down
Loading