Skip to content

Commit

Permalink
Merge pull request #164 from dbpunk-labs/feat/refactor
Browse files Browse the repository at this point in the history
refactor:adjust the openai and codellama agent
  • Loading branch information
imotai authored Oct 20, 2023
2 parents bacede7 + 1ece762 commit 25f317e
Show file tree
Hide file tree
Showing 21 changed files with 1,168 additions and 950 deletions.
2 changes: 2 additions & 0 deletions agent/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
"aiohttp>=3.8.5",
"pydantic",
"tiktoken",
"fastapi",
"uvicorn",
],
package_data={"og_agent": ["*.bnf"]},
entry_points={
Expand Down
64 changes: 34 additions & 30 deletions agent/src/og_agent/agent_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,21 @@ class StepResponseType(str, Enum):


class ContextState(BaseModel):
generated_token_count: int
iteration_count: int
llm_model_name: str
output_token_count: int
llm_name: str
total_duration: int
sent_token_count: int
llm_model_response_duration: int
output_token_count: int
llm_response_duration: int
context_id: str | None = None

@classmethod
def new_from(cls, state):
return cls(
generated_token_count=state.generated_token_count,
iteration_count=state.iteration_count,
llm_model_name=state.model_name,
output_token_count=state.output_token_count,
llm_name=state.llm_name,
total_duration=state.total_duration,
sent_token_count=state.sent_token_count,
llm_model_response_duration=state.model_respond_duration,
input_token_count=state.input_token_count,
llm_response_duration=state.llm_response_duration,
)


Expand All @@ -73,7 +71,7 @@ class StepActionEnd(BaseModel):
has_error: bool

@classmethod
def new_from(cls, step_action_end: agent_server_pb2.OnAgentActionEnd):
def new_from(cls, step_action_end: agent_server_pb2.OnStepActionEnd):
return cls(
output=step_action_end.output,
output_files=step_action_end.output_files,
Expand All @@ -85,7 +83,7 @@ class FinalAnswer(BaseModel):
answer: str

@classmethod
def new_from(cls, final_answer: agent_server_pb2.FinalRespond):
def new_from(cls, final_answer: agent_server_pb2.FinalAnswer):
return cls(answer=final_answer.answer)


Expand All @@ -94,7 +92,7 @@ class StepActionStart(BaseModel):
tool: str

@classmethod
def new_from(cls, step_action_start: agent_server_pb2.OnAgentAction):
def new_from(cls, step_action_start: agent_server_pb2.OnStepActionStart):
return cls(input=step_action_start.input, tool=step_action_start.tool)


Expand All @@ -109,49 +107,57 @@ class StepResponse(BaseModel):
final_answer: FinalAnswer | None = None

@classmethod
def new_from(cls, response: agent_server_pb2.TaskRespond):
if response.respond_type == agent_server_pb2.TaskRespond.OnAgentActionType:
def new_from(cls, response: agent_server_pb2.TaskResponse):
if response.response_type == agent_server_pb2.TaskResponse.OnStepActionStart:
return cls(
step_type=StepResponseType.OnStepActionStart,
step_state=ContextState.new_from(response.state),
step_action_start=StepActionStart.new_from(response.on_agent_action),
step_action_start=StepActionStart.new_from(
response.on_step_action_start
),
)
elif response.respond_type == agent_server_pb2.TaskRespond.OnAgentCodeTyping:
elif response.response_type == agent_server_pb2.TaskResponse.OnModelTypeCode:
return cls(
step_type=StepResponseType.OnStepCodeTyping,
step_state=ContextState.new_from(response.state),
typing_content=response.typing_content,
typing_content=response.typing_content.content,
)

elif response.respond_type == agent_server_pb2.TaskRespond.OnAgentTextTyping:
elif response.response_type == agent_server_pb2.TaskResponse.OnModelTypeText:
return cls(
step_type=StepResponseType.OnStepTextTyping,
step_state=ContextState.new_from(response.state),
typing_content=response.typing_content,
typing_content=response.typing_content.content,
)
elif response.respond_type == agent_server_pb2.TaskRespond.OnAgentActionStdout:
elif (
response.response_type
== agent_server_pb2.TaskResponse.OnStepActionStreamStdout
):
return cls(
step_type=StepResponseType.OnStepActionStdout,
step_state=ContextState.new_from(response.state),
step_action_stdout=response.console_stdout,
)
elif response.respond_type == agent_server_pb2.TaskRespond.OnAgentActionStderr:
elif (
response.response_type
== agent_server_pb2.TaskResponse.OnStepActionStreamStderr
):
return cls(
step_type=StepResponseType.OnStepActionStderr,
step_state=ContextState.new_from(response.state),
step_action_stderr=response.console_stderr,
)
elif response.respond_type == agent_server_pb2.TaskRespond.OnAgentActionEndType:
elif response.response_type == agent_server_pb2.TaskResponse.OnStepActionEnd:
return cls(
step_type=StepResponseType.OnStepActionEnd,
step_state=ContextState.new_from(response.state),
step_action_end=StepActionEnd.new_from(response.on_agent_action_end),
step_action_end=StepActionEnd.new_from(response.on_step_action_end),
)
elif response.respond_type == agent_server_pb2.TaskRespond.OnFinalAnswerType:
elif response.response_type == agent_server_pb2.TaskResponse.OnFinalAnswer:
return cls(
step_type=StepResponseType.OnFinalAnswer,
step_state=ContextState.new_from(response.state),
final_answer=FinalAnswer.new_from(response.final_respond),
final_answer=FinalAnswer.new_from(response.final_answer),
)


Expand All @@ -164,11 +170,9 @@ class TaskRequest(BaseModel):


async def run_task(task: TaskRequest, key):
index = 0
async for respond in agent_sdk.prompt(task.prompt, key, files=task.input_files):
response = StepResponse.new_from(respond).model_dump(exclude_none=True)
yield "\n" + json.dumps(response) if index > 0 else json.dumps(response)
index += 1
yield "data: %s\n" % json.dumps(response)


@app.post("/process")
Expand All @@ -181,7 +185,7 @@ async def process_task(
response.status_code = status.HTTP_401_UNAUTHORIZED
return
response.status_code = status.HTTP_200_OK
response.media_type = "application/json"
response.media_type = "text/event-stream"
agent_sdk.connect()
return StreamingResponse(run_task(task, api_token))

Expand Down
Loading

0 comments on commit 25f317e

Please sign in to comment.