From 855a3b9beca07deadf3f977e1c288695b3d6574a Mon Sep 17 00:00:00 2001 From: Ryan Marten Date: Thu, 7 Nov 2024 12:29:07 -0800 Subject: [PATCH 01/18] remove example from readme that is no longer applicable --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 0663636c..8ec54feb 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,6 @@ To run the camel example: ```bash poetry run python examples/camel.py -poetry run python examples/openhermesv3.py ``` Run the tests: From c1f0a25a2d589fdd8f110a6b6aca6d99d7a0d777 Mon Sep 17 00:00:00 2001 From: Ryan Marten Date: Thu, 7 Nov 2024 12:48:46 -0800 Subject: [PATCH 02/18] remove prompt_formatter from BaseRequestProcessor init --- poetry.lock | 9 +++++---- pyproject.toml | 1 + .../request_processor/base_request_processor.py | 17 ++++++++--------- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/poetry.lock b/poetry.lock index 1e3ab82d..4d0a3c57 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2967,13 +2967,13 @@ files = [ [[package]] name = "tqdm" -version = "4.66.6" +version = "4.67.0" description = "Fast, Extensible Progress Meter" optional = false python-versions = ">=3.7" files = [ - {file = "tqdm-4.66.6-py3-none-any.whl", hash = "sha256:223e8b5359c2efc4b30555531f09e9f2f3589bcd7fdd389271191031b49b7a63"}, - {file = "tqdm-4.66.6.tar.gz", hash = "sha256:4bdd694238bef1485ce839d67967ab50af8f9272aab687c0d7702a01da0be090"}, + {file = "tqdm-4.67.0-py3-none-any.whl", hash = "sha256:0cd8af9d56911acab92182e88d763100d4788bdf421d251616040cc4d44863be"}, + {file = "tqdm-4.67.0.tar.gz", hash = "sha256:fe5a6f95e6fe0b9755e9469b77b9c3cf850048224ecaa8293d7d2d31f97d869a"}, ] [package.dependencies] @@ -2981,6 +2981,7 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} [package.extras] dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"] +discord = ["requests"] notebook = ["ipywidgets (>=6)"] slack = ["slack-sdk"] telegram = ["requests"] @@ -3339,4 +3340,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "ff68e5d9c7907a62260570558c0c0238b48810b71748691e77639928a8c449c0" +content-hash = "20e3c109fa84799358c28bde59cdc894bc1dcc28a47736dbf30d6ba63f8db705" diff --git a/pyproject.toml b/pyproject.toml index f1fb426e..ea0e6c20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ pytest = "^8.3.3" pytest-asyncio = "^0.24.0" pandas = "^2.2.3" xxhash = "^3.5.0" +tqdm = "^4.67.0" [tool.poetry.group.dev.dependencies] black = "^24.2.0" diff --git a/src/bespokelabs/curator/request_processor/base_request_processor.py b/src/bespokelabs/curator/request_processor/base_request_processor.py index 70a8cffc..833ff28e 100644 --- a/src/bespokelabs/curator/request_processor/base_request_processor.py +++ b/src/bespokelabs/curator/request_processor/base_request_processor.py @@ -1,11 +1,13 @@ import json import logging import os +import asyncio +import glob +import aiofiles + from abc import ABC, abstractmethod from typing import Dict, List, Optional - -import tqdm -from tqdm import tqdm +from math import ceil from bespokelabs.curator.dataset import Dataset from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter @@ -18,9 +20,6 @@ class BaseRequestProcessor(ABC): Base class for all request processors. 
""" - def __init__(self, prompt_formatter: PromptFormatter): - self.prompt_formatter = prompt_formatter - @abstractmethod def get_rate_limits(self) -> dict: """ @@ -78,7 +77,7 @@ def run(self, dataset: Dataset, working_dir: str) -> Dataset: """ pass - def create_request_files(self, dataset: Optional[Dataset], working_dir: str) -> str: + def create_request_files(self, dataset: Optional[Dataset], working_dir: str, prompt_formatter: PromptFormatter) -> str: """ Creates a request file if they don't already exist or use existing. @@ -123,12 +122,12 @@ def create_request_files(self, dataset: Optional[Dataset], working_dir: str) -> # Create new requests file with open(requests_file, "w") as f: if dataset is None: - request = self.prompt_formatter.get_generic_request(dict(), 0) + request = prompt_formatter.get_generic_request(dict(), 0) api_request = self.create_api_specific_request(request) f.write(json.dumps(api_request) + "\n") else: for dataset_row_idx, dataset_row in enumerate(dataset): - request = self.prompt_formatter.get_generic_request( + request = prompt_formatter.get_generic_request( dataset_row, dataset_row_idx ) # Convert the generic request to an API-specific request From 8d9d955482ee31088b765000eb6834eed59f9d0f Mon Sep 17 00:00:00 2001 From: Ryan Marten Date: Thu, 7 Nov 2024 13:13:32 -0800 Subject: [PATCH 03/18] added request file batch creation --- src/bespokelabs/curator/dataset.py | 2 +- src/bespokelabs/curator/db.py | 12 +- src/bespokelabs/curator/prompter/prompter.py | 16 +- .../base_request_processor.py | 63 +- ...r.py => openai_batch_request_processor.py} | 0 .../openai_online_request_processor.py | 697 ++++++++++++++++++ tests/test_prompt.py | 4 +- 7 files changed, 767 insertions(+), 27 deletions(-) rename src/bespokelabs/curator/request_processor/{openai_request_processor.py => openai_batch_request_processor.py} (100%) create mode 100644 src/bespokelabs/curator/request_processor/openai_online_request_processor.py diff --git a/src/bespokelabs/curator/dataset.py b/src/bespokelabs/curator/dataset.py index a0d42c81..25255e95 100644 --- a/src/bespokelabs/curator/dataset.py +++ b/src/bespokelabs/curator/dataset.py @@ -93,7 +93,7 @@ def to_huggingface(self, in_memory: bool = False) -> None: response.response = self.prompt_formatter.response_format( **response.response ) - + # TODO(Ryan): We can make this more sophisticated by making response_generic a class if response is None: failed_responses_count += 1 diff --git a/src/bespokelabs/curator/db.py b/src/bespokelabs/curator/db.py index 2877c3c3..cdcd39d3 100644 --- a/src/bespokelabs/curator/db.py +++ b/src/bespokelabs/curator/db.py @@ -41,8 +41,7 @@ def store_metadata(self, metadata: dict): # Check if run_hash exists cursor.execute( - "SELECT run_hash FROM runs WHERE run_hash = ?", - (metadata["run_hash"],) + "SELECT run_hash FROM runs WHERE run_hash = ?", (metadata["run_hash"],) ) existing_run = cursor.fetchone() @@ -54,10 +53,7 @@ def store_metadata(self, metadata: dict): SET last_edited_time = ? WHERE run_hash = ? 
""", - ( - metadata["timestamp"], - metadata["run_hash"] - ) + (metadata["timestamp"], metadata["run_hash"]), ) else: # Insert new entry @@ -75,7 +71,7 @@ def store_metadata(self, metadata: dict): metadata["model_name"], metadata["response_format"], metadata["timestamp"], - '-' - ) + "-", + ), ) conn.commit() diff --git a/src/bespokelabs/curator/prompter/prompter.py b/src/bespokelabs/curator/prompter/prompter.py index c219321b..f4c62dce 100644 --- a/src/bespokelabs/curator/prompter/prompter.py +++ b/src/bespokelabs/curator/prompter/prompter.py @@ -17,8 +17,12 @@ from bespokelabs.curator.db import MetadataDB from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter from bespokelabs.curator.request_processor.generic_request import GenericRequest -from bespokelabs.curator.request_processor.openai_request_processor import ( +from bespokelabs.curator.request_processor.openai_batch_request_processor import ( OpenAIRequestProcessor, + OpenAIOnlineRequestProcessor, +) +from bespokelabs.curator.request_processor.base_request_processor import ( + BaseRequestProcessor, ) T = TypeVar("T") @@ -37,6 +41,9 @@ def __init__( ] ] = None, response_format: Optional[Type[BaseModel]] = None, + request_processor: Optional[ + BaseRequestProcessor + ] = OpenAIOnlineRequestProcessor, ): """Initialize a Prompter. @@ -66,6 +73,8 @@ def __init__( model_name, prompt_func, parse_func, response_format ) + self.request_processor = request_processor + def __call__(self, dataset: Optional[Iterable] = None): """Run completions on a dataset.""" return self._completions(dataset) @@ -145,8 +154,9 @@ def _completions( } metadata_db.store_metadata(metadata_dict) - request_processor = OpenAIRequestProcessor(self.prompt_formatter) - return request_processor.run(dataset, f"{curator_cache_dir}/{fingerprint}") + return self.request_processor.run( + dataset, f"{curator_cache_dir}/{fingerprint}", self.prompt_formatter + ) def _hash_chunk(chunks: list) -> list: diff --git a/src/bespokelabs/curator/request_processor/base_request_processor.py b/src/bespokelabs/curator/request_processor/base_request_processor.py index 833ff28e..b8e91ffc 100644 --- a/src/bespokelabs/curator/request_processor/base_request_processor.py +++ b/src/bespokelabs/curator/request_processor/base_request_processor.py @@ -64,7 +64,9 @@ def get_generic_response(self, response: dict) -> GenericResponse: pass @abstractmethod - def run(self, dataset: Dataset, working_dir: str) -> Dataset: + def run( + self, dataset: Dataset, working_dir: str, prompt_formatter: PromptFormatter + ) -> Dataset: """ Uses the API to completing the specific map by calling the LLM. @@ -77,7 +79,13 @@ def run(self, dataset: Dataset, working_dir: str) -> Dataset: """ pass - def create_request_files(self, dataset: Optional[Dataset], working_dir: str, prompt_formatter: PromptFormatter) -> str: + def create_request_files( + self, + dataset: Optional[Dataset], + working_dir: str, + prompt_formatter: PromptFormatter, + batch_size: Optional[int] = None, + ) -> str: """ Creates a request file if they don't already exist or use existing. @@ -89,16 +97,16 @@ def create_request_files(self, dataset: Optional[Dataset], working_dir: str, pro str: Path to the request file that was created. 
""" os.makedirs(working_dir, exist_ok=True) - requests_file = f"{working_dir}/requests.jsonl" + requests_files = glob.glob(f"{working_dir}/requests_*.jsonl") # By default use existing requests in working_dir - if os.path.exists(requests_file): + if len(requests_files) > 0: logging.info( - f"Using existing requests in {working_dir} by default. " + f"Using existing requests in {working_dir} by default. Found {len(requests_files)} request files." f"If this is not what you want, delete the directory or specify a new one and re-run." ) # count existing jobs in file and print first job - with open(requests_file, "r") as f: + with open(requests_files[0], "r") as f: # Count lines and store first job first_job = None num_jobs = 0 @@ -109,16 +117,30 @@ def create_request_files(self, dataset: Optional[Dataset], working_dir: str, pro if num_jobs > 0: logging.info( - f"There are {num_jobs} existing requests in {requests_file}" + f"There are {num_jobs} existing requests in {requests_files[0]}" ) - logging.info(f"Example request:\n{json.dumps(first_job, indent=2)}") - return requests_file - else: - logging.warning( - f"No requests found in {requests_file}. Will delete the file and start over." + logging.info( + f"Example request in {requests_files[0]}:\n{json.dumps(first_job, indent=2)}" ) - os.remove(requests_file) + # Some simple sanity checks for the user + if batch_size is not None: + if batch_size != num_jobs: + logging.warning( + f"Batch size is {batch_size}, but there are {num_jobs} requests in {requests_files[0]}. " + f"If you want to run with new batch size, you will have to delete the working directory and re-run (looses progress)" + ) + if len(requests_files) == 1 and len(dataset) > batch_size: + logging.warning( + f"Only one request file was found, but batch size is specified and dataset is larger than batch size." + f"You might be resuming from a different dataset or weren't using batching before." + f"If you want to run with batching, you will have to delete working directory and re-run (looses progress)" + ) + return requests_files + + request_count = 0 + request_file_idx = 0 + requests_file = f"{working_dir}/requests_{request_file_idx}.jsonl" # Create new requests file with open(requests_file, "w") as f: if dataset is None: @@ -134,5 +156,18 @@ def create_request_files(self, dataset: Optional[Dataset], working_dir: str, pro api_request = self.create_api_specific_request(request) # Write the API-specific request to file f.write(json.dumps(api_request) + "\n") - logging.info(f"Requests file {requests_file} written to disk.") + request_count += 1 + + # Batches could be created in parallel, but dataset is iterated sequentially + if batch_size is not None and request_count == batch_size: + request_count = 0 + request_file_idx += 1 + requests_file = ( + f"{working_dir}/requests_{request_file_idx}.jsonl" + ) + logging.info( + f"Wrote {request_count:,} requests to {requests_file}." 
+ ) + if request_count > 0: + logging.info(f"Wrote {request_count:,} requests to {requests_file}.") return requests_file diff --git a/src/bespokelabs/curator/request_processor/openai_request_processor.py b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py similarity index 100% rename from src/bespokelabs/curator/request_processor/openai_request_processor.py rename to src/bespokelabs/curator/request_processor/openai_batch_request_processor.py diff --git a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py new file mode 100644 index 00000000..32d495c0 --- /dev/null +++ b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py @@ -0,0 +1,697 @@ +import asyncio +import json +import logging +import os +import re +import time +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Optional, Set, Tuple, TypeVar + +import aiohttp +import requests +import tiktoken +from tqdm import tqdm + +from bespokelabs.curator.dataset import Dataset +from bespokelabs.curator.request_processor.base_request_processor import ( + BaseRequestProcessor, + GenericRequest, + GenericResponse, +) + +T = TypeVar("T") + + +class OpenAIRequestProcessor(BaseRequestProcessor): + model: str + url: str = "https://api.openai.com/v1/chat/completions" + api_key: str = os.getenv("OPENAI_API_KEY") + + def get_rate_limits(self) -> dict: + """ + Function to get rate limits for a given annotator. Makes a single request to openAI API + and gets the rate limits from the response headers. These rate limits vary per model + and are determined by your organization's usage tier. View the following: + https://platform.openai.com/docs/guides/rate-limits/usage-tiers + https://platform.openai.com/settings/organization/limits + + Args: + model (str): The model for which to get the rate limits. + request_url (str): The request URL for which to get the rate limits. + + Returns: + tuple[int, int]: A tuple containing the maximum number of requests and tokens per minute. + """ + # Send a dummy request to get rate limit information + response = requests.post( + self.url, + headers={"Authorization": f"Bearer {self.api_key}"}, + json={"model": self.prompt_formatter.model_name, "messages": []}, + ) + + rpm = int(response.headers.get("x-ratelimit-limit-requests", 0)) + tpm = int(response.headers.get("x-ratelimit-limit-tokens", 0)) + + if not rpm or not tpm: + logging.warning( + "Failed to get rate limits from OpenAI API, using default values" + ) + rpm = 30_000 + tpm = 150_000_000 + + logging.info(f"Automatically set max_requests_per_minute to {rpm}") + logging.info(f"Automatically set max_tokens_per_minute to {tpm}") + + rate_limits = { + "max_requests_per_minute": rpm, + "max_tokens_per_minute": tpm, + } + + return rate_limits + + def create_api_specific_request(self, generic_request: GenericRequest) -> dict: + """ + Creates a API-specific request body from a generic request body. + + Using the api_parallel_processor, we can store whatever we want in the metadata. We will store both the row and the index. + This is so we can later construct the new dataset row. + + Returns: + dict: API specific request body + """ + if generic_request.response_format: + request = { + "model": generic_request.model, + "messages": generic_request.messages, + "response_format": { + "type": "json_schema", + "json_schema": { + # TODO(ryan): not sure if this should be something else. 
+ # TODO(ryan): also not sure if we should use strict: True + "name": "output_schema", + "schema": generic_request.response_format.model_json_schema(), + }, + }, + "metadata": { + "request_idx": generic_request.row_idx, + "sample": generic_request.row, + }, + } + else: + request = { + "model": generic_request.model, + "messages": generic_request.messages, + "metadata": { + "request_idx": generic_request.row_idx, + "sample": generic_request.row, + }, + } + + return request + + def get_generic_response(self, response: Dict) -> GenericResponse: + """ + Parses a API-specific response into a generic response body. + Does error handling on the response. + If there is an error, return None. + + IMPORTANT: In the generic response body you need to provide either the original dataset row OR the index of the row in the original dataset. + + Args: + response: API-specific response + + Returns: + dict: Generic response body with an extra field "metadata" which contains the original dataset row or the index of the row in the original dataset + """ + content = response["response"]["choices"][0]["message"]["content"] + if self.prompt_formatter.response_format: + content = json.loads(content) + + return GenericResponse( + response=content, + row=response["metadata"]["sample"], + row_idx=response["metadata"]["request_idx"], + ) + + def run( + self, + dataset: Optional[Dataset], + working_dir: str, + ) -> Dataset: + """ + Uses the API to completing the specific map by calling the LLM. + + Args: + dataset (Dataset): Dataset that is being mapped over + working_dir (str): Working directory to save files (requests.jsonl, responses.jsonl, dataset.arrow) + + Returns: + Dataset: Completed dataset + """ + requests_file = self.create_request_files(dataset, working_dir) + responses_file = f"{working_dir}/responses.jsonl" + + rate_limits = self.get_rate_limits() + rpm = rate_limits["max_requests_per_minute"] + tpm = rate_limits["max_tokens_per_minute"] + + token_encoding_name = get_token_encoding_name(self.prompt_formatter.model_name) + + # NOTE(Ryan): If you wanted to do this on batches, you could run a for loop here about request_files. Although I don't recommend it because you are waiting for straggler requests to finish for each batch. + # NOTE(Ryan): And if you wanted to do batches in parallel, you would have to divide rpm and tpm by the number of parallel batches. + # TODO(Ryan): Can we abstract retries from process_api_requests_from_file so you can use it even if you use liteLLM. 
+ i = 0 + asyncio.run( + self.process_api_requests_from_file( + requests_filepath=requests_file, + save_filepath=responses_file, + request_url=self.url, + max_requests_per_minute=rpm, + max_tokens_per_minute=tpm, + token_encoding_name=token_encoding_name, + max_attempts=5, + resume=True, # detects existing jobs and resume from there + ) + ) + + return Dataset.from_working_dir(working_dir, self.prompt_formatter) + + async def process_api_requests_from_file( + self, + requests_filepath: str, + save_filepath: str, + request_url: str, + max_requests_per_minute: float, + max_tokens_per_minute: float, + token_encoding_name: str, + max_attempts: int, + resume: bool, + resume_no_retry: bool = False, + ) -> None: + """Processes API requests in parallel, throttling to stay under rate limits.""" + # constants + seconds_to_pause_after_rate_limit_error = 15 + seconds_to_sleep_each_loop = ( + 0.001 # 1 ms limits max throughput to 1,000 requests per second + ) + + # infer API endpoint and construct request header + api_endpoint = api_endpoint_from_url(self.url) + request_header = {"Authorization": f"Bearer {self.api_key}"} + # use api-key header for Azure deployments + if "/deployments" in self.url: + request_header = {"api-key": f"{self.api_key}"} + + # initialize trackers + queue_of_requests_to_retry = asyncio.Queue() + task_id_generator = ( + task_id_generator_function() + ) # generates integer IDs of 0, 1, 2, ... + status_tracker = ( + StatusTracker() + ) # single instance to track a collection of variables + next_request = None # variable to hold the next request to call + + # initialize available capacity counts + available_request_capacity = max_requests_per_minute + available_token_capacity = max_tokens_per_minute + last_update_time = time.time() + + # initialize flags + file_not_finished = True # after file is empty, we'll skip reading it + logging.debug(f"Initialization complete.") + + completed_request_ids: Set[int] = set() + if os.path.exists(save_filepath): + if resume: + # save all successfully completed requests to a temporary file, then overwrite the original file with the temporary file + logging.debug(f"Resuming progress from existing file: {save_filepath}") + logging.debug( + f"Removing all failed requests from {save_filepath} so they can be retried" + ) + temp_filepath = f"{save_filepath}.temp" + num_previously_failed_requests = 0 + with open(save_filepath, "r") as input_file, open( + temp_filepath, "w" + ) as output_file: + for line in input_file: + response = GenericResponse.model_validate_json(line) + if response.errors: + # this means that the request failed and we have a list of errors + logging.debug( + f"Request {response.row_idx} previously failed due to errors: {response.errors}, removing from output and will retry" + ) + num_previously_failed_requests += 1 + else: + completed_request_ids.add(response.row_idx) + output_file.write(line) + logging.info( + f"Found {len(completed_request_ids)} completed requests and {num_previously_failed_requests} previously failed requests" + ) + logging.info( + "Failed requests and remaining requests will now be processed." 
+ ) + os.replace(temp_filepath, save_filepath) + elif resume_no_retry: + logging.warning( + f"Resuming progress from existing file: {save_filepath}, without retrying failed requests" + ) + num_previously_failed_requests = 0 + with open(save_filepath, "r") as input_file, open( + temp_filepath, "w" + ) as output_file: + for line in tqdm(input_file, desc="Processing existing requests"): + data = json.loads(line) + if isinstance(data[1], list): + # this means that the request failed and we have a list of errors + logging.debug( + f"Request {data[2].get('request_idx')} previously failed due to errors: {data[1]}, will NOT retry" + ) + num_previously_failed_requests += 1 + completed_request_ids.add(data[2].get("request_idx")) + logging.info( + f"Found {len(completed_request_ids)} total requests and {num_previously_failed_requests} previously failed requests" + ) + logging.info("Remaining requests will now be processed.") + else: + user_input = input( + f"File {save_filepath} already exists.\nTo resume if there are remaining requests without responses, run with --resume flag.\nOverwrite? (Y/n): " + ) + if user_input.lower() != "y" and user_input.lower() != "": + logging.info("Aborting operation.") + return + + # initialize file reading + with open(requests_filepath) as file: + # `requests` will provide requests one at a time + requests = file.__iter__() + logging.debug(f"File opened. Entering main loop") + + # Count total number of requests + total_requests = sum(1 for _ in open(requests_filepath)) + if total_requests == len(completed_request_ids): + logging.debug( + "All requests have already been completed so will just reuse cache." + ) + return + + # Create progress bar + pbar = tqdm( + total=total_requests, desc="Processing parallel requests to OpenAI" + ) + + connector = aiohttp.TCPConnector(limit=10 * max_requests_per_minute) + async with aiohttp.ClientSession( + connector=connector + ) as session: # Initialize ClientSession here + while True: + # get next request (if one is not already waiting for capacity) + if next_request is None: + if not queue_of_requests_to_retry.empty(): + next_request = queue_of_requests_to_retry.get_nowait() + logging.debug( + f"Retrying request {next_request.task_id}: {next_request}" + ) + elif file_not_finished: + try: + # get new request + request_json = json.loads(next(requests)) + request_idx = request_json["metadata"]["request_idx"] + if resume and request_idx in completed_request_ids: + logging.debug( + f"Skipping already completed request {request_idx}" + ) + status_tracker.num_tasks_already_completed += 1 + continue + next_request = APIRequest( + task_id=next(task_id_generator), + request_json=request_json, + token_consumption=num_tokens_consumed_from_request( + request_json, api_endpoint, token_encoding_name + ), + attempts_left=max_attempts, + metadata=request_json.pop("metadata", None), + ) + status_tracker.num_tasks_started += 1 + status_tracker.num_tasks_in_progress += 1 + logging.debug( + f"Reading request {next_request.task_id}: {next_request}" + ) + except StopIteration: + # if file runs out, set flag to stop reading it + logging.debug("Read file exhausted") + file_not_finished = False + + # update available capacity + current_time = time.time() + seconds_since_update = current_time - last_update_time + available_request_capacity = min( + available_request_capacity + + max_requests_per_minute * seconds_since_update / 60.0, + max_requests_per_minute, + ) + available_token_capacity = min( + available_token_capacity + + max_tokens_per_minute * 
seconds_since_update / 60.0, + max_tokens_per_minute, + ) + last_update_time = current_time + + # if enough capacity available, call API + if next_request: + next_request_tokens = next_request.token_consumption + if ( + available_request_capacity >= 1 + and available_token_capacity >= next_request_tokens + ): + # update counters + available_request_capacity -= 1 + available_token_capacity -= next_request_tokens + next_request.attempts_left -= 1 + + # call API + asyncio.create_task( + next_request.call_api( + session=session, + request_url=request_url, + request_header=request_header, + retry_queue=queue_of_requests_to_retry, + save_filepath=save_filepath, + status_tracker=status_tracker, + get_generic_response=self.get_generic_response, + ) + ) + next_request = None # reset next_request to empty + else: + logging.debug( + f"Not Enough Capacity: Request tokens: {next_request_tokens}, Available request capacity: {available_request_capacity}, Available token capacity: {available_token_capacity}" + ) + + # Update progress bar when a task is completed + total_completed = ( + status_tracker.num_tasks_succeeded + + status_tracker.num_tasks_failed + + status_tracker.num_tasks_already_completed + ) + if total_completed > pbar.n: + pbar.update(total_completed - pbar.n) + + # if all tasks are finished, break + if status_tracker.num_tasks_in_progress == 0: + break + + # main loop sleeps briefly so concurrent tasks can run + await asyncio.sleep(seconds_to_sleep_each_loop) + + # if a rate limit error was hit recently, pause to cool down + seconds_since_rate_limit_error = ( + time.time() - status_tracker.time_of_last_rate_limit_error + ) + if ( + seconds_since_rate_limit_error + < seconds_to_pause_after_rate_limit_error + ): + remaining_seconds_to_pause = ( + seconds_to_pause_after_rate_limit_error + - seconds_since_rate_limit_error + ) + await asyncio.sleep(remaining_seconds_to_pause) + # ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago + logging.warn( + f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}" + ) + + # Close the progress bar + pbar.close() + + # after finishing, log final status + logging.info( + f"""Parallel processing complete. Results saved to {save_filepath}""" + ) + + logging.info(f"Status tracker: {status_tracker}") + + if status_tracker.num_tasks_failed > 0: + logging.warning( + f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed. Errors logged to {save_filepath}." + ) + if status_tracker.num_rate_limit_errors > 0: + logging.warning( + f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate." + ) + + +@dataclass +class StatusTracker: + """Stores metadata about the script's progress. Only one instance is created.""" + + num_tasks_already_completed: int = 0 + num_tasks_started: int = 0 + num_tasks_in_progress: int = 0 # script ends when this reaches 0 + num_tasks_succeeded: int = 0 + num_tasks_failed: int = 0 + num_rate_limit_errors: int = 0 + num_api_errors: int = 0 # excluding rate limit errors, counted above + num_other_errors: int = 0 + time_of_last_rate_limit_error: int = 0 # used to cool off after hitting rate limits + + +@dataclass +class APIRequest: + """Stores an API request's inputs, outputs, and other metadata. 
Contains a method to make an API call.""" + + task_id: int + request_json: dict + token_consumption: int + attempts_left: int + metadata: dict + result: list = field(default_factory=list) + + async def call_api( + self, + session: aiohttp.ClientSession, + request_url: str, + request_header: dict, + retry_queue: asyncio.Queue, + save_filepath: str, + status_tracker: StatusTracker, + get_generic_response: Callable[[list], dict], + ) -> None: + """Calls the OpenAI API and saves results.""" + logging.debug(f"Starting request #{self.task_id}") + error = None + try: + async with session.post( + url=request_url, headers=request_header, json=self.request_json + ) as response: + response = await response.json() + if "error" in response: + logging.warning( + f"Request {self.task_id} failed with error {response['error']}" + ) + status_tracker.num_api_errors += 1 + error = response + if "rate limit" in response["error"].get("message", "").lower(): + status_tracker.time_of_last_rate_limit_error = time.time() + status_tracker.num_rate_limit_errors += 1 + status_tracker.num_api_errors -= ( + 1 # rate limit errors are counted separately + ) + + except ( + Exception + ) as e: # catching naked exceptions is bad practice, but in this case we'll log & save them + logging.warning( + f"Request {self.task_id} failed with Exception {e}, attempts left {self.attempts_left}" + ) + status_tracker.num_other_errors += 1 + error = e + if error: + self.result.append(error) + if self.attempts_left: + retry_queue.put_nowait(self) + else: + logging.error( + f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}" + ) + data = GenericResponse( + request=self.request_json, + errors=[str(e) for e in self.result], + row=self.metadata["sample"], + row_idx=self.metadata["request_idx"], + ) + append_generic_response(data, save_filepath) + status_tracker.num_tasks_in_progress -= 1 + status_tracker.num_tasks_failed += 1 + else: + data = get_generic_response( + {"response": response, "metadata": self.metadata} + ) + data.raw_response = response + data.request = self.request_json + append_generic_response(data, save_filepath) + status_tracker.num_tasks_in_progress -= 1 + status_tracker.num_tasks_succeeded += 1 + logging.debug(f"Request {self.task_id} saved to {save_filepath}") + + +def get_token_encoding_name(model: str) -> str: + """Get the token encoding name for a given model.""" + if "gpt" in model: + return tiktoken.encoding_for_model(model).name + else: + logging.warning( + f'Token encoding name for model "{model}" not implemented, using cl100k_base for token counting' + ) + return "cl100k_base" + + +def get_rate_limits(model: str, request_url: str, api_key: str) -> Tuple[int, int]: + """ + Function to get rate limits for a given annotator. Makes a single request to openAI API + and gets the rate limits from the response headers. These rate limits vary per model + and are determined by your organization's usage tier. View the following: + https://platform.openai.com/docs/guides/rate-limits/usage-tiers + https://platform.openai.com/settings/organization/limits + + Args: + model (str): The model for which to get the rate limits. + request_url (str): The request URL for which to get the rate limits. + + Returns: + Tuple[int, int]: The maximum number of requests and tokens per minute. 
+ """ + if "api.openai.com" in request_url: + # Send a dummy request to get rate limit information + response = requests.post( + request_url, + headers={"Authorization": f"Bearer {api_key}"}, + json={"model": model, "messages": []}, + ) + # Extract rate limit information from headers + max_requests = int(response.headers.get("x-ratelimit-limit-requests", 30_000)) + max_tokens = int(response.headers.get("x-ratelimit-limit-tokens", 150_000_000)) + elif "api.sambanova.ai" in request_url: + # Send a dummy request to get rate limit information + max_requests = 50 + max_tokens = 100_000_000 + else: + raise NotImplementedError( + f'Rate limits for API endpoint "{request_url}" not implemented' + ) + + return max_requests, max_tokens + + +def get_api_key(request_url: str) -> str: + """Get the API key for a given request URL.""" + if "api.openai.com" in request_url: + return os.getenv("OPENAI_API_KEY") + elif "api.sambanova.ai" in request_url: + return os.getenv("SAMBANOVA_API_KEY") + else: + raise NotImplementedError( + f'Default API key environment variable for API endpoint "{request_url}" not implemented' + ) + + +def api_endpoint_from_url(request_url: str) -> str: + """Extract the API endpoint from the request URL. + This is used to determine the number of tokens consumed by the request. + """ + + # OpenAI API + match = re.search("^https://[^/]+/v\\d+/(.+)$", request_url) + if match: + return match[1] + + # for Azure OpenAI deployment urls + match = re.search( + r"^https://[^/]+/openai/deployments/[^/]+/(.+?)(\?|$)", request_url + ) + if match: + return match[1] + + # Catch all for other API endpoints using OpenAI OpenAPI format + if "chat/completions" in request_url: + return "chat/completions" + elif "completions" in request_url: + return "completions" + else: + raise NotImplementedError( + f'API endpoint "{request_url}" not implemented in this script' + ) + + +def append_generic_response(data: GenericResponse, filename: str) -> None: + """Append a json payload to the end of a jsonl file.""" + json_string = json.dumps(data.model_dump()) + with open(filename, "a") as f: + f.write(json_string + "\n") + + +def num_tokens_consumed_from_request( + request_json: dict, + api_endpoint: str, + token_encoding_name: str, +): + """Count the number of tokens in the request. 
Only supports completion and embedding requests.""" + encoding = tiktoken.get_encoding(token_encoding_name) + # if completions request, tokens = prompt + n * max_tokens + if api_endpoint.endswith("completions"): + max_tokens = request_json.get("max_tokens", 15) + n = request_json.get("n", 1) + completion_tokens = n * max_tokens + + # chat completions + if api_endpoint.startswith("chat/"): + num_tokens = 0 + for message in request_json["messages"]: + num_tokens += 4 # every message follows {role/name}\n{content}\n + for key, value in message.items(): + num_tokens += len(encoding.encode(value)) + if key == "name": # if there's a name, the role is omitted + num_tokens -= 1 # role is always required and always 1 token + num_tokens += 2 # every reply is primed with assistant + return num_tokens + completion_tokens + # normal completions + else: + prompt = request_json["prompt"] + if isinstance(prompt, str): # single prompt + prompt_tokens = len(encoding.encode(prompt)) + num_tokens = prompt_tokens + completion_tokens + return num_tokens + elif isinstance(prompt, list): # multiple prompts + prompt_tokens = sum([len(encoding.encode(p)) for p in prompt]) + num_tokens = prompt_tokens + completion_tokens * len(prompt) + return num_tokens + else: + raise TypeError( + 'Expecting either string or list of strings for "prompt" field in completion request' + ) + # if embeddings request, tokens = input tokens + elif api_endpoint == "embeddings": + input = request_json["input"] + if isinstance(input, str): # single input + num_tokens = len(encoding.encode(input)) + return num_tokens + elif isinstance(input, list): # multiple inputs + num_tokens = sum([len(encoding.encode(i)) for i in input]) + return num_tokens + else: + raise TypeError( + 'Expecting either string or list of strings for "inputs" field in embedding request' + ) + # more logic needed to support other API calls (e.g., edits, inserts, DALL-E) + else: + raise NotImplementedError( + f'API endpoint "{api_endpoint}" not implemented in this script' + ) + + +def task_id_generator_function(): + """Generate integers 0, 1, 2, and so on.""" + task_id = 0 + while True: + yield task_id + task_id += 1 diff --git a/tests/test_prompt.py b/tests/test_prompt.py index 7338e5da..848c7623 100644 --- a/tests/test_prompt.py +++ b/tests/test_prompt.py @@ -22,11 +22,13 @@ def prompter() -> Prompter: Returns: PromptCaller: A configured prompt caller instance. 
""" + def prompt_func(row): return { "user_prompt": f"Context: {row['context']} Answer this question: {row['question']}", "system_prompt": "You are a helpful assistant.", } + return Prompter( model_name="gpt-4o-mini", prompt_func=prompt_func, @@ -54,7 +56,7 @@ def test_completions(prompter: Prompter, tmp_path): result_dataset = prompter(dataset) result_dataset = result_dataset.to_huggingface() - + # Assertions assert len(result_dataset) == len(dataset) assert "message" in result_dataset.column_names From e101f48083e677d7625381c35af9f14f2e4bc41e Mon Sep 17 00:00:00 2001 From: Ryan Marten Date: Thu, 7 Nov 2024 13:26:13 -0800 Subject: [PATCH 04/18] update Dataset.to_huggingface to read multiple response files for batches --- src/bespokelabs/curator/dataset.py | 41 +++++++++---------- .../base_request_processor.py | 5 +-- 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/src/bespokelabs/curator/dataset.py b/src/bespokelabs/curator/dataset.py index 25255e95..24d9d616 100644 --- a/src/bespokelabs/curator/dataset.py +++ b/src/bespokelabs/curator/dataset.py @@ -1,12 +1,14 @@ import json import logging import os -from typing import Any, Dict, Iterable, Iterator, List, TypeVar +import glob import pandas as pd -from datasets import Dataset as HFDataset -from datasets.arrow_writer import ArrowWriter + from pydantic import BaseModel +from datasets import Dataset as HFDataset +from datasets.arrow_writer import ArrowWriter, SchemaInferenceError +from typing import Any, Dict, Iterable, Iterator, List, TypeVar from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter from bespokelabs.curator.request_processor.generic_response import GenericResponse @@ -77,24 +79,24 @@ def to_huggingface(self, in_memory: bool = False) -> None: os.makedirs(self.working_dir, exist_ok=True) dataset_file = f"{self.working_dir}/dataset.arrow" - response_file = f"{self.working_dir}/responses.jsonl" + responses_files = glob.glob(f"{self.working_dir}/responses_*.jsonl") + if len(responses_files) == 0: + raise ValueError( + f"No responses files found in {self.working_dir}, can't construct dataset" + ) # Process all response files with ArrowWriter(path=dataset_file) as writer: - if not os.path.exists(response_file): - raise ValueError(f"Responses file {response_file} does not exist") - - with open(response_file, "r") as f_in: - for line in f_in: - total_responses_count += 1 - try: + for responses_file in responses_files: + with open(responses_file, "r") as f_in: + for line in f_in: + total_responses_count += 1 response = GenericResponse.model_validate_json(line) if self.prompt_formatter.response_format: response.response = self.prompt_formatter.response_format( **response.response ) - # TODO(Ryan): We can make this more sophisticated by making response_generic a class if response is None: failed_responses_count += 1 continue @@ -109,15 +111,8 @@ def to_huggingface(self, in_memory: bool = False) -> None: for row in dataset_rows: if isinstance(row, BaseModel): row = row.model_dump() - # NOTE(Ryan): This throws a strange error if there are null values in the row writer.write(row) - # NOTE(Ryan): Catching naked exceptions is bad practice, but this prevents the program from crashing - # TODO(Ryan): Add in handling for specific exceptions as they come up - except Exception as e: - logging.warning(f"Error: {e}\nFull response: {response}") - continue - logging.info( f"Read {total_responses_count} responses, {failed_responses_count} failed" ) @@ -126,8 +121,12 @@ def to_huggingface(self, in_memory: bool = False) 
-> None: if failed_responses_count == total_responses_count: raise ValueError("All requests failed") - # NOTE(Ryan): This throws an error if all rows were None # TODO(Ryan): Look at what this file looks like before finalize. What happens during finalize? - writer.finalize() + try: + writer.finalize() + except SchemaInferenceError as e: + raise ValueError( + "Arrow writer is complaining about the schema: likely all of your parsed rows were None and writer.write only wrote None objects." + ) from e return HFDataset.from_file(dataset_file, in_memory=in_memory) diff --git a/src/bespokelabs/curator/request_processor/base_request_processor.py b/src/bespokelabs/curator/request_processor/base_request_processor.py index b8e91ffc..b9d496b1 100644 --- a/src/bespokelabs/curator/request_processor/base_request_processor.py +++ b/src/bespokelabs/curator/request_processor/base_request_processor.py @@ -1,13 +1,10 @@ import json import logging import os -import asyncio import glob -import aiofiles from abc import ABC, abstractmethod -from typing import Dict, List, Optional -from math import ceil +from typing import Optional from bespokelabs.curator.dataset import Dataset from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter From d289ca908060c265cd5dc14c744c2d6ded0b2a7b Mon Sep 17 00:00:00 2001 From: Ryan Marten Date: Thu, 7 Nov 2024 15:39:39 -0800 Subject: [PATCH 05/18] need to have random access to map back from idx, assuming dataset fits on disk --- src/bespokelabs/__init__.py | 9 ++ src/bespokelabs/curator/prompter/prompter.py | 30 +++-- .../request_processor/generic_response.py | 2 +- .../openai_batch_request_processor.py | 122 ++++++++++-------- .../openai_online_request_processor.py | 36 +++--- tests/test_batch.py | 97 ++++++++++++++ 6 files changed, 217 insertions(+), 79 deletions(-) create mode 100644 tests/test_batch.py diff --git a/src/bespokelabs/__init__.py b/src/bespokelabs/__init__.py index e69de29b..7f7ae402 100644 --- a/src/bespokelabs/__init__.py +++ b/src/bespokelabs/__init__.py @@ -0,0 +1,9 @@ +from bespokelabs.curator.request_processor import ( + OpenAIBatchRequestProcessor, + OpenAIOnlineRequestProcessor, +) + +__all__ = [ + "OpenAIBatchRequestProcessor", + "OpenAIOnlineRequestProcessor", +] diff --git a/src/bespokelabs/curator/prompter/prompter.py b/src/bespokelabs/curator/prompter/prompter.py index f4c62dce..4ffba335 100644 --- a/src/bespokelabs/curator/prompter/prompter.py +++ b/src/bespokelabs/curator/prompter/prompter.py @@ -18,7 +18,7 @@ from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter from bespokelabs.curator.request_processor.generic_request import GenericRequest from bespokelabs.curator.request_processor.openai_batch_request_processor import ( - OpenAIRequestProcessor, + OpenAIBatchRequestProcessor, OpenAIOnlineRequestProcessor, ) from bespokelabs.curator.request_processor.base_request_processor import ( @@ -41,9 +41,6 @@ def __init__( ] ] = None, response_format: Optional[Type[BaseModel]] = None, - request_processor: Optional[ - BaseRequestProcessor - ] = OpenAIOnlineRequestProcessor, ): """Initialize a Prompter. 
@@ -75,12 +72,23 @@ def __init__( self.request_processor = request_processor - def __call__(self, dataset: Optional[Iterable] = None): + def __call__( + self, + dataset: Optional[Iterable] = None, + request_processor: Optional[ + BaseRequestProcessor + ] = OpenAIOnlineRequestProcessor, + ): """Run completions on a dataset.""" - return self._completions(dataset) + return self._completions(dataset, request_processor) def _completions( - self, dataset: Optional[Iterable] = None, name: Optional[str] = None + self, + dataset: Optional[Iterable] = None, + name: Optional[str] = None, + request_processor: Optional[ + BaseRequestProcessor + ] = OpenAIOnlineRequestProcessor, ) -> "Dataset": """ Apply structured completions in parallel to a dataset using specified model and @@ -154,10 +162,16 @@ def _completions( } metadata_db.store_metadata(metadata_dict) - return self.request_processor.run( + # TODO(Ryan): do the response processing, while context of original dataset is available and need random access via row_idx) + response_files = request_processor.run( dataset, f"{curator_cache_dir}/{fingerprint}", self.prompt_formatter ) + # NOTE(Ryan): If we decide to allow user to provide any iterable as input dataset (and doens't have random access via row_idx), we can do it differently. This might be a little slower. + # https://huggingface.co/docs/datasets/v3.1.0/about_mapstyle_vs_iterable + # What we can do is sort the generic responses_i.jsonl files by row_idx (in parallel across patches), then just iterate over the dataset and create the new rows in order with the responses. + # To do this, we need to write generic responses instead API_specific responses. Requests can be api-specific, we don't care since we are not monitoring there. + def _hash_chunk(chunks: list) -> list: """Hash a chunk of data.""" diff --git a/src/bespokelabs/curator/request_processor/generic_response.py b/src/bespokelabs/curator/request_processor/generic_response.py index 5ce7971c..551d8622 100644 --- a/src/bespokelabs/curator/request_processor/generic_response.py +++ b/src/bespokelabs/curator/request_processor/generic_response.py @@ -7,6 +7,6 @@ class GenericResponse(BaseModel): response: Optional[Dict[str, Any]] | str = None request: Optional[Dict[str, Any]] = None errors: Optional[List[str]] = None - row: Dict[str, Any] + row: Optional[Dict[str, Any]] = None row_idx: int raw_response: Optional[Dict[str, Any]] = None diff --git a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py index 32d495c0..ab5e757d 100644 --- a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py +++ b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py @@ -22,15 +22,19 @@ T = TypeVar("T") -class OpenAIRequestProcessor(BaseRequestProcessor): +class OpenAIBatchRequestProcessor(BaseRequestProcessor): model: str url: str = "https://api.openai.com/v1/chat/completions" api_key: str = os.getenv("OPENAI_API_KEY") + batch_size: int = 1000 + check_interval: int = 60 def get_rate_limits(self) -> dict: """ - Function to get rate limits for a given annotator. Makes a single request to openAI API - and gets the rate limits from the response headers. These rate limits vary per model + Function to get rate limits for a given annotator. Not available via response headers, so + the following is based on tier 5 limits on Nov 6th, 2024. + + These rate limits vary per model and are determined by your organization's usage tier. 
View the following: https://platform.openai.com/docs/guides/rate-limits/usage-tiers https://platform.openai.com/settings/organization/limits @@ -42,30 +46,30 @@ def get_rate_limits(self) -> dict: Returns: tuple[int, int]: A tuple containing the maximum number of requests and tokens per minute. """ - # Send a dummy request to get rate limit information - response = requests.post( - self.url, - headers={"Authorization": f"Bearer {self.api_key}"}, - json={"model": self.prompt_formatter.model_name, "messages": []}, - ) - - rpm = int(response.headers.get("x-ratelimit-limit-requests", 0)) - tpm = int(response.headers.get("x-ratelimit-limit-tokens", 0)) + model_tpd = { + "gpt-3.5-turbo": 5_000_000_000, + "gpt-3.5-turbo-0125": 5_000_000_000, + "gpt-3.5-turbo-1106": 5_000_000_000, + "gpt-3.5-turbo-16k": 5_000_000_000, + "gpt-3.5-turbo-instruct": 200_000, + "gpt-3.5-turbo-instruct-0914": 200_000, + "gpt-4": 150_000_000, + "gpt-4-0613": 150_000_000, + "gpt-4-turbo": 300_000_000, + "gpt-4o": 10_000_000_000, + "gpt-4o-mini": 15_000_000_000, + } - if not rpm or not tpm: - logging.warning( - "Failed to get rate limits from OpenAI API, using default values" - ) - rpm = 30_000 - tpm = 150_000_000 + if self.model not in model_tpd: + tpd = 1_000_000_000 + else: + tpd = model_tpd[self.model] - logging.info(f"Automatically set max_requests_per_minute to {rpm}") - logging.info(f"Automatically set max_tokens_per_minute to {tpm}") + logging.info( + f"Automatically set max_tokens_per_day to {tpd}, model: {self.model} " + ) - rate_limits = { - "max_requests_per_minute": rpm, - "max_tokens_per_minute": tpm, - } + rate_limits = {"max_tokens_per_day": tpd} return rate_limits @@ -79,8 +83,9 @@ def create_api_specific_request(self, generic_request: GenericRequest) -> dict: Returns: dict: API specific request body """ + # NOTE(Ryan): We can have a shared place that creates the body (since it is the same for both online and batch). if generic_request.response_format: - request = { + body = { "model": generic_request.model, "messages": generic_request.messages, "response_format": { @@ -92,21 +97,20 @@ def create_api_specific_request(self, generic_request: GenericRequest) -> dict: "schema": generic_request.response_format.model_json_schema(), }, }, - "metadata": { - "request_idx": generic_request.row_idx, - "sample": generic_request.row, - }, } else: - request = { + body = { "model": generic_request.model, "messages": generic_request.messages, - "metadata": { - "request_idx": generic_request.row_idx, - "sample": generic_request.row, - }, } + request = { + "custom_id": str(generic_request.row_idx), + "method": "POST", + "url": "/v1/chat/completions", + "body": body, + } + return request def get_generic_response(self, response: Dict) -> GenericResponse: @@ -123,14 +127,28 @@ def get_generic_response(self, response: Dict) -> GenericResponse: Returns: dict: Generic response body with an extra field "metadata" which contains the original dataset row or the index of the row in the original dataset """ - content = response["response"]["choices"][0]["message"]["content"] + + request_id = response["id"] + status_code = response["response"]["status_code"] + + # TODO(Ryan): Add error handling + if status_code != 200: + logging.warning( + f"Request {request_id} failed with status code {status_code}" + ) + return None + + # NOTE(Ryan): can we actually parse the response into a an OpenAI ChatCompletions object? Easier to access fields? 
+ # TODO(Ryan): if you add token tokens to generic response, we can parse that here too, similar to my comment above we can do that in the shared place. + content = response["response"]["body"]["choices"][0]["message"]["content"] + if self.prompt_formatter.response_format: content = json.loads(content) return GenericResponse( response=content, - row=response["metadata"]["sample"], - row_idx=response["metadata"]["request_idx"], + row_idx=int(response["custom_id"]), + raw_response=response, ) def run( @@ -148,8 +166,10 @@ def run( Returns: Dataset: Completed dataset """ - requests_file = self.create_request_files(dataset, working_dir) - responses_file = f"{working_dir}/responses.jsonl" + requests_files = self.create_request_files(dataset, working_dir) + responses_files = [ + f"{working_dir}/responses_{i}.jsonl" for i in range(len(requests_files)) + ] rate_limits = self.get_rate_limits() rpm = rate_limits["max_requests_per_minute"] @@ -160,19 +180,19 @@ def run( # NOTE(Ryan): If you wanted to do this on batches, you could run a for loop here about request_files. Although I don't recommend it because you are waiting for straggler requests to finish for each batch. # NOTE(Ryan): And if you wanted to do batches in parallel, you would have to divide rpm and tpm by the number of parallel batches. # TODO(Ryan): Can we abstract retries from process_api_requests_from_file so you can use it even if you use liteLLM. - i = 0 - asyncio.run( - self.process_api_requests_from_file( - requests_filepath=requests_file, - save_filepath=responses_file, - request_url=self.url, - max_requests_per_minute=rpm, - max_tokens_per_minute=tpm, - token_encoding_name=token_encoding_name, - max_attempts=5, - resume=True, # detects existing jobs and resume from there + for i in range(len(requests_files)): + asyncio.run( + self.process_api_requests_from_file( + requests_filepath=requests_files[i], + save_filepath=responses_files[i], + request_url=self.url, + max_requests_per_minute=rpm, + max_tokens_per_minute=tpm, + token_encoding_name=token_encoding_name, + max_attempts=5, + resume=True, # detects existing jobs and resume from there + ) ) - ) return Dataset.from_working_dir(working_dir, self.prompt_formatter) diff --git a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py index 32d495c0..5306b7d2 100644 --- a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py +++ b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py @@ -22,7 +22,7 @@ T = TypeVar("T") -class OpenAIRequestProcessor(BaseRequestProcessor): +class OpenAIOnlineRequestProcessor(BaseRequestProcessor): model: str url: str = "https://api.openai.com/v1/chat/completions" api_key: str = os.getenv("OPENAI_API_KEY") @@ -86,8 +86,7 @@ def create_api_specific_request(self, generic_request: GenericRequest) -> dict: "response_format": { "type": "json_schema", "json_schema": { - # TODO(ryan): not sure if this should be something else. - # TODO(ryan): also not sure if we should use strict: True + # TODO(ryan): not sure if we should use strict: True or have name: be something else. 
"name": "output_schema", "schema": generic_request.response_format.model_json_schema(), }, @@ -148,8 +147,8 @@ def run( Returns: Dataset: Completed dataset """ - requests_file = self.create_request_files(dataset, working_dir) - responses_file = f"{working_dir}/responses.jsonl" + requests_files = self.create_request_files(dataset, working_dir) + responses_files = [] rate_limits = self.get_rate_limits() rpm = rate_limits["max_requests_per_minute"] @@ -160,21 +159,20 @@ def run( # NOTE(Ryan): If you wanted to do this on batches, you could run a for loop here about request_files. Although I don't recommend it because you are waiting for straggler requests to finish for each batch. # NOTE(Ryan): And if you wanted to do batches in parallel, you would have to divide rpm and tpm by the number of parallel batches. # TODO(Ryan): Can we abstract retries from process_api_requests_from_file so you can use it even if you use liteLLM. - i = 0 - asyncio.run( - self.process_api_requests_from_file( - requests_filepath=requests_file, - save_filepath=responses_file, - request_url=self.url, - max_requests_per_minute=rpm, - max_tokens_per_minute=tpm, - token_encoding_name=token_encoding_name, - max_attempts=5, - resume=True, # detects existing jobs and resume from there + for requests_file in requests_files: + asyncio.run( + self.process_api_requests_from_file( + requests_filepath=requests_file, + save_filepath=responses_file, + request_url=self.url, + max_requests_per_minute=rpm, + max_tokens_per_minute=tpm, + token_encoding_name=token_encoding_name, + max_attempts=5, + resume=True, # detects existing jobs and resume from there + ) ) - ) - - return Dataset.from_working_dir(working_dir, self.prompt_formatter) + return responses_files async def process_api_requests_from_file( self, diff --git a/tests/test_batch.py b/tests/test_batch.py new file mode 100644 index 00000000..44251008 --- /dev/null +++ b/tests/test_batch.py @@ -0,0 +1,97 @@ +from bespokelabs import curator +from bespokelabs.curator import OpenAIBatchRequestProcessor +from datasets import load_dataset, Dataset +import argparse + + +def convert_ShareGPT_to_IT_format(dataset: Dataset) -> Dataset: + def it_from_sharegpt(sample): + if sample["conversations"][0]["from"] == "human": + instruction = sample["conversations"][0]["value"] + assert sample["conversations"][1]["from"] == "gpt" + response = sample["conversations"][1]["value"] + elif sample["conversations"][1]["from"] == "human": + instruction = sample["conversations"][1]["value"] + assert sample["conversations"][2]["from"] == "gpt" + response = sample["conversations"][2]["value"] + else: + raise ValueError("Invalid conversation format") + return {"instruction": instruction, "original_response": response} + + dataset = dataset.map(it_from_sharegpt) + dataset = dataset.remove_columns(["conversations"]) + dataset = dataset.select_columns(["instruction", "original_response"]) + return dataset + + +def load_ShareGPT_dataset_as_IT(dataset_name: str, truncate: int = None) -> Dataset: + dataset = load_dataset(dataset_name, split="train") + if truncate is not None: + dataset = dataset.select(range(truncate)) + return convert_ShareGPT_to_IT_format(dataset) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--working_dir", + type=str, + required=True, + help="Where requests, responses, and dataset will be locally written to save intermediate results", + ) + parser.add_argument( + "--num_samples", + type=int, + default=10, + help="The number of samples to use from 
the dataset", + ) + parser.add_argument( + "--batch_size", + type=int, + default=None, + help="The number of samples to use per batch", + ) + parser.add_argument( + "--type", + type=str, + default="online", + help="The type of API to use", + ) + parser.add_argument( + "--check_interval", + type=int, + default=10, + help="The interval (in seconds) to check the status of the batch", + ) + args = parser.parse_args() + + # Load the dataset to instruction, response columns + dataset = load_ShareGPT_dataset_as_IT("teknium/OpenHermes-2.5") + print(dataset) + + dataset = dataset.select(range(args.num_samples)) + + # if args.type == "online": + # api = OpenAIOnlineAPI(model="gpt-4o-mini") + # elif args.type == "batch": + # api = OpenAIBatchAPI(model="gpt-4o-mini", check_interval=args.check_interval) + reannotate_prompter = curator.Prompter( + prompt_func=lambda row: {"user_prompt": row["instruction"]}, + parse_func=lambda row, response: {**row, "model_response": response}, + model_name="gpt-4o-mini", + ) + + request_processor = OpenAIBatchRequestProcessor( + model="gpt-4o-mini", + batch_size=args.batch_size, + check_interval=args.check_interval, + ) + reannotated_dataset = reannotate_prompter(dataset, request_processor) + + dataset = reannotated_dataset.to_huggingface() + + # Upload dataset to Hugging Face + print(dataset) + dataset_name = "mlfoundations-dev/rewrite-test-gpt-4o-mini" + dataset.push_to_hub(dataset_name) + print(f"https://huggingface.co/datasets/{dataset_name}") From b4311f8b2d32f4df89d145cf350d591e15ee99d4 Mon Sep 17 00:00:00 2001 From: Ryan Marten Date: Thu, 7 Nov 2024 16:34:45 -0800 Subject: [PATCH 06/18] add in BatchWater class for OpenAIBatchRequestProcessor --- poetry.lock | 464 ++++++++++- pyproject.toml | 1 + .../openai_batch_request_processor.py | 746 ++++++------------ tests/test_batch.py | 1 + 4 files changed, 685 insertions(+), 527 deletions(-) diff --git a/poetry.lock b/poetry.lock index 4d0a3c57..0f3fdb73 100644 --- a/poetry.lock +++ b/poetry.lock @@ -508,6 +508,90 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "contourpy" +version = "1.3.0" +description = "Python library for calculating contours of 2D quadrilateral grids" +optional = false +python-versions = ">=3.9" +files = [ + {file = "contourpy-1.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:880ea32e5c774634f9fcd46504bf9f080a41ad855f4fef54f5380f5133d343c7"}, + {file = "contourpy-1.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:76c905ef940a4474a6289c71d53122a4f77766eef23c03cd57016ce19d0f7b42"}, + {file = "contourpy-1.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:92f8557cbb07415a4d6fa191f20fd9d2d9eb9c0b61d1b2f52a8926e43c6e9af7"}, + {file = "contourpy-1.3.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:36f965570cff02b874773c49bfe85562b47030805d7d8360748f3eca570f4cab"}, + {file = "contourpy-1.3.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cacd81e2d4b6f89c9f8a5b69b86490152ff39afc58a95af002a398273e5ce589"}, + {file = "contourpy-1.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:69375194457ad0fad3a839b9e29aa0b0ed53bb54db1bfb6c3ae43d111c31ce41"}, + {file = "contourpy-1.3.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:7a52040312b1a858b5e31ef28c2e865376a386c60c0e248370bbea2d3f3b760d"}, + {file = "contourpy-1.3.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = 
"sha256:3faeb2998e4fcb256542e8a926d08da08977f7f5e62cf733f3c211c2a5586223"}, + {file = "contourpy-1.3.0-cp310-cp310-win32.whl", hash = "sha256:36e0cff201bcb17a0a8ecc7f454fe078437fa6bda730e695a92f2d9932bd507f"}, + {file = "contourpy-1.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:87ddffef1dbe5e669b5c2440b643d3fdd8622a348fe1983fad7a0f0ccb1cd67b"}, + {file = "contourpy-1.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0fa4c02abe6c446ba70d96ece336e621efa4aecae43eaa9b030ae5fb92b309ad"}, + {file = "contourpy-1.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:834e0cfe17ba12f79963861e0f908556b2cedd52e1f75e6578801febcc6a9f49"}, + {file = "contourpy-1.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dbc4c3217eee163fa3984fd1567632b48d6dfd29216da3ded3d7b844a8014a66"}, + {file = "contourpy-1.3.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4865cd1d419e0c7a7bf6de1777b185eebdc51470800a9f42b9e9decf17762081"}, + {file = "contourpy-1.3.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:303c252947ab4b14c08afeb52375b26781ccd6a5ccd81abcdfc1fafd14cf93c1"}, + {file = "contourpy-1.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:637f674226be46f6ba372fd29d9523dd977a291f66ab2a74fbeb5530bb3f445d"}, + {file = "contourpy-1.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:76a896b2f195b57db25d6b44e7e03f221d32fe318d03ede41f8b4d9ba1bff53c"}, + {file = "contourpy-1.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e1fd23e9d01591bab45546c089ae89d926917a66dceb3abcf01f6105d927e2cb"}, + {file = "contourpy-1.3.0-cp311-cp311-win32.whl", hash = "sha256:d402880b84df3bec6eab53cd0cf802cae6a2ef9537e70cf75e91618a3801c20c"}, + {file = "contourpy-1.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:6cb6cc968059db9c62cb35fbf70248f40994dfcd7aa10444bbf8b3faeb7c2d67"}, + {file = "contourpy-1.3.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:570ef7cf892f0afbe5b2ee410c507ce12e15a5fa91017a0009f79f7d93a1268f"}, + {file = "contourpy-1.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:da84c537cb8b97d153e9fb208c221c45605f73147bd4cadd23bdae915042aad6"}, + {file = "contourpy-1.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0be4d8425bfa755e0fd76ee1e019636ccc7c29f77a7c86b4328a9eb6a26d0639"}, + {file = "contourpy-1.3.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9c0da700bf58f6e0b65312d0a5e695179a71d0163957fa381bb3c1f72972537c"}, + {file = "contourpy-1.3.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:eb8b141bb00fa977d9122636b16aa67d37fd40a3d8b52dd837e536d64b9a4d06"}, + {file = "contourpy-1.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3634b5385c6716c258d0419c46d05c8aa7dc8cb70326c9a4fb66b69ad2b52e09"}, + {file = "contourpy-1.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0dce35502151b6bd35027ac39ba6e5a44be13a68f55735c3612c568cac3805fd"}, + {file = "contourpy-1.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:aea348f053c645100612b333adc5983d87be69acdc6d77d3169c090d3b01dc35"}, + {file = "contourpy-1.3.0-cp312-cp312-win32.whl", hash = "sha256:90f73a5116ad1ba7174341ef3ea5c3150ddf20b024b98fb0c3b29034752c8aeb"}, + {file = "contourpy-1.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:b11b39aea6be6764f84360fce6c82211a9db32a7c7de8fa6dd5397cf1d079c3b"}, + {file = "contourpy-1.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = 
"sha256:3e1c7fa44aaae40a2247e2e8e0627f4bea3dd257014764aa644f319a5f8600e3"}, + {file = "contourpy-1.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:364174c2a76057feef647c802652f00953b575723062560498dc7930fc9b1cb7"}, + {file = "contourpy-1.3.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32b238b3b3b649e09ce9aaf51f0c261d38644bdfa35cbaf7b263457850957a84"}, + {file = "contourpy-1.3.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d51fca85f9f7ad0b65b4b9fe800406d0d77017d7270d31ec3fb1cc07358fdea0"}, + {file = "contourpy-1.3.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:732896af21716b29ab3e988d4ce14bc5133733b85956316fb0c56355f398099b"}, + {file = "contourpy-1.3.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d73f659398a0904e125280836ae6f88ba9b178b2fed6884f3b1f95b989d2c8da"}, + {file = "contourpy-1.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c6c7c2408b7048082932cf4e641fa3b8ca848259212f51c8c59c45aa7ac18f14"}, + {file = "contourpy-1.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f317576606de89da6b7e0861cf6061f6146ead3528acabff9236458a6ba467f8"}, + {file = "contourpy-1.3.0-cp313-cp313-win32.whl", hash = "sha256:31cd3a85dbdf1fc002280c65caa7e2b5f65e4a973fcdf70dd2fdcb9868069294"}, + {file = "contourpy-1.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:4553c421929ec95fb07b3aaca0fae668b2eb5a5203d1217ca7c34c063c53d087"}, + {file = "contourpy-1.3.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:345af746d7766821d05d72cb8f3845dfd08dd137101a2cb9b24de277d716def8"}, + {file = "contourpy-1.3.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3bb3808858a9dc68f6f03d319acd5f1b8a337e6cdda197f02f4b8ff67ad2057b"}, + {file = "contourpy-1.3.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:420d39daa61aab1221567b42eecb01112908b2cab7f1b4106a52caaec8d36973"}, + {file = "contourpy-1.3.0-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4d63ee447261e963af02642ffcb864e5a2ee4cbfd78080657a9880b8b1868e18"}, + {file = "contourpy-1.3.0-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:167d6c890815e1dac9536dca00828b445d5d0df4d6a8c6adb4a7ec3166812fa8"}, + {file = "contourpy-1.3.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:710a26b3dc80c0e4febf04555de66f5fd17e9cf7170a7b08000601a10570bda6"}, + {file = "contourpy-1.3.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:75ee7cb1a14c617f34a51d11fa7524173e56551646828353c4af859c56b766e2"}, + {file = "contourpy-1.3.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:33c92cdae89ec5135d036e7218e69b0bb2851206077251f04a6c4e0e21f03927"}, + {file = "contourpy-1.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a11077e395f67ffc2c44ec2418cfebed032cd6da3022a94fc227b6faf8e2acb8"}, + {file = "contourpy-1.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e8134301d7e204c88ed7ab50028ba06c683000040ede1d617298611f9dc6240c"}, + {file = "contourpy-1.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e12968fdfd5bb45ffdf6192a590bd8ddd3ba9e58360b29683c6bb71a7b41edca"}, + {file = "contourpy-1.3.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fd2a0fc506eccaaa7595b7e1418951f213cf8255be2600f1ea1b61e46a60c55f"}, + {file = "contourpy-1.3.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:4cfb5c62ce023dfc410d6059c936dcf96442ba40814aefbfa575425a3a7f19dc"}, + {file = "contourpy-1.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68a32389b06b82c2fdd68276148d7b9275b5f5cf13e5417e4252f6d1a34f72a2"}, + {file = "contourpy-1.3.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:94e848a6b83da10898cbf1311a815f770acc9b6a3f2d646f330d57eb4e87592e"}, + {file = "contourpy-1.3.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d78ab28a03c854a873787a0a42254a0ccb3cb133c672f645c9f9c8f3ae9d0800"}, + {file = "contourpy-1.3.0-cp39-cp39-win32.whl", hash = "sha256:81cb5ed4952aae6014bc9d0421dec7c5835c9c8c31cdf51910b708f548cf58e5"}, + {file = "contourpy-1.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:14e262f67bd7e6eb6880bc564dcda30b15e351a594657e55b7eec94b6ef72843"}, + {file = "contourpy-1.3.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:fe41b41505a5a33aeaed2a613dccaeaa74e0e3ead6dd6fd3a118fb471644fd6c"}, + {file = "contourpy-1.3.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eca7e17a65f72a5133bdbec9ecf22401c62bcf4821361ef7811faee695799779"}, + {file = "contourpy-1.3.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:1ec4dc6bf570f5b22ed0d7efba0dfa9c5b9e0431aeea7581aa217542d9e809a4"}, + {file = "contourpy-1.3.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:00ccd0dbaad6d804ab259820fa7cb0b8036bda0686ef844d24125d8287178ce0"}, + {file = "contourpy-1.3.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ca947601224119117f7c19c9cdf6b3ab54c5726ef1d906aa4a69dfb6dd58102"}, + {file = "contourpy-1.3.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c6ec93afeb848a0845a18989da3beca3eec2c0f852322efe21af1931147d12cb"}, + {file = "contourpy-1.3.0.tar.gz", hash = "sha256:7ffa0db17717a8ffb127efd0c95a4362d996b892c2904db72428d5b52e1938a4"}, +] + +[package.dependencies] +numpy = ">=1.23" + +[package.extras] +bokeh = ["bokeh", "selenium"] +docs = ["furo", "sphinx (>=7.2)", "sphinx-copybutton"] +mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.11.1)", "types-Pillow"] +test = ["Pillow", "contourpy[test-no-images]", "matplotlib"] +test-no-images = ["pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "wurlitzer"] + [[package]] name = "cryptography" version = "43.0.3" @@ -557,6 +641,21 @@ ssh = ["bcrypt (>=3.1.5)"] test = ["certifi", "cryptography-vectors (==43.0.3)", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"] test-randomorder = ["pytest-randomly"] +[[package]] +name = "cycler" +version = "0.12.1" +description = "Composable style cycles" +optional = false +python-versions = ">=3.8" +files = [ + {file = "cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30"}, + {file = "cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c"}, +] + +[package.extras] +docs = ["ipython", "matplotlib", "numpydoc", "sphinx"] +tests = ["pytest", "pytest-cov", "pytest-xdist"] + [[package]] name = "datasets" version = "3.0.2" @@ -702,6 +801,77 @@ docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2. 
testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"] typing = ["typing-extensions (>=4.12.2)"] +[[package]] +name = "fonttools" +version = "4.54.1" +description = "Tools to manipulate font files" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fonttools-4.54.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7ed7ee041ff7b34cc62f07545e55e1468808691dddfd315d51dd82a6b37ddef2"}, + {file = "fonttools-4.54.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:41bb0b250c8132b2fcac148e2e9198e62ff06f3cc472065dff839327945c5882"}, + {file = "fonttools-4.54.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7965af9b67dd546e52afcf2e38641b5be956d68c425bef2158e95af11d229f10"}, + {file = "fonttools-4.54.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:278913a168f90d53378c20c23b80f4e599dca62fbffae4cc620c8eed476b723e"}, + {file = "fonttools-4.54.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:0e88e3018ac809b9662615072dcd6b84dca4c2d991c6d66e1970a112503bba7e"}, + {file = "fonttools-4.54.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:4aa4817f0031206e637d1e685251ac61be64d1adef111060df84fdcbc6ab6c44"}, + {file = "fonttools-4.54.1-cp310-cp310-win32.whl", hash = "sha256:7e3b7d44e18c085fd8c16dcc6f1ad6c61b71ff463636fcb13df7b1b818bd0c02"}, + {file = "fonttools-4.54.1-cp310-cp310-win_amd64.whl", hash = "sha256:dd9cc95b8d6e27d01e1e1f1fae8559ef3c02c76317da650a19047f249acd519d"}, + {file = "fonttools-4.54.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5419771b64248484299fa77689d4f3aeed643ea6630b2ea750eeab219588ba20"}, + {file = "fonttools-4.54.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:301540e89cf4ce89d462eb23a89464fef50915255ece765d10eee8b2bf9d75b2"}, + {file = "fonttools-4.54.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76ae5091547e74e7efecc3cbf8e75200bc92daaeb88e5433c5e3e95ea8ce5aa7"}, + {file = "fonttools-4.54.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82834962b3d7c5ca98cb56001c33cf20eb110ecf442725dc5fdf36d16ed1ab07"}, + {file = "fonttools-4.54.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d26732ae002cc3d2ecab04897bb02ae3f11f06dd7575d1df46acd2f7c012a8d8"}, + {file = "fonttools-4.54.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:58974b4987b2a71ee08ade1e7f47f410c367cdfc5a94fabd599c88165f56213a"}, + {file = "fonttools-4.54.1-cp311-cp311-win32.whl", hash = "sha256:ab774fa225238986218a463f3fe151e04d8c25d7de09df7f0f5fce27b1243dbc"}, + {file = "fonttools-4.54.1-cp311-cp311-win_amd64.whl", hash = "sha256:07e005dc454eee1cc60105d6a29593459a06321c21897f769a281ff2d08939f6"}, + {file = "fonttools-4.54.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:54471032f7cb5fca694b5f1a0aaeba4af6e10ae989df408e0216f7fd6cdc405d"}, + {file = "fonttools-4.54.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8fa92cb248e573daab8d032919623cc309c005086d743afb014c836636166f08"}, + {file = "fonttools-4.54.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a911591200114969befa7f2cb74ac148bce5a91df5645443371aba6d222e263"}, + {file = "fonttools-4.54.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93d458c8a6a354dc8b48fc78d66d2a8a90b941f7fec30e94c7ad9982b1fa6bab"}, + {file 
= "fonttools-4.54.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5eb2474a7c5be8a5331146758debb2669bf5635c021aee00fd7c353558fc659d"}, + {file = "fonttools-4.54.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c9c563351ddc230725c4bdf7d9e1e92cbe6ae8553942bd1fb2b2ff0884e8b714"}, + {file = "fonttools-4.54.1-cp312-cp312-win32.whl", hash = "sha256:fdb062893fd6d47b527d39346e0c5578b7957dcea6d6a3b6794569370013d9ac"}, + {file = "fonttools-4.54.1-cp312-cp312-win_amd64.whl", hash = "sha256:e4564cf40cebcb53f3dc825e85910bf54835e8a8b6880d59e5159f0f325e637e"}, + {file = "fonttools-4.54.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:6e37561751b017cf5c40fce0d90fd9e8274716de327ec4ffb0df957160be3bff"}, + {file = "fonttools-4.54.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:357cacb988a18aace66e5e55fe1247f2ee706e01debc4b1a20d77400354cddeb"}, + {file = "fonttools-4.54.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8e953cc0bddc2beaf3a3c3b5dd9ab7554677da72dfaf46951e193c9653e515a"}, + {file = "fonttools-4.54.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:58d29b9a294573d8319f16f2f79e42428ba9b6480442fa1836e4eb89c4d9d61c"}, + {file = "fonttools-4.54.1-cp313-cp313-win32.whl", hash = "sha256:9ef1b167e22709b46bf8168368b7b5d3efeaaa746c6d39661c1b4405b6352e58"}, + {file = "fonttools-4.54.1-cp313-cp313-win_amd64.whl", hash = "sha256:262705b1663f18c04250bd1242b0515d3bbae177bee7752be67c979b7d47f43d"}, + {file = "fonttools-4.54.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ed2f80ca07025551636c555dec2b755dd005e2ea8fbeb99fc5cdff319b70b23b"}, + {file = "fonttools-4.54.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9dc080e5a1c3b2656caff2ac2633d009b3a9ff7b5e93d0452f40cd76d3da3b3c"}, + {file = "fonttools-4.54.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d152d1be65652fc65e695e5619e0aa0982295a95a9b29b52b85775243c06556"}, + {file = "fonttools-4.54.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8583e563df41fdecef31b793b4dd3af8a9caa03397be648945ad32717a92885b"}, + {file = "fonttools-4.54.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:0d1d353ef198c422515a3e974a1e8d5b304cd54a4c2eebcae708e37cd9eeffb1"}, + {file = "fonttools-4.54.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:fda582236fee135d4daeca056c8c88ec5f6f6d88a004a79b84a02547c8f57386"}, + {file = "fonttools-4.54.1-cp38-cp38-win32.whl", hash = "sha256:e7d82b9e56716ed32574ee106cabca80992e6bbdcf25a88d97d21f73a0aae664"}, + {file = "fonttools-4.54.1-cp38-cp38-win_amd64.whl", hash = "sha256:ada215fd079e23e060157aab12eba0d66704316547f334eee9ff26f8c0d7b8ab"}, + {file = "fonttools-4.54.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:f5b8a096e649768c2f4233f947cf9737f8dbf8728b90e2771e2497c6e3d21d13"}, + {file = "fonttools-4.54.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4e10d2e0a12e18f4e2dd031e1bf7c3d7017be5c8dbe524d07706179f355c5dac"}, + {file = "fonttools-4.54.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:31c32d7d4b0958600eac75eaf524b7b7cb68d3a8c196635252b7a2c30d80e986"}, + {file = "fonttools-4.54.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c39287f5c8f4a0c5a55daf9eaf9ccd223ea59eed3f6d467133cc727d7b943a55"}, + {file = "fonttools-4.54.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:a7a310c6e0471602fe3bf8efaf193d396ea561486aeaa7adc1f132e02d30c4b9"}, + {file = 
"fonttools-4.54.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d3b659d1029946f4ff9b6183984578041b520ce0f8fb7078bb37ec7445806b33"}, + {file = "fonttools-4.54.1-cp39-cp39-win32.whl", hash = "sha256:e96bc94c8cda58f577277d4a71f51c8e2129b8b36fd05adece6320dd3d57de8a"}, + {file = "fonttools-4.54.1-cp39-cp39-win_amd64.whl", hash = "sha256:e8a4b261c1ef91e7188a30571be6ad98d1c6d9fa2427244c545e2fa0a2494dd7"}, + {file = "fonttools-4.54.1-py3-none-any.whl", hash = "sha256:37cddd62d83dc4f72f7c3f3c2bcf2697e89a30efb152079896544a93907733bd"}, + {file = "fonttools-4.54.1.tar.gz", hash = "sha256:957f669d4922f92c171ba01bef7f29410668db09f6c02111e22b2bce446f3285"}, +] + +[package.extras] +all = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "fs (>=2.2.0,<3)", "lxml (>=4.0)", "lz4 (>=1.7.4.2)", "matplotlib", "munkres", "pycairo", "scipy", "skia-pathops (>=0.5.0)", "sympy", "uharfbuzz (>=0.23.0)", "unicodedata2 (>=15.1.0)", "xattr", "zopfli (>=0.1.4)"] +graphite = ["lz4 (>=1.7.4.2)"] +interpolatable = ["munkres", "pycairo", "scipy"] +lxml = ["lxml (>=4.0)"] +pathops = ["skia-pathops (>=0.5.0)"] +plot = ["matplotlib"] +repacker = ["uharfbuzz (>=0.23.0)"] +symfont = ["sympy"] +type1 = ["xattr"] +ufo = ["fs (>=2.2.0,<3)"] +unicode = ["unicodedata2 (>=15.1.0)"] +woff = ["brotli (>=1.0.1)", "brotlicffi (>=0.8.0)", "zopfli (>=0.1.4)"] + [[package]] name = "frozenlist" version = "1.5.0" @@ -1300,6 +1470,129 @@ enabler = ["pytest-enabler (>=2.2)"] test = ["pyfakefs", "pytest (>=6,!=8.1.*)"] type = ["pygobject-stubs", "pytest-mypy", "shtab", "types-pywin32"] +[[package]] +name = "kiwisolver" +version = "1.4.7" +description = "A fast implementation of the Cassowary constraint solver" +optional = false +python-versions = ">=3.8" +files = [ + {file = "kiwisolver-1.4.7-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8a9c83f75223d5e48b0bc9cb1bf2776cf01563e00ade8775ffe13b0b6e1af3a6"}, + {file = "kiwisolver-1.4.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:58370b1ffbd35407444d57057b57da5d6549d2d854fa30249771775c63b5fe17"}, + {file = "kiwisolver-1.4.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:aa0abdf853e09aff551db11fce173e2177d00786c688203f52c87ad7fcd91ef9"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8d53103597a252fb3ab8b5845af04c7a26d5e7ea8122303dd7a021176a87e8b9"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:88f17c5ffa8e9462fb79f62746428dd57b46eb931698e42e990ad63103f35e6c"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88a9ca9c710d598fd75ee5de59d5bda2684d9db36a9f50b6125eaea3969c2599"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f4d742cb7af1c28303a51b7a27aaee540e71bb8e24f68c736f6f2ffc82f2bf05"}, + {file = "kiwisolver-1.4.7-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e28c7fea2196bf4c2f8d46a0415c77a1c480cc0724722f23d7410ffe9842c407"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e968b84db54f9d42046cf154e02911e39c0435c9801681e3fc9ce8a3c4130278"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:0c18ec74c0472de033e1bebb2911c3c310eef5649133dd0bedf2a169a1b269e5"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:8f0ea6da6d393d8b2e187e6a5e3fb81f5862010a40c3945e2c6d12ae45cfb2ad"}, + {file = 
"kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:f106407dda69ae456dd1227966bf445b157ccc80ba0dff3802bb63f30b74e895"}, + {file = "kiwisolver-1.4.7-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:84ec80df401cfee1457063732d90022f93951944b5b58975d34ab56bb150dfb3"}, + {file = "kiwisolver-1.4.7-cp310-cp310-win32.whl", hash = "sha256:71bb308552200fb2c195e35ef05de12f0c878c07fc91c270eb3d6e41698c3bcc"}, + {file = "kiwisolver-1.4.7-cp310-cp310-win_amd64.whl", hash = "sha256:44756f9fd339de0fb6ee4f8c1696cfd19b2422e0d70b4cefc1cc7f1f64045a8c"}, + {file = "kiwisolver-1.4.7-cp310-cp310-win_arm64.whl", hash = "sha256:78a42513018c41c2ffd262eb676442315cbfe3c44eed82385c2ed043bc63210a"}, + {file = "kiwisolver-1.4.7-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d2b0e12a42fb4e72d509fc994713d099cbb15ebf1103545e8a45f14da2dfca54"}, + {file = "kiwisolver-1.4.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2a8781ac3edc42ea4b90bc23e7d37b665d89423818e26eb6df90698aa2287c95"}, + {file = "kiwisolver-1.4.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:46707a10836894b559e04b0fd143e343945c97fd170d69a2d26d640b4e297935"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ef97b8df011141c9b0f6caf23b29379f87dd13183c978a30a3c546d2c47314cb"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ab58c12a2cd0fc769089e6d38466c46d7f76aced0a1f54c77652446733d2d02"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:803b8e1459341c1bb56d1c5c010406d5edec8a0713a0945851290a7930679b51"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f9a9e8a507420fe35992ee9ecb302dab68550dedc0da9e2880dd88071c5fb052"}, + {file = "kiwisolver-1.4.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18077b53dc3bb490e330669a99920c5e6a496889ae8c63b58fbc57c3d7f33a18"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6af936f79086a89b3680a280c47ea90b4df7047b5bdf3aa5c524bbedddb9e545"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:3abc5b19d24af4b77d1598a585b8a719beb8569a71568b66f4ebe1fb0449460b"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:933d4de052939d90afbe6e9d5273ae05fb836cc86c15b686edd4b3560cc0ee36"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:65e720d2ab2b53f1f72fb5da5fb477455905ce2c88aaa671ff0a447c2c80e8e3"}, + {file = "kiwisolver-1.4.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3bf1ed55088f214ba6427484c59553123fdd9b218a42bbc8c6496d6754b1e523"}, + {file = "kiwisolver-1.4.7-cp311-cp311-win32.whl", hash = "sha256:4c00336b9dd5ad96d0a558fd18a8b6f711b7449acce4c157e7343ba92dd0cf3d"}, + {file = "kiwisolver-1.4.7-cp311-cp311-win_amd64.whl", hash = "sha256:929e294c1ac1e9f615c62a4e4313ca1823ba37326c164ec720a803287c4c499b"}, + {file = "kiwisolver-1.4.7-cp311-cp311-win_arm64.whl", hash = "sha256:e33e8fbd440c917106b237ef1a2f1449dfbb9b6f6e1ce17c94cd6a1e0d438376"}, + {file = "kiwisolver-1.4.7-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:5360cc32706dab3931f738d3079652d20982511f7c0ac5711483e6eab08efff2"}, + {file = "kiwisolver-1.4.7-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:942216596dc64ddb25adb215c3c783215b23626f8d84e8eff8d6d45c3f29f75a"}, + {file = 
"kiwisolver-1.4.7-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:48b571ecd8bae15702e4f22d3ff6a0f13e54d3d00cd25216d5e7f658242065ee"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ad42ba922c67c5f219097b28fae965e10045ddf145d2928bfac2eb2e17673640"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:612a10bdae23404a72941a0fc8fa2660c6ea1217c4ce0dbcab8a8f6543ea9e7f"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9e838bba3a3bac0fe06d849d29772eb1afb9745a59710762e4ba3f4cb8424483"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:22f499f6157236c19f4bbbd472fa55b063db77a16cd74d49afe28992dff8c258"}, + {file = "kiwisolver-1.4.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:693902d433cf585133699972b6d7c42a8b9f8f826ebcaf0132ff55200afc599e"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4e77f2126c3e0b0d055f44513ed349038ac180371ed9b52fe96a32aa071a5107"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:657a05857bda581c3656bfc3b20e353c232e9193eb167766ad2dc58b56504948"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4bfa75a048c056a411f9705856abfc872558e33c055d80af6a380e3658766038"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:34ea1de54beef1c104422d210c47c7d2a4999bdecf42c7b5718fbe59a4cac383"}, + {file = "kiwisolver-1.4.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:90da3b5f694b85231cf93586dad5e90e2d71b9428f9aad96952c99055582f520"}, + {file = "kiwisolver-1.4.7-cp312-cp312-win32.whl", hash = "sha256:18e0cca3e008e17fe9b164b55735a325140a5a35faad8de92dd80265cd5eb80b"}, + {file = "kiwisolver-1.4.7-cp312-cp312-win_amd64.whl", hash = "sha256:58cb20602b18f86f83a5c87d3ee1c766a79c0d452f8def86d925e6c60fbf7bfb"}, + {file = "kiwisolver-1.4.7-cp312-cp312-win_arm64.whl", hash = "sha256:f5a8b53bdc0b3961f8b6125e198617c40aeed638b387913bf1ce78afb1b0be2a"}, + {file = "kiwisolver-1.4.7-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:2e6039dcbe79a8e0f044f1c39db1986a1b8071051efba3ee4d74f5b365f5226e"}, + {file = "kiwisolver-1.4.7-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a1ecf0ac1c518487d9d23b1cd7139a6a65bc460cd101ab01f1be82ecf09794b6"}, + {file = "kiwisolver-1.4.7-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7ab9ccab2b5bd5702ab0803676a580fffa2aa178c2badc5557a84cc943fcf750"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f816dd2277f8d63d79f9c8473a79fe54047bc0467754962840782c575522224d"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf8bcc23ceb5a1b624572a1623b9f79d2c3b337c8c455405ef231933a10da379"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dea0bf229319828467d7fca8c7c189780aa9ff679c94539eed7532ebe33ed37c"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c06a4c7cf15ec739ce0e5971b26c93638730090add60e183530d70848ebdd34"}, + {file = "kiwisolver-1.4.7-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:913983ad2deb14e66d83c28b632fd35ba2b825031f2fa4ca29675e665dfecbe1"}, + {file = 
"kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5337ec7809bcd0f424c6b705ecf97941c46279cf5ed92311782c7c9c2026f07f"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:4c26ed10c4f6fa6ddb329a5120ba3b6db349ca192ae211e882970bfc9d91420b"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:c619b101e6de2222c1fcb0531e1b17bbffbe54294bfba43ea0d411d428618c27"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:073a36c8273647592ea332e816e75ef8da5c303236ec0167196793eb1e34657a"}, + {file = "kiwisolver-1.4.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:3ce6b2b0231bda412463e152fc18335ba32faf4e8c23a754ad50ffa70e4091ee"}, + {file = "kiwisolver-1.4.7-cp313-cp313-win32.whl", hash = "sha256:f4c9aee212bc89d4e13f58be11a56cc8036cabad119259d12ace14b34476fd07"}, + {file = "kiwisolver-1.4.7-cp313-cp313-win_amd64.whl", hash = "sha256:8a3ec5aa8e38fc4c8af308917ce12c536f1c88452ce554027e55b22cbbfbff76"}, + {file = "kiwisolver-1.4.7-cp313-cp313-win_arm64.whl", hash = "sha256:76c8094ac20ec259471ac53e774623eb62e6e1f56cd8690c67ce6ce4fcb05650"}, + {file = "kiwisolver-1.4.7-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5d5abf8f8ec1f4e22882273c423e16cae834c36856cac348cfbfa68e01c40f3a"}, + {file = "kiwisolver-1.4.7-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:aeb3531b196ef6f11776c21674dba836aeea9d5bd1cf630f869e3d90b16cfade"}, + {file = "kiwisolver-1.4.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b7d755065e4e866a8086c9bdada157133ff466476a2ad7861828e17b6026e22c"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08471d4d86cbaec61f86b217dd938a83d85e03785f51121e791a6e6689a3be95"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7bbfcb7165ce3d54a3dfbe731e470f65739c4c1f85bb1018ee912bae139e263b"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d34eb8494bea691a1a450141ebb5385e4b69d38bb8403b5146ad279f4b30fa3"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9242795d174daa40105c1d86aba618e8eab7bf96ba8c3ee614da8302a9f95503"}, + {file = "kiwisolver-1.4.7-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:a0f64a48bb81af7450e641e3fe0b0394d7381e342805479178b3d335d60ca7cf"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:8e045731a5416357638d1700927529e2b8ab304811671f665b225f8bf8d8f933"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:4322872d5772cae7369f8351da1edf255a604ea7087fe295411397d0cfd9655e"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:e1631290ee9271dffe3062d2634c3ecac02c83890ada077d225e081aca8aab89"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:edcfc407e4eb17e037bca59be0e85a2031a2ac87e4fed26d3e9df88b4165f92d"}, + {file = "kiwisolver-1.4.7-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:4d05d81ecb47d11e7f8932bd8b61b720bf0b41199358f3f5e36d38e28f0532c5"}, + {file = "kiwisolver-1.4.7-cp38-cp38-win32.whl", hash = "sha256:b38ac83d5f04b15e515fd86f312479d950d05ce2368d5413d46c088dda7de90a"}, + {file = "kiwisolver-1.4.7-cp38-cp38-win_amd64.whl", hash = "sha256:d83db7cde68459fc803052a55ace60bea2bae361fc3b7a6d5da07e11954e4b09"}, + {file = "kiwisolver-1.4.7-cp39-cp39-macosx_10_9_universal2.whl", hash = 
"sha256:3f9362ecfca44c863569d3d3c033dbe8ba452ff8eed6f6b5806382741a1334bd"}, + {file = "kiwisolver-1.4.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e8df2eb9b2bac43ef8b082e06f750350fbbaf2887534a5be97f6cf07b19d9583"}, + {file = "kiwisolver-1.4.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f32d6edbc638cde7652bd690c3e728b25332acbadd7cad670cc4a02558d9c417"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:e2e6c39bd7b9372b0be21456caab138e8e69cc0fc1190a9dfa92bd45a1e6e904"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:dda56c24d869b1193fcc763f1284b9126550eaf84b88bbc7256e15028f19188a"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79849239c39b5e1fd906556c474d9b0439ea6792b637511f3fe3a41158d89ca8"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5e3bc157fed2a4c02ec468de4ecd12a6e22818d4f09cde2c31ee3226ffbefab2"}, + {file = "kiwisolver-1.4.7-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3da53da805b71e41053dc670f9a820d1157aae77b6b944e08024d17bcd51ef88"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:8705f17dfeb43139a692298cb6637ee2e59c0194538153e83e9ee0c75c2eddde"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:82a5c2f4b87c26bb1a0ef3d16b5c4753434633b83d365cc0ddf2770c93829e3c"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ce8be0466f4c0d585cdb6c1e2ed07232221df101a4c6f28821d2aa754ca2d9e2"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:409afdfe1e2e90e6ee7fc896f3df9a7fec8e793e58bfa0d052c8a82f99c37abb"}, + {file = "kiwisolver-1.4.7-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:5b9c3f4ee0b9a439d2415012bd1b1cc2df59e4d6a9939f4d669241d30b414327"}, + {file = "kiwisolver-1.4.7-cp39-cp39-win32.whl", hash = "sha256:a79ae34384df2b615eefca647a2873842ac3b596418032bef9a7283675962644"}, + {file = "kiwisolver-1.4.7-cp39-cp39-win_amd64.whl", hash = "sha256:cf0438b42121a66a3a667de17e779330fc0f20b0d97d59d2f2121e182b0505e4"}, + {file = "kiwisolver-1.4.7-cp39-cp39-win_arm64.whl", hash = "sha256:764202cc7e70f767dab49e8df52c7455e8de0df5d858fa801a11aa0d882ccf3f"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:94252291e3fe68001b1dd747b4c0b3be12582839b95ad4d1b641924d68fd4643"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:5b7dfa3b546da08a9f622bb6becdb14b3e24aaa30adba66749d38f3cc7ea9706"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bd3de6481f4ed8b734da5df134cd5a6a64fe32124fe83dde1e5b5f29fe30b1e6"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a91b5f9f1205845d488c928e8570dcb62b893372f63b8b6e98b863ebd2368ff2"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40fa14dbd66b8b8f470d5fc79c089a66185619d31645f9b0773b88b19f7223c4"}, + {file = "kiwisolver-1.4.7-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:eb542fe7933aa09d8d8f9d9097ef37532a7df6497819d16efe4359890a2f417a"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:bfa1acfa0c54932d5607e19a2c24646fb4c1ae2694437789129cf099789a3b00"}, + {file = 
"kiwisolver-1.4.7-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:eee3ea935c3d227d49b4eb85660ff631556841f6e567f0f7bda972df6c2c9935"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f3160309af4396e0ed04db259c3ccbfdc3621b5559b5453075e5de555e1f3a1b"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a17f6a29cf8935e587cc8a4dbfc8368c55edc645283db0ce9801016f83526c2d"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:10849fb2c1ecbfae45a693c070e0320a91b35dd4bcf58172c023b994283a124d"}, + {file = "kiwisolver-1.4.7-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:ac542bf38a8a4be2dc6b15248d36315ccc65f0743f7b1a76688ffb6b5129a5c2"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:8b01aac285f91ca889c800042c35ad3b239e704b150cfd3382adfc9dcc780e39"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:48be928f59a1f5c8207154f935334d374e79f2b5d212826307d072595ad76a2e"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f37cfe618a117e50d8c240555331160d73d0411422b59b5ee217843d7b693608"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:599b5c873c63a1f6ed7eead644a8a380cfbdf5db91dcb6f85707aaab213b1674"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:801fa7802e5cfabe3ab0c81a34c323a319b097dfb5004be950482d882f3d7225"}, + {file = "kiwisolver-1.4.7-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:0c6c43471bc764fad4bc99c5c2d6d16a676b1abf844ca7c8702bdae92df01ee0"}, + {file = "kiwisolver-1.4.7.tar.gz", hash = "sha256:9893ff81bd7107f7b685d3017cc6583daadb4fc26e4a888350df530e41980a60"}, +] + [[package]] name = "litellm" version = "1.51.1" @@ -1422,6 +1715,69 @@ files = [ {file = "markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0"}, ] +[[package]] +name = "matplotlib" +version = "3.9.2" +description = "Python plotting package" +optional = false +python-versions = ">=3.9" +files = [ + {file = "matplotlib-3.9.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:9d78bbc0cbc891ad55b4f39a48c22182e9bdaea7fc0e5dbd364f49f729ca1bbb"}, + {file = "matplotlib-3.9.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c375cc72229614632c87355366bdf2570c2dac01ac66b8ad048d2dabadf2d0d4"}, + {file = "matplotlib-3.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d94ff717eb2bd0b58fe66380bd8b14ac35f48a98e7c6765117fe67fb7684e64"}, + {file = "matplotlib-3.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab68d50c06938ef28681073327795c5db99bb4666214d2d5f880ed11aeaded66"}, + {file = "matplotlib-3.9.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:65aacf95b62272d568044531e41de26285d54aec8cb859031f511f84bd8b495a"}, + {file = "matplotlib-3.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:3fd595f34aa8a55b7fc8bf9ebea8aa665a84c82d275190a61118d33fbc82ccae"}, + {file = "matplotlib-3.9.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d8dd059447824eec055e829258ab092b56bb0579fc3164fa09c64f3acd478772"}, + {file = "matplotlib-3.9.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c797dac8bb9c7a3fd3382b16fe8f215b4cf0f22adccea36f1545a6d7be310b41"}, + {file = 
"matplotlib-3.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d719465db13267bcef19ea8954a971db03b9f48b4647e3860e4bc8e6ed86610f"}, + {file = "matplotlib-3.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8912ef7c2362f7193b5819d17dae8629b34a95c58603d781329712ada83f9447"}, + {file = "matplotlib-3.9.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:7741f26a58a240f43bee74965c4882b6c93df3e7eb3de160126d8c8f53a6ae6e"}, + {file = "matplotlib-3.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:ae82a14dab96fbfad7965403c643cafe6515e386de723e498cf3eeb1e0b70cc7"}, + {file = "matplotlib-3.9.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:ac43031375a65c3196bee99f6001e7fa5bdfb00ddf43379d3c0609bdca042df9"}, + {file = "matplotlib-3.9.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:be0fc24a5e4531ae4d8e858a1a548c1fe33b176bb13eff7f9d0d38ce5112a27d"}, + {file = "matplotlib-3.9.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf81de2926c2db243c9b2cbc3917619a0fc85796c6ba4e58f541df814bbf83c7"}, + {file = "matplotlib-3.9.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6ee45bc4245533111ced13f1f2cace1e7f89d1c793390392a80c139d6cf0e6c"}, + {file = "matplotlib-3.9.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:306c8dfc73239f0e72ac50e5a9cf19cc4e8e331dd0c54f5e69ca8758550f1e1e"}, + {file = "matplotlib-3.9.2-cp312-cp312-win_amd64.whl", hash = "sha256:5413401594cfaff0052f9d8b1aafc6d305b4bd7c4331dccd18f561ff7e1d3bd3"}, + {file = "matplotlib-3.9.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:18128cc08f0d3cfff10b76baa2f296fc28c4607368a8402de61bb3f2eb33c7d9"}, + {file = "matplotlib-3.9.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4876d7d40219e8ae8bb70f9263bcbe5714415acfdf781086601211335e24f8aa"}, + {file = "matplotlib-3.9.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d9f07a80deab4bb0b82858a9e9ad53d1382fd122be8cde11080f4e7dfedb38b"}, + {file = "matplotlib-3.9.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7c0410f181a531ec4e93bbc27692f2c71a15c2da16766f5ba9761e7ae518413"}, + {file = "matplotlib-3.9.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:909645cce2dc28b735674ce0931a4ac94e12f5b13f6bb0b5a5e65e7cea2c192b"}, + {file = "matplotlib-3.9.2-cp313-cp313-win_amd64.whl", hash = "sha256:f32c7410c7f246838a77d6d1eff0c0f87f3cb0e7c4247aebea71a6d5a68cab49"}, + {file = "matplotlib-3.9.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:37e51dd1c2db16ede9cfd7b5cabdfc818b2c6397c83f8b10e0e797501c963a03"}, + {file = "matplotlib-3.9.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b82c5045cebcecd8496a4d694d43f9cc84aeeb49fe2133e036b207abe73f4d30"}, + {file = "matplotlib-3.9.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f053c40f94bc51bc03832a41b4f153d83f2062d88c72b5e79997072594e97e51"}, + {file = "matplotlib-3.9.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dbe196377a8248972f5cede786d4c5508ed5f5ca4a1e09b44bda889958b33f8c"}, + {file = "matplotlib-3.9.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:5816b1e1fe8c192cbc013f8f3e3368ac56fbecf02fb41b8f8559303f24c5015e"}, + {file = "matplotlib-3.9.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:cef2a73d06601437be399908cf13aee74e86932a5ccc6ccdf173408ebc5f6bb2"}, + {file = "matplotlib-3.9.2-cp39-cp39-macosx_11_0_arm64.whl", hash = 
"sha256:e0830e188029c14e891fadd99702fd90d317df294c3298aad682739c5533721a"}, + {file = "matplotlib-3.9.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ba9c1299c920964e8d3857ba27173b4dbb51ca4bab47ffc2c2ba0eb5e2cbc5"}, + {file = "matplotlib-3.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1cd93b91ab47a3616b4d3c42b52f8363b88ca021e340804c6ab2536344fad9ca"}, + {file = "matplotlib-3.9.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:6d1ce5ed2aefcdce11904fc5bbea7d9c21fff3d5f543841edf3dea84451a09ea"}, + {file = "matplotlib-3.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:b2696efdc08648536efd4e1601b5fd491fd47f4db97a5fbfd175549a7365c1b2"}, + {file = "matplotlib-3.9.2-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:d52a3b618cb1cbb769ce2ee1dcdb333c3ab6e823944e9a2d36e37253815f9556"}, + {file = "matplotlib-3.9.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:039082812cacd6c6bec8e17a9c1e6baca230d4116d522e81e1f63a74d01d2e21"}, + {file = "matplotlib-3.9.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6758baae2ed64f2331d4fd19be38b7b4eae3ecec210049a26b6a4f3ae1c85dcc"}, + {file = "matplotlib-3.9.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:050598c2b29e0b9832cde72bcf97627bf00262adbc4a54e2b856426bb2ef0697"}, + {file = "matplotlib-3.9.2.tar.gz", hash = "sha256:96ab43906269ca64a6366934106fa01534454a69e471b7bf3d79083981aaab92"}, +] + +[package.dependencies] +contourpy = ">=1.0.1" +cycler = ">=0.10" +fonttools = ">=4.22.0" +kiwisolver = ">=1.3.1" +numpy = ">=1.23" +packaging = ">=20.0" +pillow = ">=8" +pyparsing = ">=2.3.1" +python-dateutil = ">=2.7" + +[package.extras] +dev = ["meson-python (>=0.13.1)", "numpy (>=1.25)", "pybind11 (>=2.6)", "setuptools (>=64)", "setuptools_scm (>=7)"] + [[package]] name = "matplotlib-inline" version = "0.1.7" @@ -1845,6 +2201,98 @@ files = [ [package.dependencies] ptyprocess = ">=0.5" +[[package]] +name = "pillow" +version = "11.0.0" +description = "Python Imaging Library (Fork)" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pillow-11.0.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:6619654954dc4936fcff82db8eb6401d3159ec6be81e33c6000dfd76ae189947"}, + {file = "pillow-11.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b3c5ac4bed7519088103d9450a1107f76308ecf91d6dabc8a33a2fcfb18d0fba"}, + {file = "pillow-11.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a65149d8ada1055029fcb665452b2814fe7d7082fcb0c5bed6db851cb69b2086"}, + {file = "pillow-11.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88a58d8ac0cc0e7f3a014509f0455248a76629ca9b604eca7dc5927cc593c5e9"}, + {file = "pillow-11.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:c26845094b1af3c91852745ae78e3ea47abf3dbcd1cf962f16b9a5fbe3ee8488"}, + {file = "pillow-11.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:1a61b54f87ab5786b8479f81c4b11f4d61702830354520837f8cc791ebba0f5f"}, + {file = "pillow-11.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:674629ff60030d144b7bca2b8330225a9b11c482ed408813924619c6f302fdbb"}, + {file = "pillow-11.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:598b4e238f13276e0008299bd2482003f48158e2b11826862b1eb2ad7c768b97"}, + {file = "pillow-11.0.0-cp310-cp310-win32.whl", hash = "sha256:9a0f748eaa434a41fccf8e1ee7a3eed68af1b690e75328fd7a60af123c193b50"}, + {file = "pillow-11.0.0-cp310-cp310-win_amd64.whl", hash = 
"sha256:a5629742881bcbc1f42e840af185fd4d83a5edeb96475a575f4da50d6ede337c"}, + {file = "pillow-11.0.0-cp310-cp310-win_arm64.whl", hash = "sha256:ee217c198f2e41f184f3869f3e485557296d505b5195c513b2bfe0062dc537f1"}, + {file = "pillow-11.0.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:1c1d72714f429a521d8d2d018badc42414c3077eb187a59579f28e4270b4b0fc"}, + {file = "pillow-11.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:499c3a1b0d6fc8213519e193796eb1a86a1be4b1877d678b30f83fd979811d1a"}, + {file = "pillow-11.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c8b2351c85d855293a299038e1f89db92a2f35e8d2f783489c6f0b2b5f3fe8a3"}, + {file = "pillow-11.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f4dba50cfa56f910241eb7f883c20f1e7b1d8f7d91c750cd0b318bad443f4d5"}, + {file = "pillow-11.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:5ddbfd761ee00c12ee1be86c9c0683ecf5bb14c9772ddbd782085779a63dd55b"}, + {file = "pillow-11.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:45c566eb10b8967d71bf1ab8e4a525e5a93519e29ea071459ce517f6b903d7fa"}, + {file = "pillow-11.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b4fd7bd29610a83a8c9b564d457cf5bd92b4e11e79a4ee4716a63c959699b306"}, + {file = "pillow-11.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:cb929ca942d0ec4fac404cbf520ee6cac37bf35be479b970c4ffadf2b6a1cad9"}, + {file = "pillow-11.0.0-cp311-cp311-win32.whl", hash = "sha256:006bcdd307cc47ba43e924099a038cbf9591062e6c50e570819743f5607404f5"}, + {file = "pillow-11.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:52a2d8323a465f84faaba5236567d212c3668f2ab53e1c74c15583cf507a0291"}, + {file = "pillow-11.0.0-cp311-cp311-win_arm64.whl", hash = "sha256:16095692a253047fe3ec028e951fa4221a1f3ed3d80c397e83541a3037ff67c9"}, + {file = "pillow-11.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d2c0a187a92a1cb5ef2c8ed5412dd8d4334272617f532d4ad4de31e0495bd923"}, + {file = "pillow-11.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:084a07ef0821cfe4858fe86652fffac8e187b6ae677e9906e192aafcc1b69903"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8069c5179902dcdce0be9bfc8235347fdbac249d23bd90514b7a47a72d9fecf4"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f02541ef64077f22bf4924f225c0fd1248c168f86e4b7abdedd87d6ebaceab0f"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:fcb4621042ac4b7865c179bb972ed0da0218a076dc1820ffc48b1d74c1e37fe9"}, + {file = "pillow-11.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:00177a63030d612148e659b55ba99527803288cea7c75fb05766ab7981a8c1b7"}, + {file = "pillow-11.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8853a3bf12afddfdf15f57c4b02d7ded92c7a75a5d7331d19f4f9572a89c17e6"}, + {file = "pillow-11.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3107c66e43bda25359d5ef446f59c497de2b5ed4c7fdba0894f8d6cf3822dafc"}, + {file = "pillow-11.0.0-cp312-cp312-win32.whl", hash = "sha256:86510e3f5eca0ab87429dd77fafc04693195eec7fd6a137c389c3eeb4cfb77c6"}, + {file = "pillow-11.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:8ec4a89295cd6cd4d1058a5e6aec6bf51e0eaaf9714774e1bfac7cfc9051db47"}, + {file = "pillow-11.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25"}, + {file = "pillow-11.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = 
"sha256:bcd1fb5bb7b07f64c15618c89efcc2cfa3e95f0e3bcdbaf4642509de1942a699"}, + {file = "pillow-11.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0e038b0745997c7dcaae350d35859c9715c71e92ffb7e0f4a8e8a16732150f38"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ae08bd8ffc41aebf578c2af2f9d8749d91f448b3bfd41d7d9ff573d74f2a6b2"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d69bfd8ec3219ae71bcde1f942b728903cad25fafe3100ba2258b973bd2bc1b2"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:61b887f9ddba63ddf62fd02a3ba7add935d053b6dd7d58998c630e6dbade8527"}, + {file = "pillow-11.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:c6a660307ca9d4867caa8d9ca2c2658ab685de83792d1876274991adec7b93fa"}, + {file = "pillow-11.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:73e3a0200cdda995c7e43dd47436c1548f87a30bb27fb871f352a22ab8dcf45f"}, + {file = "pillow-11.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fba162b8872d30fea8c52b258a542c5dfd7b235fb5cb352240c8d63b414013eb"}, + {file = "pillow-11.0.0-cp313-cp313-win32.whl", hash = "sha256:f1b82c27e89fffc6da125d5eb0ca6e68017faf5efc078128cfaa42cf5cb38798"}, + {file = "pillow-11.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:8ba470552b48e5835f1d23ecb936bb7f71d206f9dfeee64245f30c3270b994de"}, + {file = "pillow-11.0.0-cp313-cp313-win_arm64.whl", hash = "sha256:846e193e103b41e984ac921b335df59195356ce3f71dcfd155aa79c603873b84"}, + {file = "pillow-11.0.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:4ad70c4214f67d7466bea6a08061eba35c01b1b89eaa098040a35272a8efb22b"}, + {file = "pillow-11.0.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:6ec0d5af64f2e3d64a165f490d96368bb5dea8b8f9ad04487f9ab60dc4bb6003"}, + {file = "pillow-11.0.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c809a70e43c7977c4a42aefd62f0131823ebf7dd73556fa5d5950f5b354087e2"}, + {file = "pillow-11.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:4b60c9520f7207aaf2e1d94de026682fc227806c6e1f55bba7606d1c94dd623a"}, + {file = "pillow-11.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:1e2688958a840c822279fda0086fec1fdab2f95bf2b717b66871c4ad9859d7e8"}, + {file = "pillow-11.0.0-cp313-cp313t-win32.whl", hash = "sha256:607bbe123c74e272e381a8d1957083a9463401f7bd01287f50521ecb05a313f8"}, + {file = "pillow-11.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:5c39ed17edea3bc69c743a8dd3e9853b7509625c2462532e62baa0732163a904"}, + {file = "pillow-11.0.0-cp313-cp313t-win_arm64.whl", hash = "sha256:75acbbeb05b86bc53cbe7b7e6fe00fbcf82ad7c684b3ad82e3d711da9ba287d3"}, + {file = "pillow-11.0.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:2e46773dc9f35a1dd28bd6981332fd7f27bec001a918a72a79b4133cf5291dba"}, + {file = "pillow-11.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2679d2258b7f1192b378e2893a8a0a0ca472234d4c2c0e6bdd3380e8dfa21b6a"}, + {file = "pillow-11.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eda2616eb2313cbb3eebbe51f19362eb434b18e3bb599466a1ffa76a033fb916"}, + {file = "pillow-11.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20ec184af98a121fb2da42642dea8a29ec80fc3efbaefb86d8fdd2606619045d"}, + {file = "pillow-11.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:8594f42df584e5b4bb9281799698403f7af489fba84c34d53d1c4bfb71b7c4e7"}, + {file = "pillow-11.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", 
hash = "sha256:c12b5ae868897c7338519c03049a806af85b9b8c237b7d675b8c5e089e4a618e"}, + {file = "pillow-11.0.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:70fbbdacd1d271b77b7721fe3cdd2d537bbbd75d29e6300c672ec6bb38d9672f"}, + {file = "pillow-11.0.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:5178952973e588b3f1360868847334e9e3bf49d19e169bbbdfaf8398002419ae"}, + {file = "pillow-11.0.0-cp39-cp39-win32.whl", hash = "sha256:8c676b587da5673d3c75bd67dd2a8cdfeb282ca38a30f37950511766b26858c4"}, + {file = "pillow-11.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:94f3e1780abb45062287b4614a5bc0874519c86a777d4a7ad34978e86428b8dd"}, + {file = "pillow-11.0.0-cp39-cp39-win_arm64.whl", hash = "sha256:290f2cc809f9da7d6d622550bbf4c1e57518212da51b6a30fe8e0a270a5b78bd"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:1187739620f2b365de756ce086fdb3604573337cc28a0d3ac4a01ab6b2d2a6d2"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:fbbcb7b57dc9c794843e3d1258c0fbf0f48656d46ffe9e09b63bbd6e8cd5d0a2"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d203af30149ae339ad1b4f710d9844ed8796e97fda23ffbc4cc472968a47d0b"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21a0d3b115009ebb8ac3d2ebec5c2982cc693da935f4ab7bb5c8ebe2f47d36f2"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:73853108f56df97baf2bb8b522f3578221e56f646ba345a372c78326710d3830"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e58876c91f97b0952eb766123bfef372792ab3f4e3e1f1a2267834c2ab131734"}, + {file = "pillow-11.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:224aaa38177597bb179f3ec87eeefcce8e4f85e608025e9cfac60de237ba6316"}, + {file = "pillow-11.0.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:5bd2d3bdb846d757055910f0a59792d33b555800813c3b39ada1829c372ccb06"}, + {file = "pillow-11.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:375b8dd15a1f5d2feafff536d47e22f69625c1aa92f12b339ec0b2ca40263273"}, + {file = "pillow-11.0.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:daffdf51ee5db69a82dd127eabecce20729e21f7a3680cf7cbb23f0829189790"}, + {file = "pillow-11.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7326a1787e3c7b0429659e0a944725e1b03eeaa10edd945a86dead1913383944"}, + {file = "pillow-11.0.0.tar.gz", hash = "sha256:72bacbaf24ac003fea9bff9837d1eedb6088758d41e100c1552930151f677739"}, +] + +[package.extras] +docs = ["furo", "olefile", "sphinx (>=8.1)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinxext-opengraph"] +fpx = ["olefile"] +mic = ["olefile"] +tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] +typing = ["typing-extensions"] +xmp = ["defusedxml"] + [[package]] name = "pkginfo" version = "1.10.0" @@ -2239,6 +2687,20 @@ files = [ [package.extras] windows-terminal = ["colorama (>=0.4.6)"] +[[package]] +name = "pyparsing" +version = "3.2.0" +description = "pyparsing module - Classes and methods to define and execute parsing grammars" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pyparsing-3.2.0-py3-none-any.whl", hash = "sha256:93d9577b88da0bbea8cc8334ee8b918ed014968fd2ec383e868fb8afb1ccef84"}, + {file = "pyparsing-3.2.0.tar.gz", hash = 
"sha256:cbf74e27246d595d9a74b186b810f6fbb86726dbf3b9532efb343f6d7294fe9c"}, +] + +[package.extras] +diagrams = ["jinja2", "railroad-diagrams"] + [[package]] name = "pytest" version = "8.3.3" @@ -3340,4 +3802,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "20e3c109fa84799358c28bde59cdc894bc1dcc28a47736dbf30d6ba63f8db705" +content-hash = "f7cb5ef17d0f570775881d92e959a2643fdf594183aa6d3eb63cb34307ab06dc" diff --git a/pyproject.toml b/pyproject.toml index ea0e6c20..8d834bc2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ pytest-asyncio = "^0.24.0" pandas = "^2.2.3" xxhash = "^3.5.0" tqdm = "^4.67.0" +matplotlib = "^3.9.2" [tool.poetry.group.dev.dependencies] black = "^24.2.0" diff --git a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py index ab5e757d..eabeb424 100644 --- a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py +++ b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py @@ -2,15 +2,11 @@ import json import logging import os -import re -import time from dataclasses import dataclass, field from typing import Any, Callable, Dict, Optional, Set, Tuple, TypeVar - -import aiohttp -import requests -import tiktoken -from tqdm import tqdm +from openai import AsyncOpenAI +import pandas as pd +import matplotlib.pyplot as plt from bespokelabs.curator.dataset import Dataset from bespokelabs.curator.request_processor.base_request_processor import ( @@ -166,552 +162,250 @@ def run( Returns: Dataset: Completed dataset """ - requests_files = self.create_request_files(dataset, working_dir) - responses_files = [ - f"{working_dir}/responses_{i}.jsonl" for i in range(len(requests_files)) - ] - - rate_limits = self.get_rate_limits() - rpm = rate_limits["max_requests_per_minute"] - tpm = rate_limits["max_tokens_per_minute"] - - token_encoding_name = get_token_encoding_name(self.prompt_formatter.model_name) - - # NOTE(Ryan): If you wanted to do this on batches, you could run a for loop here about request_files. Although I don't recommend it because you are waiting for straggler requests to finish for each batch. - # NOTE(Ryan): And if you wanted to do batches in parallel, you would have to divide rpm and tpm by the number of parallel batches. - # TODO(Ryan): Can we abstract retries from process_api_requests_from_file so you can use it even if you use liteLLM. 
- for i in range(len(requests_files)): - asyncio.run( - self.process_api_requests_from_file( - requests_filepath=requests_files[i], - save_filepath=responses_files[i], - request_url=self.url, - max_requests_per_minute=rpm, - max_tokens_per_minute=tpm, - token_encoding_name=token_encoding_name, - max_attempts=5, - resume=True, # detects existing jobs and resume from there - ) - ) - - return Dataset.from_working_dir(working_dir, self.prompt_formatter) - - async def process_api_requests_from_file( - self, - requests_filepath: str, - save_filepath: str, - request_url: str, - max_requests_per_minute: float, - max_tokens_per_minute: float, - token_encoding_name: str, - max_attempts: int, - resume: bool, - resume_no_retry: bool = False, - ) -> None: - """Processes API requests in parallel, throttling to stay under rate limits.""" - # constants - seconds_to_pause_after_rate_limit_error = 15 - seconds_to_sleep_each_loop = ( - 0.001 # 1 ms limits max throughput to 1,000 requests per second + requests_files = self.create_request_files( + dataset, map, working_dir, batch_size ) + batch_objects_file = f"{working_dir}/batch_objects.jsonl" - # infer API endpoint and construct request header - api_endpoint = api_endpoint_from_url(self.url) - request_header = {"Authorization": f"Bearer {self.api_key}"} - # use api-key header for Azure deployments - if "/deployments" in self.url: - request_header = {"api-key": f"{self.api_key}"} - - # initialize trackers - queue_of_requests_to_retry = asyncio.Queue() - task_id_generator = ( - task_id_generator_function() - ) # generates integer IDs of 0, 1, 2, ... - status_tracker = ( - StatusTracker() - ) # single instance to track a collection of variables - next_request = None # variable to hold the next request to call - - # initialize available capacity counts - available_request_capacity = max_requests_per_minute - available_token_capacity = max_tokens_per_minute - last_update_time = time.time() - - # initialize flags - file_not_finished = True # after file is empty, we'll skip reading it - logging.debug(f"Initialization complete.") - - completed_request_ids: Set[int] = set() - if os.path.exists(save_filepath): - if resume: - # save all successfully completed requests to a temporary file, then overwrite the original file with the temporary file - logging.debug(f"Resuming progress from existing file: {save_filepath}") - logging.debug( - f"Removing all failed requests from {save_filepath} so they can be retried" - ) - temp_filepath = f"{save_filepath}.temp" - num_previously_failed_requests = 0 - with open(save_filepath, "r") as input_file, open( - temp_filepath, "w" - ) as output_file: - for line in input_file: - response = GenericResponse.model_validate_json(line) - if response.errors: - # this means that the request failed and we have a list of errors - logging.debug( - f"Request {response.row_idx} previously failed due to errors: {response.errors}, removing from output and will retry" - ) - num_previously_failed_requests += 1 - else: - completed_request_ids.add(response.row_idx) - output_file.write(line) - logging.info( - f"Found {len(completed_request_ids)} completed requests and {num_previously_failed_requests} previously failed requests" - ) - logging.info( - "Failed requests and remaining requests will now be processed." 
- ) - os.replace(temp_filepath, save_filepath) - elif resume_no_retry: - logging.warning( - f"Resuming progress from existing file: {save_filepath}, without retrying failed requests" - ) - num_previously_failed_requests = 0 - with open(save_filepath, "r") as input_file, open( - temp_filepath, "w" - ) as output_file: - for line in tqdm(input_file, desc="Processing existing requests"): - data = json.loads(line) - if isinstance(data[1], list): - # this means that the request failed and we have a list of errors - logging.debug( - f"Request {data[2].get('request_idx')} previously failed due to errors: {data[1]}, will NOT retry" - ) - num_previously_failed_requests += 1 - completed_request_ids.add(data[2].get("request_idx")) - logging.info( - f"Found {len(completed_request_ids)} total requests and {num_previously_failed_requests} previously failed requests" - ) - logging.info("Remaining requests will now be processed.") - else: - user_input = input( - f"File {save_filepath} already exists.\nTo resume if there are remaining requests without responses, run with --resume flag.\nOverwrite? (Y/n): " - ) - if user_input.lower() != "y" and user_input.lower() != "": - logging.info("Aborting operation.") - return - - # initialize file reading - with open(requests_filepath) as file: - # `requests` will provide requests one at a time - requests = file.__iter__() - logging.debug(f"File opened. Entering main loop") - - # Count total number of requests - total_requests = sum(1 for _ in open(requests_filepath)) - if total_requests == len(completed_request_ids): - logging.debug( - "All requests have already been completed so will just reuse cache." - ) - return - - # Create progress bar - pbar = tqdm( - total=total_requests, desc="Processing parallel requests to OpenAI" + # TODO(Ryan): we should have an easy way to cancel all batches in batch_objects.jsonl if the user realized they made a mistake + if os.path.exists(batch_objects_file): + logging.warning( + f"Batch objects file already exists, skipping batch submission and resuming: {batch_objects_file}" ) + else: + # upload requests files and submit batches + self.async_client = AsyncOpenAI() - connector = aiohttp.TCPConnector(limit=10 * max_requests_per_minute) - async with aiohttp.ClientSession( - connector=connector - ) as session: # Initialize ClientSession here - while True: - # get next request (if one is not already waiting for capacity) - if next_request is None: - if not queue_of_requests_to_retry.empty(): - next_request = queue_of_requests_to_retry.get_nowait() - logging.debug( - f"Retrying request {next_request.task_id}: {next_request}" - ) - elif file_not_finished: - try: - # get new request - request_json = json.loads(next(requests)) - request_idx = request_json["metadata"]["request_idx"] - if resume and request_idx in completed_request_ids: - logging.debug( - f"Skipping already completed request {request_idx}" - ) - status_tracker.num_tasks_already_completed += 1 - continue - next_request = APIRequest( - task_id=next(task_id_generator), - request_json=request_json, - token_consumption=num_tokens_consumed_from_request( - request_json, api_endpoint, token_encoding_name - ), - attempts_left=max_attempts, - metadata=request_json.pop("metadata", None), - ) - status_tracker.num_tasks_started += 1 - status_tracker.num_tasks_in_progress += 1 - logging.debug( - f"Reading request {next_request.task_id}: {next_request}" - ) - except StopIteration: - # if file runs out, set flag to stop reading it - logging.debug("Read file exhausted") - file_not_finished = 
False - - # update available capacity - current_time = time.time() - seconds_since_update = current_time - last_update_time - available_request_capacity = min( - available_request_capacity - + max_requests_per_minute * seconds_since_update / 60.0, - max_requests_per_minute, - ) - available_token_capacity = min( - available_token_capacity - + max_tokens_per_minute * seconds_since_update / 60.0, - max_tokens_per_minute, - ) - last_update_time = current_time - - # if enough capacity available, call API - if next_request: - next_request_tokens = next_request.token_consumption - if ( - available_request_capacity >= 1 - and available_token_capacity >= next_request_tokens - ): - # update counters - available_request_capacity -= 1 - available_token_capacity -= next_request_tokens - next_request.attempts_left -= 1 - - # call API - asyncio.create_task( - next_request.call_api( - session=session, - request_url=request_url, - request_header=request_header, - retry_queue=queue_of_requests_to_retry, - save_filepath=save_filepath, - status_tracker=status_tracker, - get_generic_response=self.get_generic_response, - ) - ) - next_request = None # reset next_request to empty - else: - logging.debug( - f"Not Enough Capacity: Request tokens: {next_request_tokens}, Available request capacity: {available_request_capacity}, Available token capacity: {available_token_capacity}" - ) - - # Update progress bar when a task is completed - total_completed = ( - status_tracker.num_tasks_succeeded - + status_tracker.num_tasks_failed - + status_tracker.num_tasks_already_completed - ) - if total_completed > pbar.n: - pbar.update(total_completed - pbar.n) + # asyncio gather preserves order + async def submit_all_batches(): + tasks = [ + self.asubmit_batch(requests_files[i]) + for i in range(len(requests_files)) + ] + return await asyncio.gather(*tasks) - # if all tasks are finished, break - if status_tracker.num_tasks_in_progress == 0: - break + batch_objects = asyncio.run(submit_all_batches()) - # main loop sleeps briefly so concurrent tasks can run - await asyncio.sleep(seconds_to_sleep_each_loop) + with open(batch_objects_file, "w") as f: + # NOTE(Ryan): we can also store the request_file_name in this object here, instead of in the metadata during batch submission. Can find a nice abstraction across other batch APIs (e.g. claude) + for obj in batch_objects: + f.write(json.dumps(obj.model_dump()) + "\n") + logging.info(f"Batch objects written to {batch_objects_file}") - # if a rate limit error was hit recently, pause to cool down - seconds_since_rate_limit_error = ( - time.time() - status_tracker.time_of_last_rate_limit_error - ) - if ( - seconds_since_rate_limit_error - < seconds_to_pause_after_rate_limit_error - ): - remaining_seconds_to_pause = ( - seconds_to_pause_after_rate_limit_error - - seconds_since_rate_limit_error - ) - await asyncio.sleep(remaining_seconds_to_pause) - # ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago - logging.warn( - f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}" - ) - - # Close the progress bar - pbar.close() - - # after finishing, log final status - logging.info( - f"""Parallel processing complete. Results saved to {save_filepath}""" - ) + # TODO(Ryan): Actually do accounting for tokens, so rate limits enforced locally. + # NOTE(Ryan): Although this isn't really practical since the limits are for an entire day and an entire organization. 
Maybe skip this and just recognize what a rate limit error for batching looks like (need to try this on a low tier account). + # rate_limits = self.get_rate_limits() + # tpd = rate_limits["max_tokens_per_day"] + # token_encoding_name = get_token_encoding_name(self.model) - logging.info(f"Status tracker: {status_tracker}") + # TODO(Ryan): based on the files that are downloaded, update completed_ids. If any are errors, try to resubmit (depending on error type). + # TODO(Ryan): This creates responses_0.jsonl, responses_1.jsonl, etc. errors named same way? or errors_0.jsonl, errors_1.jsonl? + # TODO(Ryan): retries, resubmits on lagging batches - need to study this a little closer + # TODO(Ryan): likely can add some logic for smarter check_interval based on batch size and if the batch has started or not, fine to do a dumb ping for now + batch_watcher = BatchWatcher(working_dir, check_interval=self.check_interval) - if status_tracker.num_tasks_failed > 0: - logging.warning( - f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed. Errors logged to {save_filepath}." - ) - if status_tracker.num_rate_limit_errors > 0: - logging.warning( - f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate." - ) + asyncio.run(batch_watcher.watch()) + # TODO(Ryan): Add back in here the dataset creation + return dataset -@dataclass -class StatusTracker: - """Stores metadata about the script's progress. Only one instance is created.""" - num_tasks_already_completed: int = 0 - num_tasks_started: int = 0 - num_tasks_in_progress: int = 0 # script ends when this reaches 0 - num_tasks_succeeded: int = 0 - num_tasks_failed: int = 0 - num_rate_limit_errors: int = 0 - num_api_errors: int = 0 # excluding rate limit errors, counted above - num_other_errors: int = 0 - time_of_last_rate_limit_error: int = 0 # used to cool off after hitting rate limits +@Dataset +class BatchWatcher: + def __init__(self, working_dir: str, check_interval: int = 60) -> None: + """Initialize BatchWatcher with batch objects file and check interval. + Args: + batch_objects_file (str): Path to the batch objects JSON file. + check_interval (int): Time interval (in seconds) to check batch status. + """ + self.client = AsyncOpenAI() + with open(f"{working_dir}/batch_objects.jsonl", "r") as f: + self.batch_objects = [json.loads(line) for line in f] + self.batch_ids = [obj["id"] for obj in self.batch_objects] + self.batch_id_to_request_file_name = { + obj["id"]: obj["metadata"]["request_file_name"] + for obj in self.batch_objects + } + self.batches = [] -@dataclass -class APIRequest: - """Stores an API request's inputs, outputs, and other metadata. Contains a method to make an API call.""" + async def check_batch_status(self, batch_id: str) -> tuple[str, str]: + """Check the status of a batch by its ID. - task_id: int - request_json: dict - token_consumption: int - attempts_left: int - metadata: dict - result: list = field(default_factory=list) + Args: + batch_id (str): The ID of the batch to check. 
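For context on the `batch_objects.jsonl` bookkeeping above: the watcher only needs each batch's `id` plus the `request_file_name` stored in its metadata to map results back to request files. The standalone sketch below mirrors that shape with made-up file contents; it is not code from the patch itself.

```python
import json

# Hypothetical contents of batch_objects.jsonl as written at submission time.
lines = [
    '{"id": "batch_abc123", "metadata": {"request_file_name": "cache_dir/requests_0.jsonl"}}',
    '{"id": "batch_def456", "metadata": {"request_file_name": "cache_dir/requests_1.jsonl"}}',
]

batch_objects = [json.loads(line) for line in lines]
batch_id_to_request_file_name = {
    obj["id"]: obj["metadata"]["request_file_name"] for obj in batch_objects
}
print(batch_id_to_request_file_name)
# {'batch_abc123': 'cache_dir/requests_0.jsonl', 'batch_def456': 'cache_dir/requests_1.jsonl'}
```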
- async def call_api( - self, - session: aiohttp.ClientSession, - request_url: str, - request_header: dict, - retry_queue: asyncio.Queue, - save_filepath: str, - status_tracker: StatusTracker, - get_generic_response: Callable[[list], dict], - ) -> None: - """Calls the OpenAI API and saves results.""" - logging.debug(f"Starting request #{self.task_id}") - error = None - try: - async with session.post( - url=request_url, headers=request_header, json=self.request_json - ) as response: - response = await response.json() - if "error" in response: - logging.warning( - f"Request {self.task_id} failed with error {response['error']}" - ) - status_tracker.num_api_errors += 1 - error = response - if "rate limit" in response["error"].get("message", "").lower(): - status_tracker.time_of_last_rate_limit_error = time.time() - status_tracker.num_rate_limit_errors += 1 - status_tracker.num_api_errors -= ( - 1 # rate limit errors are counted separately + Returns: + tuple[str, str]: The batch ID and its status. + """ + batch = await self.client.batches.retrieve(batch_id) + logging.info( + f"Batch {batch_id} status: {batch.status} requests: {batch.request_counts.completed}/{batch.request_counts.failed}/{batch.request_counts.total} completed/failed/total" + ) + return batch_id, batch + + async def watch(self) -> None: + """Monitor the status of batches until all are completed (includes successfully, failed, expired or cancelled).""" + completed_batches = {} + while len(completed_batches) < len(self.batch_ids): + status_tasks = [] + for batch_id in self.batch_ids: + if batch_id not in completed_batches: + status_tasks.append(self.check_batch_status(batch_id)) + + batches = await asyncio.gather(*status_tasks) + newly_completed_batches = [] + for batch_id, batch in batches: + if batch.status in ["completed", "failed", "expired", "cancelled"]: + logging.info( + f"Batch {batch_id} processing finished with status: {batch.status}" ) + completed_batches[batch_id] = batch + newly_completed_batches.append(batch) - except ( - Exception - ) as e: # catching naked exceptions is bad practice, but in this case we'll log & save them - logging.warning( - f"Request {self.task_id} failed with Exception {e}, attempts left {self.attempts_left}" - ) - status_tracker.num_other_errors += 1 - error = e - if error: - self.result.append(error) - if self.attempts_left: - retry_queue.put_nowait(self) - else: - logging.error( - f"Request {self.request_json} failed after all attempts. 
Saving errors: {self.result}" - ) - data = GenericResponse( - request=self.request_json, - errors=[str(e) for e in self.result], - row=self.metadata["sample"], - row_idx=self.metadata["request_idx"], + # NOTE(Ryan): Now downloading after each check, instead of waiting until all are completed + tasks = [ + self.download_batch_result_file(batch) + for batch in newly_completed_batches + ] + await asyncio.gather(*tasks) + + if len(completed_batches) < len(self.batch_ids): + logging.info( + f"Remaining batches processing {len(self.batch_ids) - len(completed_batches)}/{len(self.batch_ids)}" ) - append_generic_response(data, save_filepath) - status_tracker.num_tasks_in_progress -= 1 - status_tracker.num_tasks_failed += 1 - else: - data = get_generic_response( - {"response": response, "metadata": self.metadata} - ) - data.raw_response = response - data.request = self.request_json - append_generic_response(data, save_filepath) - status_tracker.num_tasks_in_progress -= 1 - status_tracker.num_tasks_succeeded += 1 - logging.debug(f"Request {self.task_id} saved to {save_filepath}") - - -def get_token_encoding_name(model: str) -> str: - """Get the token encoding name for a given model.""" - if "gpt" in model: - return tiktoken.encoding_for_model(model).name - else: - logging.warning( - f'Token encoding name for model "{model}" not implemented, using cl100k_base for token counting' - ) - return "cl100k_base" - - -def get_rate_limits(model: str, request_url: str, api_key: str) -> Tuple[int, int]: - """ - Function to get rate limits for a given annotator. Makes a single request to openAI API - and gets the rate limits from the response headers. These rate limits vary per model - and are determined by your organization's usage tier. View the following: - https://platform.openai.com/docs/guides/rate-limits/usage-tiers - https://platform.openai.com/settings/organization/limits - - Args: - model (str): The model for which to get the rate limits. - request_url (str): The request URL for which to get the rate limits. - - Returns: - Tuple[int, int]: The maximum number of requests and tokens per minute. - """ - if "api.openai.com" in request_url: - # Send a dummy request to get rate limit information - response = requests.post( - request_url, - headers={"Authorization": f"Bearer {api_key}"}, - json={"model": model, "messages": []}, - ) - # Extract rate limit information from headers - max_requests = int(response.headers.get("x-ratelimit-limit-requests", 30_000)) - max_tokens = int(response.headers.get("x-ratelimit-limit-tokens", 150_000_000)) - elif "api.sambanova.ai" in request_url: - # Send a dummy request to get rate limit information - max_requests = 50 - max_tokens = 100_000_000 - else: - raise NotImplementedError( - f'Rate limits for API endpoint "{request_url}" not implemented' - ) + logging.info(f"Sleeping for {self.check_interval} seconds...") + await asyncio.sleep(self.check_interval) - return max_requests, max_tokens + self.batches = completed_batches.values() + async def download_batch_result_file(self, batch) -> str: + """Download the result of a completed batch to file. 
-def get_api_key(request_url: str) -> str: - """Get the API key for a given request URL.""" - if "api.openai.com" in request_url: - return os.getenv("OPENAI_API_KEY") - elif "api.sambanova.ai" in request_url: - return os.getenv("SAMBANOVA_API_KEY") - else: - raise NotImplementedError( - f'Default API key environment variable for API endpoint "{request_url}" not implemented' - ) + Args: + batch: The batch object to download results from. + Returns: + str: Path to the downloaded result file. + """ + if batch.status == "completed" and batch.output_file_id: + file_content = await self.client.files.content(batch.output_file_id) + elif batch.status == "failed" and batch.error_file_id: + file_content = await self.client.files.content(batch.error_file_id) + elif batch.status == "cancelled" or batch.status == "expired": + logging.warning(f"Batch {batch.id} was cancelled or expired") + return None -def api_endpoint_from_url(request_url: str) -> str: - """Extract the API endpoint from the request URL. - This is used to determine the number of tokens consumed by the request. - """ - - # OpenAI API - match = re.search("^https://[^/]+/v\\d+/(.+)$", request_url) - if match: - return match[1] - - # for Azure OpenAI deployment urls - match = re.search( - r"^https://[^/]+/openai/deployments/[^/]+/(.+?)(\?|$)", request_url - ) - if match: - return match[1] - - # Catch all for other API endpoints using OpenAI OpenAPI format - if "chat/completions" in request_url: - return "chat/completions" - elif "completions" in request_url: - return "completions" - else: - raise NotImplementedError( - f'API endpoint "{request_url}" not implemented in this script' + # NOTE(Ryan): This is so the naming is consistent with the request file naming + request_file_id = ( + self.batch_id_to_request_file_name[batch.id].split("/")[-1].split("_", 1)[1] ) + output_path = f"{self.working_dir}/responses_{request_file_id}" + with open(output_path, "wb") as f: + f.write(file_content.content) + return output_path + # NOTE(Ryan): This could be useful for very small batches and overall total requests not too large + async def download_batch_result_in_memory(self, batch) -> list[str]: + """Download the result of a completed batch. -def append_generic_response(data: GenericResponse, filename: str) -> None: - """Append a json payload to the end of a jsonl file.""" - json_string = json.dumps(data.model_dump()) - with open(filename, "a") as f: - f.write(json_string + "\n") - - -def num_tokens_consumed_from_request( - request_json: dict, - api_endpoint: str, - token_encoding_name: str, -): - """Count the number of tokens in the request. 
Only supports completion and embedding requests.""" - encoding = tiktoken.get_encoding(token_encoding_name) - # if completions request, tokens = prompt + n * max_tokens - if api_endpoint.endswith("completions"): - max_tokens = request_json.get("max_tokens", 15) - n = request_json.get("n", 1) - completion_tokens = n * max_tokens - - # chat completions - if api_endpoint.startswith("chat/"): - num_tokens = 0 - for message in request_json["messages"]: - num_tokens += 4 # every message follows {role/name}\n{content}\n - for key, value in message.items(): - num_tokens += len(encoding.encode(value)) - if key == "name": # if there's a name, the role is omitted - num_tokens -= 1 # role is always required and always 1 token - num_tokens += 2 # every reply is primed with assistant - return num_tokens + completion_tokens - # normal completions - else: - prompt = request_json["prompt"] - if isinstance(prompt, str): # single prompt - prompt_tokens = len(encoding.encode(prompt)) - num_tokens = prompt_tokens + completion_tokens - return num_tokens - elif isinstance(prompt, list): # multiple prompts - prompt_tokens = sum([len(encoding.encode(p)) for p in prompt]) - num_tokens = prompt_tokens + completion_tokens * len(prompt) - return num_tokens - else: - raise TypeError( - 'Expecting either string or list of strings for "prompt" field in completion request' - ) - # if embeddings request, tokens = input tokens - elif api_endpoint == "embeddings": - input = request_json["input"] - if isinstance(input, str): # single input - num_tokens = len(encoding.encode(input)) - return num_tokens - elif isinstance(input, list): # multiple inputs - num_tokens = sum([len(encoding.encode(i)) for i in input]) - return num_tokens + Args: + batch: The batch object to download results from. + + Returns: + list[str]: Lines of the downloaded result. + """ + if batch.status == "completed" and batch.output_file_id: + file_content = await self.client.files.content(batch.output_file_id) + return file_content.text.splitlines() + return [] + + # NOTE(Ryan): This could be useful for very small batches and overall total requests not too large + async def download_results_in_memory(self, output_path: str) -> None: + """Download results of all batches and save to a specified path. + + Args: + output_path (str): Path to save the downloaded results. + """ + tasks = [self.download_batch_result(batch) for batch in self.batches] + results = await asyncio.gather(*tasks) + + all_results = [ + item for sublist in results for item in sublist + ] # Flatten the list of lists + + with open(output_path, "w") as f: + for result in all_results: + f.write(result + "\n") + logging.info(f"All batch results downloaded and saved to: {output_path}") + + async def download_errors(self, error_path: str, batch_id: str) -> None: + """Download error file for a specific batch if available. + + Args: + error_path (str): Path to save the error file. + batch_id (str): The ID of the batch to download errors from. 
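To make the response-file naming used a little above concrete, here is a small worked example of the `split` chain that derives `responses_<idx>.jsonl` from the stored request file name. The paths are illustrative values, not taken from a real run.

```python
# Assumed example values; only the string manipulation matters here.
working_dir = "cache_dir/abc123"
request_file_name = f"{working_dir}/requests_3.jsonl"

# "cache_dir/abc123/requests_3.jsonl" -> "requests_3.jsonl" -> "3.jsonl"
request_file_id = request_file_name.split("/")[-1].split("_", 1)[1]
output_path = f"{working_dir}/responses_{request_file_id}"
print(output_path)  # cache_dir/abc123/responses_3.jsonl
```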
+ """ + batch = await self.client.batches.retrieve(batch_id) + if batch.error_file_id: + file_content = await self.client.files.content(batch.error_file_id) + with open(error_path, "wb") as f: + f.write(file_content.content) + logging.info(f"Batch errors downloaded and saved to: {error_path}") else: - raise TypeError( - 'Expecting either string or list of strings for "inputs" field in embedding request' - ) - # more logic needed to support other API calls (e.g., edits, inserts, DALL-E) - else: - raise NotImplementedError( - f'API endpoint "{api_endpoint}" not implemented in this script' - ) + logging.info(f"No error file available for batch {batch_id}.") + async def plot_completion_data(self, output_dir: str) -> None: + """Save plots visualizing completion times for the batches. -def task_id_generator_function(): - """Generate integers 0, 1, 2, and so on.""" - task_id = 0 - while True: - yield task_id - task_id += 1 + Args: + output_dir (str): Directory to save the plots. + """ + completion_times = [] + completion_dates = [] + + for batch_id in self.batch_ids: + batch = await self.client.batches.retrieve(batch_id) + if batch.status == "completed": + duration = ( + batch.completed_at - batch.created_at + ) / 60 # Convert to minutes + completion_times.append(duration) + completion_dates.append(batch.completed_at) + + # Create a DataFrame for plotting + df = pd.DataFrame( + { + "Completion Time (min)": completion_times, # Update label to minutes + "Completion Date": pd.to_datetime(completion_dates, unit="s"), + } + ) + + # Histogram of completion durations + plt.figure(figsize=(12, 6)) + plt.hist(df["Completion Time (min)"], bins=20, color="blue", alpha=0.7) + plt.title("Histogram of Completion Durations") + plt.xlabel("Duration (minutes)") # Update label to minutes + plt.ylabel("Frequency") + plt.grid(axis="y") + plt.savefig( + os.path.join(output_dir, "completion_durations_histogram.png") + ) # Save the histogram + plt.close() # Close the plot + + # Cumulative plot of completed jobs over time + df.sort_values("Completion Date", inplace=True) + df["Cumulative Completed"] = range(1, len(df) + 1) + + plt.figure(figsize=(12, 6)) + plt.plot( + df["Completion Date"], df["Cumulative Completed"], marker="o", color="green" + ) + plt.title("Cumulative Completed Jobs Over Time") + plt.xlabel("Completion Date") + plt.ylabel("Cumulative Completed Jobs") + plt.grid() + plt.savefig( + os.path.join(output_dir, "cumulative_completed_jobs.png") + ) # Save the cumulative plot + plt.close() # Close the plot diff --git a/tests/test_batch.py b/tests/test_batch.py index 44251008..37a325ae 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -86,6 +86,7 @@ def load_ShareGPT_dataset_as_IT(dataset_name: str, truncate: int = None) -> Data batch_size=args.batch_size, check_interval=args.check_interval, ) + reannotated_dataset = reannotate_prompter(dataset, request_processor) dataset = reannotated_dataset.to_huggingface() From 263f76ea41a4c9f05f0f39263005436bcf2219aa Mon Sep 17 00:00:00 2001 From: Ryan Marten Date: Thu, 7 Nov 2024 18:23:45 -0800 Subject: [PATCH 07/18] update batch_size to use class attribute --- .../curator/request_processor/openai_batch_request_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py index eabeb424..97bf30db 100644 --- a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py +++ 
b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py @@ -163,7 +163,7 @@ def run( Dataset: Completed dataset """ requests_files = self.create_request_files( - dataset, map, working_dir, batch_size + dataset, map, working_dir, self.batch_size ) batch_objects_file = f"{working_dir}/batch_objects.jsonl" From bb87fcfd3591954ad5d99dd283c0428bdbcba8cd Mon Sep 17 00:00:00 2001 From: Ryan Marten Date: Thu, 7 Nov 2024 20:17:12 -0800 Subject: [PATCH 08/18] debug openai online --- src/bespokelabs/__init__.py | 9 -- src/bespokelabs/curator/prompter/prompter.py | 42 ++++----- .../base_request_processor.py | 90 ++++++++++++++++++- .../openai_batch_request_processor.py | 84 +++++------------ .../openai_online_request_processor.py | 48 +++++++--- tests/test_batch.py | 59 +++++++----- 6 files changed, 202 insertions(+), 130 deletions(-) diff --git a/src/bespokelabs/__init__.py b/src/bespokelabs/__init__.py index 7f7ae402..e69de29b 100644 --- a/src/bespokelabs/__init__.py +++ b/src/bespokelabs/__init__.py @@ -1,9 +0,0 @@ -from bespokelabs.curator.request_processor import ( - OpenAIBatchRequestProcessor, - OpenAIOnlineRequestProcessor, -) - -__all__ = [ - "OpenAIBatchRequestProcessor", - "OpenAIOnlineRequestProcessor", -] diff --git a/src/bespokelabs/curator/prompter/prompter.py b/src/bespokelabs/curator/prompter/prompter.py index 4ffba335..2dadd436 100644 --- a/src/bespokelabs/curator/prompter/prompter.py +++ b/src/bespokelabs/curator/prompter/prompter.py @@ -13,12 +13,14 @@ from pydantic import BaseModel from xxhash import xxh64 -from bespokelabs.curator.dataset import Dataset +from datasets import Dataset from bespokelabs.curator.db import MetadataDB from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter from bespokelabs.curator.request_processor.generic_request import GenericRequest from bespokelabs.curator.request_processor.openai_batch_request_processor import ( OpenAIBatchRequestProcessor, +) +from bespokelabs.curator.request_processor.openai_online_request_processor import ( OpenAIOnlineRequestProcessor, ) from bespokelabs.curator.request_processor.base_request_processor import ( @@ -41,6 +43,7 @@ def __init__( ] ] = None, response_format: Optional[Type[BaseModel]] = None, + batch: bool = False, ): """Initialize a Prompter. @@ -70,26 +73,20 @@ def __init__( model_name, prompt_func, parse_func, response_format ) - self.request_processor = request_processor + if batch: + self._request_processor = OpenAIBatchRequestProcessor(model=model_name) + else: + self._request_processor = OpenAIOnlineRequestProcessor(model=model_name) - def __call__( - self, - dataset: Optional[Iterable] = None, - request_processor: Optional[ - BaseRequestProcessor - ] = OpenAIOnlineRequestProcessor, - ): + def __call__(self, dataset: Optional[Iterable] = None) -> Dataset: """Run completions on a dataset.""" - return self._completions(dataset, request_processor) + return self._completions(self._request_processor, dataset) def _completions( self, + request_processor: BaseRequestProcessor, dataset: Optional[Iterable] = None, - name: Optional[str] = None, - request_processor: Optional[ - BaseRequestProcessor - ] = OpenAIOnlineRequestProcessor, - ) -> "Dataset": + ) -> Dataset: """ Apply structured completions in parallel to a dataset using specified model and prompts. 
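Since the new `batch` flag is the only thing a caller changes to switch between the online and batch processors, a usage sketch helps show the intended surface. This mirrors the call pattern in `tests/test_batch.py` later in this series; the toy dataset row and prompt text are placeholders, and running the commented-out call would require an `OPENAI_API_KEY`.

```python
from datasets import Dataset
from bespokelabs import curator

# Toy dataset; the test pulls instructions from a ShareGPT dump instead.
dataset = Dataset.from_list([{"instruction": "Write a haiku about the ocean."}])

prompter = curator.Prompter(
    prompt_func=lambda row: {"user_prompt": row["instruction"]},
    parse_func=lambda row, response: {**row, "model_response": response},
    model_name="gpt-4o-mini",
    batch=True,  # False (the default) selects the online request processor
)

# annotated = prompter(dataset)  # returns a datasets.Dataset with model_response added
```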
@@ -98,7 +95,6 @@ def _completions( dataset (Iterable): A dataset consisting of a list of items to apply completions prompter (Prompter): A Prompter that contains the logic for formatting each item in the dataset - name (str): Name of the task resume (bool): Whether to resume from the previous completions run. If True, we use a fingerprint from the input dataset and the prompter to resume from a previous run that matches the same fingerprint. @@ -106,8 +102,9 @@ def _completions( Returns: Iterable: A list of structured outputs from the completions """ - if dataset is not None: - dataset = Dataset.from_iterable(dataset) + # NOTE(Ryan): We convert from iterable to Dataset because Dataset has random access via row_idx + if not isinstance(dataset, Dataset): + dataset = Dataset.from_generator(dataset) if self is None: raise ValueError("Prompter must be provided") @@ -135,8 +132,6 @@ def _completions( ) fingerprint = xxh64(fingerprint_str.encode("utf-8")).hexdigest() - - name = f"{name.replace(' ', '-')}--{fingerprint}" if name else fingerprint metadata_db_path = os.path.join(curator_cache_dir, "metadata.db") metadata_db = MetadataDB(metadata_db_path) @@ -163,14 +158,11 @@ def _completions( metadata_db.store_metadata(metadata_dict) # TODO(Ryan): do the response processing, while context of original dataset is available and need random access via row_idx) - response_files = request_processor.run( + dataset = request_processor.run( dataset, f"{curator_cache_dir}/{fingerprint}", self.prompt_formatter ) - # NOTE(Ryan): If we decide to allow user to provide any iterable as input dataset (and doens't have random access via row_idx), we can do it differently. This might be a little slower. - # https://huggingface.co/docs/datasets/v3.1.0/about_mapstyle_vs_iterable - # What we can do is sort the generic responses_i.jsonl files by row_idx (in parallel across patches), then just iterate over the dataset and create the new rows in order with the responses. - # To do this, we need to write generic responses instead API_specific responses. Requests can be api-specific, we don't care since we are not monitoring there. + return dataset def _hash_chunk(chunks: list) -> list: diff --git a/src/bespokelabs/curator/request_processor/base_request_processor.py b/src/bespokelabs/curator/request_processor/base_request_processor.py index b9d496b1..234f7a65 100644 --- a/src/bespokelabs/curator/request_processor/base_request_processor.py +++ b/src/bespokelabs/curator/request_processor/base_request_processor.py @@ -6,10 +6,12 @@ from abc import ABC, abstractmethod from typing import Optional -from bespokelabs.curator.dataset import Dataset +from datasets import Dataset from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter from bespokelabs.curator.request_processor.generic_request import GenericRequest from bespokelabs.curator.request_processor.generic_response import GenericResponse +from datasets.arrow_writer import ArrowWriter, SchemaInferenceError +from pydantic import BaseModel class BaseRequestProcessor(ABC): @@ -17,6 +19,9 @@ class BaseRequestProcessor(ABC): Base class for all request processors. 
""" + def __init__(self, batch_size: int = 1000): + self.batch_size = batch_size + @abstractmethod def get_rate_limits(self) -> dict: """ @@ -167,4 +172,85 @@ def create_request_files( ) if request_count > 0: logging.info(f"Wrote {request_count:,} requests to {requests_file}.") - return requests_file + return requests_files + + def create_dataset_files( + self, + dataset: Dataset, + working_dir: str, + prompt_formatter: PromptFormatter, + ) -> None: + """ + Creates the request files if they don't already exist or use existing. + A single request file (requests_0.jsonl) or multiple request files + (requests_0.jsonl, requests_1.jsonl, etc.) are created depending on + batch_size. + + Args: + dataset (Dataset): The dataset to be processed. + working_dir (str): The directory where request files will be saved. + prompt_formatter (PromptFormatter): The prompt formatter to use for parsing the responses. + + Returns: + Dataset: Completed dataset + """ + total_responses_count = 0 + failed_responses_count = 0 + + responses_files = glob.glob(f"{working_dir}/responses_*.jsonl") + if len(responses_files) == 0: + raise ValueError(f"No responses files found in {working_dir}") + dataset_file = f"{working_dir}/dataset.arrow" + + # Process all response files + with ArrowWriter(path=dataset_file) as writer: + for responses_file in responses_files: + with open(responses_file, "r") as f_in: + for generic_response_string in f_in: + total_responses_count += 1 + response = GenericResponse.model_validate_json( + generic_response_string + ) + if prompt_formatter.response_format: + response.response = prompt_formatter.response_format( + **response.response + ) + + if response is None: + failed_responses_count += 1 + continue + + # Requires dataset to be Dataset object with random access + if response.row is None: + response.row = dataset[response.row_idx] + + # parse_func can return a single row or a list of rows + if prompt_formatter.parse_func: + dataset_rows = prompt_formatter.parse_func( + response.row, response.response + ) + if not isinstance(dataset_rows, list): + dataset_rows = [dataset_rows] + else: + dataset_rows = [response.response] + + for row in dataset_rows: + if isinstance(row, BaseModel): + row = row.model_dump() + writer.write(row) + + logging.info( + f"Read {total_responses_count} responses, {failed_responses_count} failed" + ) + if failed_responses_count == total_responses_count: + raise ValueError("All requests failed") + + logging.info("Finalizing writer") + try: + writer.finalize() + except SchemaInferenceError as e: + raise ValueError( + "Arrow writer is complaining about the schema: likely all of your parsed rows were None and writer.write only wrote None objects." 
+ ) from e + + return Dataset.from_file(dataset_file) diff --git a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py index 97bf30db..72035757 100644 --- a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py +++ b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py @@ -19,11 +19,18 @@ class OpenAIBatchRequestProcessor(BaseRequestProcessor): - model: str - url: str = "https://api.openai.com/v1/chat/completions" - api_key: str = os.getenv("OPENAI_API_KEY") - batch_size: int = 1000 - check_interval: int = 60 + def __init__( + self, + batch_size: int = 1000, + model: str = "gpt-4o-mini", + check_interval: int = 60, + api_key: str = os.getenv("OPENAI_API_KEY"), + url: str = "https://api.openai.com/v1/chat/completions", + ): + super().__init__(batch_size) + self.url: str = url + self.api_key: str = api_key + self.check_interval: int = check_interval def get_rate_limits(self) -> dict: """ @@ -127,7 +134,7 @@ def get_generic_response(self, response: Dict) -> GenericResponse: request_id = response["id"] status_code = response["response"]["status_code"] - # TODO(Ryan): Add error handling + # TODO(Ryan): Add error handling. This should handle error files from BatchAPI. if status_code != 200: logging.warning( f"Request {request_id} failed with status code {status_code}" @@ -206,7 +213,6 @@ async def submit_all_batches(): asyncio.run(batch_watcher.watch()) - # TODO(Ryan): Add back in here the dataset creation return dataset @@ -279,7 +285,9 @@ async def watch(self) -> None: self.batches = completed_batches.values() - async def download_batch_result_file(self, batch) -> str: + async def download_batch_result_file( + self, batch, OpenAIBatchRequestProcessor + ) -> str: """Download the result of a completed batch to file. Args: @@ -297,64 +305,18 @@ async def download_batch_result_file(self, batch) -> str: return None # NOTE(Ryan): This is so the naming is consistent with the request file naming - request_file_id = ( + request_file_idx = ( self.batch_id_to_request_file_name[batch.id].split("/")[-1].split("_", 1)[1] ) - output_path = f"{self.working_dir}/responses_{request_file_id}" + output_path = f"{self.working_dir}/responses_{request_file_idx}" with open(output_path, "wb") as f: - f.write(file_content.content) + for raw_response in file_content.text.splitlines(): + generic_response = OpenAIBatchRequestProcessor.get_generic_response( + raw_response + ) + f.write(json.dumps(generic_response.model_dump()) + "\n") return output_path - # NOTE(Ryan): This could be useful for very small batches and overall total requests not too large - async def download_batch_result_in_memory(self, batch) -> list[str]: - """Download the result of a completed batch. - - Args: - batch: The batch object to download results from. - - Returns: - list[str]: Lines of the downloaded result. - """ - if batch.status == "completed" and batch.output_file_id: - file_content = await self.client.files.content(batch.output_file_id) - return file_content.text.splitlines() - return [] - - # NOTE(Ryan): This could be useful for very small batches and overall total requests not too large - async def download_results_in_memory(self, output_path: str) -> None: - """Download results of all batches and save to a specified path. - - Args: - output_path (str): Path to save the downloaded results. 
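The "parse_func can return a single row or a list of rows" branch in `create_dataset_files` above is easiest to see with a concrete example. The parse function below is made up for illustration; it fans a single structured response out into several dataset rows, which is the case the list handling exists for.

```python
# Hypothetical parse_func illustrating the single-row-or-list contract described above.
def parse_func(row: dict, response: dict) -> list[dict]:
    """Fan a response containing several items out into one dataset row per item."""
    return [
        {**row, "model_response": item}
        for item in response.get("items", [response])
    ]


rows = parse_func({"instruction": "List two colors."}, {"items": ["red", "blue"]})
print(rows)
# [{'instruction': 'List two colors.', 'model_response': 'red'},
#  {'instruction': 'List two colors.', 'model_response': 'blue'}]
```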
- """ - tasks = [self.download_batch_result(batch) for batch in self.batches] - results = await asyncio.gather(*tasks) - - all_results = [ - item for sublist in results for item in sublist - ] # Flatten the list of lists - - with open(output_path, "w") as f: - for result in all_results: - f.write(result + "\n") - logging.info(f"All batch results downloaded and saved to: {output_path}") - - async def download_errors(self, error_path: str, batch_id: str) -> None: - """Download error file for a specific batch if available. - - Args: - error_path (str): Path to save the error file. - batch_id (str): The ID of the batch to download errors from. - """ - batch = await self.client.batches.retrieve(batch_id) - if batch.error_file_id: - file_content = await self.client.files.content(batch.error_file_id) - with open(error_path, "wb") as f: - f.write(file_content.content) - logging.info(f"Batch errors downloaded and saved to: {error_path}") - else: - logging.info(f"No error file available for batch {batch_id}.") - async def plot_completion_data(self, output_dir: str) -> None: """Save plots visualizing completion times for the batches. diff --git a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py index 5306b7d2..26f6b1d8 100644 --- a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py +++ b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py @@ -11,6 +11,7 @@ import requests import tiktoken from tqdm import tqdm +from functools import partial from bespokelabs.curator.dataset import Dataset from bespokelabs.curator.request_processor.base_request_processor import ( @@ -18,14 +19,23 @@ GenericRequest, GenericResponse, ) +from bespokelabs.curator.prompter.prompter import PromptFormatter T = TypeVar("T") class OpenAIOnlineRequestProcessor(BaseRequestProcessor): - model: str - url: str = "https://api.openai.com/v1/chat/completions" - api_key: str = os.getenv("OPENAI_API_KEY") + def __init__( + self, + batch_size: int = 1000, + model: str = "gpt-4o-mini", + api_key: str = os.getenv("OPENAI_API_KEY"), + url: str = "https://api.openai.com/v1/chat/completions", + ): + super().__init__(batch_size) + self.model: str = model + self.url: str = url + self.api_key: str = api_key def get_rate_limits(self) -> dict: """ @@ -46,7 +56,7 @@ def get_rate_limits(self) -> dict: response = requests.post( self.url, headers={"Authorization": f"Bearer {self.api_key}"}, - json={"model": self.prompt_formatter.model_name, "messages": []}, + json={"model": self.model, "messages": []}, ) rpm = int(response.headers.get("x-ratelimit-limit-requests", 0)) @@ -108,7 +118,9 @@ def create_api_specific_request(self, generic_request: GenericRequest) -> dict: return request - def get_generic_response(self, response: Dict) -> GenericResponse: + def get_generic_response( + self, response: Dict, prompt_formatter: PromptFormatter + ) -> GenericResponse: """ Parses a API-specific response into a generic response body. Does error handling on the response. 
@@ -123,7 +135,7 @@ def get_generic_response(self, response: Dict) -> GenericResponse: dict: Generic response body with an extra field "metadata" which contains the original dataset row or the index of the row in the original dataset """ content = response["response"]["choices"][0]["message"]["content"] - if self.prompt_formatter.response_format: + if prompt_formatter.response_format: content = json.loads(content) return GenericResponse( @@ -136,6 +148,7 @@ def run( self, dataset: Optional[Dataset], working_dir: str, + prompt_formatter: PromptFormatter, ) -> Dataset: """ Uses the API to completing the specific map by calling the LLM. @@ -147,19 +160,23 @@ def run( Returns: Dataset: Completed dataset """ - requests_files = self.create_request_files(dataset, working_dir) - responses_files = [] + requests_files = self.create_request_files( + dataset, working_dir, prompt_formatter + ) + responses_files = [ + f"{working_dir}/responses_{i}.jsonl" for i in range(len(requests_files)) + ] rate_limits = self.get_rate_limits() rpm = rate_limits["max_requests_per_minute"] tpm = rate_limits["max_tokens_per_minute"] - token_encoding_name = get_token_encoding_name(self.prompt_formatter.model_name) + token_encoding_name = get_token_encoding_name(prompt_formatter.model_name) # NOTE(Ryan): If you wanted to do this on batches, you could run a for loop here about request_files. Although I don't recommend it because you are waiting for straggler requests to finish for each batch. # NOTE(Ryan): And if you wanted to do batches in parallel, you would have to divide rpm and tpm by the number of parallel batches. # TODO(Ryan): Can we abstract retries from process_api_requests_from_file so you can use it even if you use liteLLM. - for requests_file in requests_files: + for requests_file, responses_file in zip(requests_files, responses_files): asyncio.run( self.process_api_requests_from_file( requests_filepath=requests_file, @@ -170,9 +187,12 @@ def run( token_encoding_name=token_encoding_name, max_attempts=5, resume=True, # detects existing jobs and resume from there + prompt_formatter=prompt_formatter, ) ) - return responses_files + + dataset = self.create_dataset_files(dataset, working_dir, prompt_formatter) + return dataset async def process_api_requests_from_file( self, @@ -184,6 +204,7 @@ async def process_api_requests_from_file( token_encoding_name: str, max_attempts: int, resume: bool, + prompt_formatter: PromptFormatter, resume_no_retry: bool = False, ) -> None: """Processes API requests in parallel, throttling to stay under rate limits.""" @@ -376,7 +397,10 @@ async def process_api_requests_from_file( retry_queue=queue_of_requests_to_retry, save_filepath=save_filepath, status_tracker=status_tracker, - get_generic_response=self.get_generic_response, + get_generic_response=partial( + self.get_generic_response, + prompt_formatter=prompt_formatter, + ), ) ) next_request = None # reset next_request to empty diff --git a/tests/test_batch.py b/tests/test_batch.py index 37a325ae..c834d7a8 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -1,5 +1,7 @@ from bespokelabs import curator -from bespokelabs.curator import OpenAIBatchRequestProcessor +from bespokelabs.curator.request_processor.openai_batch_request_processor import ( + OpenAIBatchRequestProcessor, +) from datasets import load_dataset, Dataset import argparse @@ -18,7 +20,7 @@ def it_from_sharegpt(sample): raise ValueError("Invalid conversation format") return {"instruction": instruction, "original_response": response} - dataset = 
dataset.map(it_from_sharegpt) + dataset = dataset.map(it_from_sharegpt, num_proc=8) dataset = dataset.remove_columns(["conversations"]) dataset = dataset.select_columns(["instruction", "original_response"]) return dataset @@ -33,12 +35,6 @@ def load_ShareGPT_dataset_as_IT(dataset_name: str, truncate: int = None) -> Data if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - "--working_dir", - type=str, - required=True, - help="Where requests, responses, and dataset will be locally written to save intermediate results", - ) parser.add_argument( "--num_samples", type=int, @@ -52,10 +48,10 @@ def load_ShareGPT_dataset_as_IT(dataset_name: str, truncate: int = None) -> Data help="The number of samples to use per batch", ) parser.add_argument( - "--type", - type=str, - default="online", - help="The type of API to use", + "--batch", + type=bool, + default=False, + help="Whether to use batch processing", ) parser.add_argument( "--check_interval", @@ -71,28 +67,49 @@ def load_ShareGPT_dataset_as_IT(dataset_name: str, truncate: int = None) -> Data dataset = dataset.select(range(args.num_samples)) + # python tests/test_batch.py --num_samples 10 + # if args.type == "online": # api = OpenAIOnlineAPI(model="gpt-4o-mini") # elif args.type == "batch": # api = OpenAIBatchAPI(model="gpt-4o-mini", check_interval=args.check_interval) + + # TODO(Ryan) messages as prompt_func output or if string default to user_prompt instruction + # def prompt_func(row): + # messages = [ + # {"role": "user", "content": row["instruction"]} + # ] + # return messages + + # def parse_func(row, response): + # row["model_response"] = response + # return row + + # reannotate_prompter = curator.Prompter( + # prompt_func=prompt_func, + # parse_func=parse_func, + # model_name="gpt-4o-mini", + # ) + reannotate_prompter = curator.Prompter( prompt_func=lambda row: {"user_prompt": row["instruction"]}, parse_func=lambda row, response: {**row, "model_response": response}, model_name="gpt-4o-mini", ) - request_processor = OpenAIBatchRequestProcessor( - model="gpt-4o-mini", - batch_size=args.batch_size, - check_interval=args.check_interval, - ) + # To set internal variables + # request_processor = OpenAIBatchRequestProcessor( + # model="gpt-4o-mini", + # batch_size=args.batch_size, + # check_interval=args.check_interval, + # ) - reannotated_dataset = reannotate_prompter(dataset, request_processor) + # reannotate_prompter._processor = request_processor - dataset = reannotated_dataset.to_huggingface() + reannotated_dataset = reannotate_prompter(dataset) # Upload dataset to Hugging Face - print(dataset) + print(reannotated_dataset) dataset_name = "mlfoundations-dev/rewrite-test-gpt-4o-mini" - dataset.push_to_hub(dataset_name) + reannotated_dataset.push_to_hub(dataset_name) print(f"https://huggingface.co/datasets/{dataset_name}") From d252717ccac50bd6b1b1f7e51ab93d55ded8f016 Mon Sep 17 00:00:00 2001 From: Ryan Marten Date: Thu, 7 Nov 2024 21:00:25 -0800 Subject: [PATCH 09/18] batch and online both working with small time test --- poetry.lock | 13 ++- pyproject.toml | 1 + .../base_request_processor.py | 83 ++++++++++++++----- .../openai_batch_request_processor.py | 61 +++++++++++--- tests/test_batch.py | 9 +- 5 files changed, 131 insertions(+), 36 deletions(-) diff --git a/poetry.lock b/poetry.lock index 0f3fdb73..6ebe0eff 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,16 @@ # This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. 
+[[package]] +name = "aiofiles" +version = "24.1.0" +description = "File support for asyncio." +optional = false +python-versions = ">=3.8" +files = [ + {file = "aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5"}, + {file = "aiofiles-24.1.0.tar.gz", hash = "sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c"}, +] + [[package]] name = "aiohappyeyeballs" version = "2.4.3" @@ -3802,4 +3813,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "f7cb5ef17d0f570775881d92e959a2643fdf594183aa6d3eb63cb34307ab06dc" +content-hash = "664f2dbe4be78f76d7f210e72c2b6b5dd307b7b9d7e2494032be992798d5dae7" diff --git a/pyproject.toml b/pyproject.toml index 8d834bc2..96382522 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ pandas = "^2.2.3" xxhash = "^3.5.0" tqdm = "^4.67.0" matplotlib = "^3.9.2" +aiofiles = "^24.1.0" [tool.poetry.group.dev.dependencies] black = "^24.2.0" diff --git a/src/bespokelabs/curator/request_processor/base_request_processor.py b/src/bespokelabs/curator/request_processor/base_request_processor.py index 234f7a65..76d60e6e 100644 --- a/src/bespokelabs/curator/request_processor/base_request_processor.py +++ b/src/bespokelabs/curator/request_processor/base_request_processor.py @@ -12,6 +12,9 @@ from bespokelabs.curator.request_processor.generic_response import GenericResponse from datasets.arrow_writer import ArrowWriter, SchemaInferenceError from pydantic import BaseModel +from math import ceil +import asyncio +import aiofiles class BaseRequestProcessor(ABC): @@ -87,7 +90,7 @@ def create_request_files( working_dir: str, prompt_formatter: PromptFormatter, batch_size: Optional[int] = None, - ) -> str: + ) -> list[str]: """ Creates a request file if they don't already exist or use existing. @@ -96,7 +99,7 @@ def create_request_files( working_dir (str): The directory where request files will be saved. Returns: - str: Path to the request file that was created. + list[str]: Paths to the request files that were created. """ os.makedirs(working_dir, exist_ok=True) requests_files = glob.glob(f"{working_dir}/requests_*.jsonl") @@ -143,6 +146,8 @@ def create_request_files( request_count = 0 request_file_idx = 0 requests_file = f"{working_dir}/requests_{request_file_idx}.jsonl" + requests_files = [] + # Create new requests file with open(requests_file, "w") as f: if dataset is None: @@ -150,30 +155,66 @@ def create_request_files( api_request = self.create_api_specific_request(request) f.write(json.dumps(api_request) + "\n") else: - for dataset_row_idx, dataset_row in enumerate(dataset): - request = prompt_formatter.get_generic_request( - dataset_row, dataset_row_idx - ) - # Convert the generic request to an API-specific request - api_request = self.create_api_specific_request(request) - # Write the API-specific request to file - f.write(json.dumps(api_request) + "\n") - request_count += 1 - - # Batches could be created in parallel, but dataset is iterated sequentially - if batch_size is not None and request_count == batch_size: - request_count = 0 - request_file_idx += 1 - requests_file = ( - f"{working_dir}/requests_{request_file_idx}.jsonl" - ) - logging.info( - f"Wrote {request_count:,} requests to {requests_file}." 
+ if batch_size: + num_batches = ceil(len(dataset) / batch_size) + requests_files = [ + f"{working_dir}/requests_{i}.jsonl" for i in range(num_batches) + ] + + async def create_all_request_files(): + tasks = [ + self.acreate_request_file( + dataset, + prompt_formatter, + requests_files[i], + i * batch_size, + batch_size, + ) + for i in range(num_batches) + ] + await asyncio.gather(*tasks) + + asyncio.run(create_all_request_files()) + else: + requests_files = [f"{working_dir}/requests_0.jsonl"] + asyncio.run( + self.acreate_request_file( + dataset, prompt_formatter, requests_files[0] ) + ) + if request_count > 0: logging.info(f"Wrote {request_count:,} requests to {requests_file}.") + return requests_files + # NOTE(Ryan): Instead of doing this, just iterate over iterable and keep counter and change filename when hit batch_size, this will be slower but this whole thing is dominated by llm calls anyways + async def acreate_request_file( + self, + dataset: Dataset, + prompt_formatter: PromptFormatter, + request_file: str, + start_idx: int = 0, + batch_size: int = None, + ) -> str: + if batch_size is not None: + end_idx = min(start_idx + batch_size, len(dataset)) + dataset = dataset.select(range(start_idx, end_idx)) + + # NOTE(Ryan): For loops only for IterableDataset which allows for _very_ large datasets, when start_idx and batch_size are not specified + async with aiofiles.open(request_file, "w") as f: + for idx, dataset_row in enumerate(dataset): + dataset_row_idx = idx + start_idx + # Get the generic request from the map function + request = prompt_formatter.get_generic_request( + dataset_row, dataset_row_idx + ) + # Convert the generic request to an API-specific request + api_request = self.create_api_specific_request(request) + # Write the API-specific request to file + await f.write(json.dumps(api_request) + "\n") + logging.info(f"Requests file {request_file} written to disk.") + def create_dataset_files( self, dataset: Dataset, diff --git a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py index 72035757..c2bdfef5 100644 --- a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py +++ b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py @@ -7,6 +7,7 @@ from openai import AsyncOpenAI import pandas as pd import matplotlib.pyplot as plt +import aiofiles from bespokelabs.curator.dataset import Dataset from bespokelabs.curator.request_processor.base_request_processor import ( @@ -14,6 +15,7 @@ GenericRequest, GenericResponse, ) +from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter T = TypeVar("T") @@ -116,7 +118,9 @@ def create_api_specific_request(self, generic_request: GenericRequest) -> dict: return request - def get_generic_response(self, response: Dict) -> GenericResponse: + def get_generic_response( + self, response: Dict, prompt_formatter: PromptFormatter + ) -> GenericResponse: """ Parses a API-specific response into a generic response body. Does error handling on the response. 
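A quick worked example of the batching arithmetic introduced above in `create_request_files`: with ceil-division over the dataset length, every request file holds `batch_size` rows except the last, which simply ends up smaller. The numbers below are illustrative only.

```python
from math import ceil

# Assumed sizes for illustration.
dataset_len = 1_000
batch_size = 300

num_batches = ceil(dataset_len / batch_size)  # 4
ranges = [
    (i * batch_size, min((i + 1) * batch_size, dataset_len))
    for i in range(num_batches)
]
print(num_batches, ranges)
# 4 [(0, 300), (300, 600), (600, 900), (900, 1000)]
```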
@@ -130,7 +134,6 @@ def get_generic_response(self, response: Dict) -> GenericResponse: Returns: dict: Generic response body with an extra field "metadata" which contains the original dataset row or the index of the row in the original dataset """ - request_id = response["id"] status_code = response["response"]["status_code"] @@ -145,7 +148,7 @@ def get_generic_response(self, response: Dict) -> GenericResponse: # TODO(Ryan): if you add token tokens to generic response, we can parse that here too, similar to my comment above we can do that in the shared place. content = response["response"]["body"]["choices"][0]["message"]["content"] - if self.prompt_formatter.response_format: + if prompt_formatter.response_format: content = json.loads(content) return GenericResponse( @@ -154,10 +157,32 @@ def get_generic_response(self, response: Dict) -> GenericResponse: raw_response=response, ) + async def asubmit_batch(self, batch_file: str) -> dict: + async with aiofiles.open(batch_file, "rb") as file: + file_content = await file.read() + batch_file_upload = await self.async_client.files.create( + file=file_content, purpose="batch" + ) + + logging.info(f"File uploaded: {batch_file_upload}") + + batch_object = await self.async_client.batches.create( + input_file_id=batch_file_upload.id, + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={ + "request_file_name": batch_file + }, # for easily mapping back later, NOTE(Ryan): can convert to the int or UUID later + ) + logging.info(f"Batch request submitted, received batch object: {batch_object}") + + return batch_object + def run( self, dataset: Optional[Dataset], working_dir: str, + prompt_formatter: PromptFormatter, ) -> Dataset: """ Uses the API to completing the specific map by calling the LLM. @@ -170,7 +195,7 @@ def run( Dataset: Completed dataset """ requests_files = self.create_request_files( - dataset, map, working_dir, self.batch_size + dataset, working_dir, prompt_formatter, self.batch_size ) batch_objects_file = f"{working_dir}/batch_objects.jsonl" @@ -211,12 +236,12 @@ async def submit_all_batches(): # TODO(Ryan): likely can add some logic for smarter check_interval based on batch size and if the batch has started or not, fine to do a dumb ping for now batch_watcher = BatchWatcher(working_dir, check_interval=self.check_interval) - asyncio.run(batch_watcher.watch()) + asyncio.run(batch_watcher.watch(prompt_formatter, self.get_generic_response)) + dataset = self.create_dataset_files(dataset, working_dir, prompt_formatter) return dataset -@Dataset class BatchWatcher: def __init__(self, working_dir: str, check_interval: int = 60) -> None: """Initialize BatchWatcher with batch objects file and check interval. @@ -234,6 +259,8 @@ def __init__(self, working_dir: str, check_interval: int = 60) -> None: for obj in self.batch_objects } self.batches = [] + self.check_interval = check_interval + self.working_dir = working_dir async def check_batch_status(self, batch_id: str) -> tuple[str, str]: """Check the status of a batch by its ID. 
@@ -250,7 +277,11 @@ async def check_batch_status(self, batch_id: str) -> tuple[str, str]: ) return batch_id, batch - async def watch(self) -> None: + async def watch( + self, + prompt_formatter: PromptFormatter, + get_generic_response: Callable[[Dict], GenericResponse], + ) -> None: """Monitor the status of batches until all are completed (includes successfully, failed, expired or cancelled).""" completed_batches = {} while len(completed_batches) < len(self.batch_ids): @@ -271,7 +302,9 @@ async def watch(self) -> None: # NOTE(Ryan): Now downloading after each check, instead of waiting until all are completed tasks = [ - self.download_batch_result_file(batch) + self.download_batch_result_file( + batch, prompt_formatter, get_generic_response + ) for batch in newly_completed_batches ] await asyncio.gather(*tasks) @@ -286,7 +319,10 @@ async def watch(self) -> None: self.batches = completed_batches.values() async def download_batch_result_file( - self, batch, OpenAIBatchRequestProcessor + self, + batch, + prompt_formatter: PromptFormatter, + get_generic_response: Callable[[Dict], GenericResponse], ) -> str: """Download the result of a completed batch to file. @@ -309,10 +345,11 @@ async def download_batch_result_file( self.batch_id_to_request_file_name[batch.id].split("/")[-1].split("_", 1)[1] ) output_path = f"{self.working_dir}/responses_{request_file_idx}" - with open(output_path, "wb") as f: + with open(output_path, "w") as f: for raw_response in file_content.text.splitlines(): - generic_response = OpenAIBatchRequestProcessor.get_generic_response( - raw_response + # TODO(Ryan): We should abstract this out + generic_response = get_generic_response( + json.loads(raw_response), prompt_formatter ) f.write(json.dumps(generic_response.model_dump()) + "\n") return output_path diff --git a/tests/test_batch.py b/tests/test_batch.py index c834d7a8..84034538 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -4,6 +4,7 @@ ) from datasets import load_dataset, Dataset import argparse +import logging def convert_ShareGPT_to_IT_format(dataset: Dataset) -> Dataset: @@ -49,8 +50,7 @@ def load_ShareGPT_dataset_as_IT(dataset_name: str, truncate: int = None) -> Data ) parser.add_argument( "--batch", - type=bool, - default=False, + action="store_true", help="Whether to use batch processing", ) parser.add_argument( @@ -91,10 +91,15 @@ def load_ShareGPT_dataset_as_IT(dataset_name: str, truncate: int = None) -> Data # model_name="gpt-4o-mini", # ) + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + reannotate_prompter = curator.Prompter( prompt_func=lambda row: {"user_prompt": row["instruction"]}, parse_func=lambda row, response: {**row, "model_response": response}, model_name="gpt-4o-mini", + batch=args.batch, ) # To set internal variables From d1870cae74be2f03407aa6250f90ac7d59504737 Mon Sep 17 00:00:00 2001 From: Ryan Marten Date: Thu, 7 Nov 2024 21:14:36 -0800 Subject: [PATCH 10/18] debugged and testing batch at larger size --- .../base_request_processor.py | 23 ++++++++----------- .../openai_batch_request_processor.py | 2 +- tests/test_batch.py | 13 +++++------ 3 files changed, 17 insertions(+), 21 deletions(-) diff --git a/src/bespokelabs/curator/request_processor/base_request_processor.py b/src/bespokelabs/curator/request_processor/base_request_processor.py index 76d60e6e..ad29e033 100644 --- a/src/bespokelabs/curator/request_processor/base_request_processor.py +++ b/src/bespokelabs/curator/request_processor/base_request_processor.py 
@@ -89,7 +89,6 @@ def create_request_files( dataset: Optional[Dataset], working_dir: str, prompt_formatter: PromptFormatter, - batch_size: Optional[int] = None, ) -> list[str]: """ Creates a request file if they don't already exist or use existing. @@ -129,13 +128,13 @@ def create_request_files( ) # Some simple sanity checks for the user - if batch_size is not None: - if batch_size != num_jobs: + if self.batch_size is not None: + if self.batch_size != num_jobs: logging.warning( - f"Batch size is {batch_size}, but there are {num_jobs} requests in {requests_files[0]}. " + f"Batch size is {self.batch_size}, but there are {num_jobs} requests in {requests_files[0]}. " f"If you want to run with new batch size, you will have to delete the working directory and re-run (looses progress)" ) - if len(requests_files) == 1 and len(dataset) > batch_size: + if len(requests_files) == 1 and len(dataset) > self.batch_size: logging.warning( f"Only one request file was found, but batch size is specified and dataset is larger than batch size." f"You might be resuming from a different dataset or weren't using batching before." @@ -155,8 +154,8 @@ def create_request_files( api_request = self.create_api_specific_request(request) f.write(json.dumps(api_request) + "\n") else: - if batch_size: - num_batches = ceil(len(dataset) / batch_size) + if self.batch_size: + num_batches = ceil(len(dataset) / self.batch_size) requests_files = [ f"{working_dir}/requests_{i}.jsonl" for i in range(num_batches) ] @@ -167,8 +166,7 @@ async def create_all_request_files(): dataset, prompt_formatter, requests_files[i], - i * batch_size, - batch_size, + start_idx=i * self.batch_size, ) for i in range(num_batches) ] @@ -195,10 +193,9 @@ async def acreate_request_file( prompt_formatter: PromptFormatter, request_file: str, start_idx: int = 0, - batch_size: int = None, ) -> str: - if batch_size is not None: - end_idx = min(start_idx + batch_size, len(dataset)) + if self.batch_size is not None: + end_idx = min(start_idx + self.batch_size, len(dataset)) dataset = dataset.select(range(start_idx, end_idx)) # NOTE(Ryan): For loops only for IterableDataset which allows for _very_ large datasets, when start_idx and batch_size are not specified @@ -213,7 +210,7 @@ async def acreate_request_file( api_request = self.create_api_specific_request(request) # Write the API-specific request to file await f.write(json.dumps(api_request) + "\n") - logging.info(f"Requests file {request_file} written to disk.") + logging.info(f"Wrote {end_idx - start_idx} requests to {request_file}.") def create_dataset_files( self, diff --git a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py index c2bdfef5..6f36e70d 100644 --- a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py +++ b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py @@ -195,7 +195,7 @@ def run( Dataset: Completed dataset """ requests_files = self.create_request_files( - dataset, working_dir, prompt_formatter, self.batch_size + dataset, working_dir, prompt_formatter ) batch_objects_file = f"{working_dir}/batch_objects.jsonl" diff --git a/tests/test_batch.py b/tests/test_batch.py index 84034538..09fa590b 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -103,14 +103,13 @@ def load_ShareGPT_dataset_as_IT(dataset_name: str, truncate: int = None) -> Data ) # To set internal variables - # request_processor = OpenAIBatchRequestProcessor( - # 
model="gpt-4o-mini", - # batch_size=args.batch_size, - # check_interval=args.check_interval, - # ) - - # reannotate_prompter._processor = request_processor + request_processor = OpenAIBatchRequestProcessor( + model="gpt-4o-mini", + batch_size=args.batch_size, + check_interval=args.check_interval, + ) + reannotate_prompter._request_processor = request_processor reannotated_dataset = reannotate_prompter(dataset) # Upload dataset to Hugging Face From 21a77bd15b88a51113c3dd76c6e3eb0af912622f Mon Sep 17 00:00:00 2001 From: Ryan Marten Date: Thu, 7 Nov 2024 21:36:20 -0800 Subject: [PATCH 11/18] added new advanced example for distilling --- examples/distill.py | 26 ++++ .../curator/prompter/prompt_formatter.py | 12 +- src/bespokelabs/curator/prompter/prompter.py | 50 +------- tests/test_batch.py | 119 ------------------ 4 files changed, 34 insertions(+), 173 deletions(-) create mode 100644 examples/distill.py delete mode 100644 tests/test_batch.py diff --git a/examples/distill.py b/examples/distill.py new file mode 100644 index 00000000..96769feb --- /dev/null +++ b/examples/distill.py @@ -0,0 +1,26 @@ +from bespokelabs import curator +from datasets import load_dataset +import logging + +logging.basicConfig(level=logging.INFO) + +dataset = load_dataset("allenai/WildChat", split="train") +dataset = dataset.select(range(3_000)) + + +def prompt_func(row): + return row["conversation"][0]["content"] + + +def parse_func(row, response): + instruction = row["conversation"][0]["content"] + return {"instruction": instruction, "new_response": response} + + +distill_prompter = curator.Prompter( + prompt_func=prompt_func, parse_func=parse_func, model_name="gpt-4o-mini", batch=True +) + +distilled_dataset = distill_prompter(dataset) +print(distilled_dataset) +print(distilled_dataset[0]) diff --git a/src/bespokelabs/curator/prompter/prompt_formatter.py b/src/bespokelabs/curator/prompter/prompt_formatter.py index 1f46fca2..a978311e 100644 --- a/src/bespokelabs/curator/prompter/prompt_formatter.py +++ b/src/bespokelabs/curator/prompter/prompt_formatter.py @@ -48,13 +48,11 @@ def get_generic_request( f"Prompting function {self.prompt_func} must have 0 or 1 arguments." 
) - messages = [] - system_prompt = prompts.get("system_prompt", "You are a helpful AI assistant.") - messages.append({"role": "system", "content": system_prompt}) - - if "user_prompt" not in prompts: - raise ValueError("user_prompt is required") - messages.append({"role": "user", "content": prompts["user_prompt"]}) + if isinstance(prompts, str): + messages = [{"role": "user", "content": prompts}] + else: + # TODO(Ryan): Add validation here + messages = prompts # Convert BaseModel to dict for serialization if isinstance(row, BaseModel): diff --git a/src/bespokelabs/curator/prompter/prompter.py b/src/bespokelabs/curator/prompter/prompter.py index 2dadd436..9e6ee6c9 100644 --- a/src/bespokelabs/curator/prompter/prompter.py +++ b/src/bespokelabs/curator/prompter/prompter.py @@ -49,8 +49,8 @@ def __init__( Args: model_name (str): The name of the LLM to use - prompt_func (Callable[[Dict[str, Any]], Dict[str, str]]): A function that takes a single row - and returns a dict with "system_prompt" and "user_prompt" + prompt_func (Callable[[Dict[str, Any]], Union[str, List[Dict[str, Any]]]]): A function that takes a single row + and returns either a string (assumed to be a user prompt) or messages list parse_func (Callable[[Dict[str, Any], Any], T]): A function that takes the input row and response object and returns the parsed output response_format (Optional[Type[BaseModel]]): A Pydantic model specifying the @@ -113,7 +113,7 @@ def _completions( "CURATOR_CACHE_DIR", os.path.expanduser("~/.cache/curator") ) - dataset_hash = _hash_dataset(dataset) + dataset_hash = dataset._fingerprint prompt_func_hash = _get_function_hash(self.prompt_formatter.prompt_func) parse_func_hash = _get_function_hash(self.prompt_formatter.parse_func) @@ -165,50 +165,6 @@ def _completions( return dataset -def _hash_chunk(chunks: list) -> list: - """Hash a chunk of data.""" - - def _json_dumps_row(row): - if isinstance(row, BaseModel): - row = row.model_dump() - return json.dumps(row, sort_keys=True) - - chunks = [_json_dumps_row(row) for row in chunks] - chunk_str = "|||".join(chunks) - return xxh64(chunk_str).hexdigest() - - -def _hash_dataset(dataset: Optional[Iterable]): - """Hash a dataset to a consistent value using parallel processing.""" - start = time.perf_counter_ns() - if dataset is None: - return xxh64("").hexdigest() - - # Convert to list and determine chunking parameters - dataset_list = list(dataset) - if len(dataset_list) == 0: - return xxh64("").hexdigest() - - num_cores = 4 - total_size = len(dataset_list) - chunk_size = math.ceil(total_size / (num_cores * 4)) # 4 chunks per core - - chunks = [ - dataset_list[i : i + chunk_size] for i in range(0, total_size, chunk_size) - ] - - # Process chunks in parallel - with ProcessPoolExecutor(max_workers=num_cores) as executor: - chunk_hash = list(executor.map(_hash_chunk, chunks)) - chunk_hash_str = "|||".join(chunk_hash) - hash_value = xxh64(chunk_hash_str).hexdigest() - - logging.debug( - f"Dataset hash time: {(time.perf_counter_ns() - start) / 1e6:.6f} milliseconds" - ) - return hash_value - - def _get_function_hash(func) -> str: """Get a hash of a function's source code.""" if func is None: diff --git a/tests/test_batch.py b/tests/test_batch.py deleted file mode 100644 index 09fa590b..00000000 --- a/tests/test_batch.py +++ /dev/null @@ -1,119 +0,0 @@ -from bespokelabs import curator -from bespokelabs.curator.request_processor.openai_batch_request_processor import ( - OpenAIBatchRequestProcessor, -) -from datasets import load_dataset, Dataset -import argparse 
-import logging - - -def convert_ShareGPT_to_IT_format(dataset: Dataset) -> Dataset: - def it_from_sharegpt(sample): - if sample["conversations"][0]["from"] == "human": - instruction = sample["conversations"][0]["value"] - assert sample["conversations"][1]["from"] == "gpt" - response = sample["conversations"][1]["value"] - elif sample["conversations"][1]["from"] == "human": - instruction = sample["conversations"][1]["value"] - assert sample["conversations"][2]["from"] == "gpt" - response = sample["conversations"][2]["value"] - else: - raise ValueError("Invalid conversation format") - return {"instruction": instruction, "original_response": response} - - dataset = dataset.map(it_from_sharegpt, num_proc=8) - dataset = dataset.remove_columns(["conversations"]) - dataset = dataset.select_columns(["instruction", "original_response"]) - return dataset - - -def load_ShareGPT_dataset_as_IT(dataset_name: str, truncate: int = None) -> Dataset: - dataset = load_dataset(dataset_name, split="train") - if truncate is not None: - dataset = dataset.select(range(truncate)) - return convert_ShareGPT_to_IT_format(dataset) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--num_samples", - type=int, - default=10, - help="The number of samples to use from the dataset", - ) - parser.add_argument( - "--batch_size", - type=int, - default=None, - help="The number of samples to use per batch", - ) - parser.add_argument( - "--batch", - action="store_true", - help="Whether to use batch processing", - ) - parser.add_argument( - "--check_interval", - type=int, - default=10, - help="The interval (in seconds) to check the status of the batch", - ) - args = parser.parse_args() - - # Load the dataset to instruction, response columns - dataset = load_ShareGPT_dataset_as_IT("teknium/OpenHermes-2.5") - print(dataset) - - dataset = dataset.select(range(args.num_samples)) - - # python tests/test_batch.py --num_samples 10 - - # if args.type == "online": - # api = OpenAIOnlineAPI(model="gpt-4o-mini") - # elif args.type == "batch": - # api = OpenAIBatchAPI(model="gpt-4o-mini", check_interval=args.check_interval) - - # TODO(Ryan) messages as prompt_func output or if string default to user_prompt instruction - # def prompt_func(row): - # messages = [ - # {"role": "user", "content": row["instruction"]} - # ] - # return messages - - # def parse_func(row, response): - # row["model_response"] = response - # return row - - # reannotate_prompter = curator.Prompter( - # prompt_func=prompt_func, - # parse_func=parse_func, - # model_name="gpt-4o-mini", - # ) - - logging.basicConfig(level=logging.INFO) - logger = logging.getLogger(__name__) - logger.setLevel(logging.INFO) - - reannotate_prompter = curator.Prompter( - prompt_func=lambda row: {"user_prompt": row["instruction"]}, - parse_func=lambda row, response: {**row, "model_response": response}, - model_name="gpt-4o-mini", - batch=args.batch, - ) - - # To set internal variables - request_processor = OpenAIBatchRequestProcessor( - model="gpt-4o-mini", - batch_size=args.batch_size, - check_interval=args.check_interval, - ) - - reannotate_prompter._request_processor = request_processor - reannotated_dataset = reannotate_prompter(dataset) - - # Upload dataset to Hugging Face - print(reannotated_dataset) - dataset_name = "mlfoundations-dev/rewrite-test-gpt-4o-mini" - reannotated_dataset.push_to_hub(dataset_name) - print(f"https://huggingface.co/datasets/{dataset_name}") From 5ae98fc745f44e5f55088cc0bfebd5b1f27b0483 Mon Sep 17 00:00:00 2001 From: Ryan 
Marten Date: Thu, 7 Nov 2024 22:17:31 -0800 Subject: [PATCH 12/18] fix json serialization issue for datetime --- .../base_request_processor.py | 4 +++- .../openai_batch_request_processor.py | 23 +++++++++++++------ .../openai_online_request_processor.py | 6 +++-- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/bespokelabs/curator/request_processor/base_request_processor.py b/src/bespokelabs/curator/request_processor/base_request_processor.py index ad29e033..8ca42383 100644 --- a/src/bespokelabs/curator/request_processor/base_request_processor.py +++ b/src/bespokelabs/curator/request_processor/base_request_processor.py @@ -52,7 +52,9 @@ def create_api_specific_request(self, generic_request: GenericRequest) -> dict: pass @abstractmethod - def get_generic_response(self, response: dict) -> GenericResponse: + def get_generic_response( + self, response: dict, prompt_formatter: PromptFormatter, dataset: Dataset + ) -> GenericResponse: """ Parses a API-specific response into a generic response body. Does error handling on the response. diff --git a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py index 6f36e70d..48569fe9 100644 --- a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py +++ b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py @@ -8,7 +8,7 @@ import pandas as pd import matplotlib.pyplot as plt import aiofiles - +import io from bespokelabs.curator.dataset import Dataset from bespokelabs.curator.request_processor.base_request_processor import ( BaseRequestProcessor, @@ -16,6 +16,7 @@ GenericResponse, ) from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter +from io import BytesIO T = TypeVar("T") @@ -119,7 +120,7 @@ def create_api_specific_request(self, generic_request: GenericRequest) -> dict: return request def get_generic_response( - self, response: Dict, prompt_formatter: PromptFormatter + self, response: Dict, prompt_formatter: PromptFormatter, dataset: Dataset ) -> GenericResponse: """ Parses a API-specific response into a generic response body. @@ -147,13 +148,17 @@ def get_generic_response( # NOTE(Ryan): can we actually parse the response into a an OpenAI ChatCompletions object? Easier to access fields? # TODO(Ryan): if you add token tokens to generic response, we can parse that here too, similar to my comment above we can do that in the shared place. content = response["response"]["body"]["choices"][0]["message"]["content"] + row_idx = int(response["custom_id"]) if prompt_formatter.response_format: content = json.loads(content) + # NOTE(Ryan): So dicts that have objects that are not JSON serializable will be converted to strings. 
+ return GenericResponse( response=content, - row_idx=int(response["custom_id"]), + row_idx=row_idx, + row=dataset[row_idx], raw_response=response, ) @@ -236,7 +241,9 @@ async def submit_all_batches(): # TODO(Ryan): likely can add some logic for smarter check_interval based on batch size and if the batch has started or not, fine to do a dumb ping for now batch_watcher = BatchWatcher(working_dir, check_interval=self.check_interval) - asyncio.run(batch_watcher.watch(prompt_formatter, self.get_generic_response)) + asyncio.run( + batch_watcher.watch(prompt_formatter, self.get_generic_response, dataset) + ) dataset = self.create_dataset_files(dataset, working_dir, prompt_formatter) return dataset @@ -281,6 +288,7 @@ async def watch( self, prompt_formatter: PromptFormatter, get_generic_response: Callable[[Dict], GenericResponse], + dataset: Dataset, ) -> None: """Monitor the status of batches until all are completed (includes successfully, failed, expired or cancelled).""" completed_batches = {} @@ -303,7 +311,7 @@ async def watch( # NOTE(Ryan): Now downloading after each check, instead of waiting until all are completed tasks = [ self.download_batch_result_file( - batch, prompt_formatter, get_generic_response + batch, prompt_formatter, get_generic_response, dataset ) for batch in newly_completed_batches ] @@ -323,6 +331,7 @@ async def download_batch_result_file( batch, prompt_formatter: PromptFormatter, get_generic_response: Callable[[Dict], GenericResponse], + dataset: Dataset, ) -> str: """Download the result of a completed batch to file. @@ -349,9 +358,9 @@ async def download_batch_result_file( for raw_response in file_content.text.splitlines(): # TODO(Ryan): We should abstract this out generic_response = get_generic_response( - json.loads(raw_response), prompt_formatter + json.loads(raw_response), prompt_formatter, dataset ) - f.write(json.dumps(generic_response.model_dump()) + "\n") + f.write(json.dumps(generic_response.model_dump(), default=str) + "\n") return output_path async def plot_completion_data(self, output_dir: str) -> None: diff --git a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py index 26f6b1d8..8dee4375 100644 --- a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py +++ b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py @@ -119,7 +119,7 @@ def create_api_specific_request(self, generic_request: GenericRequest) -> dict: return request def get_generic_response( - self, response: Dict, prompt_formatter: PromptFormatter + self, response: Dict, prompt_formatter: PromptFormatter, dataset: Dataset ) -> GenericResponse: """ Parses a API-specific response into a generic response body. 
@@ -140,7 +140,9 @@ def get_generic_response( return GenericResponse( response=content, - row=response["metadata"]["sample"], + row=response["metadata"][ + "sample" + ], # Or can do dataset[response["metadata"]["request_idx"]] row_idx=response["metadata"]["request_idx"], ) From d91163d7a55d7fef372d02e8fa50a09e8cd055c7 Mon Sep 17 00:00:00 2001 From: Ryan Marten Date: Thu, 7 Nov 2024 22:36:10 -0800 Subject: [PATCH 13/18] make poetry example work --- examples/poetry.py | 6 ++---- src/bespokelabs/curator/prompter/prompter.py | 6 ++++-- .../curator/request_processor/base_request_processor.py | 4 ++-- .../request_processor/openai_online_request_processor.py | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/poetry.py b/examples/poetry.py index 46de2fe0..ad0300b1 100644 --- a/examples/poetry.py +++ b/examples/poetry.py @@ -1,11 +1,9 @@ from bespokelabs import curator poet = curator.Prompter( - prompt_func=lambda: { - "user_prompt": "Write a poem about the beauty of computer science" - }, + prompt_func=lambda: "Write a poem about the beauty of computer science", model_name="gpt-4o-mini", ) poem = poet() -print(poem.to_list()[0]) +print(poem["response"][0]) diff --git a/src/bespokelabs/curator/prompter/prompter.py b/src/bespokelabs/curator/prompter/prompter.py index 9e6ee6c9..66d6da7d 100644 --- a/src/bespokelabs/curator/prompter/prompter.py +++ b/src/bespokelabs/curator/prompter/prompter.py @@ -103,7 +103,7 @@ def _completions( Iterable: A list of structured outputs from the completions """ # NOTE(Ryan): We convert from iterable to Dataset because Dataset has random access via row_idx - if not isinstance(dataset, Dataset): + if not isinstance(dataset, Dataset) and dataset is not None: dataset = Dataset.from_generator(dataset) if self is None: @@ -113,7 +113,9 @@ def _completions( "CURATOR_CACHE_DIR", os.path.expanduser("~/.cache/curator") ) - dataset_hash = dataset._fingerprint + dataset_hash = ( + dataset._fingerprint if dataset is not None else xxh64("").hexdigest() + ) prompt_func_hash = _get_function_hash(self.prompt_formatter.prompt_func) parse_func_hash = _get_function_hash(self.prompt_formatter.parse_func) diff --git a/src/bespokelabs/curator/request_processor/base_request_processor.py b/src/bespokelabs/curator/request_processor/base_request_processor.py index 8ca42383..29ff5c5d 100644 --- a/src/bespokelabs/curator/request_processor/base_request_processor.py +++ b/src/bespokelabs/curator/request_processor/base_request_processor.py @@ -22,7 +22,7 @@ class BaseRequestProcessor(ABC): Base class for all request processors. 
""" - def __init__(self, batch_size: int = 1000): + def __init__(self, batch_size: Optional[int] = None): self.batch_size = batch_size @abstractmethod @@ -272,7 +272,7 @@ def create_dataset_files( if not isinstance(dataset_rows, list): dataset_rows = [dataset_rows] else: - dataset_rows = [response.response] + dataset_rows = [{"response": response.response}] for row in dataset_rows: if isinstance(row, BaseModel): diff --git a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py index 8dee4375..261739d7 100644 --- a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py +++ b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py @@ -27,7 +27,7 @@ class OpenAIOnlineRequestProcessor(BaseRequestProcessor): def __init__( self, - batch_size: int = 1000, + batch_size: Optional[int] = None, model: str = "gpt-4o-mini", api_key: str = os.getenv("OPENAI_API_KEY"), url: str = "https://api.openai.com/v1/chat/completions", @@ -119,7 +119,7 @@ def create_api_specific_request(self, generic_request: GenericRequest) -> dict: return request def get_generic_response( - self, response: Dict, prompt_formatter: PromptFormatter, dataset: Dataset + self, response: Dict, prompt_formatter: PromptFormatter ) -> GenericResponse: """ Parses a API-specific response into a generic response body. From 688e3f1b05b4bb4459ecc99b47ac5f087f43a5b5 Mon Sep 17 00:00:00 2001 From: Ryan Marten Date: Thu, 7 Nov 2024 22:59:54 -0800 Subject: [PATCH 14/18] fix camel --- examples/camel.py | 19 ++--- .../base_request_processor.py | 69 +++++++++---------- 2 files changed, 38 insertions(+), 50 deletions(-) diff --git a/examples/camel.py b/examples/camel.py index 15b69f6d..13e4119f 100644 --- a/examples/camel.py +++ b/examples/camel.py @@ -1,6 +1,4 @@ from typing import List - -import pandas as pd from pydantic import BaseModel, Field from bespokelabs import curator @@ -24,18 +22,14 @@ class QAs(BaseModel): subject_prompter = curator.Prompter( - prompt_func=lambda: { - "user_prompt": f"Generate a diverse list of 3 subjects. Keep it high-level (e.g. Math, Science)." - }, + prompt_func=lambda: f"Generate a diverse list of 3 subjects. Keep it high-level (e.g. Math, Science).", parse_func=lambda _, subjects: [subject for subject in subjects.subjects], model_name="gpt-4o-mini", response_format=Subjects, ) subject_dataset = subject_prompter() subsubject_prompter = curator.Prompter( - prompt_func=lambda subject: { - "user_prompt": f"For the given subject {subject}. Generate 3 diverse subsubjects. No explanation." - }, + prompt_func=lambda subject: f"For the given subject {subject}. Generate 3 diverse subsubjects. No explanation.", parse_func=lambda subject, subsubjects: [ {"subject": subject["subject"], "subsubject": subsubject.subject} for subsubject in subsubjects.subjects @@ -46,9 +40,7 @@ class QAs(BaseModel): subsubject_dataset = subsubject_prompter(subject_dataset) qa_prompter = curator.Prompter( - prompt_func=lambda subsubject: { - "user_prompt": f"For the given subsubject {subsubject}. Generate 3 diverse questions and answers. No explanation." - }, + prompt_func=lambda subsubject: f"For the given subsubject {subsubject}. Generate 3 diverse questions and answers. 
No explanation.", model_name="gpt-4o-mini", response_format=QAs, parse_func=lambda subsubject, qas: [ @@ -63,6 +55,5 @@ class QAs(BaseModel): ) qa_dataset = qa_prompter(subsubject_dataset) -qa_hf_dataset = qa_dataset.to_huggingface() -qa_hf_dataset.map(lambda row: {"answer": row["answer"].strip()}, num_proc=2) -print(qa_hf_dataset) +qa_dataset.map(lambda row: {"answer": row["answer"].strip()}, num_proc=2) +print(qa_dataset) diff --git a/src/bespokelabs/curator/request_processor/base_request_processor.py b/src/bespokelabs/curator/request_processor/base_request_processor.py index 29ff5c5d..4ac2f3c8 100644 --- a/src/bespokelabs/curator/request_processor/base_request_processor.py +++ b/src/bespokelabs/curator/request_processor/base_request_processor.py @@ -144,47 +144,42 @@ def create_request_files( ) return requests_files - request_count = 0 - request_file_idx = 0 - requests_file = f"{working_dir}/requests_{request_file_idx}.jsonl" - requests_files = [] - # Create new requests file - with open(requests_file, "w") as f: - if dataset is None: + requests_file = f"{working_dir}/requests_0.jsonl" + requests_files = [requests_file] + + if dataset is None: + with open(requests_file, "w") as f: request = prompt_formatter.get_generic_request(dict(), 0) api_request = self.create_api_specific_request(request) f.write(json.dumps(api_request) + "\n") - else: - if self.batch_size: - num_batches = ceil(len(dataset) / self.batch_size) - requests_files = [ - f"{working_dir}/requests_{i}.jsonl" for i in range(num_batches) - ] - - async def create_all_request_files(): - tasks = [ - self.acreate_request_file( - dataset, - prompt_formatter, - requests_files[i], - start_idx=i * self.batch_size, - ) - for i in range(num_batches) - ] - await asyncio.gather(*tasks) - - asyncio.run(create_all_request_files()) - else: - requests_files = [f"{working_dir}/requests_0.jsonl"] - asyncio.run( - self.acreate_request_file( - dataset, prompt_formatter, requests_files[0] - ) + return requests_files + + if self.batch_size: + num_batches = ceil(len(dataset) / self.batch_size) + requests_files = [ + f"{working_dir}/requests_{i}.jsonl" for i in range(num_batches) + ] + + async def create_all_request_files(): + tasks = [ + self.acreate_request_file( + dataset, + prompt_formatter, + requests_files[i], + start_idx=i * self.batch_size, ) - - if request_count > 0: - logging.info(f"Wrote {request_count:,} requests to {requests_file}.") + for i in range(num_batches) + ] + await asyncio.gather(*tasks) + + asyncio.run(create_all_request_files()) + else: + asyncio.run( + self.acreate_request_file( + dataset, prompt_formatter, requests_file + ) + ) return requests_files @@ -199,6 +194,8 @@ async def acreate_request_file( if self.batch_size is not None: end_idx = min(start_idx + self.batch_size, len(dataset)) dataset = dataset.select(range(start_idx, end_idx)) + else: + end_idx = len(dataset) # NOTE(Ryan): For loops only for IterableDataset which allows for _very_ large datasets, when start_idx and batch_size are not specified async with aiofiles.open(request_file, "w") as f: From 44f0a8191e31107e97a6b4d2de544f1731d3c917 Mon Sep 17 00:00:00 2001 From: Trung Vu Date: Fri, 8 Nov 2024 07:11:44 +0000 Subject: [PATCH 15/18] fix name conflict & remove matplotlib --- build.py => build_pkg.py | 4 +- .../openai_batch_request_processor.py | 62 +------------------ 2 files changed, 3 insertions(+), 63 deletions(-) rename build.py => build_pkg.py (99%) diff --git a/build.py b/build_pkg.py similarity index 99% rename from build.py rename to build_pkg.py 
index 2e60f838..5a6cd635 100644 --- a/build.py +++ b/build_pkg.py @@ -1,9 +1,9 @@ -import os -import subprocess import shutil +import subprocess import sys from pathlib import Path + def run_command(command, cwd=None): result = subprocess.run(command, shell=True, cwd=cwd, check=True) return result diff --git a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py index 48569fe9..20afcaf1 100644 --- a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py +++ b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py @@ -2,13 +2,9 @@ import json import logging import os -from dataclasses import dataclass, field -from typing import Any, Callable, Dict, Optional, Set, Tuple, TypeVar +from typing import Callable, Dict, Optional, TypeVar from openai import AsyncOpenAI -import pandas as pd -import matplotlib.pyplot as plt import aiofiles -import io from bespokelabs.curator.dataset import Dataset from bespokelabs.curator.request_processor.base_request_processor import ( BaseRequestProcessor, @@ -16,7 +12,6 @@ GenericResponse, ) from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter -from io import BytesIO T = TypeVar("T") @@ -362,58 +357,3 @@ async def download_batch_result_file( ) f.write(json.dumps(generic_response.model_dump(), default=str) + "\n") return output_path - - async def plot_completion_data(self, output_dir: str) -> None: - """Save plots visualizing completion times for the batches. - - Args: - output_dir (str): Directory to save the plots. - """ - completion_times = [] - completion_dates = [] - - for batch_id in self.batch_ids: - batch = await self.client.batches.retrieve(batch_id) - if batch.status == "completed": - duration = ( - batch.completed_at - batch.created_at - ) / 60 # Convert to minutes - completion_times.append(duration) - completion_dates.append(batch.completed_at) - - # Create a DataFrame for plotting - df = pd.DataFrame( - { - "Completion Time (min)": completion_times, # Update label to minutes - "Completion Date": pd.to_datetime(completion_dates, unit="s"), - } - ) - - # Histogram of completion durations - plt.figure(figsize=(12, 6)) - plt.hist(df["Completion Time (min)"], bins=20, color="blue", alpha=0.7) - plt.title("Histogram of Completion Durations") - plt.xlabel("Duration (minutes)") # Update label to minutes - plt.ylabel("Frequency") - plt.grid(axis="y") - plt.savefig( - os.path.join(output_dir, "completion_durations_histogram.png") - ) # Save the histogram - plt.close() # Close the plot - - # Cumulative plot of completed jobs over time - df.sort_values("Completion Date", inplace=True) - df["Cumulative Completed"] = range(1, len(df) + 1) - - plt.figure(figsize=(12, 6)) - plt.plot( - df["Completion Date"], df["Cumulative Completed"], marker="o", color="green" - ) - plt.title("Cumulative Completed Jobs Over Time") - plt.xlabel("Completion Date") - plt.ylabel("Cumulative Completed Jobs") - plt.grid() - plt.savefig( - os.path.join(output_dir, "cumulative_completed_jobs.png") - ) # Save the cumulative plot - plt.close() # Close the plot From 10a063e5e995d3059e50ca3735979d0cc9505db5 Mon Sep 17 00:00:00 2001 From: Trung Vu Date: Fri, 8 Nov 2024 07:24:30 +0000 Subject: [PATCH 16/18] fix cache read logic from arrow --- .../request_processor/base_request_processor.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git 
a/src/bespokelabs/curator/request_processor/base_request_processor.py b/src/bespokelabs/curator/request_processor/base_request_processor.py index 4ac2f3c8..43e9c80a 100644 --- a/src/bespokelabs/curator/request_processor/base_request_processor.py +++ b/src/bespokelabs/curator/request_processor/base_request_processor.py @@ -1,20 +1,20 @@ +import asyncio +import glob import json import logging import os -import glob - from abc import ABC, abstractmethod +from math import ceil from typing import Optional +import aiofiles from datasets import Dataset +from datasets.arrow_writer import ArrowWriter, SchemaInferenceError +from pydantic import BaseModel + from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter from bespokelabs.curator.request_processor.generic_request import GenericRequest from bespokelabs.curator.request_processor.generic_response import GenericResponse -from datasets.arrow_writer import ArrowWriter, SchemaInferenceError -from pydantic import BaseModel -from math import ceil -import asyncio -import aiofiles class BaseRequestProcessor(ABC): @@ -238,6 +238,9 @@ def create_dataset_files( if len(responses_files) == 0: raise ValueError(f"No responses files found in {working_dir}") dataset_file = f"{working_dir}/dataset.arrow" + if os.path.exists(dataset_file): + logging.info(f"Using existing dataset file {dataset_file}") + return Dataset.from_file(dataset_file) # Process all response files with ArrowWriter(path=dataset_file) as writer: From 6d31b05dd46ef05a6e51220076ba600a89eb3f34 Mon Sep 17 00:00:00 2001 From: Trung Vu Date: Fri, 8 Nov 2024 07:24:46 +0000 Subject: [PATCH 17/18] fix qa_dataset print --- examples/camel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/camel.py b/examples/camel.py index 13e4119f..bffa0507 100644 --- a/examples/camel.py +++ b/examples/camel.py @@ -1,4 +1,5 @@ from typing import List + from pydantic import BaseModel, Field from bespokelabs import curator @@ -56,4 +57,4 @@ class QAs(BaseModel): qa_dataset = qa_prompter(subsubject_dataset) qa_dataset.map(lambda row: {"answer": row["answer"].strip()}, num_proc=2) -print(qa_dataset) +print(qa_dataset.to_pandas()) From 8171a9b9eced2e13bde8278b5f37de01e78b3983 Mon Sep 17 00:00:00 2001 From: Trung Vu Date: Fri, 8 Nov 2024 07:25:13 +0000 Subject: [PATCH 18/18] black --- build_pkg.py | 43 +++++----- src/bespokelabs/curator/prompter/prompter.py | 14 ++-- .../base_request_processor.py | 4 +- src/bespokelabs/curator/viewer/__main__.py | 80 +++++++++++-------- 4 files changed, 77 insertions(+), 64 deletions(-) diff --git a/build_pkg.py b/build_pkg.py index 5a6cd635..86c886c8 100644 --- a/build_pkg.py +++ b/build_pkg.py @@ -8,19 +8,21 @@ def run_command(command, cwd=None): result = subprocess.run(command, shell=True, cwd=cwd, check=True) return result + def npm_install(): print("Running npm install") run_command("npm install", cwd="bespoke-dataset-viewer") + def nextjs_build(): print("Running Next.js build") run_command("npm run build", cwd="bespoke-dataset-viewer") print("Copying build artifacts to static folder") - + # Source and target directories source_base = Path("bespoke-dataset-viewer") target_base = Path("src/bespokelabs/curator/viewer/static") - + # Ensure target directory exists if target_base.exists(): shutil.rmtree(target_base) @@ -28,26 +30,26 @@ def nextjs_build(): # Copy only the necessary files, excluding node_modules files_to_copy = [ - '.next', - 'app', - 'components', - 'lib', - 'public', - 'types', - 'package.json', - 'package-lock.json', - 
'next.config.ts', - 'next-env.d.ts', - 'tsconfig.json', - 'postcss.config.mjs', - 'tailwind.config.ts', - 'components.json' + ".next", + "app", + "components", + "lib", + "public", + "types", + "package.json", + "package-lock.json", + "next.config.ts", + "next-env.d.ts", + "tsconfig.json", + "postcss.config.mjs", + "tailwind.config.ts", + "components.json", ] - + for item in files_to_copy: source = source_base / item target = target_base / item - + if source.exists(): if source.is_file(): shutil.copy2(source, target) @@ -60,6 +62,7 @@ def nextjs_build(): else: print(f"Warning: {source} not found") + def run_pytest(): print("Running pytest") try: @@ -68,11 +71,13 @@ def run_pytest(): print("Pytest failed. Aborting build.") sys.exit(1) + def main(): npm_install() nextjs_build() run_pytest() print("Build completed successfully.") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/bespokelabs/curator/prompter/prompter.py b/src/bespokelabs/curator/prompter/prompter.py index 66d6da7d..171ddad5 100644 --- a/src/bespokelabs/curator/prompter/prompter.py +++ b/src/bespokelabs/curator/prompter/prompter.py @@ -1,21 +1,19 @@ """Curator: Bespoke Labs Synthetic Data Generation Library.""" import inspect -import json -import logging -import math import os -import time -from concurrent.futures import ProcessPoolExecutor from datetime import datetime from typing import Any, Callable, Dict, Iterable, Optional, Type, TypeVar, Union +from datasets import Dataset from pydantic import BaseModel from xxhash import xxh64 -from datasets import Dataset from bespokelabs.curator.db import MetadataDB from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter +from bespokelabs.curator.request_processor.base_request_processor import ( + BaseRequestProcessor, +) from bespokelabs.curator.request_processor.generic_request import GenericRequest from bespokelabs.curator.request_processor.openai_batch_request_processor import ( OpenAIBatchRequestProcessor, @@ -23,9 +21,6 @@ from bespokelabs.curator.request_processor.openai_online_request_processor import ( OpenAIOnlineRequestProcessor, ) -from bespokelabs.curator.request_processor.base_request_processor import ( - BaseRequestProcessor, -) T = TypeVar("T") @@ -116,6 +111,7 @@ def _completions( dataset_hash = ( dataset._fingerprint if dataset is not None else xxh64("").hexdigest() ) + prompt_func_hash = _get_function_hash(self.prompt_formatter.prompt_func) parse_func_hash = _get_function_hash(self.prompt_formatter.parse_func) diff --git a/src/bespokelabs/curator/request_processor/base_request_processor.py b/src/bespokelabs/curator/request_processor/base_request_processor.py index 43e9c80a..083cca64 100644 --- a/src/bespokelabs/curator/request_processor/base_request_processor.py +++ b/src/bespokelabs/curator/request_processor/base_request_processor.py @@ -176,9 +176,7 @@ async def create_all_request_files(): asyncio.run(create_all_request_files()) else: asyncio.run( - self.acreate_request_file( - dataset, prompt_formatter, requests_file - ) + self.acreate_request_file(dataset, prompt_formatter, requests_file) ) return requests_files diff --git a/src/bespokelabs/curator/viewer/__main__.py b/src/bespokelabs/curator/viewer/__main__.py index 3bf229f1..145241ab 100644 --- a/src/bespokelabs/curator/viewer/__main__.py +++ b/src/bespokelabs/curator/viewer/__main__.py @@ -12,60 +12,74 @@ import tempfile import shutil + def get_viewer_path(): return str(Path(__file__).parent) + def ensure_dependencies(): """Ensure npm dependencies are 
installed""" - static_dir = os.path.join(get_viewer_path(), 'static') - node_modules = os.path.join(static_dir, 'node_modules') - + static_dir = os.path.join(get_viewer_path(), "static") + node_modules = os.path.join(static_dir, "node_modules") + if not os.path.exists(node_modules): print("First run: Installing Node.js dependencies...") print("Your node_modules path: ", node_modules) try: - subprocess.run( - ["npm", "install"], - cwd=static_dir, - check=True - ) + subprocess.run(["npm", "install"], cwd=static_dir, check=True) print("Dependencies installed successfully.") except subprocess.CalledProcessError as e: print(f"Error installing dependencies: {e}") sys.exit(1) except FileNotFoundError: - print("Error: Node.js is not installed. Please install Node.js to run the viewer.") + print( + "Error: Node.js is not installed. Please install Node.js to run the viewer." + ) sys.exit(1) + def _setup_logging(level): logging.basicConfig( - format='%(asctime)s %(levelname)-8s] %(message)s', + format="%(asctime)s %(levelname)-8s] %(message)s", level=level, - datefmt='%Y-%m-%d %H:%M:%S' + datefmt="%Y-%m-%d %H:%M:%S", ) + def check_node_installed(): """Check if Node.js is installed and return version if found""" try: result = subprocess.run( - ["node", "--version"], - capture_output=True, - text=True, - check=True + ["node", "--version"], capture_output=True, text=True, check=True ) return result.stdout.strip() except (subprocess.CalledProcessError, FileNotFoundError): return None + def main(): parser = ArgumentParser(description="Curator Viewer") - parser.add_argument("--host", default="127.0.0.1", help="Host to run the server on (default: localhost)") - parser.add_argument("--port", type=int, default=3000, help="Port to run the server on (default: 3000)") - parser.add_argument("--verbose", "-v", action="store_true", help="Enables debug logging for more verbose output") + parser.add_argument( + "--host", + default="127.0.0.1", + help="Host to run the server on (default: localhost)", + ) + parser.add_argument( + "--port", + type=int, + default=3000, + help="Port to run the server on (default: 3000)", + ) + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Enables debug logging for more verbose output", + ) args = parser.parse_args() _setup_logging(logging.DEBUG if args.verbose else logging.INFO) - + # Check if Node.js is installed node_version = check_node_installed() if not node_version: @@ -76,7 +90,7 @@ def main(): print("2. Verify installation by running: node --version") print("3. Run curator-viewer again") sys.exit(1) - + ensure_dependencies() # Set environment variables for the Next.js server @@ -87,26 +101,26 @@ def main(): # Start the Next.js server viewer_path = get_viewer_path() - static_dir = os.path.join(viewer_path, 'static') - server_file = os.path.join(viewer_path, 'server.js') - - if not os.path.exists(os.path.join(static_dir, '.next')): - print("Error: Next.js build artifacts not found. The package may not be built correctly.") + static_dir = os.path.join(viewer_path, "static") + server_file = os.path.join(viewer_path, "server.js") + + if not os.path.exists(os.path.join(static_dir, ".next")): + print( + "Error: Next.js build artifacts not found. The package may not be built correctly." 
+ ) sys.exit(1) - + try: - subprocess.run( - ["node", server_file], - cwd=viewer_path, - env=env, - check=True - ) + subprocess.run(["node", server_file], cwd=viewer_path, env=env, check=True) except subprocess.CalledProcessError as e: print(f"Error starting Next.js server: {e}") sys.exit(1) except FileNotFoundError: - print("Error: Node.js is not installed. Please install Node.js to run the viewer.") + print( + "Error: Node.js is not installed. Please install Node.js to run the viewer." + ) sys.exit(1) + if __name__ == "__main__": main()