Skip to content

Commit

Permalink
test: fix the test case
Browse files Browse the repository at this point in the history
  • Loading branch information
imotai committed Oct 27, 2023
1 parent a1172be commit 212950c
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 18 deletions.
27 changes: 17 additions & 10 deletions agent/src/og_agent/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,19 @@

""" """
import json
import io
import logging
import time
from typing import List
from pydantic import BaseModel, Field
from og_proto.kernel_server_pb2 import ExecuteResponse
from og_proto.agent_server_pb2 import TaskResponse, ContextState
from og_sdk.utils import parse_image_filename, process_char_stream
from og_proto.agent_server_pb2 import OnStepActionStart, TaskResponse, OnStepActionEnd, FinalAnswer, TypingContent
from .tokenizer import tokenize
import tiktoken

encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -116,7 +121,7 @@ def _get_message_token_count(self, message):
return response_token_count

async def _read_function_call_message(
self, message, old_text_content, old_code_content, task_context, task_opt
self, message, queue, old_text_content, old_code_content, task_context, task_opt
):
typing_language = "text"
if message["function_call"].get("name", "") in [
Expand All @@ -126,26 +131,23 @@ async def _read_function_call_message(
typing_language = "python"
elif message["function_call"].get("name", "") == "execute_bash_code":
typing_language = "bash"

logger.debug(
f"argument explanation:{explanation_str} code:{code_str} text_content:{old_text_content}"
)
is_code = False
if message["function_call"].get("name", "") == "python":
is_code = True
arguments = message["function_call"].get("arguments", "")
return await self._send_typing_message(
arguments,
queue,
old_text_content,
old_code_content,
typing_language,
queue,
task_context,
task_opt,
is_code=is_code,
)

async def _read_json_message(
self, message, old_text_content, old_code_content, task_context, task_opt
self, message, queue, old_text_content, old_code_content, task_context, task_opt
):
arguments = messages.get("content", "")
typing_language = "text"
Expand All @@ -160,16 +162,18 @@ async def _read_json_message(
old_code_content,
typing_language,
queue,
task_context,
task_opt,
)

async def _send_typing_message(
self,
arguments,
queue,
old_text_content,
old_code_content,
language,
queue,
task_context,
task_opt,
is_code=False,
):
Expand All @@ -190,8 +194,8 @@ async def _send_typing_message(
typing_content=TypingContent(content=typed_chars, language="text"),
)
await queue.put(task_response)
return new_text_content, old_code_context
if code_str and old_code_context != code_str:
return new_text_content, old_code_content
if code_str and old_code_content != code_str:
typed_chars = code_str[len(old_code_content) :]
code_content = code_str
if task_opt.streaming and len(typed_chars) > 0:
Expand All @@ -205,6 +209,7 @@ async def _send_typing_message(
)
)
return old_text_content, code_content
return old_text_content, old_code_content

async def extract_message(
self,
Expand Down Expand Up @@ -248,6 +253,7 @@ async def extract_message(
new_code_content,
) = await self._read_function_call_message(
message,
queue,
text_content,
code_content,
task_context,
Expand All @@ -269,6 +275,7 @@ async def extract_message(
if is_json_format:
await self._read_json_message(
message,
queue,
text_content,
code_content,
task_context,
Expand Down
3 changes: 1 addition & 2 deletions agent/src/og_agent/base_stream_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@ def __init__(self, endpoint, key):

async def arun(self, request_data):
    """Stream the raw response lines for a POST of *request_data*.

    Sends ``request_data`` as the JSON body (aiohttp's ``json=`` kwarg
    serializes it and sets the ``Content-Type: application/json`` header,
    replacing the older manual ``json.dumps`` + ``data=`` approach) and
    yields each non-empty line of the chunked response as it arrives.

    Args:
        request_data: JSON-serializable payload for the request body.

    Yields:
        bytes: each non-empty line read from the response content stream.

    Raises:
        aiohttp.ClientResponseError: on non-2xx status, because the
            session is created with ``raise_for_status=True``.
    """
    logging.debug(f"{request_data}")
    # self.key is sent verbatim; callers are expected to include any
    # scheme prefix (e.g. "Bearer ...") themselves.
    headers = {"Authorization": self.key}
    async with aiohttp.ClientSession(
        headers=headers, raise_for_status=True
    ) as session:
        async with session.post(self.endpoint, json=request_data) as r:
            async for line in r.content:
                if line:
                    yield line
4 changes: 2 additions & 2 deletions agent/src/og_agent/llama_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
logger = logging.getLogger(__name__)


class LlamaChatClient(BaseStreamClient):
class LlamaClient(BaseStreamClient):

def __init__(self, endpoint, key, grammar):
super().__init__(endpoint + "/v1/chat/completions", key)
Expand All @@ -25,8 +25,8 @@ async def chat(
):
data = {
"messages": messages,
"grammar": self.grammar,
"temperature": temperature,
"grammar": self.grammar,
"stream": True,
"model": model,
"max_tokens": max_tokens,
Expand Down
2 changes: 1 addition & 1 deletion serving/src/og_serving/http_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,6 @@
def run_serving():
    """Create the FastAPI app and serve it with uvicorn.

    Host and port come from the module-level ``config`` mapping, with
    defaults ``localhost:8080`` (the commit changed the default port
    from 9517 to 8080).
    """
    app = create_app(settings)
    host = config.get("host", "localhost")
    # config values may be strings (e.g. from a config file), so coerce.
    port = int(config.get("port", "8080"))
    logger.info(f"Starting serving at {host}:{port}")
    uvicorn.run(app, host=host, port=port)
5 changes: 2 additions & 3 deletions serving/src/og_serving/server_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,6 @@ class CreateCompletionRequest(BaseModel):
top_k: int = top_k_field
repeat_penalty: float = repeat_penalty_field
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
grammar: str = Field(default=None)
model_config = {
"json_schema_extra": {
"examples": [{
Expand Down Expand Up @@ -727,7 +726,7 @@ class CreateChatCompletionRequest(BaseModel):
model: Optional[str] = model_field
n: Optional[int] = 1
user: Optional[str] = Field(None)

grammar: str = Field(default=None)
# llama.cpp specific parameters
top_k: int = top_k_field
repeat_penalty: float = repeat_penalty_field
Expand Down Expand Up @@ -767,7 +766,7 @@ async def create_chat_completion(
"user",
}
kwargs = body.model_dump(exclude=exclude)
if "grammar" in kwargs["grammar"] and kwargs["grammar"]:
if "grammar" in kwargs and kwargs["grammar"]:
kwargs["grammar"] = LlamaGrammar.from_string(kwargs["grammar"])
if body.logit_bias is not None:
kwargs["logits_processor"] = llama_cpp.LogitsProcessorList([
Expand Down

0 comments on commit 212950c

Please sign in to comment.