Repairing Partial JSON Parsing in Streaming

61 reads · 5-minute read

hermes_tool_parser.py

```text
ERROR 08-13 22:13:17 [hermes_tool_parser.py:370] Error trying to handle streaming tool call.
ERROR 08-13 22:13:17 [hermes_tool_parser.py:370] Traceback (most recent call last):
ERROR 08-13 22:13:17 [hermes_tool_parser.py:370]   File "/home/hzw/miniconda3/envs/vllm_env/lib/python3.11/site-packages/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py", line 241, in extract_tool_calls_streaming
ERROR 08-13 22:13:17 [hermes_tool_parser.py:370]     current_tool_call = partial_json_parser.loads(
ERROR 08-13 22:13:17 [hermes_tool_parser.py:370]                         ^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-13 22:13:17 [hermes_tool_parser.py:370]   File "/home/hzw/miniconda3/envs/vllm_env/lib/python3.11/site-packages/partial_json_parser/core/api.py", line 15, in parse_json
ERROR 08-13 22:13:17 [hermes_tool_parser.py:370]     return parser(ensure_json(json_string, allow_partial, use_fast_fix))
ERROR 08-13 22:13:17 [hermes_tool_parser.py:370] ... (log truncated)
```

The excerpt is cut off before the final exception line, but the frames show where the failure happens. `extract_tool_calls_streaming` calls `partial_json_parser.loads` (line 241). The error is raised inside the library's `parse_json` / `ensure_json`, so the parser is rejecting the incomplete tool-call JSON it was handed mid-stream. The exception then propagates to the handler that logs "Error trying to handle streaming tool call." at line 370.

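Below is the complete `hermes_tool_parser.py` with a repaired `extract_tool_calls_streaming` method. Parse failures from `partial_json_parser` are now caught and the affected chunk is skipped instead of raising.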

# SPDX-License-Identifier: Apache-2.0

# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

  


import json

from collections.abc import Sequence

from typing import Union

  


import partial_json_parser

import regex as re

from partial_json_parser.core.options import Allow

  


from vllm.entrypoints.chat_utils import random_tool_call_id

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,

DeltaFunctionCall, DeltaMessage,

DeltaToolCall,

ExtractedToolCallInformation,

FunctionCall, ToolCall)

from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (

ToolParser, ToolParserManager)

from vllm.logger import init_logger

from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer

  


# Module-level logger created via vLLM's init_logger helper; shared by the
# parser methods below for debug tracing and error reporting.
logger = init_logger(__name__)

  
  


@ToolParserManager.register_module("hermes")
class Hermes2ProToolParser(ToolParser):
    """Tool-call parser for Hermes 2 Pro style model output.

    Tool calls are JSON objects with ``name`` and ``arguments`` keys wrapped
    in ``<tool_call>`` ... ``</tool_call>`` tags. The parser is registered
    under the name ``"hermes"`` and offers a one-shot parser
    (``extract_tool_calls``) and an incremental one
    (``extract_tool_calls_streaming``).
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        if isinstance(self.model_tokenizer, MistralTokenizer):
            logger.error(
                "Detected Mistral tokenizer when using a Hermes model")
            # Unwrap to the tokenizer object held by the Mistral wrapper.
            self.model_tokenizer = self.model_tokenizer.tokenizer

        # Streaming state, carried across extract_tool_calls_streaming calls.
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        # map what has been streamed for each tool so far to a list
        self.streamed_args_for_tool: list[str] = []

        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"

        # Either a closed <tool_call>...</tool_call> pair, or an opening tag
        # that runs to end-of-string (an unterminated final call).
        self.tool_call_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
        self.scratch_pad_regex = re.compile(
            r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction.")
        # Both tags must exist as single vocabulary entries: streaming
        # detection works by counting their token ids.
        self.tool_call_start_token_id = self.vocab.get(
            self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
        if (self.tool_call_start_token_id is None
                or self.tool_call_end_token_id is None):
            raise RuntimeError(
                "Hermes 2 Pro Tool parser could not locate tool call start/end "
                "tokens in the tokenizer!")

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract every tool call from a complete (non-streamed) response.

        Args:
            model_output: The full generated text.
            request: The originating request (not used by this parser).

        Returns:
            ``tools_called=True`` with one ``ToolCall`` per ``<tool_call>``
            JSON object, and ``content`` set to the text before the first
            ``<tool_call>`` (``None`` if that text is empty). If there is no
            ``<tool_call>`` tag, or any call fails to parse, the whole output
            is returned as ``content`` with ``tools_called=False``.
        """
        # sanity check; avoid unnecessary processing
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        try:
            # findall yields one (closed_body, unterminated_body) tuple per
            # match; the group from the alternative that did not match is an
            # empty string.
            function_call_tuples = self.tool_call_regex.findall(model_output)

            # load the JSON, and then use it to build the Function and
            # Tool Call
            raw_function_calls = [
                json.loads(match[0] if match[0] else match[1])
                for match in function_call_tuples
            ]
            tool_calls = [
                ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=function_call["name"],
                        # function call args are JSON but as a string
                        arguments=json.dumps(function_call["arguments"],
                                             ensure_ascii=False)))
                for function_call in raw_function_calls
            ]

            content = model_output[:model_output.find(
                self.tool_call_start_token)]
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None)

        except Exception:
            logger.exception(
                "Error in extracting tool call from response.")
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Convert one streamed chunk into a content or tool-call delta.

        Receives the text/token ids before (``previous_*``) and after
        (``current_*``) the chunk, plus the chunk itself (``delta_*``). The
        position in the output is inferred by counting ``<tool_call>`` /
        ``</tool_call>`` token ids. Progress across chunks is kept in
        ``self.current_tool_id``, ``self.current_tool_name_sent``,
        ``self.prev_tool_call_arr`` and ``self.streamed_args_for_tool``.
        ``request`` is not used by this parser.

        Returns:
            One of the following:

            - A ``DeltaMessage`` with plain ``content``.
            - A ``DeltaMessage`` with a single ``DeltaToolCall``: first the
              tool name and id, then argument fragments.
            - ``None`` when nothing should be emitted for this chunk, e.g.
              while the partial JSON cannot be parsed yet.

            Exceptions are logged and never propagated.
        """
        logger.debug("delta_text: %s", delta_text)
        logger.debug("delta_token_ids: %s", delta_token_ids)
        # check to see if we should be streaming a tool call
        if self.tool_call_start_token_id not in current_token_ids:
            logger.debug("No tool call tokens found!")
            return DeltaMessage(content=delta_text)

        try:
            # figure out where we are in the parsing by counting tool call
            # start & end tags before and after this chunk
            prev_tool_start_count = previous_token_ids.count(
                self.tool_call_start_token_id)
            prev_tool_end_count = previous_token_ids.count(
                self.tool_call_end_token_id)
            cur_tool_start_count = current_token_ids.count(
                self.tool_call_start_token_id)
            cur_tool_end_count = current_token_ids.count(
                self.tool_call_end_token_id)
            tool_call_portion = None
            text_portion = None

            # case: if we're generating text, OR rounding out a tool call
            if (cur_tool_start_count == cur_tool_end_count
                    and prev_tool_end_count == cur_tool_end_count
                    and self.tool_call_end_token not in delta_text):
                logger.debug("Generating text content! skipping tool parsing.")
                return DeltaMessage(content=delta_text)

            if self.tool_call_end_token in delta_text:
                logger.debug("tool_call_end_token in delta_text")
                # current_text already ends with delta_text, so it alone is
                # the full text; take the body of the last <tool_call>.
                tool_call_portion = current_text.split(
                    self.tool_call_start_token)[-1].split(
                        self.tool_call_end_token)[0].rstrip()
                # Capture the text after the end tag *before* truncating
                # delta_text to the part that precedes the tag.
                text_portion = delta_text.split(
                    self.tool_call_end_token)[-1].lstrip()
                delta_text = delta_text.split(
                    self.tool_call_end_token)[0].rstrip()

            if (cur_tool_start_count > cur_tool_end_count
                    and cur_tool_start_count > prev_tool_start_count):
                # case -- we're starting a new tool call
                if len(delta_token_ids) > 1:
                    tool_call_portion = current_text.split(
                        self.tool_call_start_token)[-1]
                else:
                    tool_call_portion = None
                text_portion = None

                # set cursors and state appropriately
                self.current_tool_id += 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("Starting on a new tool %s", self.current_tool_id)

            elif (cur_tool_start_count > cur_tool_end_count
                  and cur_tool_start_count == prev_tool_start_count):
                # case -- we're updating an existing tool call
                tool_call_portion = current_text.split(
                    self.tool_call_start_token)[-1]
                text_portion = None

            elif (cur_tool_start_count == cur_tool_end_count
                  and cur_tool_end_count >= prev_tool_end_count):
                # case -- the current tool call is being closed.
                if not (0 <= self.current_tool_id
                        < len(self.prev_tool_call_arr)):
                    logger.debug(
                        "attempting to close tool call, but no tool call")
                    return None
                last_arguments = self.prev_tool_call_arr[
                    self.current_tool_id].get("arguments")
                if last_arguments:
                    # Flush this chunk up to its last '"}'.
                    # NOTE(review): if the arguments end in a non-string
                    # value there is no '"}' and this chunk emits nothing.
                    if '"}' not in delta_text:
                        return None
                    end_loc = delta_text.rindex('"}')
                    closing_args = delta_text[:end_loc] + '"}'
                    logger.debug(
                        "Finishing tool and found diff that had not been "
                        "streamed yet: %s", closing_args)
                    self.streamed_args_for_tool[
                        self.current_tool_id] += closing_args
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=closing_args).model_dump(
                                              exclude_none=True))
                    ])

            else:
                # case -- otherwise we're just generating text
                text = delta_text.replace(self.tool_call_start_token, "")
                text = text.replace(self.tool_call_end_token, "")
                return DeltaMessage(tool_calls=[], content=text)

            # Compute the flags only now: starting a new tool call above
            # resets current_tool_name_sent, and partial strings must stay
            # disallowed until that tool's name has been sent (so a truncated
            # name is not reported).
            flags = (Allow.ALL if self.current_tool_name_sent else
                     Allow.ALL & ~Allow.STR)

            current_tool_call = None
            if tool_call_portion:
                try:
                    current_tool_call = partial_json_parser.loads(
                        tool_call_portion, flags)
                except (partial_json_parser.core.exceptions.MalformedJSON,
                        json.decoder.JSONDecodeError) as e:
                    # The partial text cannot be parsed (yet); skip this
                    # chunk instead of failing the stream.
                    logger.debug("Partial JSON parsing failed: %s", e)
                    return None
                logger.debug("Parsed tool call %s", current_tool_call)
                if not isinstance(current_tool_call, dict):
                    # Only a JSON object can carry "name"/"arguments".
                    logger.debug("Parsed tool call is not a JSON object")
                    return None

            # case - we haven't sent the tool name yet
            if not self.current_tool_name_sent:
                function_name = (current_tool_call.get("name")
                                 if current_tool_call else None)
                if not function_name:
                    return None
                self.current_tool_name_sent = True
                return DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  type="function",
                                  id=random_tool_call_id(),
                                  function=DeltaFunctionCall(
                                      name=function_name).model_dump(
                                          exclude_none=True))
                ])

            # if the tool call portion is None, send the delta as text
            if tool_call_portion is None:
                return (DeltaMessage(content=delta_text)
                        if text_portion is not None else None)

            logger.debug("Trying to parse current tool call with ID %s",
                         self.current_tool_id)

            # Ensure current_tool_id indexes a slot; a new tool call starts
            # from an empty placeholder object.
            while len(self.prev_tool_call_arr) <= self.current_tool_id:
                self.prev_tool_call_arr.append({})

            # main logic for tool parsing
            prev_arguments = self.prev_tool_call_arr[
                self.current_tool_id].get("arguments")
            cur_arguments = (current_tool_call.get("arguments")
                             if current_tool_call else None)

            logger.debug("diffing old arguments: %s", prev_arguments)
            logger.debug("against new ones: %s", cur_arguments)

            if not cur_arguments and not prev_arguments:
                logger.debug("Skipping text %s - no arguments", delta_text)
                delta = None
            elif not cur_arguments:
                logger.error("Arguments reset mid-call. Skipping streaming.")
                delta = None
            elif not prev_arguments:
                # First argument tokens: emit the serialized arguments up to
                # and including the latest occurrence of this chunk's text.
                # NOTE(review): [:-2] drops the last two characters of the
                # dump, e.g. the '"}' closing a still-partial string value.
                cur_arguments_json = json.dumps(cur_arguments,
                                                ensure_ascii=False)
                if delta_text not in cur_arguments_json[:-2]:
                    return None
                args_delta_start_loc = cur_arguments_json[:-2].rindex(
                    delta_text) + len(delta_text)
                arguments_delta = cur_arguments_json[:args_delta_start_loc]
                logger.debug("First tokens in arguments received: %s",
                             arguments_delta)
                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  function=DeltaFunctionCall(
                                      arguments=arguments_delta).model_dump(
                                          exclude_none=True))
                ])
                self.streamed_args_for_tool[
                    self.current_tool_id] += arguments_delta
            else:
                # Arguments already in flight: forward the raw chunk, minus a
                # trailing '}' (presumably the one closing the outer
                # {"name": ..., "arguments": ...} object).
                if delta_text.rstrip().endswith("}"):
                    delta_text = delta_text.rstrip()[:-1]
                logger.debug("got diff %s", delta_text)
                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  function=DeltaFunctionCall(
                                      arguments=delta_text).model_dump(
                                          exclude_none=True))
                ])
                self.streamed_args_for_tool[
                    self.current_tool_id] += delta_text

            # Save the parsed call for diffing against the next chunk. Only
            # real objects are stored: a None here would make the .get()
            # lookups above fail on every later chunk of this tool call.
            if current_tool_call is not None:
                self.prev_tool_call_arr[
                    self.current_tool_id] = current_tool_call

            return delta

        except Exception as e:
            logger.exception("Error in streaming tool call processing: %s", e)
            return None