Skip to content

Commit

Permalink
feat(store): Generate AI images for store submissions (#9090)
Browse files Browse the repository at this point in the history
Allow generating AI images for store submissions
  • Loading branch information
Swiftyos authored Dec 19, 2024
1 parent d028f5b commit 8e634d7
Show file tree
Hide file tree
Showing 6 changed files with 275 additions and 25 deletions.
94 changes: 94 additions & 0 deletions autogpt_platform/backend/backend/server/v2/store/image_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import io
import logging
from enum import Enum

import replicate
import replicate.exceptions
import requests
from replicate.helpers import FileOutput

from backend.data.graph import Graph
from backend.util.settings import Settings

logger = logging.getLogger(__name__)


class ImageSize(str, Enum):
    """Image dimensions, encoded as a "<width>x<height>" string."""

    # The only size defined here: 1024 wide by 768 high.
    LANDSCAPE = "1024x768"


class ImageStyle(str, Enum):
    """Visual style names, stored as plain strings."""

    # The only style defined here.
    DIGITAL_ART = "digital art"


def _fetch_image_bytes(url: str, timeout: float = 30.0) -> bytes:
    """Download the image at ``url`` and return its raw bytes."""
    # Without a timeout a stalled download would hang the caller indefinitely.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return response.content


def _extract_image_bytes(output: object) -> bytes:
    """Turn a Replicate model output (file object, URL, or list of either) into bytes."""
    # A non-empty list is reduced to its first element; an empty list falls
    # through to the "unexpected format" error below.
    if isinstance(output, list) and output:
        output = output[0]

    if isinstance(output, FileOutput):
        return output.read()
    if isinstance(output, str):
        # A plain string is treated as the URL of the generated image.
        return _fetch_image_bytes(output)
    raise RuntimeError("Unexpected output format from the model.")


def _run_flux_model(client: replicate.Client, input_data: dict) -> bytes:
    """Run the Flux model synchronously and return the generated image bytes."""
    output = client.run("black-forest-labs/flux-pro", input=input_data)
    return _extract_image_bytes(output)


async def generate_agent_image(agent: Graph) -> io.BytesIO:
    """
    Generate an image for an agent using Flux model via Replicate API.

    Args:
        agent (Graph): The agent to generate an image for

    Returns:
        io.BytesIO: The generated image as bytes

    Raises:
        RuntimeError: If the API key is missing, the Replicate call fails, the
            image download fails, or the model output has an unexpected
            format. The underlying error is chained as ``__cause__``.
    """
    # Local import keeps this block self-contained.
    import asyncio

    try:
        settings = Settings()

        if not settings.secrets.replicate_api_key:
            raise ValueError("Missing Replicate API key in settings")

        # Construct prompt from agent details
        prompt = f"App store image for AI agent that gives a cool visual representation of what the agent does: - {agent.name} - {agent.description}"

        # Set up Replicate client
        client = replicate.Client(api_token=settings.secrets.replicate_api_key)

        # Model parameters
        input_data = {
            "prompt": prompt,
            "width": 1024,
            "height": 768,
            "aspect_ratio": "4:3",
            "output_format": "jpg",
            "output_quality": 90,
            "num_inference_steps": 30,
            "guidance": 3.5,
            "negative_prompt": "blurry, low quality, distorted, deformed",
            "disable_safety_checker": True,
        }

        try:
            # The Replicate client and ``requests`` are blocking; running them
            # in a worker thread keeps the event loop free during generation.
            image_bytes = await asyncio.to_thread(_run_flux_model, client, input_data)
        except replicate.exceptions.ReplicateError as e:
            if e.status == 401:
                raise RuntimeError("Invalid Replicate API token") from e
            raise RuntimeError(f"Replicate API error: {str(e)}") from e

        return io.BytesIO(image_bytes)

    except Exception as e:
        logger.exception("Failed to generate agent image")
        # Chain the cause so the original traceback is preserved for callers.
        raise RuntimeError(f"Image generation failed: {str(e)}") from e
48 changes: 46 additions & 2 deletions autogpt_platform/backend/backend/server/v2/store/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,45 @@
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB


async def upload_media(user_id: str, file: fastapi.UploadFile) -> str:
async def check_media_exists(user_id: str, filename: str) -> str | None:
    """
    Look up a user's media file in the configured GCS bucket.

    The user's images folder is probed first, then the videos folder.

    Args:
        user_id (str): ID of the user who uploaded the file
        filename (str): Name of the file to check

    Returns:
        str | None: Public URL of the first matching blob, or None when no
        blob matches or the lookup raises (the error is logged, not raised)
    """
    try:
        settings = Settings()
        gcs_client = storage.Client()
        bucket = gcs_client.bucket(settings.config.media_gcs_bucket_name)

        # Candidate locations in lookup order: images, then videos.
        candidate_paths = (
            f"users/{user_id}/images/{filename}",
            f"users/{user_id}/videos/{filename}",
        )
        for blob_path in candidate_paths:
            blob = bucket.blob(blob_path)
            if blob.exists():
                return blob.public_url
    except Exception as e:
        # Best-effort check: any storage failure is reported as "not found".
        logger.error(f"Error checking if media file exists: {str(e)}")

    return None


async def upload_media(
user_id: str, file: fastapi.UploadFile, use_file_name: bool = False
) -> str:

# Get file content for deeper validation
try:
Expand Down Expand Up @@ -84,6 +122,9 @@ async def upload_media(user_id: str, file: fastapi.UploadFile) -> str:
try:
# Validate file type
content_type = file.content_type
if content_type is None:
content_type = "image/jpeg"

if (
content_type not in ALLOWED_IMAGE_TYPES
and content_type not in ALLOWED_VIDEO_TYPES
Expand Down Expand Up @@ -119,7 +160,10 @@ async def upload_media(user_id: str, file: fastapi.UploadFile) -> str:
# Generate unique filename
filename = file.filename or ""
file_ext = os.path.splitext(filename)[1].lower()
unique_filename = f"{uuid.uuid4()}{file_ext}"
if use_file_name:
unique_filename = filename
else:
unique_filename = f"{uuid.uuid4()}{file_ext}"

# Construct storage path
media_type = "images" if content_type in ALLOWED_IMAGE_TYPES else "videos"
Expand Down
62 changes: 62 additions & 0 deletions autogpt_platform/backend/backend/server/v2/store/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import fastapi
import fastapi.responses

import backend.data.graph
import backend.server.v2.store.db
import backend.server.v2.store.image_gen
import backend.server.v2.store.media
import backend.server.v2.store.model

Expand Down Expand Up @@ -439,3 +441,63 @@ async def upload_submission_media(
raise fastapi.HTTPException(
status_code=500, detail=f"Failed to upload media file: {str(e)}"
)


@router.post(
    "/submissions/generate_image",
    tags=["store", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def generate_image(
    agent_id: str,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
) -> fastapi.responses.Response:
    """
    Generate an image for a store listing submission.

    An image already stored under the agent's filename is reused; otherwise a
    new one is generated and uploaded.

    Args:
        agent_id (str): ID of the agent to generate an image for
        user_id (str): ID of the authenticated user

    Returns:
        JSONResponse: JSON containing the URL of the generated image

    Raises:
        fastapi.HTTPException: 404 if the agent is not found for this user,
            500 if generation or upload fails.
    """
    try:
        agent = await backend.data.graph.get_graph(agent_id, user_id=user_id)

        if not agent:
            raise fastapi.HTTPException(
                status_code=404, detail=f"Agent with ID {agent_id} not found"
            )

        # Use .jpeg here since we are generating JPEG images
        filename = f"agent_{agent_id}.jpeg"

        existing_url = await backend.server.v2.store.media.check_media_exists(
            user_id, filename
        )
        if existing_url:
            logger.info(f"Using existing image for agent {agent_id}")
            return fastapi.responses.JSONResponse(content={"image_url": existing_url})

        # Generate agent image as JPEG
        image = await backend.server.v2.store.image_gen.generate_agent_image(
            agent=agent
        )

        # Wrap the bytes in an UploadFile carrying the deterministic filename.
        # No content type is set here; upload_media appears to fall back to
        # "image/jpeg" when it is None (verify against upload_media).
        image_file = fastapi.UploadFile(
            file=image,
            filename=filename,
        )

        # use_file_name=True stores the file under `filename` so the
        # check_media_exists lookup above can find it on later calls.
        image_url = await backend.server.v2.store.media.upload_media(
            user_id=user_id, file=image_file, use_file_name=True
        )

        return fastapi.responses.JSONResponse(content={"image_url": image_url})
    except fastapi.HTTPException:
        # Deliberate HTTP errors (e.g. the 404 above) must reach the client
        # unchanged rather than being rewritten as a 500 below.
        raise
    except Exception as e:
        logger.exception("Exception occurred whilst generating submission image")
        raise fastapi.HTTPException(
            status_code=500, detail=f"Failed to generate image: {str(e)}"
        ) from e
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ interface PublishAgentInfoProps {
) => void;
onClose: () => void;
initialData?: {
agent_id: string;
title: string;
subheader: string;
slug: string;
Expand All @@ -36,6 +37,7 @@ export const PublishAgentInfo: React.FC<PublishAgentInfoProps> = ({
onClose,
initialData,
}) => {
const [agentId, setAgentId] = React.useState<string | null>(null);
const [images, setImages] = React.useState<string[]>(
initialData?.additionalImages
? [initialData.thumbnailSrc, ...initialData.additionalImages]
Expand All @@ -59,10 +61,10 @@ export const PublishAgentInfo: React.FC<PublishAgentInfoProps> = ({
);
const [slug, setSlug] = React.useState(initialData?.slug || "");
const thumbnailsContainerRef = React.useRef<HTMLDivElement | null>(null);

React.useEffect(() => {
if (initialData) {
setImages(initialData.additionalImages || []);
setAgentId(initialData.agent_id);
setImagesWithValidation(initialData.additionalImages || []);
setSelectedImage(initialData.thumbnailSrc || null);
setTitle(initialData.title);
setSubheader(initialData.subheader);
Expand All @@ -73,10 +75,18 @@ export const PublishAgentInfo: React.FC<PublishAgentInfoProps> = ({
}
}, [initialData]);

const setImagesWithValidation = (newImages: string[]) => {
  // Drop duplicate URLs (a Set keeps first-seen order), then cap the list at
  // five entries before storing it in state.
  const deduped = Array.from(new Set(newImages));
  setImages(deduped.slice(0, 5));
};

const handleRemoveImage = (indexToRemove: number) => {
const newImages = [...images];
newImages.splice(indexToRemove, 1);
setImages(newImages);
setImagesWithValidation(newImages);
if (newImages[indexToRemove] === selectedImage) {
setSelectedImage(newImages[0] || null);
}
Expand All @@ -88,6 +98,8 @@ export const PublishAgentInfo: React.FC<PublishAgentInfoProps> = ({
};

const handleAddImage = async () => {
if (images.length >= 5) return;

const input = document.createElement("input");
input.type = "file";
input.accept = "image/*";
Expand Down Expand Up @@ -115,11 +127,7 @@ export const PublishAgentInfo: React.FC<PublishAgentInfoProps> = ({
"$1",
);

setImages((prev) => {
const newImages = [...prev, imageUrl];
console.log("Added image. Images now:", newImages);
return newImages;
});
setImagesWithValidation([...images, imageUrl]);
if (!selectedImage) {
setSelectedImage(imageUrl);
}
Expand All @@ -128,6 +136,27 @@ export const PublishAgentInfo: React.FC<PublishAgentInfoProps> = ({
}
};

const [isGenerating, setIsGenerating] = React.useState(false);

// Ask the backend to generate an image for the current agent and append the
// returned URL to the image list. Does nothing while a generation is already
// running or once five images are present.
const handleGenerateImage = async () => {
  if (isGenerating || images.length >= 5) return;

  setIsGenerating(true);
  try {
    if (!agentId) {
      throw new Error("Agent ID is required");
    }
    const api = new BackendAPI();
    const { image_url } = await api.generateStoreSubmissionImage(agentId);
    // Functional update: `images` captured by this closure may be stale by
    // the time the request resolves (the user can add or remove images while
    // waiting), so append to the latest state. Same dedupe and five-image cap
    // as the other image updates.
    setImages((prev) => Array.from(new Set([...prev, image_url])).slice(0, 5));
  } catch (error) {
    // Failures (including a missing agent ID) are only logged.
    console.error("Failed to generate image:", error);
  } finally {
    setIsGenerating(false);
  }
};

const handleSubmit = (e: React.MouseEvent<HTMLButtonElement>) => {
e.preventDefault();
onSubmit(title, subheader, slug, description, images, youtubeLink, [
Expand Down Expand Up @@ -284,19 +313,21 @@ export const PublishAgentInfo: React.FC<PublishAgentInfoProps> = ({
</button>
</div>
))}
<Button
onClick={handleAddImage}
variant="ghost"
className="flex h-[70px] w-[100px] flex-col items-center justify-center rounded-md bg-neutral-200 hover:bg-neutral-300 dark:bg-neutral-700 dark:hover:bg-neutral-600"
>
<IconPlus
size="lg"
className="text-neutral-600 dark:text-neutral-300"
/>
<span className="mt-1 font-['Geist'] text-xs font-normal text-neutral-600 dark:text-neutral-300">
Add image
</span>
</Button>
{images.length < 5 && (
<Button
onClick={handleAddImage}
variant="ghost"
className="flex h-[70px] w-[100px] flex-col items-center justify-center rounded-md bg-neutral-200 hover:bg-neutral-300 dark:bg-neutral-700 dark:hover:bg-neutral-600"
>
<IconPlus
size="lg"
className="text-neutral-600 dark:text-neutral-300"
/>
<span className="mt-1 font-['Geist'] text-xs font-normal text-neutral-600 dark:text-neutral-300">
Add image
</span>
</Button>
)}
</>
)}
</div>
Expand All @@ -313,9 +344,17 @@ export const PublishAgentInfo: React.FC<PublishAgentInfoProps> = ({
<Button
variant="default"
size="sm"
className="bg-neutral-800 text-white hover:bg-neutral-900 dark:bg-neutral-600 dark:hover:bg-neutral-500"
className={`bg-neutral-800 text-white hover:bg-neutral-900 dark:bg-neutral-600 dark:hover:bg-neutral-500 ${
images.length >= 5 ? "cursor-not-allowed opacity-50" : ""
}`}
onClick={handleGenerateImage}
disabled={isGenerating || images.length >= 5}
>
Generate
{isGenerating
? "Generating..."
: images.length >= 5
? "Max images reached"
: "Generate"}
</Button>
</div>
</div>
Expand Down
Loading

0 comments on commit 8e634d7

Please sign in to comment.