Explicitly close AsyncClient to avoid asyncio "event loop is closed" issues #101

Merged 2 commits on Nov 14, 2024

@@ -134,6 +134,7 @@ def create_api_specific_request(
         return request
 
     async def asubmit_batch(self, batch_file: str) -> dict:
+        async_client = AsyncOpenAI()
         # Create a list to store API-specific requests
         api_specific_requests = []
 
@@ -147,13 +148,13 @@ async def asubmit_batch(self, batch_file: str) -> dict:
         # Join requests with newlines and encode to bytes for upload
         file_content = "\n".join(api_specific_requests).encode()
 
-        batch_file_upload = await self.async_client.files.create(
+        batch_file_upload = await async_client.files.create(
             file=file_content, purpose="batch"
         )
 
         logger.info(f"File uploaded: {batch_file_upload}")
 
-        batch_object = await self.async_client.batches.create(
+        batch_object = await async_client.batches.create(
             input_file_id=batch_file_upload.id,
             endpoint="/v1/chat/completions",
             completion_window="24h",
@@ -164,6 +165,9 @@ async def asubmit_batch(self, batch_file: str) -> dict:
         logger.info(
             f"Batch request submitted, received batch object: {batch_object}"
         )
+        # Explicitly close the client. Otherwise we get something like
+        # future: <Task finished name='Task-46' coro=<AsyncClient.aclose() done ... >>
+        await async_client.close()
 
         return batch_object
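
The pattern these hunks introduce (create the AsyncOpenAI client inside the coroutine, use it, and await close() before returning) keeps the client's entire lifetime within a running event loop. The failure mode it avoids can be sketched without the OpenAI SDK at all, using httpx, the HTTP library the SDK wraps; fetch_status and the URL are hypothetical:

import asyncio

import httpx


async def fetch_status(url: str) -> int:
    # Create and close the client inside the running loop, mirroring
    # the asubmit_batch change above.
    client = httpx.AsyncClient()
    try:
        response = await client.get(url)
        return response.status_code
    finally:
        # Skipping this leaves cleanup to garbage collection, which may
        # run only after asyncio.run() has closed the loop; that is
        # where "RuntimeError: Event loop is closed" comes from.
        await client.aclose()


print(asyncio.run(fetch_status("https://example.com")))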

@@ -198,8 +202,6 @@ def run(
             )
         else:
             # upload requests files and submit batches
-            self.async_client = AsyncOpenAI()
-
             # asyncio gather preserves order
             async def submit_all_batches():
                 tasks = [
@@ -226,18 +228,21 @@ async def submit_all_batches():
             # TODO(Ryan): This creates responses_0.jsonl, responses_1.jsonl, etc. errors named same way? or errors_0.jsonl, errors_1.jsonl?
             # TODO(Ryan): retries, resubmits on lagging batches - need to study this a little closer
             # TODO(Ryan): likely can add some logic for smarter check_interval based on batch size and if the batch has started or not, fine to do a dumb ping for now
-            batch_watcher = BatchWatcher(
-                working_dir, check_interval=self.check_interval
-            )
 
             # NOTE(Ryan): If we allow for multiple heterogeneous requests per dataset row, we will need to update this.
             total_requests = 1 if dataset is None else len(dataset)
 
-            run_in_event_loop(
-                batch_watcher.watch(
+            async def watch_batches():
+                batch_watcher = BatchWatcher(
+                    working_dir, check_interval=self.check_interval
+                )
+                await batch_watcher.watch(
                     prompt_formatter.response_format, total_requests
                 )
-            )
+                # Explicitly close the client. Otherwise we get something like
+                # future: <Task finished name='Task-46' coro=<AsyncClient.aclose() done ... >>
+                await batch_watcher.close_client()
 
+            run_in_event_loop(watch_batches())
 
             dataset = self.create_dataset_files(
                 working_dir, parse_func_hash, prompt_formatter
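
This refactor applies the same rule at a larger scale: BatchWatcher construction, the watch() call, and client teardown all move inside a single coroutine, so nothing holding an async HTTP client outlives the loop that run_in_event_loop drives. A minimal sketch of that shape, with Watcher as a hypothetical stand-in for BatchWatcher and asyncio.run standing in for run_in_event_loop:

import asyncio


class Watcher:
    # Hypothetical stand-in for BatchWatcher.
    async def watch(self) -> None:
        await asyncio.sleep(0)  # poll batch statuses here

    async def close_client(self) -> None:
        pass  # release the underlying HTTP client here


async def watch_batches() -> None:
    # Construct, use, and tear down the watcher on a single loop.
    watcher = Watcher()
    try:
        await watcher.watch()
    finally:
        # Unlike the diff above, try/finally also releases the client
        # when watch() raises.
        await watcher.close_client()


asyncio.run(watch_batches())
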
@@ -266,6 +271,9 @@ def __init__(self, working_dir: str, check_interval) -> None:
         self.check_interval = check_interval
         self.working_dir = working_dir
 
+    async def close_client(self):
+        await self.client.close()
+
     async def check_batch_status(self, batch_id: str) -> tuple[str, str]:
         """Check the status of a batch by its ID.
 
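
A possible follow-up: httpx clients, and the AsyncOpenAI wrapper around them, also implement the async context manager protocol (assumed here for an openai-python v1.x SDK), so the explicit close() calls could become scoped async with blocks. A sketch of asubmit_batch in that style, with the request-building logic elided and batch_file_content standing in as a hypothetical parameter for the encoded upload bytes:

from openai import AsyncOpenAI


async def asubmit_batch(batch_file_content: bytes) -> dict:
    # __aexit__ closes the underlying HTTP resources even if the upload
    # or the batch creation raises, with no separate close() call.
    async with AsyncOpenAI() as async_client:
        batch_file_upload = await async_client.files.create(
            file=batch_file_content, purpose="batch"
        )
        return await async_client.batches.create(
            input_file_id=batch_file_upload.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
        )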