Skip to content

Commit

Permalink
feat(platform): Simplify Credentials UX (#8524)
Browse files Browse the repository at this point in the history
- Change `provider` of default credentials to actual provider names (e.g. `anthropic`), remove `llm` provider
- Add `discriminator` and `discriminator_mapping` to `CredentialsField` that allows to filter credentials input to only allow  providers for matching models in `useCredentials` hook (thanks @ntindle for the idea!); e.g. user chooses `GPT4_TURBO` so then only OpenAI credentials are allowed
- Choose credentials automatically and hide credentials input on the node completely if there's only one possible option
- Move `getValue` and `parseKeys` to utils
- Add `ANTHROPIC`, `GROQ` and `OLLAMA` to providers in frontend `types.ts`
- Add `hidden` field to credentials that is used for default system keys to hide them in user profile
- Now `provider` field in `CredentialsField` can accept multiple providers as a list

-----------------
Co-authored-by: Nicholas Tindle <[email protected]>
Co-authored-by: Reinier van der Leer <[email protected]>
  • Loading branch information
kcze authored Nov 12, 2024
1 parent ef7e504 commit e907ffd
Show file tree
Hide file tree
Showing 11 changed files with 168 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,21 @@
)
openai_credentials = APIKeyCredentials(
id="53c25cb8-e3ee-465c-a4d1-e75a4c899c2a",
provider="llm",
provider="openai",
api_key=SecretStr(settings.secrets.openai_api_key),
title="Use Credits for OpenAI",
expires_at=None,
)
anthropic_credentials = APIKeyCredentials(
id="24e5d942-d9e3-4798-8151-90143ee55629",
provider="llm",
provider="anthropic",
api_key=SecretStr(settings.secrets.anthropic_api_key),
title="Use Credits for Anthropic",
expires_at=None,
)
groq_credentials = APIKeyCredentials(
id="4ec22295-8f97-4dd1-b42b-2c6957a02545",
provider="llm",
provider="groq",
api_key=SecretStr(settings.secrets.groq_api_key),
title="Use Credits for Groq",
expires_at=None,
Expand Down
11 changes: 8 additions & 3 deletions autogpt_platform/backend/backend/blocks/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@
# "ollama": BlockSecret(value=""),
# }

AICredentials = CredentialsMetaInput[Literal["llm"], Literal["api_key"]]
LLMProviderName = Literal["anthropic", "groq", "openai", "ollama"]
AICredentials = CredentialsMetaInput[LLMProviderName, Literal["api_key"]]

TEST_CREDENTIALS = APIKeyCredentials(
id="ed55ac19-356e-4243-a6cb-bc599e9b716f",
provider="llm",
provider="openai",
api_key=SecretStr("mock-openai-api-key"),
title="Mock OpenAI API key",
expires_at=None,
Expand All @@ -50,8 +51,12 @@
def AICredentialsField() -> AICredentials:
return CredentialsField(
description="API key for the LLM provider.",
provider="llm",
provider=["anthropic", "groq", "openai", "ollama"],
supported_credential_types={"api_key"},
discriminator="model",
discriminator_mapping={
model.value: model.metadata.provider for model in LlmModel
},
)


Expand Down
10 changes: 8 additions & 2 deletions autogpt_platform/backend/backend/data/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,12 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):


def CredentialsField(
provider: CP,
provider: CP | list[CP],
supported_credential_types: set[CT],
required_scopes: set[str] = set(),
*,
discriminator: Optional[str] = None,
discriminator_mapping: Optional[dict[str, Any]] = None,
title: Optional[str] = None,
description: Optional[str] = None,
**kwargs,
Expand All @@ -167,9 +169,13 @@ def CredentialsField(
json_extra = {
k: v
for k, v in {
"credentials_provider": provider,
"credentials_provider": (
[provider] if isinstance(provider, str) else provider
),
"credentials_scopes": list(required_scopes) or None, # omit if empty
"credentials_types": list(supported_credential_types),
"discriminator": discriminator,
"discriminator_mapping": discriminator_mapping,
}.items()
if v is not None
}
Expand Down
33 changes: 23 additions & 10 deletions autogpt_platform/frontend/src/app/profile/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,14 @@ import { useSupabase } from "@/components/SupabaseProvider";
import { Button } from "@/components/ui/button";
import useUser from "@/hooks/useUser";
import { useRouter } from "next/navigation";
import { useCallback, useContext } from "react";
import { useCallback, useContext, useMemo } from "react";
import { FaSpinner } from "react-icons/fa";
import { Separator } from "@/components/ui/separator";
import { useToast } from "@/components/ui/use-toast";
import { IconKey, IconUser } from "@/components/ui/icons";
import { LogOutIcon, Trash2Icon } from "lucide-react";
import { providerIcons } from "@/components/integrations/credentials-input";
import {
CredentialsProviderName,
CredentialsProvidersContext,
} from "@/components/integrations/credentials-provider";
import { CredentialsProvidersContext } from "@/components/integrations/credentials-provider";
import {
Table,
TableBody,
Expand All @@ -23,6 +20,7 @@ import {
TableHeader,
TableRow,
} from "@/components/ui/table";
import { CredentialsProviderName } from "@/lib/autogpt-server-api";

export default function PrivatePage() {
const { user, isLoading, error } = useUser();
Expand Down Expand Up @@ -62,7 +60,22 @@ export default function PrivatePage() {
[providers, toast],
);

if (isLoading || !providers || !providers) {
//TODO: remove when the way system credentials are handled is updated
// This contains ids for built-in "Use Credits for X" credentials
const hiddenCredentials = useMemo(
() => [
"fdb7f412-f519-48d1-9b5f-d2f73d0e01fe", // Revid
"760f84fc-b270-42de-91f6-08efe1b512d0", // Ideogram
"6b9fc200-4726-4973-86c9-cd526f5ce5db", // Replicate
"53c25cb8-e3ee-465c-a4d1-e75a4c899c2a", // OpenAI
"24e5d942-d9e3-4798-8151-90143ee55629", // Anthropic
"4ec22295-8f97-4dd1-b42b-2c6957a02545", // Groq
"7f7b0654-c36b-4565-8fa7-9a52575dfae2", // D-ID
],
[],
);

if (isLoading || !providers) {
return (
<div className="flex h-[80vh] items-center justify-center">
<FaSpinner className="mr-2 h-16 w-16 animate-spin" />
Expand All @@ -76,15 +89,15 @@ export default function PrivatePage() {
}

const allCredentials = Object.values(providers).flatMap((provider) =>
[...provider.savedOAuthCredentials, ...provider.savedApiKeys].map(
(credentials) => ({
[...provider.savedOAuthCredentials, ...provider.savedApiKeys]
.filter((cred) => !hiddenCredentials.includes(cred.id))
.map((credentials) => ({
...credentials,
provider: provider.provider,
providerName: provider.providerName,
ProviderIcon: providerIcons[provider.provider],
TypeIcon: { oauth2: IconUser, api_key: IconKey }[credentials.type],
}),
),
})),
);

return (
Expand Down
58 changes: 10 additions & 48 deletions autogpt_platform/frontend/src/components/CustomNode.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@ import {
BlockUIType,
BlockCost,
} from "@/lib/autogpt-server-api/types";
import { beautifyString, cn, setNestedProperty } from "@/lib/utils";
import {
beautifyString,
cn,
getValue,
parseKeys,
setNestedProperty,
} from "@/lib/utils";
import { Button } from "@/components/ui/button";
import { Switch } from "@/components/ui/switch";
import { history } from "./history";
Expand All @@ -36,8 +42,6 @@ import * as Separator from "@radix-ui/react-separator";
import * as ContextMenu from "@radix-ui/react-context-menu";
import { DotsVerticalIcon, TrashIcon, CopyIcon } from "@radix-ui/react-icons";

type ParsedKey = { key: string; index?: number };

export type ConnectionData = Array<{
edge_id: string;
source: string;
Expand Down Expand Up @@ -178,7 +182,7 @@ export function CustomNode({
className=""
selfKey={noteKey}
schema={noteSchema as BlockIOStringSubSchema}
value={getValue(noteKey)}
value={getValue(noteKey, data.hardcodedValues)}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
error={data.errors?.[noteKey] ?? ""}
Expand Down Expand Up @@ -228,7 +232,7 @@ export function CustomNode({
nodeId={id}
propKey={getInputPropKey(propKey)}
propSchema={propSchema}
currentValue={getValue(getInputPropKey(propKey))}
currentValue={getValue(propKey, data.hardcodedValues)}
connections={data.connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
Expand Down Expand Up @@ -283,48 +287,6 @@ export function CustomNode({
setErrors({ ...errors });
};

// Helper function to parse keys with array indices
//TODO move to utils
const parseKeys = (key: string): ParsedKey[] => {
const splits = key.split(/_@_|_#_|_\$_|\./);
const keys: ParsedKey[] = [];
let currentKey: string | null = null;

splits.forEach((split) => {
const isInteger = /^\d+$/.test(split);
if (!isInteger) {
if (currentKey !== null) {
keys.push({ key: currentKey });
}
currentKey = split;
} else {
if (currentKey !== null) {
keys.push({ key: currentKey, index: parseInt(split, 10) });
currentKey = null;
} else {
throw new Error("Invalid key format: array index without a key");
}
}
});

if (currentKey !== null) {
keys.push({ key: currentKey });
}

return keys;
};

const getValue = (key: string) => {
const keys = parseKeys(key);
return keys.reduce((acc, k) => {
if (acc === undefined) return undefined;
if (k.index !== undefined) {
return Array.isArray(acc[k.key]) ? acc[k.key][k.index] : undefined;
}
return acc[k.key];
}, data.hardcodedValues as any);
};

const isHandleConnected = (key: string) => {
return (
data.connections &&
Expand All @@ -347,7 +309,7 @@ export function CustomNode({
const handleInputClick = (key: string) => {
console.debug(`Opening modal for key: ${key}`);
setActiveKey(key);
const value = getValue(key);
const value = getValue(key, data.hardcodedValues);
setInputModalValue(
typeof value === "object" ? JSON.stringify(value, null, 2) : value,
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,18 @@ export const providerIcons: Record<
CredentialsProviderName,
React.FC<{ className?: string }>
> = {
anthropic: fallbackIcon,
github: FaGithub,
google: FaGoogle,
groq: fallbackIcon,
notion: NotionLogoIcon,
discord: FaDiscord,
d_id: fallbackIcon,
google_maps: FaGoogle,
jina: fallbackIcon,
ideogram: fallbackIcon,
llm: fallbackIcon,
medium: FaMedium,
ollama: fallbackIcon,
openai: fallbackIcon,
openweathermap: fallbackIcon,
pinecone: fallbackIcon,
Expand All @@ -80,7 +82,7 @@ export type OAuthPopupResultMessage = { message_type: "oauth_popup_result" } & (
export const CredentialsInput: FC<{
className?: string;
selectedCredentials?: CredentialsMetaInput;
onSelectCredentials: (newValue: CredentialsMetaInput) => void;
onSelectCredentials: (newValue?: CredentialsMetaInput) => void;
}> = ({ className, selectedCredentials, onSelectCredentials }) => {
const api = useMemo(() => new AutoGPTServerAPI(), []);
const credentials = useCredentials();
Expand All @@ -91,14 +93,10 @@ export const CredentialsInput: FC<{
useState<AbortController | null>(null);
const [oAuthError, setOAuthError] = useState<string | null>(null);

if (!credentials) {
if (!credentials || credentials.isLoading) {
return null;
}

if (credentials.isLoading) {
return <div>Loading...</div>;
}

const {
schema,
provider,
Expand Down Expand Up @@ -222,10 +220,21 @@ export const CredentialsInput: FC<{
</>
);

// Deselect credentials if they do not exist (e.g. provider was changed)
if (
selectedCredentials &&
!savedApiKeys
.concat(savedOAuthCredentials)
.some((c) => c.id === selectedCredentials.id)
) {
onSelectCredentials(undefined);
}

// No saved credentials yet
if (savedApiKeys.length === 0 && savedOAuthCredentials.length === 0) {
return (
<>
<span className="text-m green mb-0 text-gray-900">Credentials</span>
<div className={cn("flex flex-row space-x-2", className)}>
{supportsOAuth2 && (
<Button onClick={handleOAuthLogin}>
Expand All @@ -248,6 +257,25 @@ export const CredentialsInput: FC<{
);
}

const singleCredential =
savedApiKeys.length === 1 && savedOAuthCredentials.length === 0
? savedApiKeys[0]
: savedOAuthCredentials.length === 1 && savedApiKeys.length === 0
? savedOAuthCredentials[0]
: null;

if (singleCredential) {
if (!selectedCredentials) {
onSelectCredentials({
id: singleCredential.id,
type: singleCredential.type,
provider,
title: singleCredential.title,
});
}
return null;
}

function handleValueChange(newValue: string) {
if (newValue === "sign-in") {
// Trigger OAuth2 sign in flow
Expand All @@ -263,7 +291,7 @@ export const CredentialsInput: FC<{
onSelectCredentials({
id: selectedCreds.id,
type: selectedCreds.type,
provider: schema.credentials_provider,
provider: provider,
// title: customTitle, // TODO: add input for title
});
}
Expand All @@ -272,6 +300,7 @@ export const CredentialsInput: FC<{
// Saved credentials exist
return (
<>
<span className="text-m green mb-0 text-gray-900">Credentials</span>
<Select value={selectedCredentials?.id} onValueChange={handleValueChange}>
<SelectTrigger>
<SelectValue placeholder={schema.placeholder} />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,18 @@ const CREDENTIALS_PROVIDER_NAMES = Object.values(

// --8<-- [start:CredentialsProviderNames]
const providerDisplayNames: Record<CredentialsProviderName, string> = {
anthropic: "Anthropic",
discord: "Discord",
d_id: "D-ID",
github: "GitHub",
google: "Google",
google_maps: "Google Maps",
groq: "Groq",
ideogram: "Ideogram",
jina: "Jina",
medium: "Medium",
llm: "LLM",
notion: "Notion",
ollama: "Ollama",
openai: "OpenAI",
openweathermap: "OpenWeatherMap",
pinecone: "Pinecone",
Expand Down
Loading

0 comments on commit e907ffd

Please sign in to comment.