From 3628b053d26953626d45e8b041e4549ea45087c2 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 6 Apr 2024 10:08:16 -0400 Subject: [PATCH] Add ability to configure status spinner (#59) * Add ability to configure status spinner --- src/shelloracle/config.py | 13 +++++++++ src/shelloracle/shelloracle.py | 19 +++++++++++- tests/test_shelloracle.py | 53 +++++++++++++++++++++++++--------- 3 files changed, 71 insertions(+), 14 deletions(-) diff --git a/src/shelloracle/config.py b/src/shelloracle/config.py index 2462128..2240b0e 100644 --- a/src/shelloracle/config.py +++ b/src/shelloracle/config.py @@ -1,16 +1,19 @@ from __future__ import annotations +import logging import os import sys from collections.abc import Mapping, Iterator from pathlib import Path from typing import Any +from yaspin.spinners import SPINNERS_DATA if sys.version_info < (3, 11): import tomli as tomllib else: import tomllib +logger = logging.getLogger(__name__) shelloracle_home = Path.home() / ".shelloracle" @@ -42,6 +45,16 @@ def __iter__(self) -> Iterator[Any]: def provider(self) -> str: return self["shelloracle"]["provider"] + @property + def spinner_style(self) -> str | None: + style = self["shelloracle"].get("spinner_style", None) + if not style: + return None + if style not in SPINNERS_DATA: + logger.warning("invalid spinner style: %s", style) + return None + return style + _config: Configuration | None = None diff --git a/src/shelloracle/shelloracle.py b/src/shelloracle/shelloracle.py index 7ce4968..b35fb37 100644 --- a/src/shelloracle/shelloracle.py +++ b/src/shelloracle/shelloracle.py @@ -5,6 +5,7 @@ import os import sys from pathlib import Path +from typing import TYPE_CHECKING from prompt_toolkit import PromptSession, print_formatted_text from prompt_toolkit.application import create_app_session_from_tty @@ -12,10 +13,14 @@ from prompt_toolkit.history import FileHistory from prompt_toolkit.patch_stdout import patch_stdout from yaspin import yaspin +from yaspin.spinners import Spinners from .config import get_config from .providers import get_provider +if TYPE_CHECKING: + from yaspin.core import Yaspin + logger = logging.getLogger(__name__) @@ -44,6 +49,18 @@ def get_query_from_pipe() -> str | None: return lines[0].rstrip() +def spinner() -> Yaspin: + """Get the correct spinner based on the user's configuration + + :returns: yaspin object + """ + config = get_config() + if not config.spinner_style: + return yaspin() + style = getattr(Spinners, config.spinner_style) + return yaspin(style) + + async def shelloracle() -> None: """ShellOracle program entrypoint @@ -65,7 +82,7 @@ async def shelloracle() -> None: logger.info("user prompt: %s", prompt) shell_command = "" - with create_app_session_from_tty(), patch_stdout(raw=True), yaspin() as sp: + with create_app_session_from_tty(), patch_stdout(raw=True), spinner() as sp: async for token in provider.generate(prompt): # some models may erroneously return a newline, which causes issues with the status spinner token = token.replace("\n", "") diff --git a/tests/test_shelloracle.py b/tests/test_shelloracle.py index b40a45f..ca98c47 100644 --- a/tests/test_shelloracle.py +++ b/tests/test_shelloracle.py @@ -1,26 +1,53 @@ +from __future__ import annotations + import os import sys +from unittest.mock import MagicMock, call import pytest +from yaspin.spinners import Spinners -from shelloracle.shelloracle import get_query_from_pipe +from shelloracle.shelloracle import get_query_from_pipe, spinner -def test_get_query_from_pipe(monkeypatch): - # Is a TTY - monkeypatch.setattr(os, "isatty", lambda _: True) - assert get_query_from_pipe() is None +@pytest.fixture +def mock_yaspin(monkeypatch): + mock = MagicMock() + monkeypatch.setattr("shelloracle.shelloracle.yaspin", mock) + return mock - # Not a TTY and no lines in the pipe - monkeypatch.setattr(os, "isatty", lambda _: False) - monkeypatch.setattr(sys.stdin, "readlines", lambda: []) - assert get_query_from_pipe() is None - # Not TTY and one line in the pipe - monkeypatch.setattr(sys.stdin, "readlines", lambda: ["what is up"]) - assert get_query_from_pipe() == "what is up" +@pytest.fixture +def mock_config(monkeypatch): + config = MagicMock() + monkeypatch.setattr("shelloracle.config._config", config) + return config + + +@pytest.mark.parametrize("spinner_style,expected", [(None, call()), ("earth", call(Spinners.earth))]) +def test_spinner(spinner_style, expected, mock_config, mock_yaspin): + mock_config.spinner_style = spinner_style + spinner() + assert mock_yaspin.call_args == expected + - # Not a TTY and multiple lines in the pipe +def test_spinner_fail(mock_yaspin, mock_config): + mock_config.spinner_style = "not a spinner style" + with pytest.raises(AttributeError): + spinner() + + +@pytest.mark.parametrize("isatty,readlines,expected", [ + (True, None, None), (False, [], None), (False, ["what is up"], "what is up") +]) +def test_get_query_from_pipe(isatty, readlines, expected, monkeypatch): + monkeypatch.setattr(os, "isatty", lambda _: isatty) + monkeypatch.setattr(sys.stdin, "readlines", lambda: readlines) + assert get_query_from_pipe() == expected + + +def test_get_query_from_pipe_fail(monkeypatch): + monkeypatch.setattr(os, "isatty", lambda _: False) monkeypatch.setattr(sys.stdin, "readlines", lambda: ["what is up", "what is down"]) with pytest.raises(ValueError): get_query_from_pipe()