diff --git a/src/inspect_ai/_cli/trace.py b/src/inspect_ai/_cli/trace.py index 0a14612de..adf77ce16 100644 --- a/src/inspect_ai/_cli/trace.py +++ b/src/inspect_ai/_cli/trace.py @@ -1,4 +1,5 @@ import os +import shlex import shutil import subprocess import time @@ -57,15 +58,21 @@ def read_command(trace_file: str) -> None: @trace_command.command("anomalies") @click.argument("trace-file", type=str, required=False, default=TRACE_FILE_NAME) -def anomolies_command(trace_file: str) -> None: +@click.option( + "--all", + is_flag=True, + default=False, + help="Show all anomolies including errors and timeouts (by default only still running and cancelled actions are shown).", +) +def anomolies_command(trace_file: str, all: bool) -> None: """Look for anomalies in a trace file (never completed or cancelled actions).""" trace_file_path = resolve_trace_file_path(trace_file) traces = read_trace_file(trace_file_path) # Track started actions running_actions: dict[str, ActionTraceRecord] = {} - error_actions: dict[str, ActionTraceRecord] = {} canceled_actions: dict[str, ActionTraceRecord] = {} + error_actions: dict[str, ActionTraceRecord] = {} def action_started(trace: ActionTraceRecord) -> None: running_actions[trace.trace_id] = trace @@ -79,7 +86,8 @@ def action_completed(trace: ActionTraceRecord) -> ActionTraceRecord: raise RuntimeError(f"Expected {trace.trace_id} in action dictionary.") def action_failed(trace: ActionTraceRecord) -> None: - error_actions[start_trace.trace_id] = trace + if all: + error_actions[start_trace.trace_id] = trace def action_canceled(trace: ActionTraceRecord) -> None: canceled_actions[start_trace.trace_id] = trace @@ -111,6 +119,12 @@ def action_canceled(trace: ActionTraceRecord) -> None: case _: print(f"Unknown event type: {trace.event}") + # do we have any traces? + if len(running_actions) + len(canceled_actions) + len(error_actions) == 0: + print(f"TRACE: {shlex.quote(trace_file_path.as_posix())}\n") + print("No anomalies found in trace log.") + return + with open(os.devnull, "w") as f: # generate output console = Console(record=True, file=f) @@ -118,6 +132,8 @@ def action_canceled(trace: ActionTraceRecord) -> None: def print_fn(o: RenderableType) -> None: console.print(o, highlight=False) + print_fn(f"[bold]TRACE: {shlex.quote(trace_file_path.as_posix())}[bold]\n") + _print_bucket(print_fn, "Running Actions", running_actions) _print_bucket(print_fn, "Canceled Actions", canceled_actions) _print_bucket(print_fn, "Error Actions", error_actions) @@ -125,9 +141,8 @@ def print_fn(o: RenderableType) -> None: # display with 'less' if possible less = shutil.which("less") if less: - subprocess.run( - [less, "-R"], input=console.export_text(styles=True).encode() - ) + ansi_output = console.export_text(styles=True).encode() + subprocess.run([less, "-R"], input=ansi_output) else: print(console.export_text(styles=False))