Skip to content

Commit

Permalink
Add ability to configure status spinner (#59)
Browse files Browse the repository at this point in the history
* Add ability to configure status spinner
  • Loading branch information
djcopley committed Apr 27, 2024
1 parent aef0344 commit 26f71a0
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 14 deletions.
13 changes: 13 additions & 0 deletions src/shelloracle/config.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from __future__ import annotations

import logging
import os
import sys
from collections.abc import Mapping, Iterator
from pathlib import Path
from typing import Any
from yaspin.spinners import SPINNERS_DATA

if sys.version_info < (3, 11):
import tomli as tomllib
else:
import tomllib

logger = logging.getLogger(__name__)
shelloracle_home = Path.home() / ".shelloracle"


Expand Down Expand Up @@ -42,6 +45,16 @@ def __iter__(self) -> Iterator[Any]:
def provider(self) -> str:
return self["shelloracle"]["provider"]

@property
def spinner_style(self) -> str | None:
style = self["shelloracle"].get("spinner_style", None)
if not style:
return None
if style not in SPINNERS_DATA:
logger.warning("invalid spinner style: %s", style)
return None
return style


_config: Configuration | None = None

Expand Down
19 changes: 18 additions & 1 deletion src/shelloracle/shelloracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,22 @@
import logging
import os
import sys
from typing import TYPE_CHECKING

from prompt_toolkit import PromptSession, print_formatted_text
from prompt_toolkit.application import create_app_session_from_tty
from prompt_toolkit.formatted_text import FormattedText
from prompt_toolkit.history import FileHistory
from prompt_toolkit.patch_stdout import patch_stdout
from yaspin import yaspin
from yaspin.spinners import Spinners

from .config import get_config, shelloracle_home
from .providers import get_provider

if TYPE_CHECKING:
from yaspin.core import Yaspin

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -43,6 +48,18 @@ def get_query_from_pipe() -> str | None:
return lines[0].rstrip()


def spinner() -> Yaspin:
"""Get the correct spinner based on the user's configuration
:returns: yaspin object
"""
config = get_config()
if not config.spinner_style:
return yaspin()
style = getattr(Spinners, config.spinner_style)
return yaspin(style)


async def shelloracle() -> None:
"""ShellOracle program entrypoint
Expand All @@ -64,7 +81,7 @@ async def shelloracle() -> None:
logger.info("user prompt: %s", prompt)

shell_command = ""
with create_app_session_from_tty(), patch_stdout(raw=True), yaspin() as sp:
with create_app_session_from_tty(), patch_stdout(raw=True), spinner() as sp:
async for token in provider.generate(prompt):
# some models may erroneously return a newline, which causes issues with the status spinner
token = token.replace("\n", "")
Expand Down
53 changes: 40 additions & 13 deletions tests/test_shelloracle.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,53 @@
from __future__ import annotations

import os
import sys
from unittest.mock import MagicMock, call

import pytest
from yaspin.spinners import Spinners

from shelloracle.shelloracle import get_query_from_pipe
from shelloracle.shelloracle import get_query_from_pipe, spinner


def test_get_query_from_pipe(monkeypatch):
# Is a TTY
monkeypatch.setattr(os, "isatty", lambda _: True)
assert get_query_from_pipe() is None
@pytest.fixture
def mock_yaspin(monkeypatch):
mock = MagicMock()
monkeypatch.setattr("shelloracle.shelloracle.yaspin", mock)
return mock

# Not a TTY and no lines in the pipe
monkeypatch.setattr(os, "isatty", lambda _: False)
monkeypatch.setattr(sys.stdin, "readlines", lambda: [])
assert get_query_from_pipe() is None

# Not TTY and one line in the pipe
monkeypatch.setattr(sys.stdin, "readlines", lambda: ["what is up"])
assert get_query_from_pipe() == "what is up"
@pytest.fixture
def mock_config(monkeypatch):
config = MagicMock()
monkeypatch.setattr("shelloracle.config._config", config)
return config


@pytest.mark.parametrize("spinner_style,expected", [(None, call()), ("earth", call(Spinners.earth))])
def test_spinner(spinner_style, expected, mock_config, mock_yaspin):
mock_config.spinner_style = spinner_style
spinner()
assert mock_yaspin.call_args == expected


# Not a TTY and multiple lines in the pipe
def test_spinner_fail(mock_yaspin, mock_config):
mock_config.spinner_style = "not a spinner style"
with pytest.raises(AttributeError):
spinner()


@pytest.mark.parametrize("isatty,readlines,expected", [
(True, None, None), (False, [], None), (False, ["what is up"], "what is up")
])
def test_get_query_from_pipe(isatty, readlines, expected, monkeypatch):
monkeypatch.setattr(os, "isatty", lambda _: isatty)
monkeypatch.setattr(sys.stdin, "readlines", lambda: readlines)
assert get_query_from_pipe() == expected


def test_get_query_from_pipe_fail(monkeypatch):
monkeypatch.setattr(os, "isatty", lambda _: False)
monkeypatch.setattr(sys.stdin, "readlines", lambda: ["what is up", "what is down"])
with pytest.raises(ValueError):
get_query_from_pipe()

0 comments on commit 26f71a0

Please sign in to comment.