Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove the use of asyncio.run to make asyncio work in colab #69

Merged
merged 5 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from pydantic import BaseModel

from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter
from bespokelabs.curator.request_processor.event_loop import (
get_or_create_event_loop,
)
from bespokelabs.curator.request_processor.generic_request import GenericRequest
from bespokelabs.curator.request_processor.generic_response import (
GenericResponse,
Expand Down Expand Up @@ -171,9 +174,11 @@ async def create_all_request_files():
]
await asyncio.gather(*tasks)

asyncio.run(create_all_request_files())
loop = get_or_create_event_loop()
loop.run_until_complete(create_all_request_files())
else:
asyncio.run(
loop = get_or_create_event_loop()
loop.run_until_complete(
self.acreate_request_file(
dataset, prompt_formatter, requests_file
)
Expand Down
16 changes: 16 additions & 0 deletions src/bespokelabs/curator/request_processor/event_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import asyncio


def get_or_create_event_loop():
    """Return the currently running event loop, or create and install a new one.

    Intended for environments like Jupyter/Colab where a loop may already be
    running (``asyncio.run`` would raise there). Callers drive the returned
    loop with ``run_until_complete``.

    Returns:
        asyncio.AbstractEventLoop: the running loop if one exists, otherwise a
        freshly created loop that has been registered as the current loop for
        this thread via ``asyncio.set_event_loop``.

    Note:
        - If a loop is already running, calling ``run_until_complete`` on it
          will raise ``RuntimeError`` unless loop re-entrancy has been enabled
          (e.g. via ``nest_asyncio``) — presumably the Colab setup does this;
          verify in the caller's environment.
        - Each call made while no loop is running creates a *new* loop; any
          previously created (and now superseded) loop is not closed here.
    """
    try:
        return asyncio.get_running_loop()
    except RuntimeError:
        # asyncio.get_running_loop() raises RuntimeError when no loop is
        # running in the current thread
        # (https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.get_running_loop).
        # In that case, create a new loop and make it the thread's current one.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,21 @@
import logging
import os
from typing import Callable, Dict, Optional, TypeVar
from openai import AsyncOpenAI

import aiofiles
from openai import AsyncOpenAI
from tqdm import tqdm

from bespokelabs.curator.dataset import Dataset
from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter
from bespokelabs.curator.request_processor.base_request_processor import (
BaseRequestProcessor,
GenericRequest,
GenericResponse,
)
from bespokelabs.curator.prompter.prompt_formatter import PromptFormatter
from tqdm import tqdm
from bespokelabs.curator.request_processor.event_loop import (
get_or_create_event_loop,
)

T = TypeVar("T")
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -232,7 +237,8 @@ async def submit_all_batches():
]
return await asyncio.gather(*tasks)

batch_objects = asyncio.run(submit_all_batches())
loop = get_or_create_event_loop()
batch_objects = loop.run_until_complete(submit_all_batches())

with open(batch_objects_file, "w") as f:
# NOTE(Ryan): we can also store the request_file_name in this object here, instead of in the metadata during batch submission. Can find a nice abstraction across other batch APIs (e.g. claude)
Expand All @@ -254,7 +260,8 @@ async def submit_all_batches():
working_dir, check_interval=self.check_interval
)

asyncio.run(
loop = get_or_create_event_loop()
loop.run_until_complete(
batch_watcher.watch(
prompt_formatter, self.get_generic_response, dataset
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,24 @@
import re
import time
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Dict, Optional, Set, Tuple, TypeVar

import aiohttp
import requests
import tiktoken
from tqdm import tqdm
from functools import partial

from bespokelabs.curator.dataset import Dataset
from bespokelabs.curator.prompter.prompter import PromptFormatter
from bespokelabs.curator.request_processor.base_request_processor import (
BaseRequestProcessor,
GenericRequest,
GenericResponse,
)
from bespokelabs.curator.prompter.prompter import PromptFormatter
from bespokelabs.curator.request_processor.event_loop import (
get_or_create_event_loop,
)

T = TypeVar("T")
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -187,7 +190,8 @@ def run(
for requests_file, responses_file in zip(
requests_files, responses_files
):
asyncio.run(
loop = get_or_create_event_loop()
loop.run_until_complete(
self.process_api_requests_from_file(
requests_filepath=requests_file,
save_filepath=responses_file,
Expand Down