Skip to content

Commit

Permalink
feat(backend): Enable json parsing with typing & conversion (#8578)
Browse files Browse the repository at this point in the history
  • Loading branch information
majdyz authored Nov 15, 2024
1 parent 6a1cea4 commit 8987fdd
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 12 deletions.
2 changes: 1 addition & 1 deletion autogpt_platform/backend/backend/data/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def from_graph(graph: AgentGraphExecution):
def from_db(execution: AgentNodeExecution):
if execution.executionData:
# Execution that has been queued for execution will persist its data.
input_data = json.loads(execution.executionData)
input_data = json.loads(execution.executionData, target_type=dict[str, Any])
else:
# For incomplete execution, executionData will not be yet available.
input_data: BlockInput = defaultdict()
Expand Down
19 changes: 12 additions & 7 deletions autogpt_platform/backend/backend/data/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def from_db(node: AgentNode):
obj = Node(
id=node.id,
block_id=node.AgentBlock.id,
input_default=json.loads(node.constantInput),
metadata=json.loads(node.metadata),
input_default=json.loads(node.constantInput, target_type=dict[str, Any]),
metadata=json.loads(node.metadata, target_type=dict[str, Any]),
)
obj.input_links = [Link.from_db(link) for link in node.Input or []]
obj.output_links = [Link.from_db(link) for link in node.Output or []]
Expand All @@ -80,10 +80,13 @@ def from_db(execution: AgentGraphExecution):
duration = (end_time - start_time).total_seconds()
total_run_time = duration

if execution.stats:
stats = json.loads(execution.stats)
duration = stats.get("walltime", duration)
total_run_time = stats.get("nodes_walltime", total_run_time)
try:
stats = json.loads(execution.stats or "{}", target_type=dict[str, Any])
except ValueError:
stats = {}

duration = stats.get("walltime", duration)
total_run_time = stats.get("nodes_walltime", total_run_time)

return GraphExecution(
id=execution.id,
Expand Down Expand Up @@ -311,7 +314,9 @@ def from_db(graph: AgentGraph, hide_credentials: bool = False):
def _process_node(node: AgentNode, hide_credentials: bool) -> Node:
node_dict = node.model_dump()
if hide_credentials and "constantInput" in node_dict:
constant_input = json.loads(node_dict["constantInput"])
constant_input = json.loads(
node_dict["constantInput"], target_type=dict[str, Any]
)
constant_input = Graph._hide_credentials_in_input(constant_input)
node_dict["constantInput"] = json.dumps(constant_input)
return Node.from_db(AgentNode(**node_dict))
Expand Down
20 changes: 19 additions & 1 deletion autogpt_platform/backend/backend/util/json.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import json
from typing import Any, Type, TypeVar, overload

from fastapi.encoders import jsonable_encoder

from .type import type_match


def to_dict(data) -> dict:
return jsonable_encoder(data)
Expand All @@ -11,4 +14,19 @@ def dumps(data) -> str:
return json.dumps(jsonable_encoder(data))


loads = json.loads
T = TypeVar("T")


@overload
def loads(data: str, *args, target_type: Type[T], **kwargs) -> T: ...


@overload
def loads(data: str, *args, **kwargs) -> Any: ...


def loads(data: str, *args, target_type: Type[T] | None = None, **kwargs) -> Any:
parsed = json.loads(data, *args, **kwargs)
if target_type:
return type_match(parsed, target_type)
return parsed
22 changes: 19 additions & 3 deletions autogpt_platform/backend/backend/util/type.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import json
from typing import Any, Type, TypeVar, get_args, get_origin
from typing import Any, Type, TypeVar, cast, get_args, get_origin


class ConversionError(Exception):
class ConversionError(ValueError):
pass


Expand Down Expand Up @@ -102,7 +102,7 @@ def __convert_bool(value: Any) -> bool:
return bool(value)


def convert(value: Any, target_type: Type):
def _try_convert(value: Any, target_type: Type, raise_on_mismatch: bool) -> Any:
origin = get_origin(target_type)
args = get_args(target_type)
if origin is None:
Expand Down Expand Up @@ -133,6 +133,8 @@ def convert(value: Any, target_type: Type):
return {convert(v, args[0]) for v in value}
else:
return value
elif raise_on_mismatch:
raise TypeError(f"Value {value} is not of expected type {target_type}")
else:
# Need to convert value to the origin type
if origin is list:
Expand Down Expand Up @@ -175,3 +177,17 @@ def convert(value: Any, target_type: Type):
return __convert_bool(value)
else:
return value


T = TypeVar("T")


def type_match(value: Any, target_type: Type[T]) -> T:
return cast(T, _try_convert(value, target_type, raise_on_mismatch=True))


def convert(value: Any, target_type: Type[T]) -> T:
try:
return cast(T, _try_convert(value, target_type, raise_on_mismatch=False))
except Exception as e:
raise ConversionError(f"Failed to convert {value} to {target_type}") from e

0 comments on commit 8987fdd

Please sign in to comment.