Skip to content

Commit

Permalink
chore(platform): Refactor GraphExecution naming clash and remove unused Graph Execution functions (#8939)
Browse files Browse the repository at this point in the history

This is a follow-up to
#8752.

There are several APIs and functions related to graph execution that are
unused now.
There is also a naming clash: a class named `GraphExecution` exists in
both graph.py and execution.py.

### Changes 🏗️

* Renamed `GraphExecution` in `execution.py` to `GraphExecutionEntry`,
this is only used as a queue entry for execution.
* Removed unused `get_graph_execution` & `list_executions` in
`execution.py`.
* Removed `with_run` option on `get_graph` function in `graph.py`.
* Removed `GraphMetaWithRuns`
* Removed exposed functions only for testing.
* Removed `executions` fields in Graph model.

### Checklist 📋

#### For code changes:
- [ ] I have clearly listed my changes in the PR description
- [ ] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [ ] ...

<details>
  <summary>Example test plan</summary>
  
  - [ ] Create from scratch and execute an agent with at least 3 blocks
  - [ ] Import an agent from file upload, and confirm it executes correctly
  - [ ] Upload agent to marketplace
  - [ ] Import an agent from marketplace and confirm it executes correctly
  - [ ] Edit an agent from monitor, and confirm it executes correctly
</details>

#### For configuration changes:
- [ ] `.env.example` is updated or already compatible with my changes
- [ ] `docker-compose.yml` is updated or already compatible with my changes
- [ ] I have included a list of my configuration changes in the PR description (under **Changes**)

<details>
  <summary>Examples of configuration changes</summary>

  - Changing ports
  - Adding new services that need to communicate with each other
  - Secrets or environment variable changes
  - New or infrastructure changes such as databases
</details>

---------

Co-authored-by: Krzysztof Czerwinski <[email protected]>
  • Loading branch information
majdyz and kcze authored Dec 11, 2024
1 parent 7a9115d commit 6490b4e
Show file tree
Hide file tree
Showing 11 changed files with 49 additions and 164 deletions.
35 changes: 3 additions & 32 deletions autogpt_platform/backend/backend/data/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
AgentNodeExecution,
AgentNodeExecutionInputOutput,
)
from prisma.types import AgentGraphExecutionWhereInput
from pydantic import BaseModel

from backend.data.block import BlockData, BlockInput, CompletedBlockOutput
Expand All @@ -19,14 +18,14 @@
from backend.util.settings import Config


class GraphExecution(BaseModel):
class GraphExecutionEntry(BaseModel):
user_id: str
graph_exec_id: str
graph_id: str
start_node_execs: list["NodeExecution"]
start_node_execs: list["NodeExecutionEntry"]


