diff --git a/pyproject.toml b/pyproject.toml index 49a814d..80d0d21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,15 @@ classifiers = [ "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", ] +[project.optional-dependencies] +tests = [ + "tox", + "pytest", + "pytest-sugar", + "pytest-asyncio", + "pytest-httpx" +] + [project.scripts] shor = "shelloracle.__main__:main" diff --git a/src/shelloracle/providers/localai.py b/src/shelloracle/providers/localai.py index 7c83846..539904c 100644 --- a/src/shelloracle/providers/localai.py +++ b/src/shelloracle/providers/localai.py @@ -1,7 +1,7 @@ from collections.abc import AsyncIterator from openai import APIError -from openai import AsyncOpenAI as OpenAIClient +from openai import AsyncOpenAI from . import Provider, ProviderError, Setting, system_prompt @@ -19,7 +19,7 @@ def endpoint(self) -> str: def __init__(self): # Use a placeholder API key so the client will work - self.client = OpenAIClient(api_key="sk-xxx", base_url=self.endpoint) + self.client = AsyncOpenAI(api_key="sk-xxx", base_url=self.endpoint) async def generate(self, prompt: str) -> AsyncIterator[str]: try: diff --git a/src/shelloracle/providers/openai.py b/src/shelloracle/providers/openai.py index a1d8461..4e79ce8 100644 --- a/src/shelloracle/providers/openai.py +++ b/src/shelloracle/providers/openai.py @@ -1,7 +1,6 @@ from collections.abc import AsyncIterator -from openai import APIError -from openai import AsyncOpenAI as OpenAIClient +from openai import AsyncOpenAI, APIError from . import Provider, ProviderError, Setting, system_prompt @@ -15,7 +14,7 @@ class OpenAI(Provider): def __init__(self): if not self.api_key: raise ProviderError("No API key provided") - self.client = OpenAIClient(api_key=self.api_key) + self.client = AsyncOpenAI(api_key=self.api_key) async def generate(self, prompt: str) -> AsyncIterator[str]: try: diff --git a/tests/providers/conftest.py b/tests/providers/conftest.py index 9acd629..d97d9ca 100644 --- a/tests/providers/conftest.py +++ b/tests/providers/conftest.py @@ -1,17 +1,45 @@ -import textwrap +from unittest.mock import MagicMock +import openai +import openai.resources import pytest +def split_with_delimiter(string, delim): + result = [] + last_split = 0 + for index, character in enumerate(string): + if character == delim: + result.append(string[last_split:index + 1]) + last_split = index + 1 + if last_split != len(string): + result.append(string[last_split:]) + return result + + @pytest.fixture -def ollama_config(set_config): - config = textwrap.dedent("""\ - [shelloracle] - provider = "Ollama" - - [provider.Ollama] - host = "localhost" - port = 11434 - model = "dolphin-mistral" - """) - set_config(config) +def mock_asyncopenai(monkeypatch): + class AsyncChatCompletionIterator: + def __init__(self, answer: str): + self.answer_index = 0 + self.answer_deltas = split_with_delimiter(answer, " ") + + def __aiter__(self): + return self + + async def __anext__(self): + if self.answer_index < len(self.answer_deltas): + answer_chunk = self.answer_deltas[self.answer_index] + self.answer_index += 1 + chunk = MagicMock() + chunk.delta.content = answer_chunk + answer = MagicMock() + answer.choices = [chunk] + return answer + else: + raise StopAsyncIteration + + async def mock_acreate(*args, **kwargs): + return AsyncChatCompletionIterator("cat test.py") + + monkeypatch.setattr(openai.resources.chat.AsyncCompletions, "create", mock_acreate) diff --git a/tests/providers/test_localai.py b/tests/providers/test_localai.py index 6667db1..e25ab3b 100644 --- a/tests/providers/test_localai.py +++ b/tests/providers/test_localai.py @@ -1,10 +1,28 @@ import pytest +from shelloracle.providers.localai import LocalAI -class TestLocalAI: + +class TestOpenAI: @pytest.fixture def localai_config(self, set_config): - config = {'shelloracle': {'provider': 'LocalAI'}, - 'provider': {'LocalAI': {'host': 'localhost', 'port': 8080, 'model': 'gpt-3.5-turbo'}}} + config = {'shelloracle': {'provider': 'LocalAI'}, 'provider': { + 'LocalAI': {'host': 'localhost', 'port': 8080, 'model': 'mistral-openorca'}}} + set_config(config) + + @pytest.fixture + def localai_instance(self, localai_config): + return LocalAI() + + def test_name(self): + assert LocalAI.name == "LocalAI" + + def test_model(self, localai_instance): + assert localai_instance.model == "mistral-openorca" - return set_config(config) + @pytest.mark.asyncio + async def test_generate(self, mock_asyncopenai, localai_instance): + result = "" + async for response in localai_instance.generate(""): + result += response + assert result == "cat test.py" diff --git a/tests/providers/test_ollama.py b/tests/providers/test_ollama.py index 115b020..0571a5d 100644 --- a/tests/providers/test_ollama.py +++ b/tests/providers/test_ollama.py @@ -9,7 +9,7 @@ class TestOllama: def ollama_config(self, set_config): config = {'shelloracle': {'provider': 'Ollama'}, 'provider': {'Ollama': {'host': 'localhost', 'port': 11434, 'model': 'dolphin-mistral'}}} - return set_config(config) + set_config(config) @pytest.fixture def ollama_instance(self, ollama_config): diff --git a/tests/providers/test_openai.py b/tests/providers/test_openai.py index 3dbef19..fa459c6 100644 --- a/tests/providers/test_openai.py +++ b/tests/providers/test_openai.py @@ -8,7 +8,7 @@ class TestOpenAI: def openai_config(self, set_config): config = {'shelloracle': {'provider': 'OpenAI'}, 'provider': { 'OpenAI': {'api_key': 'sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx', 'model': 'gpt-3.5-turbo'}}} - return set_config(config) + set_config(config) @pytest.fixture def openai_instance(self, openai_config): @@ -16,3 +16,16 @@ def openai_instance(self, openai_config): def test_name(self): assert OpenAI.name == "OpenAI" + + def test_api_key(self, openai_instance): + assert openai_instance.api_key == "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + + def test_model(self, openai_instance): + assert openai_instance.model == "gpt-3.5-turbo" + + @pytest.mark.asyncio + async def test_generate(self, mock_asyncopenai, openai_instance): + result = "" + async for response in openai_instance.generate(""): + result += response + assert result == "cat test.py"