Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(platform): Simplify Credentials UX #8524

Merged
merged 16 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,48 +29,55 @@
api_key=SecretStr(settings.secrets.revid_api_key),
title="Use Credits for Revid",
expires_at=None,
hidden=True,
)
ideogram_credentials = APIKeyCredentials(
id="760f84fc-b270-42de-91f6-08efe1b512d0",
provider="ideogram",
api_key=SecretStr(settings.secrets.ideogram_api_key),
title="Use Credits for Ideogram",
expires_at=None,
hidden=True,
)
replicate_credentials = APIKeyCredentials(
id="6b9fc200-4726-4973-86c9-cd526f5ce5db",
provider="replicate",
api_key=SecretStr(settings.secrets.replicate_api_key),
title="Use Credits for Replicate",
expires_at=None,
hidden=True,
)
openai_credentials = APIKeyCredentials(
id="53c25cb8-e3ee-465c-a4d1-e75a4c899c2a",
provider="llm",
provider="openai",
api_key=SecretStr(settings.secrets.openai_api_key),
title="Use Credits for OpenAI",
expires_at=None,
hidden=True,
)
anthropic_credentials = APIKeyCredentials(
id="24e5d942-d9e3-4798-8151-90143ee55629",
provider="llm",
provider="anthropic",
api_key=SecretStr(settings.secrets.anthropic_api_key),
title="Use Credits for Anthropic",
expires_at=None,
hidden=True,
)
groq_credentials = APIKeyCredentials(
id="4ec22295-8f97-4dd1-b42b-2c6957a02545",
provider="llm",
provider="groq",
api_key=SecretStr(settings.secrets.groq_api_key),
title="Use Credits for Groq",
expires_at=None,
hidden=True,
)
did_credentials = APIKeyCredentials(
id="7f7b0654-c36b-4565-8fa7-9a52575dfae2",
provider="d_id",
api_key=SecretStr(settings.secrets.did_api_key),
title="Use Credits for D-ID",
expires_at=None,
hidden=True,
)

