Skip to content

Commit

Permalink
feat(libs): Add API key rate limit middleware (#8850)
Browse files Browse the repository at this point in the history
Once we release api key feature, we will want to be able to rate limit
as well. This is the foundation for that.
For now it is a blanket rate limit, later we will be able to add tiered
rate limits

### Changes 🏗️

Added new middleware libary in autogpt_libs which contains the logic for
getting the api key, storing it's details in redis and checking how many
requests it's done, how many are left and what the reset time is.

---------

Co-authored-by: Zamil Majdy <[email protected]>
Co-authored-by: Reinier van der Leer <[email protected]>
  • Loading branch information
3 people authored Dec 3, 2024
1 parent 7c2e371 commit 3bca279
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def feature_flag(
"""

def decorator(
func: Callable[P, Union[T, Awaitable[T]]]
func: Callable[P, Union[T, Awaitable[T]]],
) -> Callable[P, Union[T, Awaitable[T]]]:
@wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@


class LoggingConfig(BaseSettings):

level: str = Field(
default="INFO",
description="Logging level",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
),
("", ""),
("hello", "hello"),
("hello\x1B[31m world", "hello world"),
("\x1B[36mHello,\x1B[32m World!", "Hello, World!"),
("hello\x1b[31m world", "hello world"),
("\x1b[36mHello,\x1b[32m World!", "Hello, World!"),
(
"\x1B[1m\x1B[31mError:\x1B[0m\x1B[31m file not found",
"\x1b[1m\x1b[31mError:\x1b[0m\x1b[31m file not found",
"Error: file not found",
),
],
Expand Down
Empty file.
31 changes: 31 additions & 0 deletions autogpt_platform/autogpt_libs/autogpt_libs/rate_limit/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict


class RateLimitSettings(BaseSettings):
redis_host: str = Field(
default="redis://localhost:6379",
description="Redis host",
validation_alias="REDIS_HOST",
)

redis_port: str = Field(
default="6379", description="Redis port", validation_alias="REDIS_PORT"
)

redis_password: str = Field(
default="password",
description="Redis password",
validation_alias="REDIS_PASSWORD",
)

requests_per_minute: int = Field(
default=60,
description="Maximum number of requests allowed per minute per API key",
validation_alias="RATE_LIMIT_REQUESTS_PER_MINUTE",
)

model_config = SettingsConfigDict(case_sensitive=True, extra="ignore")


RATE_LIMIT_SETTINGS = RateLimitSettings()
51 changes: 51 additions & 0 deletions autogpt_platform/autogpt_libs/autogpt_libs/rate_limit/limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import time
from typing import Tuple

from redis import Redis

from .config import RATE_LIMIT_SETTINGS


class RateLimiter:
def __init__(
self,
redis_host: str = RATE_LIMIT_SETTINGS.redis_host,
redis_port: str = RATE_LIMIT_SETTINGS.redis_port,
redis_password: str = RATE_LIMIT_SETTINGS.redis_password,
requests_per_minute: int = RATE_LIMIT_SETTINGS.requests_per_minute,
):
self.redis = Redis(
host=redis_host,
port=redis_port,
password=redis_password,
decode_responses=True,
)
self.window = 60
self.max_requests = requests_per_minute

async def check_rate_limit(self, api_key_id: str) -> Tuple[bool, int, int]:
"""
Check if request is within rate limits.
Args:
api_key_id: The API key identifier to check
Returns:
Tuple of (is_allowed, remaining_requests, reset_time)
"""
now = time.time()
window_start = now - self.window
key = f"ratelimit:{api_key_id}:1min"

pipe = self.redis.pipeline()
pipe.zremrangebyscore(key, 0, window_start)
pipe.zadd(key, {str(now): now})
pipe.zcount(key, window_start, now)
pipe.expire(key, self.window)

_, _, request_count, _ = pipe.execute()

remaining = max(0, self.max_requests - request_count)
reset_time = int(now + self.window)

return request_count <= self.max_requests, remaining, reset_time
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from fastapi import HTTPException, Request
from starlette.middleware.base import RequestResponseEndpoint

from .limiter import RateLimiter


async def rate_limit_middleware(request: Request, call_next: RequestResponseEndpoint):
"""FastAPI middleware for rate limiting API requests."""
limiter = RateLimiter()

if not request.url.path.startswith("/api"):
return await call_next(request)

api_key = request.headers.get("Authorization")
if not api_key:
return await call_next(request)

api_key = api_key.replace("Bearer ", "")

is_allowed, remaining, reset_time = await limiter.check_rate_limit(api_key)

if not is_allowed:
raise HTTPException(
status_code=429, detail="Rate limit exceeded. Please try again later."
)

response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(limiter.max_requests)
response.headers["X-RateLimit-Remaining"] = str(remaining)
response.headers["X-RateLimit-Reset"] = str(reset_time)

return response

0 comments on commit 3bca279

Please sign in to comment.