Skip to content

Commit

Permalink
Add foundational tests (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
djcopley authored May 13, 2024
1 parent 3f56e60 commit e9f4e12
Show file tree
Hide file tree
Showing 11 changed files with 238 additions and 10 deletions.
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ classifiers = [
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
]

[project.optional-dependencies]
tests = [
"tox",
"pytest",
"pytest-sugar",
"pytest-asyncio",
"pytest-httpx"
]

[project.scripts]
shor = "shelloracle.__main__:main"

Expand Down
4 changes: 2 additions & 2 deletions src/shelloracle/providers/localai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import AsyncIterator

from openai import APIError
from openai import AsyncOpenAI as OpenAIClient
from openai import AsyncOpenAI

from . import Provider, ProviderError, Setting, system_prompt

Expand All @@ -19,7 +19,7 @@ def endpoint(self) -> str:

def __init__(self):
# Use a placeholder API key so the client will work
self.client = OpenAIClient(api_key="sk-xxx", base_url=self.endpoint)
self.client = AsyncOpenAI(api_key="sk-xxx", base_url=self.endpoint)

async def generate(self, prompt: str) -> AsyncIterator[str]:
try:
Expand Down
5 changes: 2 additions & 3 deletions src/shelloracle/providers/openai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from collections.abc import AsyncIterator

from openai import APIError
from openai import AsyncOpenAI as OpenAIClient
from openai import AsyncOpenAI, APIError

from . import Provider, ProviderError, Setting, system_prompt

Expand All @@ -15,7 +14,7 @@ class OpenAI(Provider):
def __init__(self):
if not self.api_key:
raise ProviderError("No API key provided")
self.client = OpenAIClient(api_key=self.api_key)
self.client = AsyncOpenAI(api_key=self.api_key)

async def generate(self, prompt: str) -> AsyncIterator[str]:
try:
Expand Down
26 changes: 26 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pytest
import tomlkit

from shelloracle.config import Configuration


@pytest.fixture(autouse=True)
def tmp_shelloracle_home(monkeypatch, tmp_path):
monkeypatch.setattr("shelloracle.settings.Settings.shelloracle_home", tmp_path)
return tmp_path


@pytest.fixture
def set_config(monkeypatch, tmp_shelloracle_home):
config_path = tmp_shelloracle_home / "config.toml"

def _set_config(config: dict) -> Configuration:
with config_path.open("w") as f:
tomlkit.dump(config, f)
configuration = Configuration(config_path)
monkeypatch.setattr("shelloracle.config._config", configuration)

yield _set_config

config_path.unlink()

43 changes: 43 additions & 0 deletions tests/providers/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from unittest.mock import MagicMock

import pytest


def split_with_delimiter(string, delim):
result = []
last_split = 0
for index, character in enumerate(string):
if character == delim:
result.append(string[last_split:index + 1])
last_split = index + 1
if last_split != len(string):
result.append(string[last_split:])
return result


@pytest.fixture
def mock_asyncopenai(monkeypatch):
class AsyncChatCompletionIterator:
def __init__(self, answer: str):
self.answer_index = 0
self.answer_deltas = split_with_delimiter(answer, " ")

def __aiter__(self):
return self

async def __anext__(self):
if self.answer_index < len(self.answer_deltas):
answer_chunk = self.answer_deltas[self.answer_index]
self.answer_index += 1
choice = MagicMock()
choice.delta.content = answer_chunk
chunk = MagicMock()
chunk.choices = [choice]
return chunk
else:
raise StopAsyncIteration

async def mock_acreate(*args, **kwargs):
return AsyncChatCompletionIterator("head -c 100 /dev/urandom | hexdump -C")

monkeypatch.setattr("openai.resources.chat.AsyncCompletions.create", mock_acreate)
28 changes: 28 additions & 0 deletions tests/providers/test_localai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pytest

from shelloracle.providers.localai import LocalAI


class TestOpenAI:
@pytest.fixture
def localai_config(self, set_config):
config = {'shelloracle': {'provider': 'LocalAI'}, 'provider': {
'LocalAI': {'host': 'localhost', 'port': 8080, 'model': 'mistral-openorca'}}}
set_config(config)

@pytest.fixture
def localai_instance(self, localai_config):
return LocalAI()

def test_name(self):
assert LocalAI.name == "LocalAI"

def test_model(self, localai_instance):
assert localai_instance.model == "mistral-openorca"

@pytest.mark.asyncio
async def test_generate(self, mock_asyncopenai, localai_instance):
result = ""
async for response in localai_instance.generate(""):
result += response
assert result == "head -c 100 /dev/urandom | hexdump -C"
42 changes: 40 additions & 2 deletions tests/providers/test_ollama.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,43 @@
import pytest
from pytest_httpx import IteratorStream

from shelloracle.providers.ollama import Ollama


def test_name():
assert Ollama.name == "Ollama"
class TestOllama:
@pytest.fixture
def ollama_config(self, set_config):
config = {'shelloracle': {'provider': 'Ollama'},
'provider': {'Ollama': {'host': 'localhost', 'port': 11434, 'model': 'dolphin-mistral'}}}
set_config(config)

@pytest.fixture
def ollama_instance(self, ollama_config):
return Ollama()

def test_name(self):
assert Ollama.name == "Ollama"

def test_host(self, ollama_instance):
assert ollama_instance.host == "localhost"

def test_port(self, ollama_instance):
assert ollama_instance.port == 11434

def test_model(self, ollama_instance):
assert ollama_instance.model == "dolphin-mistral"

def test_endpoint(self, ollama_instance):
assert ollama_instance.endpoint == "http://localhost:11434/api/generate"

@pytest.mark.asyncio
async def test_generate(self, ollama_instance, httpx_mock):
responses = [
b'{"response": "cat"}\n', b'{"response": " test"}\n', b'{"response": "."}\n', b'{"response": "py"}\n',
b'{"response": ""}\n'
]
httpx_mock.add_response(stream=IteratorStream(responses))
result = ""
async for response in ollama_instance.generate(""):
result += response
assert result == "cat test.py"
30 changes: 28 additions & 2 deletions tests/providers/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,31 @@
import pytest

from shelloracle.providers.openai import OpenAI


def test_name():
assert OpenAI.name == "OpenAI"
class TestOpenAI:
@pytest.fixture
def openai_config(self, set_config):
config = {'shelloracle': {'provider': 'OpenAI'}, 'provider': {
'OpenAI': {'api_key': 'sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx', 'model': 'gpt-3.5-turbo'}}}
set_config(config)

@pytest.fixture
def openai_instance(self, openai_config):
return OpenAI()

def test_name(self):
assert OpenAI.name == "OpenAI"

def test_api_key(self, openai_instance):
assert openai_instance.api_key == "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"

def test_model(self, openai_instance):
assert openai_instance.model == "gpt-3.5-turbo"

@pytest.mark.asyncio
async def test_generate(self, mock_asyncopenai, openai_instance):
result = ""
async for response in openai_instance.generate(""):
result += response
assert result == "head -c 100 /dev/urandom | hexdump -C"
57 changes: 57 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

import pytest

from shelloracle.config import get_config, initialize_config


class TestConfiguration:
@pytest.fixture
def default_config(self, set_config):
config = {'shelloracle': {'provider': 'Ollama', 'spinner_style': 'earth'},
'provider': {'Ollama': {'host': 'localhost', 'port': 11434, 'model': 'dolphin-mistral'}}}
set_config(config)
return config

def test_initialize_config(self, default_config):
with pytest.raises(RuntimeError):
initialize_config()

def test_from_file(self, default_config):
assert get_config() == default_config

def test_getitem(self, default_config):
for key in default_config:
assert default_config[key] == get_config()[key]

def test_len(self, default_config):
assert len(default_config) == len(get_config())

def test_iter(self, default_config):
assert list(iter(default_config)) == list(iter(get_config()))

def test_str(self, default_config):
assert str(get_config()) == f"Configuration({default_config})"

def test_repr(self, default_config):
assert repr(default_config) == str(default_config)

def test_provider(self, default_config):
assert get_config().provider == "Ollama"

def test_spinner_style(self, default_config):
assert get_config().spinner_style == "earth"

def test_no_spinner_style(self, caplog, set_config):
config_dict = {'shelloracle': {'provider': 'Ollama'},
'provider': {'Ollama': {'host': 'localhost', 'port': 11434, 'model': 'dolphin-mistral'}}}
set_config(config_dict)
assert get_config().spinner_style is None
assert "invalid spinner style" not in caplog.text

def test_invalid_spinner_style(self, caplog, set_config):
config_dict = {'shelloracle': {'provider': 'Ollama', 'spinner_style': 'invalid'},
'provider': {'Ollama': {'host': 'localhost', 'port': 11434, 'model': 'dolphin-mistral'}}}
set_config(config_dict)
assert get_config().spinner_style is None
assert "invalid spinner style" in caplog.text
2 changes: 1 addition & 1 deletion tests/test_shelloracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
import sys
from unittest.mock import MagicMock, call
from unittest.mock import call, MagicMock

import pytest
from yaspin.spinners import Spinners
Expand Down
2 changes: 2 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ description = run unit tests with pytest
deps =
pytest>=7
pytest-sugar
pytest-asyncio
pytest-httpx
commands =
pytest {posargs:tests}

Expand Down

0 comments on commit e9f4e12

Please sign in to comment.