Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
djcopley committed May 13, 2024
1 parent 24c7d80 commit fc30481
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 23 deletions.
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ classifiers = [
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
]

[project.optional-dependencies]
tests = [
"tox",
"pytest",
"pytest-sugar",
"pytest-asyncio",
"pytest-httpx"
]

[project.scripts]
shor = "shelloracle.__main__:main"

Expand Down
4 changes: 2 additions & 2 deletions src/shelloracle/providers/localai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import AsyncIterator

from openai import APIError
from openai import AsyncOpenAI as OpenAIClient
from openai import AsyncOpenAI

from . import Provider, ProviderError, Setting, system_prompt

Expand All @@ -19,7 +19,7 @@ def endpoint(self) -> str:

def __init__(self):
# Use a placeholder API key so the client will work
self.client = OpenAIClient(api_key="sk-xxx", base_url=self.endpoint)
self.client = AsyncOpenAI(api_key="sk-xxx", base_url=self.endpoint)

async def generate(self, prompt: str) -> AsyncIterator[str]:
try:
Expand Down
5 changes: 2 additions & 3 deletions src/shelloracle/providers/openai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from collections.abc import AsyncIterator

from openai import APIError
from openai import AsyncOpenAI as OpenAIClient
from openai import AsyncOpenAI, APIError

from . import Provider, ProviderError, Setting, system_prompt

Expand All @@ -15,7 +14,7 @@ class OpenAI(Provider):
def __init__(self):
if not self.api_key:
raise ProviderError("No API key provided")
self.client = OpenAIClient(api_key=self.api_key)
self.client = AsyncOpenAI(api_key=self.api_key)

async def generate(self, prompt: str) -> AsyncIterator[str]:
try:
Expand Down
52 changes: 40 additions & 12 deletions tests/providers/conftest.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,45 @@
import textwrap
from unittest.mock import MagicMock

import openai
import openai.resources
import pytest


def split_with_delimiter(string, delim):
result = []
last_split = 0
for index, character in enumerate(string):
if character == delim:
result.append(string[last_split:index + 1])
last_split = index + 1
if last_split != len(string):
result.append(string[last_split:])
return result


@pytest.fixture
def ollama_config(set_config):
config = textwrap.dedent("""\
[shelloracle]
provider = "Ollama"
[provider.Ollama]
host = "localhost"
port = 11434
model = "dolphin-mistral"
""")
set_config(config)
def mock_asyncopenai(monkeypatch):
class AsyncChatCompletionIterator:
def __init__(self, answer: str):
self.answer_index = 0
self.answer_deltas = split_with_delimiter(answer, " ")

def __aiter__(self):
return self

async def __anext__(self):
if self.answer_index < len(self.answer_deltas):
answer_chunk = self.answer_deltas[self.answer_index]
self.answer_index += 1
chunk = MagicMock()
chunk.delta.content = answer_chunk
answer = MagicMock()
answer.choices = [chunk]
return answer
else:
raise StopAsyncIteration

async def mock_acreate(*args, **kwargs):
return AsyncChatCompletionIterator("cat test.py")

monkeypatch.setattr(openai.resources.chat.AsyncCompletions, "create", mock_acreate)
26 changes: 22 additions & 4 deletions tests/providers/test_localai.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,28 @@
import pytest

from shelloracle.providers.localai import LocalAI

class TestLocalAI:

class TestOpenAI:
@pytest.fixture
def localai_config(self, set_config):
config = {'shelloracle': {'provider': 'LocalAI'},
'provider': {'LocalAI': {'host': 'localhost', 'port': 8080, 'model': 'gpt-3.5-turbo'}}}
config = {'shelloracle': {'provider': 'LocalAI'}, 'provider': {
'LocalAI': {'host': 'localhost', 'port': 8080, 'model': 'mistral-openorca'}}}
set_config(config)

@pytest.fixture
def localai_instance(self, localai_config):
return LocalAI()

def test_name(self):
assert LocalAI.name == "LocalAI"

def test_model(self, localai_instance):
assert localai_instance.model == "mistral-openorca"

return set_config(config)
@pytest.mark.asyncio
async def test_generate(self, mock_asyncopenai, localai_instance):
result = ""
async for response in localai_instance.generate(""):
result += response
assert result == "cat test.py"
2 changes: 1 addition & 1 deletion tests/providers/test_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class TestOllama:
def ollama_config(self, set_config):
config = {'shelloracle': {'provider': 'Ollama'},
'provider': {'Ollama': {'host': 'localhost', 'port': 11434, 'model': 'dolphin-mistral'}}}
return set_config(config)
set_config(config)

@pytest.fixture
def ollama_instance(self, ollama_config):
Expand Down
15 changes: 14 additions & 1 deletion tests/providers/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,24 @@ class TestOpenAI:
def openai_config(self, set_config):
config = {'shelloracle': {'provider': 'OpenAI'}, 'provider': {
'OpenAI': {'api_key': 'sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx', 'model': 'gpt-3.5-turbo'}}}
return set_config(config)
set_config(config)

@pytest.fixture
def openai_instance(self, openai_config):
return OpenAI()

def test_name(self):
assert OpenAI.name == "OpenAI"

def test_api_key(self, openai_instance):
assert openai_instance.api_key == "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"

def test_model(self, openai_instance):
assert openai_instance.model == "gpt-3.5-turbo"

@pytest.mark.asyncio
async def test_generate(self, mock_asyncopenai, openai_instance):
result = ""
async for response in openai_instance.generate(""):
result += response
assert result == "cat test.py"

0 comments on commit fc30481

Please sign in to comment.