Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix JSON parsing from model #85

Merged
merged 9 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from typing import Optional

import aiofiles
from datasets import Dataset
from datasets.arrow_writer import ArrowWriter, SchemaInferenceError
from pydantic import BaseModel
import pyarrow
from datasets import Dataset
from datasets.arrow_writer import ArrowWriter
from pydantic import BaseModel, ValidationError

from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter
from bespokelabs.curator.request_processor.event_loop import run_in_event_loop
Expand Down Expand Up @@ -259,11 +259,23 @@ def create_dataset_files(
if prompt_formatter.response_format:
# Response message is a string, which is converted to a dict
# The dict is then used to construct the response_format Pydantic model
response.response_message = (
prompt_formatter.response_format(
**response.response_message
try:
response.response_message = (
prompt_formatter.response_format(
**response.response_message
)
)
)
except ValidationError as e:
warning_msg = (
f"Pydantic failed to parse response message {response.response_message} with `response_format` {prompt_formatter.response_format}."
f"The model likely returned a JSON that does not match the schema of the `response_format`. Will skip this response."
)

logger.warning(warning_msg)
response.response_message = None
response.response_errors = [
f"{warning_msg}. Original error: {str(e)}"
]

# parse_func can return a single row or a list of rows
if prompt_formatter.parse_func:
Expand Down Expand Up @@ -306,3 +318,21 @@ def create_dataset_files(
output_dataset = Dataset.from_file(dataset_file)

return output_dataset


def parse_response_message(
    response_message: str, response_format: Optional[BaseModel]
) -> tuple[Optional[dict | str], Optional[list[str]]]:
    """Parse a raw model response, optionally decoding it as JSON.

    Args:
        response_message: The raw response text returned by the model.
        response_format: If truthy, the response is expected to be JSON and
            is decoded with `json.loads`; if falsy, the text is passed
            through unchanged.

    Returns:
        A `(response_message, response_errors)` tuple. On a successful
        parse (or when no `response_format` is given) `response_errors` is
        `None`. If JSON decoding fails, `response_message` is `None` and
        `response_errors` contains a message with the original raw text.
    """
    response_errors = None
    if response_format:
        try:
            response_message = json.loads(response_message)
        except json.JSONDecodeError:
            logger.warning(
                f"Failed to parse response as JSON: {response_message}, skipping this response."
            )
            # Record the error BEFORE clearing response_message, so the
            # original unparseable text is preserved in the error (the
            # previous order interpolated None into the message).
            response_errors = [
                f"Failed to parse response as JSON: {response_message}"
            ]
            response_message = None
    return response_message, response_errors
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import logging
import os
from typing import Optional, Type, TypeVar
from pydantic import BaseModel
from openai import AsyncOpenAI

import aiofiles
from openai import AsyncOpenAI
from pydantic import BaseModel
from tqdm import tqdm

from bespokelabs.curator.dataset import Dataset
Expand All @@ -15,6 +15,7 @@
BaseRequestProcessor,
GenericRequest,
GenericResponse,
parse_response_message,
)
from bespokelabs.curator.request_processor.event_loop import run_in_event_loop

Expand Down Expand Up @@ -413,15 +414,14 @@ async def download_batch_to_generic_responses_file(
else:
# NOTE(Ryan): can we actually parse the response into an OpenAI ChatCompletions object? Easier to access fields?
# TODO(Ryan): if you add token tokens to generic response
content = raw_response["response"]["body"]["choices"][0][
"message"
]["content"]

if response_format:
content = json.loads(content)

generic_response.response_message = content

choices = raw_response["response"]["body"]["choices"]
# Assuming N = 1
response_message = choices[0]["message"]["content"]
response_message, response_errors = parse_response_message(
response_message, response_format
)
generic_response.response_message = response_message
generic_response.response_errors = response_errors
f.write(
json.dumps(generic_response.model_dump(), default=str)
+ "\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
BaseRequestProcessor,
GenericRequest,
GenericResponse,
parse_response_message,
)
from bespokelabs.curator.request_processor.event_loop import run_in_event_loop

Expand Down Expand Up @@ -104,7 +105,6 @@ def create_api_specific_request(
request["response_format"] = {
"type": "json_schema",
"json_schema": {
# TODO(ryan): not sure if we should use strict: True or have name: be something else.
"name": "output_schema",
"schema": generic_request.response_format,
},
Expand Down Expand Up @@ -561,11 +561,12 @@ async def call_api(
)
else:
response_message = response["choices"][0]["message"]["content"]
if self.generic_request.response_format:
response_message = json.loads(response_message)
response_message, response_errors = parse_response_message(
response_message, self.generic_request.response_format
)
generic_response = GenericResponse(
response_message=response_message,
response_errors=None,
response_errors=response_errors,
raw_request=self.api_specific_request_json,
raw_response=response,
generic_request=self.generic_request,
Expand Down