From 33f697ab8c91c9ddbbd9410a7c55a06c72f114c5 Mon Sep 17 00:00:00 2001 From: Daniel Copley Date: Mon, 13 May 2024 13:54:51 -0700 Subject: [PATCH] Test updates --- tests/conftest.py | 5 ++--- tests/providers/conftest.py | 14 ++++++-------- tests/providers/test_localai.py | 2 +- tests/providers/test_openai.py | 2 +- 4 files changed, 10 insertions(+), 13 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 4f2985b..9c8b3d2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,14 +14,13 @@ def tmp_shelloracle_home(monkeypatch, tmp_path): def set_config(monkeypatch, tmp_shelloracle_home): config_path = tmp_shelloracle_home / "config.toml" - def setter(config: dict) -> Configuration: + def _set_config(config: dict) -> Configuration: with config_path.open("w") as f: tomlkit.dump(config, f) configuration = Configuration(config_path) monkeypatch.setattr("shelloracle.config._config", configuration) - return configuration - yield setter + yield _set_config config_path.unlink() diff --git a/tests/providers/conftest.py b/tests/providers/conftest.py index d97d9ca..9686a2e 100644 --- a/tests/providers/conftest.py +++ b/tests/providers/conftest.py @@ -1,7 +1,5 @@ from unittest.mock import MagicMock -import openai -import openai.resources import pytest @@ -31,15 +29,15 @@ async def __anext__(self): if self.answer_index < len(self.answer_deltas): answer_chunk = self.answer_deltas[self.answer_index] self.answer_index += 1 + choice = MagicMock() + choice.delta.content = answer_chunk chunk = MagicMock() - chunk.delta.content = answer_chunk - answer = MagicMock() - answer.choices = [chunk] - return answer + chunk.choices = [choice] + return chunk else: raise StopAsyncIteration async def mock_acreate(*args, **kwargs): - return AsyncChatCompletionIterator("cat test.py") + return AsyncChatCompletionIterator("head -c 100 /dev/urandom | hexdump -C") - monkeypatch.setattr(openai.resources.chat.AsyncCompletions, "create", mock_acreate) + monkeypatch.setattr("openai.resources.chat.AsyncCompletions.create", mock_acreate) diff --git a/tests/providers/test_localai.py b/tests/providers/test_localai.py index e25ab3b..efedcc4 100644 --- a/tests/providers/test_localai.py +++ b/tests/providers/test_localai.py @@ -25,4 +25,4 @@ async def test_generate(self, mock_asyncopenai, localai_instance): result = "" async for response in localai_instance.generate(""): result += response - assert result == "cat test.py" + assert result == "head -c 100 /dev/urandom | hexdump -C" diff --git a/tests/providers/test_openai.py b/tests/providers/test_openai.py index fa459c6..e411289 100644 --- a/tests/providers/test_openai.py +++ b/tests/providers/test_openai.py @@ -28,4 +28,4 @@ async def test_generate(self, mock_asyncopenai, openai_instance): result = "" async for response in openai_instance.generate(""): result += response - assert result == "cat test.py" + assert result == "head -c 100 /dev/urandom | hexdump -C"