Skip to content

Commit

Permalink
[BugFix] Fix hermes tool parser output error stream arguments in some…
Browse files Browse the repository at this point in the history
… cases (vllm-project#10395) (vllm-project#10398)

Signed-off-by: xiyuan lee <[email protected]>
  • Loading branch information
xiyuan-lee authored and weilong.yu committed Dec 13, 2024
1 parent c7724c9 commit 93dfd65
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid
Expand Down Expand Up @@ -190,8 +188,11 @@ def extract_tool_calls_streaming(
diff = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments")
if diff:
diff = json.dumps(diff).replace(
self.streamed_args_for_tool[self.current_tool_id], "")
diff = diff.encode('utf-8').decode(
'unicode_escape') if diff is str else diff
diff = json.dumps(
diff, ensure_ascii=False
)[len(self.streamed_args_for_tool[self.current_tool_id]):]
logger.debug(
"Finishing tool and found diff that had not "
"been streamed yet: %s", diff)
Expand Down Expand Up @@ -307,22 +308,20 @@ def extract_tool_calls_streaming(

# last case -- we have an update to existing arguments.
elif cur_arguments and prev_arguments:
if isinstance(delta_text, str) and len(delta_text.rstrip(
)) >= 1 and delta_text.rstrip()[-1] == '}':
delta_text = delta_text.rstrip()[:-1]

logger.debug("got diff %s", delta_text)

cur_args_json = json.dumps(cur_arguments)
prev_args_json = json.dumps(prev_arguments)
logger.debug("Searching for diff between\n%s", cur_args_json)
logger.debug("and\n%s", prev_args_json)
argument_diff = extract_intermediate_diff(
cur_args_json, prev_args_json)
logger.debug("got argument diff %s", argument_diff)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).model_dump(
arguments=delta_text).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[self.current_tool_id] \
+= argument_diff
+= delta_text

# handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration
Expand Down

0 comments on commit 93dfd65

Please sign in to comment.