diff --git a/changes/3085.fix.md b/changes/3085.fix.md new file mode 100644 index 0000000000..3acb33be3e --- /dev/null +++ b/changes/3085.fix.md @@ -0,0 +1 @@ +Fix session `status_info` not being updated correctly when batch executions fail, ensuring failed batch execution states are properly reflected in the sessions table diff --git a/src/ai/backend/agent/agent.py b/src/ai/backend/agent/agent.py index ba9a3d8f1d..fdb6a89277 100644 --- a/src/ai/backend/agent/agent.py +++ b/src/ai/backend/agent/agent.py @@ -1650,7 +1650,7 @@ async def execute_batch( if result["exitCode"] == 0: await self.produce_event( SessionSuccessEvent( - session_id, KernelLifecycleEventReason.TASK_DONE, 0 + session_id, KernelLifecycleEventReason.TASK_FINISHED, 0 ), ) else: diff --git a/src/ai/backend/common/events.py b/src/ai/backend/common/events.py index 3815fb2fbe..4d4f25b767 100644 --- a/src/ai/backend/common/events.py +++ b/src/ai/backend/common/events.py @@ -226,7 +226,6 @@ class KernelLifecycleEventReason(str, enum.Enum): RESTART_TIMEOUT = "restart-timeout" RESUMING_AGENT_OPERATION = "resuming-agent-operation" SELF_TERMINATED = "self-terminated" - TASK_DONE = "task-done" TASK_FAILED = "task-failed" TASK_TIMEOUT = "task-timeout" TASK_CANCELLED = "task-cancelled" diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index b4dc16d767..a25c744876 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -3647,10 +3647,11 @@ async def handle_batch_result( """ Update the database according to the batch-job completion results """ - if isinstance(event, SessionSuccessEvent): - await SessionRow.set_session_result(context.db, event.session_id, True, event.exit_code) - elif isinstance(event, SessionFailureEvent): - await SessionRow.set_session_result(context.db, event.session_id, False, event.exit_code) + match event: + case SessionSuccessEvent(session_id=session_id, reason=reason, exit_code=exit_code): + await SessionRow.set_session_result(context.db, session_id, True, exit_code) + case SessionFailureEvent(session_id=session_id, reason=reason, exit_code=exit_code): + await SessionRow.set_session_result(context.db, session_id, False, exit_code) async with context.db.begin_session() as db_sess: try: session = await SessionRow.get_session( @@ -3660,7 +3661,7 @@ async def handle_batch_result( return await context.destroy_session( session, - reason=KernelLifecycleEventReason.TASK_FINISHED, + reason=reason, ) await invoke_session_callback(context, source, event)