Skip to content

Commit

Permalink
Merge pull request #85 from bespokelabsai/ryanm/invalid-json-from-model
Browse files Browse the repository at this point in the history
Fix JSON parsing from model
  • Loading branch information
RyanMarten authored Nov 13, 2024
2 parents 1c7412c + c7a50f3 commit e7bb89e
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from typing import Optional

import aiofiles
from datasets import Dataset
from datasets.arrow_writer import ArrowWriter, SchemaInferenceError
from pydantic import BaseModel
import pyarrow
from datasets import Dataset
from datasets.arrow_writer import ArrowWriter
from pydantic import BaseModel, ValidationError

from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter
from bespokelabs.curator.request_processor.event_loop import run_in_event_loop
Expand Down Expand Up @@ -259,11 +259,23 @@ def create_dataset_files(
if prompt_formatter.response_format:
# Response message is a string, which is converted to a dict
# The dict is then used to construct the response_format Pydantic model
response.response_message = (
prompt_formatter.response_format(
**response.response_message
try:
response.response_message = (
prompt_formatter.response_format(
**response.response_message
)
)
)
except ValidationError as e:
warning_msg = (
f"Pydantic failed to parse response message {response.response_message} with `response_format` {prompt_formatter.response_format}."
f"The model likely returned a JSON that does not match the schema of the `response_format`. Will skip this response."
)

logger.warning(warning_msg)
response.response_message = None
response.response_errors = [
f"{warning_msg}. Original error: {str(e)}"
]

# parse_func can return a single row or a list of rows
if prompt_formatter.parse_func:
Expand Down Expand Up @@ -317,3 +329,21 @@ def create_dataset_files(
output_dataset = Dataset.from_file(dataset_file)

return output_dataset


def parse_response_message(
    response_message: str, response_format: Optional[BaseModel]
) -> tuple[Optional[dict | str], Optional[list[str]]]:
    """Parse a raw model response into a dict when a response_format is requested.

    Args:
        response_message: The raw string content returned by the model.
        response_format: The Pydantic response schema, if one was requested.
            Only its truthiness is checked here; actual schema validation
            happens downstream in create_dataset_files.

    Returns:
        A (response_message, response_errors) tuple. When response_format is
        set and the string is valid JSON, response_message is the parsed dict
        and response_errors is None. When JSON parsing fails, response_message
        is None and response_errors holds a single descriptive message. When
        no response_format is set, the original string is passed through
        unchanged with no errors.
    """
    response_errors = None
    if response_format:
        try:
            response_message = json.loads(response_message)
        except json.JSONDecodeError:
            logger.warning(
                f"Failed to parse response as JSON: {response_message}, skipping this response."
            )
            # Record the error BEFORE replacing response_message with None,
            # so the error message preserves the original unparseable text
            # instead of reading "Failed to parse response as JSON: None".
            response_errors = [
                f"Failed to parse response as JSON: {response_message}"
            ]
            response_message = None
    return response_message, response_errors
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import logging
import os
from typing import Optional, Type, TypeVar
from pydantic import BaseModel
from openai import AsyncOpenAI

import aiofiles
from openai import AsyncOpenAI
from pydantic import BaseModel
from tqdm import tqdm

from bespokelabs.curator.dataset import Dataset
Expand All @@ -15,6 +15,7 @@
BaseRequestProcessor,
GenericRequest,
GenericResponse,
parse_response_message,
)
from bespokelabs.curator.request_processor.event_loop import run_in_event_loop

Expand Down Expand Up @@ -413,15 +414,14 @@ async def download_batch_to_generic_responses_file(
else:
# NOTE(Ryan): can we actually parse the response into an OpenAI ChatCompletions object? Easier to access fields?
# TODO(Ryan): add token counts to generic response
content = raw_response["response"]["body"]["choices"][0][
"message"
]["content"]

if response_format:
content = json.loads(content)

generic_response.response_message = content

choices = raw_response["response"]["body"]["choices"]
# Assuming N = 1
response_message = choices[0]["message"]["content"]
response_message, response_errors = parse_response_message(
response_message, response_format
)
generic_response.response_message = response_message
generic_response.response_errors = response_errors
f.write(
json.dumps(generic_response.model_dump(), default=str)
+ "\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
BaseRequestProcessor,
GenericRequest,
GenericResponse,
parse_response_message,
)
from bespokelabs.curator.request_processor.event_loop import run_in_event_loop

Expand Down Expand Up @@ -104,7 +105,6 @@ def create_api_specific_request(
request["response_format"] = {
"type": "json_schema",
"json_schema": {
# TODO(ryan): not sure if we should use strict: True or have name: be something else.
"name": "output_schema",
"schema": generic_request.response_format,
},
Expand Down Expand Up @@ -561,11 +561,12 @@ async def call_api(
)
else:
response_message = response["choices"][0]["message"]["content"]
if self.generic_request.response_format:
response_message = json.loads(response_message)
response_message, response_errors = parse_response_message(
response_message, self.generic_request.response_format
)
generic_response = GenericResponse(
response_message=response_message,
response_errors=None,
response_errors=response_errors,
raw_request=self.api_specific_request_json,
raw_response=response,
generic_request=self.generic_request,
Expand Down

0 comments on commit e7bb89e

Please sign in to comment.