DEFAULT_CREDENTIALS = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class _BaseCredentials(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
provider: str
title: Optional[str]
hidden: Optional[bool] = False
Pwuts marked this conversation as resolved.
Show resolved Hide resolved
Pwuts marked this conversation as resolved.
Show resolved Hide resolved

@field_serializer("*")
def dump_secret_strings(value: Any, _info):
Expand Down
10 changes: 7 additions & 3 deletions autogpt_platform/backend/backend/blocks/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
# "ollama": BlockSecret(value=""),
# }

AICredentials = CredentialsMetaInput[Literal["llm"], Literal["api_key"]]
AICredentials = CredentialsMetaInput[list[str], Literal["api_key"]]
Pwuts marked this conversation as resolved.
Show resolved Hide resolved

TEST_CREDENTIALS = APIKeyCredentials(
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
provider="llm",
provider="openai",
api_key=SecretStr("mock-openai-api-key"),
title="Mock OpenAI API key",
expires_at=None,
Expand All @@ -50,8 +50,12 @@
def AICredentialsField() -> AICredentials:
return CredentialsField(
description="API key for the LLM provider.",
provider="llm",
provider=["anthropic", "groq", "openai", "ollama"],
supported_credential_types={"api_key"},
discriminator="model",
discriminator_mapping={
model.value: model.metadata.provider for model in LlmModel
},
)


Expand Down
20 changes: 18 additions & 2 deletions autogpt_platform/backend/backend/data/model.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a note: this really has to be considered a tempfix. Scopes are provider-specific. So is supported_credential_types. The CredentialsField was made to handle credentials from a specific provider. If a block supports multiple providers, the discrimination logic should be handled outside the CredentialsField.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See also #8524 (comment)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I'd like to leave this discussion here for reference in the follow-up PR)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I'm happy for us to merge this as a much needed temp hotfix and we can adjust that in a follow up.

Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def SchemaField(
)


CP = TypeVar("CP", bound=str)
CP = TypeVar("CP", bound=str | list[str])
CT = TypeVar("CT", bound=CredentialsType)


Expand All @@ -156,6 +156,8 @@ def CredentialsField(
supported_credential_types: set[CT],
required_scopes: set[str] = set(),
*,
discriminator: Optional[str] = None,
discriminator_mapping: Optional[dict[str, Any]] = None,
title: Optional[str] = None,
description: Optional[str] = None,
**kwargs,
Expand All @@ -164,16 +166,30 @@ def CredentialsField(
`CredentialsField` must and can only be used on fields named `credentials`.
This is enforced by the `BlockSchema` base class.
"""
# Convert provider list to enum for schema
provider_list = [provider] if isinstance(provider, str) else provider

json_extra = {
k: v
for k, v in {
"credentials_provider": provider,
"credentials_provider": (
provider if isinstance(provider, str) else list(provider)
),
Pwuts marked this conversation as resolved.
Show resolved Hide resolved
"credentials_scopes": list(required_scopes) or None, # omit if empty
"credentials_types": list(supported_credential_types),
"discriminator": discriminator,
"discriminator_mapping": discriminator_mapping,
}.items()
if v is not None
}

# Force provider to always be a string
# Otherwise input provider will fail jsonschema validation
# if multiple providers are provided
json_extra.update(
{"properties": {"provider": {"type": "string", "enum": provider_list}}}
)

Pwuts marked this conversation as resolved.
Show resolved Hide resolved
return Field(
title=title,
description=description,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class CredentialsMetaResponse(BaseModel):
title: str | None
scopes: list[str] | None
username: str | None
hidden: bool | None


@router.post("/{provider}/callback")
Expand Down Expand Up @@ -111,6 +112,7 @@ def callback(
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
hidden=credentials.hidden,
)


Expand All @@ -127,6 +129,7 @@ def list_credentials(
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
hidden=cred.hidden,
)
for cred in credentials
]
Expand Down
16 changes: 7 additions & 9 deletions autogpt_platform/frontend/src/app/profile/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@ import { useToast } from "@/components/ui/use-toast";
import { IconKey, IconUser } from "@/components/ui/icons";
import { LogOutIcon, Trash2Icon } from "lucide-react";
import { providerIcons } from "@/components/integrations/credentials-input";
import {
CredentialsProviderName,
CredentialsProvidersContext,
} from "@/components/integrations/credentials-provider";
import { CredentialsProvidersContext } from "@/components/integrations/credentials-provider";
import {
Table,
TableBody,
Expand All @@ -23,6 +20,7 @@ import {
TableHeader,
TableRow,
} from "@/components/ui/table";
import { CredentialsProviderName } from "@/lib/autogpt-server-api";

export default function PrivatePage() {
const { user, isLoading, error } = useUser();
Expand Down Expand Up @@ -62,7 +60,7 @@ export default function PrivatePage() {
[providers, toast],
);

if (isLoading || !providers || !providers) {
if (isLoading || !providers) {
return (
<div className="flex h-[80vh] items-center justify-center">
<FaSpinner className="mr-2 h-16 w-16 animate-spin" />
Expand All @@ -76,15 +74,15 @@ export default function PrivatePage() {
}

const allCredentials = Object.values(providers).flatMap((provider) =>
[...provider.savedOAuthCredentials, ...provider.savedApiKeys].map(
(credentials) => ({
[...provider.savedOAuthCredentials, ...provider.savedApiKeys]
.filter((cred) => !cred.hidden)
.map((credentials) => ({
...credentials,
provider: provider.provider,
providerName: provider.providerName,
ProviderIcon: providerIcons[provider.provider],
TypeIcon: { oauth2: IconUser, api_key: IconKey }[credentials.type],
}),
),
})),
Pwuts marked this conversation as resolved.
Show resolved Hide resolved
);

return (
Expand Down
68 changes: 13 additions & 55 deletions autogpt_platform/frontend/src/components/CustomNode.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@ import {
BlockUIType,
BlockCost,
} from "@/lib/autogpt-server-api/types";
import { beautifyString, cn, setNestedProperty } from "@/lib/utils";
import {
beautifyString,
cn,
getValue,
parseKeys,
setNestedProperty,
} from "@/lib/utils";
import { Button } from "@/components/ui/button";
import { Switch } from "@/components/ui/switch";
import { history } from "./history";
Expand All @@ -36,8 +42,6 @@ import * as Separator from "@radix-ui/react-separator";
import * as ContextMenu from "@radix-ui/react-context-menu";
import { DotsVerticalIcon, TrashIcon, CopyIcon } from "@radix-ui/react-icons";

type ParsedKey = { key: string; index?: number };

export type ConnectionData = Array<{
edge_id: string;
source: string;
Expand Down Expand Up @@ -181,7 +185,7 @@ export function CustomNode({
nodeId={id}
propKey={propKey}
propSchema={propSchema}
currentValue={getValue(propKey)}
currentValue={getValue(propKey, data.hardcodedValues)}
connections={data.connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
Expand All @@ -204,7 +208,7 @@ export function CustomNode({
className=""
selfKey={noteKey}
schema={noteSchema as BlockIOStringSubSchema}
value={getValue(noteKey)}
value={getValue(noteKey, data.hardcodedValues)}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
error={data.errors?.[noteKey] ?? ""}
Expand Down Expand Up @@ -240,7 +244,7 @@ export function CustomNode({
nodeId={id}
propKey={propKey}
propSchema={propSchema}
currentValue={getValue(propKey)}
currentValue={getValue(propKey, data.hardcodedValues)}
connections={data.connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
Expand All @@ -261,11 +265,7 @@ export function CustomNode({
return (
(isRequired || isAdvancedOpen || isConnected || !isAdvanced) && (
<div key={propKey} data-id={`input-handle-${propKey}`}>
{"credentials_provider" in propSchema ? (
<span className="text-m green mb-0 text-gray-900">
Credentials
</span>
) : (
{!("credentials_provider" in propSchema) && (
<NodeHandle
keyName={propKey}
isConnected={isConnected}
Expand All @@ -279,7 +279,7 @@ export function CustomNode({
nodeId={id}
propKey={propKey}
propSchema={propSchema}
currentValue={getValue(propKey)}
currentValue={getValue(propKey, data.hardcodedValues)}
connections={data.connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
Expand Down Expand Up @@ -336,48 +336,6 @@ export function CustomNode({
setErrors({ ...errors });
};

// Helper function to parse keys with array indices
//TODO move to utils
const parseKeys = (key: string): ParsedKey[] => {
const splits = key.split(/_@_|_#_|_\$_|\./);
const keys: ParsedKey[] = [];
let currentKey: string | null = null;

splits.forEach((split) => {
const isInteger = /^\d+$/.test(split);
if (!isInteger) {
if (currentKey !== null) {
keys.push({ key: currentKey });
}
currentKey = split;
} else {
if (currentKey !== null) {
keys.push({ key: currentKey, index: parseInt(split, 10) });
currentKey = null;
} else {
throw new Error("Invalid key format: array index without a key");
}
}
});

if (currentKey !== null) {
keys.push({ key: currentKey });
}

return keys;
};

const getValue = (key: string) => {
const keys = parseKeys(key);
return keys.reduce((acc, k) => {
if (acc === undefined) return undefined;
if (k.index !== undefined) {
return Array.isArray(acc[k.key]) ? acc[k.key][k.index] : undefined;
}
return acc[k.key];
}, data.hardcodedValues as any);
};

const isHandleConnected = (key: string) => {
return (
data.connections &&
Expand All @@ -400,7 +358,7 @@ export function CustomNode({
const handleInputClick = (key: string) => {
console.debug(`Opening modal for key: ${key}`);
setActiveKey(key);
const value = getValue(key);
const value = getValue(key, data.hardcodedValues);
setInputModalValue(
typeof value === "object" ? JSON.stringify(value, null, 2) : value,
);
Expand Down
Loading
Loading