diff --git a/src/shelloracle/__main__.py b/src/shelloracle/__main__.py index c16bb03..7545544 100644 --- a/src/shelloracle/__main__.py +++ b/src/shelloracle/__main__.py @@ -2,21 +2,30 @@ import logging from importlib.metadata import version -from shelloracle.config import initialize_config from . import shelloracle +from .config import initialize_config from .settings import Settings +from .tty_log_handler import TtyLogHandler + +logger = logging.getLogger(__name__) def configure_logging(): root_logger = logging.getLogger() root_logger.setLevel(logging.DEBUG) - formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s") - handler = logging.FileHandler(Settings.shelloracle_home / "shelloracle.log") - handler.setLevel(logging.DEBUG) - handler.setFormatter(formatter) + file_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s") + file_handler = logging.FileHandler(Settings.shelloracle_home / "shelloracle.log") + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(file_formatter) + + tty_formatter = logging.Formatter("%(message)s") + tty_handler = TtyLogHandler() + tty_handler.setLevel(logging.WARNING) + tty_handler.setFormatter(tty_formatter) - root_logger.addHandler(handler) + root_logger.addHandler(file_handler) + root_logger.addHandler(tty_handler) def configure(): @@ -37,13 +46,18 @@ def parse_args() -> argparse.Namespace: def main() -> None: + args = parse_args() configure_logging() - args = parse_args() if action := getattr(args, "action", None): action() exit(0) - initialize_config() + + try: + initialize_config() + except FileNotFoundError: + logger.warning("ShellOracle configuration not found. Run `shor configure` to initialize.") + exit(1) shelloracle.cli() diff --git a/src/shelloracle/shelloracle.py b/src/shelloracle/shelloracle.py index e831855..139bc81 100644 --- a/src/shelloracle/shelloracle.py +++ b/src/shelloracle/shelloracle.py @@ -7,9 +7,8 @@ from pathlib import Path from typing import TYPE_CHECKING -from prompt_toolkit import PromptSession, print_formatted_text +from prompt_toolkit import PromptSession from prompt_toolkit.application import create_app_session_from_tty -from prompt_toolkit.formatted_text import FormattedText from prompt_toolkit.history import FileHistory from prompt_toolkit.patch_stdout import patch_stdout from yaspin import yaspin @@ -100,6 +99,4 @@ def cli() -> None: except (KeyboardInterrupt, asyncio.exceptions.CancelledError): return except Exception as err: - logger.exception("An unhandled exception occurred") - with create_app_session_from_tty(): - print_formatted_text(FormattedText([("ansired", f"\n{err}")])) + logger.error("An error occurred: %s", err) diff --git a/src/shelloracle/tty_log_handler.py b/src/shelloracle/tty_log_handler.py new file mode 100644 index 0000000..d005c75 --- /dev/null +++ b/src/shelloracle/tty_log_handler.py @@ -0,0 +1,19 @@ +import logging + +from prompt_toolkit import print_formatted_text +from prompt_toolkit.application import create_app_session_from_tty +from prompt_toolkit.formatted_text import FormattedText + + +class TtyLogHandler(logging.Handler): + def emit(self, record): + if record.levelno >= logging.ERROR: + color = "ansired" + elif record.levelno == logging.WARNING: + color = "ansiyellow" + else: + color = "ansywhite" + log_entry = self.format(record) + formatted_log_entry = FormattedText([(color, f"\n{log_entry}")]) + with create_app_session_from_tty(): + print_formatted_text(formatted_log_entry)