Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(backend): Ensure validity of OAuth credentials during graph execution #8191

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
ffcf200
feat(platform/executor): Refresh OAuth credentials before starting a …
Pwuts Sep 26, 2024
5850b13
feat(libs): Make `SupabaseIntegrationCredentialsStore` thread safe
Pwuts Sep 27, 2024
1f4d0ad
feat(backend): Add `IntegrationCredentialsManager`
Pwuts Sep 27, 2024
ffeb605
feat(executor): Change credential injection mechanism to acquire cred…
Pwuts Sep 27, 2024
478b59c
Merge branch 'master' into reinier/open-1891-ensure-oauth-credentials…
Pwuts Sep 27, 2024
8267679
Remove unused `_get_provider_oauth_handler` from `executor.manager`
Pwuts Sep 27, 2024
5f3ef42
fix process forking issue in `AgentServer`
Pwuts Sep 27, 2024
da29cce
remove broken `agent_server_client` caching
Pwuts Sep 27, 2024
b464e76
lint
Pwuts Sep 28, 2024
dadd262
Merge branch 'master' into reinier/open-1891-ensure-oauth-credentials…
Pwuts Sep 30, 2024
2095326
fix refresh check concurrency issue in `IntegrationCredentialsManager…
Pwuts Oct 1, 2024
94fba3c
Extract annotation type peeling logic from `@expose`
Pwuts Oct 1, 2024
fab700b
Merge branch 'master' into reinier/open-1891-ensure-oauth-credentials…
Pwuts Oct 2, 2024
c4b7533
Merge branch 'master' into reinier/open-1891-ensure-oauth-credentials…
Pwuts Oct 9, 2024
3b43949
Minimize cross-thread lock access
Pwuts Oct 9, 2024
1f5c261
Merge branch 'master' into reinier/open-1891-ensure-oauth-credentials…
Pwuts Oct 9, 2024
31ed844
fix lock scopes
Pwuts Oct 9, 2024
d509060
fix circular import issues
Pwuts Oct 9, 2024
08e9e1e
move `IntegrationCredentialsManager` to `backend.integrations.creds_s…
Pwuts Oct 9, 2024
b43f0a9
lint
Pwuts Oct 9, 2024
cdee453
add notes
Pwuts Oct 10, 2024
519e396
Merge branch 'master' into reinier/open-1891-ensure-oauth-credentials…
ntindle Oct 10, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .store import SupabaseIntegrationCredentialsStore
from .types import APIKeyCredentials, OAuth2Credentials
from .types import Credentials, APIKeyCredentials, OAuth2Credentials

__all__ = [
"SupabaseIntegrationCredentialsStore",
"Credentials",
"APIKeyCredentials",
"OAuth2Credentials",
]
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import secrets
from datetime import datetime, timedelta, timezone
from typing import cast
from typing import TYPE_CHECKING, cast

from supabase import Client
if TYPE_CHECKING:
from redis import Redis
from supabase import Client

from autogpt_libs.utils.synchronize import RedisKeyedMutex