class NodeExecution(BaseModel):
class NodeExecutionEntry(BaseModel):
user_id: str
graph_exec_id: str
graph_id: str
Expand Down Expand Up @@ -325,34 +324,6 @@ async def update_execution_status(
return ExecutionResult.from_db(res)


async def get_graph_execution(
graph_exec_id: str, user_id: str
) -> AgentGraphExecution | None:
"""
Retrieve a specific graph execution by its ID.
Args:
graph_exec_id (str): The ID of the graph execution to retrieve.
user_id (str): The ID of the user to whom the graph (execution) belongs.
Returns:
AgentGraphExecution | None: The graph execution if found, None otherwise.
"""
execution = await AgentGraphExecution.prisma().find_first(
where={"id": graph_exec_id, "userId": user_id},
include=GRAPH_EXECUTION_INCLUDE,
)
return execution


async def list_executions(graph_id: str, graph_version: int | None = None) -> list[str]:
where: AgentGraphExecutionWhereInput = {"agentGraphId": graph_id}
if graph_version is not None:
where["agentGraphVersion"] = graph_version
executions = await AgentGraphExecution.prisma().find_many(where=where)
return [execution.id for execution in executions]


async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]:
executions = await AgentNodeExecution.prisma().find_many(
where={"agentGraphExecutionId": graph_exec_id},
Expand Down
28 changes: 10 additions & 18 deletions autogpt_platform/backend/backend/data/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import prisma
from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
from prisma.types import AgentGraphExecutionWhereInput, AgentGraphWhereInput
from prisma.types import AgentGraphWhereInput
from pydantic.fields import computed_field

from backend.blocks.agent import AgentExecutorBlock
Expand Down Expand Up @@ -143,7 +143,6 @@ class Graph(BaseDbModel):
is_template: bool = False
name: str
description: str
executions: list[GraphExecution] = []
nodes: list[Node] = []
links: list[Link] = []

Expand Down Expand Up @@ -329,11 +328,6 @@ def is_static_output_block(nid: str) -> bool:

@staticmethod
def from_db(graph: AgentGraph, hide_credentials: bool = False):
executions = [
GraphExecution.from_db(execution)
for execution in graph.AgentGraphExecution or []
]

return GraphModel(
id=graph.id,
user_id=graph.userId,
Expand All @@ -342,7 +336,6 @@ def from_db(graph: AgentGraph, hide_credentials: bool = False):
is_template=graph.isTemplate,
name=graph.name or "",
description=graph.description or "",
executions=executions,
nodes=[
GraphModel._process_node(node, hide_credentials)
for node in graph.AgentNodes or []
Expand Down Expand Up @@ -412,15 +405,13 @@ async def set_node_webhook(node_id: str, webhook_id: str | None) -> NodeModel:

async def get_graphs(
user_id: str,
include_executions: bool = False,
filter_by: Literal["active", "template"] | None = "active",
) -> list[GraphModel]:
"""
Retrieves graph metadata objects.
Default behaviour is to get all currently active graphs.
Args:
include_executions: Whether to include executions in the graph metadata.
filter_by: An optional filter to either select templates or active graphs.
user_id: The ID of the user that owns the graph.
Expand All @@ -434,30 +425,31 @@ async def get_graphs(
elif filter_by == "template":
where_clause["isTemplate"] = True

graph_include = AGENT_GRAPH_INCLUDE
graph_include["AgentGraphExecution"] = include_executions

graphs = await AgentGraph.prisma().find_many(
where=where_clause,
distinct=["id"],
order={"version": "desc"},
include=graph_include,
include=AGENT_GRAPH_INCLUDE,
)

return [GraphModel.from_db(graph) for graph in graphs]


async def get_executions(user_id: str) -> list[GraphExecution]:
where_clause: AgentGraphExecutionWhereInput = {"userId": user_id}

executions = await AgentGraphExecution.prisma().find_many(
where=where_clause,
where={"userId": user_id},
order={"createdAt": "desc"},
)

return [GraphExecution.from_db(execution) for execution in executions]


async def get_execution(user_id: str, execution_id: str) -> GraphExecution | None:
execution = await AgentGraphExecution.prisma().find_first(
where={"id": execution_id, "userId": user_id}
)
return GraphExecution.from_db(execution) if execution else None


async def get_graph(
graph_id: str,
version: int | None = None,
Expand Down
42 changes: 22 additions & 20 deletions autogpt_platform/backend/backend/executor/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
ExecutionQueue,
ExecutionResult,
ExecutionStatus,
GraphExecution,
NodeExecution,
GraphExecutionEntry,
NodeExecutionEntry,
merge_execution_input,
parse_execution_output,
)
Expand Down Expand Up @@ -96,13 +96,13 @@ def _wrap(self, msg: str, **extra):


T = TypeVar("T")
ExecutionStream = Generator[NodeExecution, None, None]
ExecutionStream = Generator[NodeExecutionEntry, None, None]


def execute_node(
db_client: "DatabaseManager",
creds_manager: IntegrationCredentialsManager,
data: NodeExecution,
data: NodeExecutionEntry,
execution_stats: dict[str, Any] | None = None,
) -> ExecutionStream:
"""
Expand Down Expand Up @@ -252,15 +252,15 @@ def _enqueue_next_nodes(
graph_exec_id: str,
graph_id: str,
log_metadata: LogMetadata,
) -> list[NodeExecution]:
) -> list[NodeExecutionEntry]:
def add_enqueued_execution(
node_exec_id: str, node_id: str, data: BlockInput
) -> NodeExecution:
) -> NodeExecutionEntry:
exec_update = db_client.update_execution_status(
node_exec_id, ExecutionStatus.QUEUED, data
)
db_client.send_execution_update(exec_update)
return NodeExecution(
return NodeExecutionEntry(
user_id=user_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
Expand All @@ -269,7 +269,7 @@ def add_enqueued_execution(
data=data,
)

def register_next_executions(node_link: Link) -> list[NodeExecution]:
def register_next_executions(node_link: Link) -> list[NodeExecutionEntry]:
enqueued_executions = []
next_output_name = node_link.source_name
next_input_name = node_link.sink_name
Expand Down Expand Up @@ -501,8 +501,8 @@ def on_node_executor_sigterm(cls):
@error_logged
def on_node_execution(
cls,
q: ExecutionQueue[NodeExecution],
node_exec: NodeExecution,
q: ExecutionQueue[NodeExecutionEntry],
node_exec: NodeExecutionEntry,
) -> dict[str, Any]:
log_metadata = LogMetadata(
user_id=node_exec.user_id,
Expand All @@ -529,8 +529,8 @@ def on_node_execution(
@time_measured
def _on_node_execution(
cls,
q: ExecutionQueue[NodeExecution],
node_exec: NodeExecution,
q: ExecutionQueue[NodeExecutionEntry],
node_exec: NodeExecutionEntry,
log_metadata: LogMetadata,
stats: dict[str, Any] | None = None,
):
Expand Down Expand Up @@ -580,7 +580,9 @@ def _init_node_executor_pool(cls):

@classmethod
@error_logged
def on_graph_execution(cls, graph_exec: GraphExecution, cancel: threading.Event):
def on_graph_execution(
cls, graph_exec: GraphExecutionEntry, cancel: threading.Event
):
log_metadata = LogMetadata(
user_id=graph_exec.user_id,
graph_eid=graph_exec.graph_exec_id,
Expand All @@ -605,7 +607,7 @@ def on_graph_execution(cls, graph_exec: GraphExecution, cancel: threading.Event)
@time_measured
def _on_graph_execution(
cls,
graph_exec: GraphExecution,
graph_exec: GraphExecutionEntry,
cancel: threading.Event,
log_metadata: LogMetadata,
) -> tuple[dict[str, Any], Exception | None]:
Expand Down Expand Up @@ -636,13 +638,13 @@ def cancel_handler():
cancel_thread.start()

try:
queue = ExecutionQueue[NodeExecution]()
queue = ExecutionQueue[NodeExecutionEntry]()
for node_exec in graph_exec.start_node_execs:
queue.add(node_exec)

running_executions: dict[str, AsyncResult] = {}

def make_exec_callback(exec_data: NodeExecution):
def make_exec_callback(exec_data: NodeExecutionEntry):
node_id = exec_data.node_id

def callback(result: object):
Expand Down Expand Up @@ -717,7 +719,7 @@ def __init__(self):
self.use_redis = True
self.use_supabase = True
self.pool_size = settings.config.num_graph_workers
self.queue = ExecutionQueue[GraphExecution]()
self.queue = ExecutionQueue[GraphExecutionEntry]()
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}

@classmethod
Expand Down Expand Up @@ -768,7 +770,7 @@ def add_execution(
data: BlockInput,
user_id: str,
graph_version: int | None = None,
) -> GraphExecution:
) -> GraphExecutionEntry:
graph: GraphModel | None = self.db_client.get_graph(
graph_id=graph_id, user_id=user_id, version=graph_version
)
Expand Down Expand Up @@ -818,7 +820,7 @@ def add_execution(
starting_node_execs = []
for node_exec in node_execs:
starting_node_execs.append(
NodeExecution(
NodeExecutionEntry(
user_id=user_id,
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
Expand All @@ -832,7 +834,7 @@ def add_execution(
)
self.db_client.send_execution_update(exec_update)

graph_exec = GraphExecution(
graph_exec = GraphExecutionEntry(
user_id=user_id,
graph_id=graph_id,
graph_exec_id=graph_exec_id,
Expand Down
12 changes: 6 additions & 6 deletions autogpt_platform/backend/backend/server/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,17 +117,17 @@ async def test_execute_graph(
async def test_create_graph(
create_graph: backend.server.routers.v1.CreateGraph,
user_id: str,
is_template=False,
):
return await backend.server.routers.v1.create_new_graph(create_graph, user_id)

@staticmethod
async def test_get_graph_run_status(
graph_id: str, graph_exec_id: str, user_id: str
):
return await backend.server.routers.v1.get_graph_run_status(
graph_id, graph_exec_id, user_id
async def test_get_graph_run_status(graph_exec_id: str, user_id: str):
execution = await backend.data.graph.get_execution(
user_id=user_id, execution_id=graph_exec_id
)
if not execution:
raise ValueError(f"Execution {graph_exec_id} not found")
return execution.status

@staticmethod
async def test_get_graph_run_node_execution_results(
Expand Down
Loading

0 comments on commit 6490b4e

Please sign in to comment.