diff --git a/src/bespokelabs/curator/request_processor/base_request_processor.py b/src/bespokelabs/curator/request_processor/base_request_processor.py index 3ef2e104..e5cb5d0d 100644 --- a/src/bespokelabs/curator/request_processor/base_request_processor.py +++ b/src/bespokelabs/curator/request_processor/base_request_processor.py @@ -13,6 +13,9 @@ from pydantic import BaseModel from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter +from bespokelabs.curator.request_processor.event_loop import ( +    get_or_create_event_loop, +) from bespokelabs.curator.request_processor.generic_request import GenericRequest from bespokelabs.curator.request_processor.generic_response import ( GenericResponse, @@ -171,9 +174,11 @@ async def create_all_request_files(): ] await asyncio.gather(*tasks) - asyncio.run(create_all_request_files()) + loop = get_or_create_event_loop() + loop.run_until_complete(create_all_request_files()) else: - asyncio.run( + loop = get_or_create_event_loop() + loop.run_until_complete( self.acreate_request_file( dataset, prompt_formatter, requests_file ) diff --git a/src/bespokelabs/curator/request_processor/event_loop.py b/src/bespokelabs/curator/request_processor/event_loop.py new file mode 100644 index 00000000..0cf2960c --- /dev/null +++ b/src/bespokelabs/curator/request_processor/event_loop.py @@ -0,0 +1,16 @@ +import asyncio + + +def get_or_create_event_loop(): +    """ +    Get the current event loop or create a new one if there isn't one. +    """ +    try: +        return asyncio.get_running_loop() +    except RuntimeError: +        # If no event loop is running, asyncio will +        # return a RuntimeError (https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.get_running_loop).
+        # In that case, we can create a new event loop.
+ loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop diff --git a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py index 95404978..1c21db9e 100644 --- a/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py +++ b/src/bespokelabs/curator/request_processor/openai_batch_request_processor.py @@ -3,16 +3,21 @@ import logging import os from typing import Callable, Dict, Optional, TypeVar -from openai import AsyncOpenAI + import aiofiles +from openai import AsyncOpenAI +from tqdm import tqdm + from bespokelabs.curator.dataset import Dataset +from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter from bespokelabs.curator.request_processor.base_request_processor import ( BaseRequestProcessor, GenericRequest, GenericResponse, ) -from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter -from tqdm import tqdm +from bespokelabs.curator.request_processor.event_loop import ( + get_or_create_event_loop, +) T = TypeVar("T") logger = logging.getLogger(__name__) @@ -232,7 +237,8 @@ async def submit_all_batches(): ] return await asyncio.gather(*tasks) - batch_objects = asyncio.run(submit_all_batches()) + loop = get_or_create_event_loop() + batch_objects = loop.run_until_complete(submit_all_batches()) with open(batch_objects_file, "w") as f: # NOTE(Ryan): we can also store the request_file_name in this object here, instead of in the metadata during batch submission. Can find a nice abstraction across other batch APIs (e.g. 
claude) @@ -254,7 +260,8 @@ async def submit_all_batches(): working_dir, check_interval=self.check_interval ) - asyncio.run( + loop = get_or_create_event_loop() + loop.run_until_complete( batch_watcher.watch( prompt_formatter, self.get_generic_response, dataset ) diff --git a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py index 37469b9a..d80c7d88 100644 --- a/src/bespokelabs/curator/request_processor/openai_online_request_processor.py +++ b/src/bespokelabs/curator/request_processor/openai_online_request_processor.py @@ -5,21 +5,24 @@ import re import time from dataclasses import dataclass, field +from functools import partial from typing import Any, Callable, Dict, Optional, Set, Tuple, TypeVar import aiohttp import requests import tiktoken from tqdm import tqdm -from functools import partial from bespokelabs.curator.dataset import Dataset +from bespokelabs.curator.prompter.prompter import PromptFormatter from bespokelabs.curator.request_processor.base_request_processor import ( BaseRequestProcessor, GenericRequest, GenericResponse, ) -from bespokelabs.curator.prompter.prompter import PromptFormatter +from bespokelabs.curator.request_processor.event_loop import ( + get_or_create_event_loop, +) T = TypeVar("T") logger = logging.getLogger(__name__) @@ -187,7 +190,8 @@ def run( for requests_file, responses_file in zip( requests_files, responses_files ): - asyncio.run( + loop = get_or_create_event_loop() + loop.run_until_complete( self.process_api_requests_from_file( requests_filepath=requests_file, save_filepath=responses_file,