from .types import (
Credentials,
Expand All @@ -14,26 +18,28 @@


class SupabaseIntegrationCredentialsStore:
def __init__(self, supabase: Client):
def __init__(self, supabase: "Client", redis: "Redis"):
self.supabase = supabase
self.locks = RedisKeyedMutex(redis)

def add_creds(self, user_id: str, credentials: Credentials) -> None:
if self.get_creds_by_id(user_id, credentials.id):
raise ValueError(
f"Can not re-create existing credentials with ID {credentials.id} "
f"for user with ID {user_id}"
with self.locked_user_metadata(user_id):
if self.get_creds_by_id(user_id, credentials.id):
raise ValueError(
f"Can not re-create existing credentials #{credentials.id} "
f"for user #{user_id}"
)
self._set_user_integration_creds(
user_id, [*self.get_all_creds(user_id), credentials]
)
self._set_user_integration_creds(
user_id, [*self.get_all_creds(user_id), credentials]
)

def get_all_creds(self, user_id: str) -> list[Credentials]:
user_metadata = self._get_user_metadata(user_id)
return UserMetadata.model_validate(user_metadata).integration_credentials

def get_creds_by_id(self, user_id: str, credentials_id: str) -> Credentials | None:
credentials = self.get_all_creds(user_id)
return next((c for c in credentials if c.id == credentials_id), None)
all_credentials = self.get_all_creds(user_id)
return next((c for c in all_credentials if c.id == credentials_id), None)

def get_creds_by_provider(self, user_id: str, provider: str) -> list[Credentials]:
credentials = self.get_all_creds(user_id)
Expand All @@ -44,42 +50,45 @@ def get_authorized_providers(self, user_id: str) -> list[str]:
return list(set(c.provider for c in credentials))

def update_creds(self, user_id: str, updated: Credentials) -> None:
current = self.get_creds_by_id(user_id, updated.id)
if not current:
raise ValueError(
f"Credentials with ID {updated.id} "
f"for user with ID {user_id} not found"
)
if type(current) is not type(updated):
raise TypeError(
f"Can not update credentials with ID {updated.id} "
f"from type {type(current)} "
f"to type {type(updated)}"
)

# Ensure no scopes are removed when updating credentials
if (
isinstance(updated, OAuth2Credentials)
and isinstance(current, OAuth2Credentials)
and not set(updated.scopes).issuperset(current.scopes)
):
raise ValueError(
f"Can not update credentials with ID {updated.id} "
f"and scopes {current.scopes} "
f"to more restrictive set of scopes {updated.scopes}"
)

# Update the credentials
updated_credentials_list = [
updated if c.id == updated.id else c for c in self.get_all_creds(user_id)
]
self._set_user_integration_creds(user_id, updated_credentials_list)
with self.locked_user_metadata(user_id):
current = self.get_creds_by_id(user_id, updated.id)
if not current:
raise ValueError(
f"Credentials with ID {updated.id} "
f"for user with ID {user_id} not found"
)
if type(current) is not type(updated):
raise TypeError(
f"Can not update credentials with ID {updated.id} "
f"from type {type(current)} "
f"to type {type(updated)}"
)

# Ensure no scopes are removed when updating credentials
if (
isinstance(updated, OAuth2Credentials)
and isinstance(current, OAuth2Credentials)
and not set(updated.scopes).issuperset(current.scopes)
):
raise ValueError(
f"Can not update credentials with ID {updated.id} "
f"and scopes {current.scopes} "
f"to more restrictive set of scopes {updated.scopes}"
)

# Update the credentials
updated_credentials_list = [
updated if c.id == updated.id else c
for c in self.get_all_creds(user_id)
]
self._set_user_integration_creds(user_id, updated_credentials_list)

def delete_creds_by_id(self, user_id: str, credentials_id: str) -> None:
filtered_credentials = [
c for c in self.get_all_creds(user_id) if c.id != credentials_id
]
self._set_user_integration_creds(user_id, filtered_credentials)
with self.locked_user_metadata(user_id):
filtered_credentials = [
c for c in self.get_all_creds(user_id) if c.id != credentials_id
]
self._set_user_integration_creds(user_id, filtered_credentials)

async def store_state_token(
self, user_id: str, provider: str, scopes: list[str]
Expand All @@ -94,14 +103,15 @@ async def store_state_token(
scopes=scopes,
)

user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.get("integration_oauth_states", [])
oauth_states.append(state.model_dump())
user_metadata["integration_oauth_states"] = oauth_states
with self.locked_user_metadata(user_id):
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.get("integration_oauth_states", [])
oauth_states.append(state.model_dump())
user_metadata["integration_oauth_states"] = oauth_states

self.supabase.auth.admin.update_user_by_id(
user_id, {"user_metadata": user_metadata}
)
self.supabase.auth.admin.update_user_by_id(
user_id, {"user_metadata": user_metadata}
)

return token

Expand Down Expand Up @@ -136,29 +146,30 @@ async def get_any_valid_scopes_from_state_token(
return []

async def verify_state_token(self, user_id: str, token: str, provider: str) -> bool:
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.get("integration_oauth_states", [])

now = datetime.now(timezone.utc)
valid_state = next(
(
state
for state in oauth_states
if state["token"] == token
and state["provider"] == provider
and state["expires_at"] > now.timestamp()
),
None,
)

if valid_state:
# Remove the used state
oauth_states.remove(valid_state)
user_metadata["integration_oauth_states"] = oauth_states
self.supabase.auth.admin.update_user_by_id(
user_id, {"user_metadata": user_metadata}
with self.locked_user_metadata(user_id):
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.get("integration_oauth_states", [])

now = datetime.now(timezone.utc)
valid_state = next(
(
state
for state in oauth_states
if state["token"] == token
and state["provider"] == provider
and state["expires_at"] > now.timestamp()
),
None,
)
return True

if valid_state:
# Remove the used state
oauth_states.remove(valid_state)
user_metadata["integration_oauth_states"] = oauth_states
self.supabase.auth.admin.update_user_by_id(
user_id, {"user_metadata": user_metadata}
)
return True

return False

Expand All @@ -178,3 +189,7 @@ def _get_user_metadata(self, user_id: str) -> UserMetadataRaw:
if not response.user:
raise ValueError(f"User with ID {user_id} not found")
return cast(UserMetadataRaw, response.user.user_metadata)

def locked_user_metadata(self, user_id: str):
key = (self.supabase.supabase_url, f"user:{user_id}", "metadata")
return self.locks.locked(key)
56 changes: 56 additions & 0 deletions autogpt_platform/autogpt_libs/autogpt_libs/utils/synchronize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from contextlib import contextmanager
from threading import Lock
from typing import TYPE_CHECKING, Any

from expiringdict import ExpiringDict

if TYPE_CHECKING:
from redis import Redis
from redis.lock import Lock as RedisLock


class RedisKeyedMutex:
"""
This class provides a mutex that can be locked and unlocked by a specific key,
using Redis as a distributed locking provider.
It uses an ExpiringDict to automatically clear the mutex after a specified timeout,
in case the key is not unlocked for a specified duration, to prevent memory leaks.
"""

def __init__(self, redis: "Redis", timeout: int | None = 60):
self.redis = redis
self.timeout = timeout
self.locks: dict[Any, "RedisLock"] = ExpiringDict(
max_len=6000, max_age_seconds=self.timeout
)
self.locks_lock = Lock()

@contextmanager
def locked(self, key: Any):
lock = self.acquire(key)
try:
yield
finally:
lock.release()

def acquire(self, key: Any) -> "RedisLock":
"""Acquires and returns a lock with the given key"""
with self.locks_lock:
if key not in self.locks:
self.locks[key] = self.redis.lock(
str(key), self.timeout, thread_local=False
)
lock = self.locks[key]
lock.acquire()
return lock

def release(self, key: Any):
if lock := self.locks.get(key):
ntindle marked this conversation as resolved.
Show resolved Hide resolved
lock.release()

def release_all_locks(self):
"""Call this on process termination to ensure all locks are released"""
self.locks_lock.acquire(blocking=False)
ntindle marked this conversation as resolved.
Show resolved Hide resolved
for lock in self.locks.values():
if lock.locked() and lock.owned():
lock.release()
36 changes: 35 additions & 1 deletion autogpt_platform/autogpt_libs/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions autogpt_platform/autogpt_libs/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ packages = [{ include = "autogpt_libs" }]

[tool.poetry.dependencies]
colorama = "^0.4.6"
expiringdict = "^1.2.2"
google-cloud-logging = "^3.8.0"
pydantic = "^2.8.2"
pydantic-settings = "^2.5.2"
Expand All @@ -16,6 +17,9 @@ python = ">=3.10,<4.0"
python-dotenv = "^1.0.1"
supabase = "^2.7.2"

[tool.poetry.group.dev.dependencies]
redis = "^5.0.8"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
3 changes: 2 additions & 1 deletion autogpt_platform/backend/backend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def main(**kwargs):
"""

from backend.executor import ExecutionManager, ExecutionScheduler
from backend.server import AgentServer, WebsocketServer
from backend.server.rest_api import AgentServer
from backend.server.ws_api import WebsocketServer

run_processes(
ExecutionManager(),
Expand Down
2 changes: 0 additions & 2 deletions autogpt_platform/backend/backend/data/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from multiprocessing import Manager
from typing import Any, Generic, TypeVar

from autogpt_libs.supabase_integration_credentials_store.types import Credentials
from prisma.enums import AgentExecutionStatus
from prisma.models import (
AgentGraphExecution,
Expand All @@ -26,7 +25,6 @@ class GraphExecution(BaseModel):
graph_exec_id: str
graph_id: str
start_node_execs: list["NodeExecution"]
node_input_credentials: dict[str, Credentials] # dict[node_id, Credentials]


class NodeExecution(BaseModel):
Expand Down
Loading
Loading