Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(backend): Add migrations to fix credentials inputs #8674

Merged
merged 4 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions autogpt_platform/backend/backend/data/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import datetime, timezone
from typing import Any, Literal, Type

import prisma
from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
from prisma.types import AgentGraphWhereInput
from pydantic.fields import computed_field
Expand Down Expand Up @@ -528,3 +529,84 @@ async def __create_graph(tx, graph: Graph, user_id: str):
for link in graph.links
]
)


# ------------------------ UTILITIES ------------------------ #


async def fix_llm_provider_credentials():
"""Fix node credentials with provider `llm`"""
from autogpt_libs.supabase_integration_credentials_store import (
SupabaseIntegrationCredentialsStore,
)

from .redis import get_redis
from .user import get_user_integrations

store = SupabaseIntegrationCredentialsStore(get_redis())

broken_nodes = await prisma.get_client().query_raw(
"""
SELECT user.id user_id,
node.id node_id,
node."constantInput" node_preset_input
FROM platform."AgentGraph" graph
LEFT JOIN platform."AgentNode" node
ON node."agentGraphId" = graph.id
LEFT JOIN platform."User" user
ON graph."userId" = user.id
WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
ORDER BY user_id;
"""
)
logger.info(f"Fixing LLM credential inputs on {len(broken_nodes)} nodes")

user_id: str = ""
user_integrations = None
for node in broken_nodes:
if node["user_id"] != user_id:
# Save queries by only fetching once per user
user_id = node["user_id"]
user_integrations = await get_user_integrations(user_id)
elif not user_integrations:
raise RuntimeError(f"Impossible state while processing node {node}")

node_id: str = node["node_id"]
node_preset_input: dict = json.loads(node["node_preset_input"])
credentials_meta: dict = node_preset_input["credentials"]

credentials = next(
(
c
for c in user_integrations.credentials
if c.id == credentials_meta["id"]
),
None,
)
if not credentials:
continue
if credentials.type != "api_key":
logger.warning(
f"User {user_id} credentials {credentials.id} with provider 'llm' "
f"has invalid type '{credentials.type}'"
)
continue

api_key = credentials.api_key.get_secret_value()
if api_key.startswith("sk-ant-api03-"):
credentials.provider = credentials_meta["provider"] = "anthropic"
elif api_key.startswith("sk-"):
credentials.provider = credentials_meta["provider"] = "openai"
elif api_key.startswith("gsk_"):
credentials.provider = credentials_meta["provider"] = "groq"
else:
logger.warning(
f"Could not identify provider from key prefix {api_key[:13]}*****"
)
continue

store.update_creds(user_id, credentials)
await AgentNode.prisma().update(
where={"id": node_id},
data={"constantInput": json.dumps(node_preset_input)},
)
2 changes: 2 additions & 0 deletions autogpt_platform/backend/backend/server/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import backend.data.block
import backend.data.db
import backend.data.graph
import backend.data.user
import backend.server.routers.v1
import backend.util.service
Expand All @@ -23,6 +24,7 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.db.connect()
await backend.data.block.initialize_blocks()
await backend.data.user.migrate_and_encrypt_user_integrations()
await backend.data.graph.fix_llm_provider_credentials()
yield
await backend.data.db.disconnect()

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
-- Correct credentials.provider field on all nodes with 'llm' provider credentials
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{credentials,provider}',
CASE
WHEN "constantInput"::jsonb->'credentials'->>'id' = '53c25cb8-e3ee-465c-a4d1-e75a4c899c2a' THEN '"openai"'::jsonb
WHEN "constantInput"::jsonb->'credentials'->>'id' = '24e5d942-d9e3-4798-8151-90143ee55629' THEN '"anthropic"'::jsonb
WHEN "constantInput"::jsonb->'credentials'->>'id' = '4ec22295-8f97-4dd1-b42b-2c6957a02545' THEN '"groq"'::jsonb
ELSE ("constantInput"::jsonb->'credentials'->>'provider')::jsonb
END
)::text
WHERE "constantInput"::jsonb->'credentials'->>'provider' = 'llm';
Loading