feat(block): Add AI video generator block with Fal txt 2 vid (#8528)
### Background

Implements an AI Video Generator Block for text-to-video models hosted on Fal.


![image](https://github.com/user-attachments/assets/9cb70015-4174-4419-8c1a-4144f324442f)
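
For illustration, a minimal sketch of how the new block could be exercised directly, using the mock credentials shipped with the block and stubbing out the network call. The platform normally runs blocks through its executor, so the direct calls below are only an assumption for local experimentation, mirroring the block's own `test_input` and `test_mock`:

```python
from backend.blocks.fal._auth import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
from backend.blocks.fal.ai_video_generator import AIVideoGeneratorBlock, FalModel

block = AIVideoGeneratorBlock()

# Stub out the real FAL call, the same way the block's test_mock does.
block.generate_video = lambda input_data, credentials: (
    "https://fal.media/files/example/video.mp4"
)

# Build the input the same way test_input does (assumed to validate against
# the CredentialsMetaInput schema).
input_data = AIVideoGeneratorBlock.Input(
    prompt="A dog running in a field.",
    model=FalModel.MOCHI,
    credentials=TEST_CREDENTIALS_INPUT,
)

# run() is a generator that yields (output_name, value) pairs.
for name, value in block.run(input_data, credentials=TEST_CREDENTIALS):
    print(name, value)  # expected: video_url https://fal.media/files/example/video.mp4
```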

---------

Co-authored-by: Aarushi <[email protected]>
Co-authored-by: Aarushi <[email protected]>
3 people committed Dec 1, 2024
1 parent a50da78 commit 02f6a82
Showing 6 changed files with 240 additions and 0 deletions.
36 changes: 36 additions & 0 deletions autogpt_platform/backend/backend/blocks/fal/_auth.py
@@ -0,0 +1,36 @@
from typing import Literal

from pydantic import SecretStr

from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput

FalCredentials = APIKeyCredentials
FalCredentialsInput = CredentialsMetaInput[
    Literal["fal"],
    Literal["api_key"],
]

TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="fal",
    api_key=SecretStr("mock-fal-api-key"),
    title="Mock FAL API key",
    expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}


def FalCredentialsField() -> FalCredentialsInput:
    """
    Creates a FAL credentials input on a block.
    """
    return CredentialsField(
        provider="fal",
        supported_credential_types={"api_key"},
        description="The FAL integration can be used with an API Key.",
    )
199 changes: 199 additions & 0 deletions autogpt_platform/backend/backend/blocks/fal/ai_video_generator.py
@@ -0,0 +1,199 @@
import logging
import time
from enum import Enum
from typing import Any, Dict

import httpx

from backend.blocks.fal._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    FalCredentials,
    FalCredentialsField,
    FalCredentialsInput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField

logger = logging.getLogger(__name__)


class FalModel(str, Enum):
    MOCHI = "fal-ai/mochi-v1"
    LUMA = "fal-ai/luma-dream-machine"


class AIVideoGeneratorBlock(Block):
    class Input(BlockSchema):
        prompt: str = SchemaField(
            description="Description of the video to generate.",
            placeholder="A dog running in a field.",
        )
        model: FalModel = SchemaField(
            title="FAL Model",
            default=FalModel.MOCHI,
            description="The FAL model to use for video generation.",
        )
        credentials: FalCredentialsInput = FalCredentialsField()

    class Output(BlockSchema):
        video_url: str = SchemaField(description="The URL of the generated video.")
        error: str = SchemaField(
            description="Error message if video generation failed."
        )
        logs: list[str] = SchemaField(
            description="Generation progress logs.", optional=True
        )

    def __init__(self):
        super().__init__(
            id="530cf046-2ce0-4854-ae2c-659db17c7a46",
            description="Generate videos using FAL AI models.",
            categories={BlockCategory.AI},
            input_schema=self.Input,
            output_schema=self.Output,
            test_input={
                "prompt": "A dog running in a field.",
                "model": FalModel.MOCHI,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[("video_url", "https://fal.media/files/example/video.mp4")],
            test_mock={
                "generate_video": lambda *args, **kwargs: "https://fal.media/files/example/video.mp4"
            },
        )

    def _get_headers(self, api_key: str) -> Dict[str, str]:
        """Get headers for FAL API requests."""
        return {
            "Authorization": f"Key {api_key}",
            "Content-Type": "application/json",
        }

    def _submit_request(
        self, url: str, headers: Dict[str, str], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Submit a request to the FAL API."""
        try:
            response = httpx.post(url, headers=headers, json=data)
            response.raise_for_status()
            return response.json()
        except httpx.HTTPError as e:
            logger.error(f"FAL API request failed: {str(e)}")
            raise RuntimeError(f"Failed to submit request: {str(e)}")

    def _poll_status(self, status_url: str, headers: Dict[str, str]) -> Dict[str, Any]:
        """Poll the status endpoint until completion or failure."""
        try:
            response = httpx.get(status_url, headers=headers)
            response.raise_for_status()
            return response.json()
        except httpx.HTTPError as e:
            logger.error(f"Failed to get status: {str(e)}")
            raise RuntimeError(f"Failed to get status: {str(e)}")

    def generate_video(self, input_data: Input, credentials: FalCredentials) -> str:
        """Generate video using the specified FAL model."""
        base_url = "https://queue.fal.run"
        api_key = credentials.api_key.get_secret_value()
        headers = self._get_headers(api_key)

        # Submit generation request
        submit_url = f"{base_url}/{input_data.model.value}"
        submit_data = {"prompt": input_data.prompt}

        seen_logs = set()

        try:
            # Submit request to queue
            submit_response = httpx.post(submit_url, headers=headers, json=submit_data)
            submit_response.raise_for_status()
            request_data = submit_response.json()

            # Get request_id and urls from initial response
            request_id = request_data.get("request_id")
            status_url = request_data.get("status_url")
            result_url = request_data.get("response_url")

            if not all([request_id, status_url, result_url]):
                raise ValueError("Missing required data in submission response")

            # Poll for status with exponential backoff
            max_attempts = 30
            attempt = 0
            base_wait_time = 5

            while attempt < max_attempts:
                status_response = httpx.get(f"{status_url}?logs=1", headers=headers)
                status_response.raise_for_status()
                status_data = status_response.json()

                # Process new logs only
                logs = status_data.get("logs", [])
                if logs and isinstance(logs, list):
                    for log in logs:
                        if isinstance(log, dict):
                            # Create a unique key for this log entry
                            log_key = (
                                f"{log.get('timestamp', '')}-{log.get('message', '')}"
                            )
                            if log_key not in seen_logs:
                                seen_logs.add(log_key)
                                message = log.get("message", "")
                                if message:
                                    logger.debug(
                                        f"[FAL Generation] [{log.get('level', 'INFO')}] [{log.get('source', '')}] [{log.get('timestamp', '')}] {message}"
                                    )

                status = status_data.get("status")
                if status == "COMPLETED":
                    # Get the final result
                    result_response = httpx.get(result_url, headers=headers)
                    result_response.raise_for_status()
                    result_data = result_response.json()

                    if "video" not in result_data or not isinstance(
                        result_data["video"], dict
                    ):
                        raise ValueError("Invalid response format - missing video data")

                    video_url = result_data["video"].get("url")
                    if not video_url:
                        raise ValueError("No video URL in response")

                    return video_url

                elif status == "FAILED":
                    error_msg = status_data.get("error", "No error details provided")
                    raise RuntimeError(f"Video generation failed: {error_msg}")
                elif status == "IN_QUEUE":
                    position = status_data.get("queue_position", "unknown")
                    logger.debug(
                        f"[FAL Generation] Status: In queue, position: {position}"
                    )
                elif status == "IN_PROGRESS":
                    logger.debug(
                        "[FAL Generation] Status: Request is being processed..."
                    )
                else:
                    logger.info(f"[FAL Generation] Status: Unknown status: {status}")

                wait_time = min(base_wait_time * (2**attempt), 60)  # Cap at 60 seconds
                time.sleep(wait_time)
                attempt += 1

            raise RuntimeError("Maximum polling attempts reached")

        except httpx.HTTPError as e:
            raise RuntimeError(f"API request failed: {str(e)}")

    def run(
        self, input_data: Input, *, credentials: FalCredentials, **kwargs
    ) -> BlockOutput:
        try:
            video_url = self.generate_video(input_data, credentials)
            yield "video_url", video_url
        except Exception as e:
            error_message = str(e)
            yield "error", error_message
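
For reference, the block talks to FAL's queue API directly with `httpx` rather than going through the `fal-client` SDK. Below is a condensed sketch of the same submit, poll, and fetch flow outside the block; the `FAL_KEY` environment variable, the hard-coded model name, and the fixed 5-second sleep are illustrative assumptions (the block itself uses exponential backoff capped at 60 seconds):

```python
import os
import time

import httpx

# Assumption: FAL_KEY holds a valid FAL API key.
headers = {
    "Authorization": f"Key {os.environ['FAL_KEY']}",
    "Content-Type": "application/json",
}

# 1. Submit the generation request to the queue.
submit = httpx.post(
    "https://queue.fal.run/fal-ai/mochi-v1",
    headers=headers,
    json={"prompt": "A dog running in a field."},
)
submit.raise_for_status()
job = submit.json()  # contains request_id, status_url, response_url

# 2. Poll the status URL until the job completes or fails.
for _ in range(30):
    status = httpx.get(f"{job['status_url']}?logs=1", headers=headers)
    status.raise_for_status()
    state = status.json().get("status")
    if state == "COMPLETED":
        break
    if state == "FAILED":
        raise RuntimeError("Video generation failed")
    time.sleep(5)
else:
    raise RuntimeError("Maximum polling attempts reached")

# 3. Fetch the final result and read the generated video URL.
result = httpx.get(job["response_url"], headers=headers)
result.raise_for_status()
print(result.json()["video"]["url"])
```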
2 changes: 2 additions & 0 deletions autogpt_platform/backend/backend/util/settings.py
@@ -290,6 +290,8 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
    jina_api_key: str = Field(default="", description="Jina API Key")
    unreal_speech_api_key: str = Field(default="", description="Unreal Speech API Key")

    fal_key: str = Field(default="", description="FAL API key")

    # Add more secret fields as needed

    model_config = SettingsConfigDict(
@@ -64,6 +64,7 @@ export const providerIcons: Record<
  open_router: fallbackIcon,
  pinecone: fallbackIcon,
  replicate: fallbackIcon,
  fal: fallbackIcon,
  revid: fallbackIcon,
  unreal_speech: fallbackIcon,
  hubspot: fallbackIcon,
@@ -38,6 +38,7 @@ const providerDisplayNames: Record<CredentialsProviderName, string> = {
  open_router: "Open Router",
  pinecone: "Pinecone",
  replicate: "Replicate",
  fal: "FAL",
  revid: "Rev.ID",
  unreal_speech: "Unreal Speech",
  hubspot: "Hubspot",
@@ -116,6 +116,7 @@ export const PROVIDER_NAMES = {
  OPEN_ROUTER: "open_router",
  PINECONE: "pinecone",
  REPLICATE: "replicate",
  FAL: "fal",
  REVID: "revid",
  UNREAL_SPEECH: "unreal_speech",
  HUBSPOT: "hubspot",
