Merge pull request #70 from bespokelabsai/ryanm/batch_size_arg
Add Prompter arg for batch size
RyanMarten authored Nov 12, 2024
2 parents e9abb83 + 1f78e61 commit 41f4ed8
Showing 5 changed files with 20 additions and 5 deletions.
Empty file added: diff-file.patch
examples/distill.py (8 changes: 6 additions & 2 deletions)
@@ -3,7 +3,7 @@
 import logging
 
 dataset = load_dataset("allenai/WildChat", split="train")
-dataset = dataset.select(range(3_000))
+dataset = dataset.select(range(300))
 
 # To see more detail about how batches are being processed
 logger = logging.getLogger("bespokelabs.curator")
@@ -20,7 +20,11 @@ def parse_func(row, response):
 
 
 distill_prompter = curator.Prompter(
-    prompt_func=prompt_func, parse_func=parse_func, model_name="gpt-4o-mini", batch=True
+    prompt_func=prompt_func,
+    parse_func=parse_func,
+    model_name="gpt-4o-mini",
+    batch=True,
+    batch_size=100,
 )
 
 distilled_dataset = distill_prompter(dataset)
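The example turns on the "bespokelabs.curator" logger to watch how batches are processed. A minimal sketch of one way to surface that output before calling the prompter (the handler setup below is not part of this diff and uses only the standard library):

    import logging

    # Route log records to stderr and let the curator logger emit DEBUG-level
    # detail, including batch progress and the new `batch_size` warning added
    # in prompter.py below.
    logging.basicConfig(level=logging.WARNING)
    logging.getLogger("bespokelabs.curator").setLevel(logging.DEBUG)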
src/bespokelabs/curator/prompter/prompter.py (14 changes: 13 additions & 1 deletion)
@@ -8,6 +8,7 @@
 from datasets import Dataset
 from pydantic import BaseModel
 from xxhash import xxh64
+import logging
 
 from bespokelabs.curator.db import MetadataDB
 from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter
@@ -24,6 +25,8 @@
 
 T = TypeVar("T")
 
+logger = logging.getLogger(__name__)
+
 
 class Prompter:
     """Interface for prompting LLMs."""
@@ -39,6 +42,7 @@ def __init__(
         ] = None,
         response_format: Optional[Type[BaseModel]] = None,
         batch: bool = False,
+        batch_size: Optional[int] = None,
     ):
         """Initialize a Prompter.
@@ -50,6 +54,8 @@
                 response object and returns the parsed output
             response_format (Optional[Type[BaseModel]]): A Pydantic model specifying the
                 response format from the LLM.
+            batch (bool): Whether to use batch processing
+            batch_size (Optional[int]): The size of the batch to use, only used if batch is True
         """
         prompt_sig = inspect.signature(prompt_func)
         if len(prompt_sig.parameters) > 1:
@@ -69,8 +75,14 @@
             )
 
         if batch:
-            self._request_processor = OpenAIBatchRequestProcessor(model=model_name)
+            self._request_processor = OpenAIBatchRequestProcessor(
+                model=model_name, batch_size=batch_size
+            )
         else:
+            if batch_size is not None:
+                logger.warning(
+                    f"Prompter argument `batch_size` {batch_size} is ignored because `batch` is False"
+                )
             self._request_processor = OpenAIOnlineRequestProcessor(model=model_name)
 
     def __call__(self, dataset: Optional[Iterable] = None) -> Dataset:
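Taken together, the new constructor logic forwards batch_size to the batch request processor when batch=True, and warns that it is ignored otherwise. A usage sketch, not part of the diff (the `from bespokelabs import curator` import and the trivial prompt function are assumptions mirroring examples/distill.py):

    from bespokelabs import curator  # import path assumed from the examples

    # batch=True: batch_size is passed through to OpenAIBatchRequestProcessor.
    batch_prompter = curator.Prompter(
        prompt_func=lambda: "Say hello.",
        model_name="gpt-4o-mini",
        batch=True,
        batch_size=100,
    )

    # batch=False (the default): batch_size is accepted but ignored, and the new
    # "`batch_size` ... is ignored because `batch` is False" warning is logged.
    online_prompter = curator.Prompter(
        prompt_func=lambda: "Say hello.",
        model_name="gpt-4o-mini",
        batch_size=100,
    )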
OpenAIOnlineRequestProcessor (3 changes: 1 addition & 2 deletions)
@@ -28,12 +28,11 @@
 class OpenAIOnlineRequestProcessor(BaseRequestProcessor):
     def __init__(
         self,
-        batch_size: Optional[int] = None,
         model: str = "gpt-4o-mini",
         api_key: str = os.getenv("OPENAI_API_KEY"),
         url: str = "https://api.openai.com/v1/chat/completions",
     ):
-        super().__init__(batch_size)
+        super().__init__(batch_size=None)
         self.model: str = model
         self.url: str = url
         self.api_key: str = api_key
Empty file added: tests/test_cache.py
