From e9f4e12729e2b00a34613934a95d4fd2e7da0b5a Mon Sep 17 00:00:00 2001 From: Daniel Date: Mon, 13 May 2024 16:56:00 -0400 Subject: [PATCH] Add foundational tests (#60) --- pyproject.toml | 9 +++++ src/shelloracle/providers/localai.py | 4 +- src/shelloracle/providers/openai.py | 5 +-- tests/conftest.py | 26 +++++++++++++ tests/providers/conftest.py | 43 +++++++++++++++++++++ tests/providers/test_localai.py | 28 ++++++++++++++ tests/providers/test_ollama.py | 42 +++++++++++++++++++- tests/providers/test_openai.py | 30 ++++++++++++++- tests/test_config.py | 57 ++++++++++++++++++++++++++++ tests/test_shelloracle.py | 2 +- tox.ini | 2 + 11 files changed, 238 insertions(+), 10 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/providers/conftest.py create mode 100644 tests/providers/test_localai.py create mode 100644 tests/test_config.py diff --git a/pyproject.toml b/pyproject.toml index 49a814d..80d0d21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,15 @@ classifiers = [ "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", ] +[project.optional-dependencies] +tests = [ + "tox", + "pytest", + "pytest-sugar", + "pytest-asyncio", + "pytest-httpx" +] + [project.scripts] shor = "shelloracle.__main__:main" diff --git a/src/shelloracle/providers/localai.py b/src/shelloracle/providers/localai.py index 7c83846..539904c 100644 --- a/src/shelloracle/providers/localai.py +++ b/src/shelloracle/providers/localai.py @@ -1,7 +1,7 @@ from collections.abc import AsyncIterator from openai import APIError -from openai import AsyncOpenAI as OpenAIClient +from openai import AsyncOpenAI from . import Provider, ProviderError, Setting, system_prompt @@ -19,7 +19,7 @@ def endpoint(self) -> str: def __init__(self): # Use a placeholder API key so the client will work - self.client = OpenAIClient(api_key="sk-xxx", base_url=self.endpoint) + self.client = AsyncOpenAI(api_key="sk-xxx", base_url=self.endpoint) async def generate(self, prompt: str) -> AsyncIterator[str]: try: diff --git a/src/shelloracle/providers/openai.py b/src/shelloracle/providers/openai.py index a1d8461..4e79ce8 100644 --- a/src/shelloracle/providers/openai.py +++ b/src/shelloracle/providers/openai.py @@ -1,7 +1,6 @@ from collections.abc import AsyncIterator -from openai import APIError -from openai import AsyncOpenAI as OpenAIClient +from openai import AsyncOpenAI, APIError from . import Provider, ProviderError, Setting, system_prompt @@ -15,7 +14,7 @@ class OpenAI(Provider): def __init__(self): if not self.api_key: raise ProviderError("No API key provided") - self.client = OpenAIClient(api_key=self.api_key) + self.client = AsyncOpenAI(api_key=self.api_key) async def generate(self, prompt: str) -> AsyncIterator[str]: try: diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..9c8b3d2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,26 @@ +import pytest +import tomlkit + +from shelloracle.config import Configuration + + +@pytest.fixture(autouse=True) +def tmp_shelloracle_home(monkeypatch, tmp_path): + monkeypatch.setattr("shelloracle.settings.Settings.shelloracle_home", tmp_path) + return tmp_path + + +@pytest.fixture +def set_config(monkeypatch, tmp_shelloracle_home): + config_path = tmp_shelloracle_home / "config.toml" + + def _set_config(config: dict) -> Configuration: + with config_path.open("w") as f: + tomlkit.dump(config, f) + configuration = Configuration(config_path) + monkeypatch.setattr("shelloracle.config._config", configuration) + + yield _set_config + + config_path.unlink() + diff --git a/tests/providers/conftest.py b/tests/providers/conftest.py new file mode 100644 index 0000000..9686a2e --- /dev/null +++ b/tests/providers/conftest.py @@ -0,0 +1,43 @@ +from unittest.mock import MagicMock + +import pytest + + +def split_with_delimiter(string, delim): + result = [] + last_split = 0 + for index, character in enumerate(string): + if character == delim: + result.append(string[last_split:index + 1]) + last_split = index + 1 + if last_split != len(string): + result.append(string[last_split:]) + return result + + +@pytest.fixture +def mock_asyncopenai(monkeypatch): + class AsyncChatCompletionIterator: + def __init__(self, answer: str): + self.answer_index = 0 + self.answer_deltas = split_with_delimiter(answer, " ") + + def __aiter__(self): + return self + + async def __anext__(self): + if self.answer_index < len(self.answer_deltas): + answer_chunk = self.answer_deltas[self.answer_index] + self.answer_index += 1 + choice = MagicMock() + choice.delta.content = answer_chunk + chunk = MagicMock() + chunk.choices = [choice] + return chunk + else: + raise StopAsyncIteration + + async def mock_acreate(*args, **kwargs): + return AsyncChatCompletionIterator("head -c 100 /dev/urandom | hexdump -C") + + monkeypatch.setattr("openai.resources.chat.AsyncCompletions.create", mock_acreate) diff --git a/tests/providers/test_localai.py b/tests/providers/test_localai.py new file mode 100644 index 0000000..efedcc4 --- /dev/null +++ b/tests/providers/test_localai.py @@ -0,0 +1,28 @@ +import pytest + +from shelloracle.providers.localai import LocalAI + + +class TestOpenAI: + @pytest.fixture + def localai_config(self, set_config): + config = {'shelloracle': {'provider': 'LocalAI'}, 'provider': { + 'LocalAI': {'host': 'localhost', 'port': 8080, 'model': 'mistral-openorca'}}} + set_config(config) + + @pytest.fixture + def localai_instance(self, localai_config): + return LocalAI() + + def test_name(self): + assert LocalAI.name == "LocalAI" + + def test_model(self, localai_instance): + assert localai_instance.model == "mistral-openorca" + + @pytest.mark.asyncio + async def test_generate(self, mock_asyncopenai, localai_instance): + result = "" + async for response in localai_instance.generate(""): + result += response + assert result == "head -c 100 /dev/urandom | hexdump -C" diff --git a/tests/providers/test_ollama.py b/tests/providers/test_ollama.py index c573633..0571a5d 100644 --- a/tests/providers/test_ollama.py +++ b/tests/providers/test_ollama.py @@ -1,5 +1,43 @@ +import pytest +from pytest_httpx import IteratorStream + from shelloracle.providers.ollama import Ollama -def test_name(): - assert Ollama.name == "Ollama" +class TestOllama: + @pytest.fixture + def ollama_config(self, set_config): + config = {'shelloracle': {'provider': 'Ollama'}, + 'provider': {'Ollama': {'host': 'localhost', 'port': 11434, 'model': 'dolphin-mistral'}}} + set_config(config) + + @pytest.fixture + def ollama_instance(self, ollama_config): + return Ollama() + + def test_name(self): + assert Ollama.name == "Ollama" + + def test_host(self, ollama_instance): + assert ollama_instance.host == "localhost" + + def test_port(self, ollama_instance): + assert ollama_instance.port == 11434 + + def test_model(self, ollama_instance): + assert ollama_instance.model == "dolphin-mistral" + + def test_endpoint(self, ollama_instance): + assert ollama_instance.endpoint == "http://localhost:11434/api/generate" + + @pytest.mark.asyncio + async def test_generate(self, ollama_instance, httpx_mock): + responses = [ + b'{"response": "cat"}\n', b'{"response": " test"}\n', b'{"response": "."}\n', b'{"response": "py"}\n', + b'{"response": ""}\n' + ] + httpx_mock.add_response(stream=IteratorStream(responses)) + result = "" + async for response in ollama_instance.generate(""): + result += response + assert result == "cat test.py" diff --git a/tests/providers/test_openai.py b/tests/providers/test_openai.py index 277b032..e411289 100644 --- a/tests/providers/test_openai.py +++ b/tests/providers/test_openai.py @@ -1,5 +1,31 @@ +import pytest + from shelloracle.providers.openai import OpenAI -def test_name(): - assert OpenAI.name == "OpenAI" +class TestOpenAI: + @pytest.fixture + def openai_config(self, set_config): + config = {'shelloracle': {'provider': 'OpenAI'}, 'provider': { + 'OpenAI': {'api_key': 'sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx', 'model': 'gpt-3.5-turbo'}}} + set_config(config) + + @pytest.fixture + def openai_instance(self, openai_config): + return OpenAI() + + def test_name(self): + assert OpenAI.name == "OpenAI" + + def test_api_key(self, openai_instance): + assert openai_instance.api_key == "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + + def test_model(self, openai_instance): + assert openai_instance.model == "gpt-3.5-turbo" + + @pytest.mark.asyncio + async def test_generate(self, mock_asyncopenai, openai_instance): + result = "" + async for response in openai_instance.generate(""): + result += response + assert result == "head -c 100 /dev/urandom | hexdump -C" diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..0e6b45e --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import pytest + +from shelloracle.config import get_config, initialize_config + + +class TestConfiguration: + @pytest.fixture + def default_config(self, set_config): + config = {'shelloracle': {'provider': 'Ollama', 'spinner_style': 'earth'}, + 'provider': {'Ollama': {'host': 'localhost', 'port': 11434, 'model': 'dolphin-mistral'}}} + set_config(config) + return config + + def test_initialize_config(self, default_config): + with pytest.raises(RuntimeError): + initialize_config() + + def test_from_file(self, default_config): + assert get_config() == default_config + + def test_getitem(self, default_config): + for key in default_config: + assert default_config[key] == get_config()[key] + + def test_len(self, default_config): + assert len(default_config) == len(get_config()) + + def test_iter(self, default_config): + assert list(iter(default_config)) == list(iter(get_config())) + + def test_str(self, default_config): + assert str(get_config()) == f"Configuration({default_config})" + + def test_repr(self, default_config): + assert repr(default_config) == str(default_config) + + def test_provider(self, default_config): + assert get_config().provider == "Ollama" + + def test_spinner_style(self, default_config): + assert get_config().spinner_style == "earth" + + def test_no_spinner_style(self, caplog, set_config): + config_dict = {'shelloracle': {'provider': 'Ollama'}, + 'provider': {'Ollama': {'host': 'localhost', 'port': 11434, 'model': 'dolphin-mistral'}}} + set_config(config_dict) + assert get_config().spinner_style is None + assert "invalid spinner style" not in caplog.text + + def test_invalid_spinner_style(self, caplog, set_config): + config_dict = {'shelloracle': {'provider': 'Ollama', 'spinner_style': 'invalid'}, + 'provider': {'Ollama': {'host': 'localhost', 'port': 11434, 'model': 'dolphin-mistral'}}} + set_config(config_dict) + assert get_config().spinner_style is None + assert "invalid spinner style" in caplog.text diff --git a/tests/test_shelloracle.py b/tests/test_shelloracle.py index ca98c47..80e6ca2 100644 --- a/tests/test_shelloracle.py +++ b/tests/test_shelloracle.py @@ -2,7 +2,7 @@ import os import sys -from unittest.mock import MagicMock, call +from unittest.mock import call, MagicMock import pytest from yaspin.spinners import Spinners diff --git a/tox.ini b/tox.ini index 9c87a83..eb45929 100644 --- a/tox.ini +++ b/tox.ini @@ -8,6 +8,8 @@ description = run unit tests with pytest deps = pytest>=7 pytest-sugar + pytest-asyncio + pytest-httpx commands = pytest {posargs:tests}