From 815a51ad67c82da89c7c2d72626958733d28e50e Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Wed, 22 Mar 2023 17:52:39 +0800 Subject: [PATCH] fix(batch): Cancel task should not propagate error. (#8675) --- dashboard/proto/gen/task_service.ts | 550 ++++++++++++++++++ proto/task_service.proto | 9 +- src/batch/src/rpc/service/task_service.rs | 16 +- src/batch/src/task/task_execution.rs | 24 +- src/batch/src/task/task_manager.rs | 10 +- .../src/scheduler/distributed/stage.rs | 8 +- src/rpc_client/src/compute_client.rs | 6 +- 7 files changed, 596 insertions(+), 27 deletions(-) create mode 100644 dashboard/proto/gen/task_service.ts diff --git a/dashboard/proto/gen/task_service.ts b/dashboard/proto/gen/task_service.ts new file mode 100644 index 0000000000000..183626f2d0518 --- /dev/null +++ b/dashboard/proto/gen/task_service.ts @@ -0,0 +1,550 @@ +/* eslint-disable */ +import { PlanFragment, TaskId as TaskId1, TaskOutputId } from "./batch_plan"; +import { BatchQueryEpoch, Status } from "./common"; +import { DataChunk } from "./data"; +import { StreamMessage } from "./stream_plan"; + +export const protobufPackage = "task_service"; + +/** Task is a running instance of Stage. */ +export interface TaskId { + queryId: string; + stageId: number; + taskId: number; +} + +export interface TaskInfoResponse { + taskId: TaskId1 | undefined; + taskStatus: TaskInfoResponse_TaskStatus; + /** Optional error message for failed task. */ + errorMessage: string; +} + +export const TaskInfoResponse_TaskStatus = { + /** UNSPECIFIED - Note: Requirement of proto3: first enum must be 0. */ + UNSPECIFIED: "UNSPECIFIED", + PENDING: "PENDING", + RUNNING: "RUNNING", + FINISHED: "FINISHED", + FAILED: "FAILED", + ABORTED: "ABORTED", + CANCELLED: "CANCELLED", + UNRECOGNIZED: "UNRECOGNIZED", +} as const; + +export type TaskInfoResponse_TaskStatus = typeof TaskInfoResponse_TaskStatus[keyof typeof TaskInfoResponse_TaskStatus]; + +export function taskInfoResponse_TaskStatusFromJSON(object: any): TaskInfoResponse_TaskStatus { + switch (object) { + case 0: + case "UNSPECIFIED": + return TaskInfoResponse_TaskStatus.UNSPECIFIED; + case 2: + case "PENDING": + return TaskInfoResponse_TaskStatus.PENDING; + case 3: + case "RUNNING": + return TaskInfoResponse_TaskStatus.RUNNING; + case 6: + case "FINISHED": + return TaskInfoResponse_TaskStatus.FINISHED; + case 7: + case "FAILED": + return TaskInfoResponse_TaskStatus.FAILED; + case 8: + case "ABORTED": + return TaskInfoResponse_TaskStatus.ABORTED; + case 9: + case "CANCELLED": + return TaskInfoResponse_TaskStatus.CANCELLED; + case -1: + case "UNRECOGNIZED": + default: + return TaskInfoResponse_TaskStatus.UNRECOGNIZED; + } +} + +export function taskInfoResponse_TaskStatusToJSON(object: TaskInfoResponse_TaskStatus): string { + switch (object) { + case TaskInfoResponse_TaskStatus.UNSPECIFIED: + return "UNSPECIFIED"; + case TaskInfoResponse_TaskStatus.PENDING: + return "PENDING"; + case TaskInfoResponse_TaskStatus.RUNNING: + return "RUNNING"; + case TaskInfoResponse_TaskStatus.FINISHED: + return "FINISHED"; + case TaskInfoResponse_TaskStatus.FAILED: + return "FAILED"; + case TaskInfoResponse_TaskStatus.ABORTED: + return "ABORTED"; + case TaskInfoResponse_TaskStatus.CANCELLED: + return "CANCELLED"; + case TaskInfoResponse_TaskStatus.UNRECOGNIZED: + default: + return "UNRECOGNIZED"; + } +} + +export interface CreateTaskRequest { + taskId: TaskId1 | undefined; + plan: PlanFragment | undefined; + epoch: BatchQueryEpoch | undefined; +} + +export interface CancelTaskRequest { + taskId: TaskId1 | undefined; +} + +export interface CancelTaskResponse { + status: Status | undefined; +} + +export interface GetTaskInfoRequest { + taskId: TaskId1 | undefined; +} + +export interface GetDataResponse { + recordBatch: DataChunk | undefined; +} + +export interface ExecuteRequest { + taskId: TaskId1 | undefined; + plan: PlanFragment | undefined; + epoch: BatchQueryEpoch | undefined; +} + +export interface GetDataRequest { + taskOutputId: TaskOutputId | undefined; +} + +export interface GetStreamRequest { + value?: { $case: "get"; get: GetStreamRequest_Get } | { + $case: "addPermits"; + addPermits: GetStreamRequest_AddPermits; + }; +} + +/** The first message, which tells the upstream which channel this exchange stream is for. */ +export interface GetStreamRequest_Get { + upActorId: number; + downActorId: number; + upFragmentId: number; + downFragmentId: number; +} + +/** The following messages, which adds the permits back to the upstream to achieve back-pressure. */ +export interface GetStreamRequest_AddPermits { + permits: number; +} + +export interface GetStreamResponse { + message: + | StreamMessage + | undefined; + /** The number of permits acquired for this message, which should be sent back to the upstream with `AddPermits`. */ + permits: number; +} + +function createBaseTaskId(): TaskId { + return { queryId: "", stageId: 0, taskId: 0 }; +} + +export const TaskId = { + fromJSON(object: any): TaskId { + return { + queryId: isSet(object.queryId) ? String(object.queryId) : "", + stageId: isSet(object.stageId) ? Number(object.stageId) : 0, + taskId: isSet(object.taskId) ? Number(object.taskId) : 0, + }; + }, + + toJSON(message: TaskId): unknown { + const obj: any = {}; + message.queryId !== undefined && (obj.queryId = message.queryId); + message.stageId !== undefined && (obj.stageId = Math.round(message.stageId)); + message.taskId !== undefined && (obj.taskId = Math.round(message.taskId)); + return obj; + }, + + fromPartial, I>>(object: I): TaskId { + const message = createBaseTaskId(); + message.queryId = object.queryId ?? ""; + message.stageId = object.stageId ?? 0; + message.taskId = object.taskId ?? 0; + return message; + }, +}; + +function createBaseTaskInfoResponse(): TaskInfoResponse { + return { taskId: undefined, taskStatus: TaskInfoResponse_TaskStatus.UNSPECIFIED, errorMessage: "" }; +} + +export const TaskInfoResponse = { + fromJSON(object: any): TaskInfoResponse { + return { + taskId: isSet(object.taskId) ? TaskId1.fromJSON(object.taskId) : undefined, + taskStatus: isSet(object.taskStatus) + ? taskInfoResponse_TaskStatusFromJSON(object.taskStatus) + : TaskInfoResponse_TaskStatus.UNSPECIFIED, + errorMessage: isSet(object.errorMessage) ? String(object.errorMessage) : "", + }; + }, + + toJSON(message: TaskInfoResponse): unknown { + const obj: any = {}; + message.taskId !== undefined && (obj.taskId = message.taskId ? TaskId1.toJSON(message.taskId) : undefined); + message.taskStatus !== undefined && (obj.taskStatus = taskInfoResponse_TaskStatusToJSON(message.taskStatus)); + message.errorMessage !== undefined && (obj.errorMessage = message.errorMessage); + return obj; + }, + + fromPartial, I>>(object: I): TaskInfoResponse { + const message = createBaseTaskInfoResponse(); + message.taskId = (object.taskId !== undefined && object.taskId !== null) + ? TaskId1.fromPartial(object.taskId) + : undefined; + message.taskStatus = object.taskStatus ?? TaskInfoResponse_TaskStatus.UNSPECIFIED; + message.errorMessage = object.errorMessage ?? ""; + return message; + }, +}; + +function createBaseCreateTaskRequest(): CreateTaskRequest { + return { taskId: undefined, plan: undefined, epoch: undefined }; +} + +export const CreateTaskRequest = { + fromJSON(object: any): CreateTaskRequest { + return { + taskId: isSet(object.taskId) ? TaskId1.fromJSON(object.taskId) : undefined, + plan: isSet(object.plan) ? PlanFragment.fromJSON(object.plan) : undefined, + epoch: isSet(object.epoch) ? BatchQueryEpoch.fromJSON(object.epoch) : undefined, + }; + }, + + toJSON(message: CreateTaskRequest): unknown { + const obj: any = {}; + message.taskId !== undefined && (obj.taskId = message.taskId ? TaskId1.toJSON(message.taskId) : undefined); + message.plan !== undefined && (obj.plan = message.plan ? PlanFragment.toJSON(message.plan) : undefined); + message.epoch !== undefined && (obj.epoch = message.epoch ? BatchQueryEpoch.toJSON(message.epoch) : undefined); + return obj; + }, + + fromPartial, I>>(object: I): CreateTaskRequest { + const message = createBaseCreateTaskRequest(); + message.taskId = (object.taskId !== undefined && object.taskId !== null) + ? TaskId1.fromPartial(object.taskId) + : undefined; + message.plan = (object.plan !== undefined && object.plan !== null) + ? PlanFragment.fromPartial(object.plan) + : undefined; + message.epoch = (object.epoch !== undefined && object.epoch !== null) + ? BatchQueryEpoch.fromPartial(object.epoch) + : undefined; + return message; + }, +}; + +function createBaseCancelTaskRequest(): CancelTaskRequest { + return { taskId: undefined }; +} + +export const CancelTaskRequest = { + fromJSON(object: any): CancelTaskRequest { + return { taskId: isSet(object.taskId) ? TaskId1.fromJSON(object.taskId) : undefined }; + }, + + toJSON(message: CancelTaskRequest): unknown { + const obj: any = {}; + message.taskId !== undefined && (obj.taskId = message.taskId ? TaskId1.toJSON(message.taskId) : undefined); + return obj; + }, + + fromPartial, I>>(object: I): CancelTaskRequest { + const message = createBaseCancelTaskRequest(); + message.taskId = (object.taskId !== undefined && object.taskId !== null) + ? TaskId1.fromPartial(object.taskId) + : undefined; + return message; + }, +}; + +function createBaseCancelTaskResponse(): CancelTaskResponse { + return { status: undefined }; +} + +export const CancelTaskResponse = { + fromJSON(object: any): CancelTaskResponse { + return { status: isSet(object.status) ? Status.fromJSON(object.status) : undefined }; + }, + + toJSON(message: CancelTaskResponse): unknown { + const obj: any = {}; + message.status !== undefined && (obj.status = message.status ? Status.toJSON(message.status) : undefined); + return obj; + }, + + fromPartial, I>>(object: I): CancelTaskResponse { + const message = createBaseCancelTaskResponse(); + message.status = (object.status !== undefined && object.status !== null) + ? Status.fromPartial(object.status) + : undefined; + return message; + }, +}; + +function createBaseGetTaskInfoRequest(): GetTaskInfoRequest { + return { taskId: undefined }; +} + +export const GetTaskInfoRequest = { + fromJSON(object: any): GetTaskInfoRequest { + return { taskId: isSet(object.taskId) ? TaskId1.fromJSON(object.taskId) : undefined }; + }, + + toJSON(message: GetTaskInfoRequest): unknown { + const obj: any = {}; + message.taskId !== undefined && (obj.taskId = message.taskId ? TaskId1.toJSON(message.taskId) : undefined); + return obj; + }, + + fromPartial, I>>(object: I): GetTaskInfoRequest { + const message = createBaseGetTaskInfoRequest(); + message.taskId = (object.taskId !== undefined && object.taskId !== null) + ? TaskId1.fromPartial(object.taskId) + : undefined; + return message; + }, +}; + +function createBaseGetDataResponse(): GetDataResponse { + return { recordBatch: undefined }; +} + +export const GetDataResponse = { + fromJSON(object: any): GetDataResponse { + return { recordBatch: isSet(object.recordBatch) ? DataChunk.fromJSON(object.recordBatch) : undefined }; + }, + + toJSON(message: GetDataResponse): unknown { + const obj: any = {}; + message.recordBatch !== undefined && + (obj.recordBatch = message.recordBatch ? DataChunk.toJSON(message.recordBatch) : undefined); + return obj; + }, + + fromPartial, I>>(object: I): GetDataResponse { + const message = createBaseGetDataResponse(); + message.recordBatch = (object.recordBatch !== undefined && object.recordBatch !== null) + ? DataChunk.fromPartial(object.recordBatch) + : undefined; + return message; + }, +}; + +function createBaseExecuteRequest(): ExecuteRequest { + return { taskId: undefined, plan: undefined, epoch: undefined }; +} + +export const ExecuteRequest = { + fromJSON(object: any): ExecuteRequest { + return { + taskId: isSet(object.taskId) ? TaskId1.fromJSON(object.taskId) : undefined, + plan: isSet(object.plan) ? PlanFragment.fromJSON(object.plan) : undefined, + epoch: isSet(object.epoch) ? BatchQueryEpoch.fromJSON(object.epoch) : undefined, + }; + }, + + toJSON(message: ExecuteRequest): unknown { + const obj: any = {}; + message.taskId !== undefined && (obj.taskId = message.taskId ? TaskId1.toJSON(message.taskId) : undefined); + message.plan !== undefined && (obj.plan = message.plan ? PlanFragment.toJSON(message.plan) : undefined); + message.epoch !== undefined && (obj.epoch = message.epoch ? BatchQueryEpoch.toJSON(message.epoch) : undefined); + return obj; + }, + + fromPartial, I>>(object: I): ExecuteRequest { + const message = createBaseExecuteRequest(); + message.taskId = (object.taskId !== undefined && object.taskId !== null) + ? TaskId1.fromPartial(object.taskId) + : undefined; + message.plan = (object.plan !== undefined && object.plan !== null) + ? PlanFragment.fromPartial(object.plan) + : undefined; + message.epoch = (object.epoch !== undefined && object.epoch !== null) + ? BatchQueryEpoch.fromPartial(object.epoch) + : undefined; + return message; + }, +}; + +function createBaseGetDataRequest(): GetDataRequest { + return { taskOutputId: undefined }; +} + +export const GetDataRequest = { + fromJSON(object: any): GetDataRequest { + return { taskOutputId: isSet(object.taskOutputId) ? TaskOutputId.fromJSON(object.taskOutputId) : undefined }; + }, + + toJSON(message: GetDataRequest): unknown { + const obj: any = {}; + message.taskOutputId !== undefined && + (obj.taskOutputId = message.taskOutputId ? TaskOutputId.toJSON(message.taskOutputId) : undefined); + return obj; + }, + + fromPartial, I>>(object: I): GetDataRequest { + const message = createBaseGetDataRequest(); + message.taskOutputId = (object.taskOutputId !== undefined && object.taskOutputId !== null) + ? TaskOutputId.fromPartial(object.taskOutputId) + : undefined; + return message; + }, +}; + +function createBaseGetStreamRequest(): GetStreamRequest { + return { value: undefined }; +} + +export const GetStreamRequest = { + fromJSON(object: any): GetStreamRequest { + return { + value: isSet(object.get) + ? { $case: "get", get: GetStreamRequest_Get.fromJSON(object.get) } + : isSet(object.addPermits) + ? { $case: "addPermits", addPermits: GetStreamRequest_AddPermits.fromJSON(object.addPermits) } + : undefined, + }; + }, + + toJSON(message: GetStreamRequest): unknown { + const obj: any = {}; + message.value?.$case === "get" && + (obj.get = message.value?.get ? GetStreamRequest_Get.toJSON(message.value?.get) : undefined); + message.value?.$case === "addPermits" && (obj.addPermits = message.value?.addPermits + ? GetStreamRequest_AddPermits.toJSON(message.value?.addPermits) + : undefined); + return obj; + }, + + fromPartial, I>>(object: I): GetStreamRequest { + const message = createBaseGetStreamRequest(); + if (object.value?.$case === "get" && object.value?.get !== undefined && object.value?.get !== null) { + message.value = { $case: "get", get: GetStreamRequest_Get.fromPartial(object.value.get) }; + } + if ( + object.value?.$case === "addPermits" && + object.value?.addPermits !== undefined && + object.value?.addPermits !== null + ) { + message.value = { + $case: "addPermits", + addPermits: GetStreamRequest_AddPermits.fromPartial(object.value.addPermits), + }; + } + return message; + }, +}; + +function createBaseGetStreamRequest_Get(): GetStreamRequest_Get { + return { upActorId: 0, downActorId: 0, upFragmentId: 0, downFragmentId: 0 }; +} + +export const GetStreamRequest_Get = { + fromJSON(object: any): GetStreamRequest_Get { + return { + upActorId: isSet(object.upActorId) ? Number(object.upActorId) : 0, + downActorId: isSet(object.downActorId) ? Number(object.downActorId) : 0, + upFragmentId: isSet(object.upFragmentId) ? Number(object.upFragmentId) : 0, + downFragmentId: isSet(object.downFragmentId) ? Number(object.downFragmentId) : 0, + }; + }, + + toJSON(message: GetStreamRequest_Get): unknown { + const obj: any = {}; + message.upActorId !== undefined && (obj.upActorId = Math.round(message.upActorId)); + message.downActorId !== undefined && (obj.downActorId = Math.round(message.downActorId)); + message.upFragmentId !== undefined && (obj.upFragmentId = Math.round(message.upFragmentId)); + message.downFragmentId !== undefined && (obj.downFragmentId = Math.round(message.downFragmentId)); + return obj; + }, + + fromPartial, I>>(object: I): GetStreamRequest_Get { + const message = createBaseGetStreamRequest_Get(); + message.upActorId = object.upActorId ?? 0; + message.downActorId = object.downActorId ?? 0; + message.upFragmentId = object.upFragmentId ?? 0; + message.downFragmentId = object.downFragmentId ?? 0; + return message; + }, +}; + +function createBaseGetStreamRequest_AddPermits(): GetStreamRequest_AddPermits { + return { permits: 0 }; +} + +export const GetStreamRequest_AddPermits = { + fromJSON(object: any): GetStreamRequest_AddPermits { + return { permits: isSet(object.permits) ? Number(object.permits) : 0 }; + }, + + toJSON(message: GetStreamRequest_AddPermits): unknown { + const obj: any = {}; + message.permits !== undefined && (obj.permits = Math.round(message.permits)); + return obj; + }, + + fromPartial, I>>(object: I): GetStreamRequest_AddPermits { + const message = createBaseGetStreamRequest_AddPermits(); + message.permits = object.permits ?? 0; + return message; + }, +}; + +function createBaseGetStreamResponse(): GetStreamResponse { + return { message: undefined, permits: 0 }; +} + +export const GetStreamResponse = { + fromJSON(object: any): GetStreamResponse { + return { + message: isSet(object.message) ? StreamMessage.fromJSON(object.message) : undefined, + permits: isSet(object.permits) ? Number(object.permits) : 0, + }; + }, + + toJSON(message: GetStreamResponse): unknown { + const obj: any = {}; + message.message !== undefined && + (obj.message = message.message ? StreamMessage.toJSON(message.message) : undefined); + message.permits !== undefined && (obj.permits = Math.round(message.permits)); + return obj; + }, + + fromPartial, I>>(object: I): GetStreamResponse { + const message = createBaseGetStreamResponse(); + message.message = (object.message !== undefined && object.message !== null) + ? StreamMessage.fromPartial(object.message) + : undefined; + message.permits = object.permits ?? 0; + return message; + }, +}; + +type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined; + +export type DeepPartial = T extends Builtin ? T + : T extends Array ? Array> : T extends ReadonlyArray ? ReadonlyArray> + : T extends { $case: string } ? { [K in keyof Omit]?: DeepPartial } & { $case: T["$case"] } + : T extends {} ? { [K in keyof T]?: DeepPartial } + : Partial; + +type KeysOfUnion = T extends T ? keyof T : never; +export type Exact = P extends Builtin ? P + : P & { [K in keyof P]: Exact } & { [K in Exclude>]: never }; + +function isSet(value: any): boolean { + return value !== null && value !== undefined; +} diff --git a/proto/task_service.proto b/proto/task_service.proto index 0be05132472fa..f6d061cb9a707 100644 --- a/proto/task_service.proto +++ b/proto/task_service.proto @@ -26,6 +26,7 @@ message TaskInfoResponse { FINISHED = 6; FAILED = 7; ABORTED = 8; + CANCELLED = 9; } batch_plan.TaskId task_id = 1; TaskStatus task_status = 2; @@ -39,11 +40,11 @@ message CreateTaskRequest { common.BatchQueryEpoch epoch = 3; } -message AbortTaskRequest { +message CancelTaskRequest { batch_plan.TaskId task_id = 1; } -message AbortTaskResponse { +message CancelTaskResponse { common.Status status = 1; } @@ -63,8 +64,8 @@ message ExecuteRequest { service TaskService { rpc CreateTask(CreateTaskRequest) returns (stream TaskInfoResponse); - // Abort an already-died (self execution-failure, previous aborted, completed) task will still succeed. - rpc AbortTask(AbortTaskRequest) returns (AbortTaskResponse); + // Cancel an already-died (self execution-failure, previous aborted, completed) task will still succeed. + rpc CancelTask(CancelTaskRequest) returns (CancelTaskResponse); rpc Execute(ExecuteRequest) returns (stream GetDataResponse); } diff --git a/src/batch/src/rpc/service/task_service.rs b/src/batch/src/rpc/service/task_service.rs index b6b710127a8f0..a343685bde8c3 100644 --- a/src/batch/src/rpc/service/task_service.rs +++ b/src/batch/src/rpc/service/task_service.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use risingwave_pb::batch_plan::TaskOutputId; use risingwave_pb::task_service::task_service_server::TaskService; use risingwave_pb::task_service::{ - AbortTaskRequest, AbortTaskResponse, CreateTaskRequest, ExecuteRequest, GetDataResponse, + CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, GetDataResponse, TaskInfoResponse, }; use tokio_stream::wrappers::ReceiverStream; @@ -95,17 +95,15 @@ impl TaskService for BatchServiceImpl { } #[cfg_attr(coverage, no_coverage)] - async fn abort_task( + async fn cancel_task( &self, - req: Request, - ) -> Result, Status> { + req: Request, + ) -> Result, Status> { let req = req.into_inner(); tracing::trace!("Aborting task: {:?}", req.get_task_id().unwrap()); - self.mgr.abort_task( - req.get_task_id().expect("no task id found"), - "abort task request".to_string(), - ); - Ok(Response::new(AbortTaskResponse { status: None })) + self.mgr + .cancel_task(req.get_task_id().expect("no task id found")); + Ok(Response::new(CancelTaskResponse { status: None })) } #[cfg_attr(coverage, no_coverage)] diff --git a/src/batch/src/task/task_execution.rs b/src/batch/src/task/task_execution.rs index 62502ad471e03..bcb3758ee95bb 100644 --- a/src/batch/src/task/task_execution.rs +++ b/src/batch/src/task/task_execution.rs @@ -446,6 +446,7 @@ impl BatchTaskExecution { ctx2, ); self.runtime.spawn(alloc_stat_wrap_fut); + Ok(()) } @@ -490,8 +491,18 @@ impl BatchTaskExecution { // We prioritize abort signal over normal data chunks. biased; err_reason = &mut shutdown_rx => { - state = TaskStatus::Aborted; - error = Some(Aborted(err_reason.unwrap_or("".to_string()))); + match err_reason { + Ok(reason_str) => { + state = TaskStatus::Aborted; + error = Some(Aborted(reason_str)); + } + Err(_) => { + // We use early close shutdown channel to cancel task. + // Cancelling a task is different from aborting a task + // in that it's not an error and should not be reported to user. + state = TaskStatus::Cancelled; + } + } break; } res = data_chunk_stream.next() => { @@ -558,7 +569,7 @@ impl BatchTaskExecution { } } - pub fn abort_task(&self, err_msg: String) { + pub fn abort(&self, err_msg: String) { if let Some(sender) = self.shutdown_tx.lock().take() { // No need to set state to be Aborted here cuz it will be set by shutdown receiver. // Stop task execution. @@ -570,6 +581,13 @@ impl BatchTaskExecution { }; } + pub fn cancel(&self) { + if let Some(sender) = self.shutdown_tx.lock().take() { + // Drop sender directly to mark cancel without error. + drop(sender); + }; + } + pub fn get_task_output(&self, output_id: &PbTaskOutputId) -> Result { let task_id = TaskId::from(output_id.get_task_id()?); let receiver = self.receivers.lock()[output_id.get_output_id() as usize] diff --git a/src/batch/src/task/task_manager.rs b/src/batch/src/task/task_manager.rs index 7f50b43cd581f..16fdf6ff4f3de 100644 --- a/src/batch/src/task/task_manager.rs +++ b/src/batch/src/task/task_manager.rs @@ -145,12 +145,14 @@ impl BatchManager { .get_task_output(output_id) } - pub fn abort_task(&self, sid: &PbTaskId, msg: String) { + pub fn cancel_task(&self, sid: &PbTaskId) { let sid = TaskId::from(sid); match self.tasks.lock().remove(&sid) { Some(task) => { tracing::trace!("Removed task: {:?}", task.get_task_id()); - task.abort_task(msg); + // Use `cancel` rather than `abort` here since this is not an error which should be + // propagated to upstream. + task.cancel(); self.metrics.task_num.dec() } None => { @@ -232,7 +234,7 @@ impl BatchManager { let t = guard.get(&id).unwrap(); // FIXME: `Abort` will not report error but truncated results to user. We should // consider throw error. - t.abort_task(reason); + t.abort(reason); } } @@ -386,7 +388,7 @@ mod tests { ) .await .unwrap(); - manager.abort_task(&task_id, "".to_string()); + manager.cancel_task(&task_id); let task_id = TaskId::from(&task_id); assert!(!manager.tasks.lock().contains_key(&task_id)); } diff --git a/src/frontend/src/scheduler/distributed/stage.rs b/src/frontend/src/scheduler/distributed/stage.rs index d9773fdb566fe..5f2dcc49ade31 100644 --- a/src/frontend/src/scheduler/distributed/stage.rs +++ b/src/frontend/src/scheduler/distributed/stage.rs @@ -41,7 +41,7 @@ use risingwave_pb::batch_plan::{ PlanNode as PlanNodePb, PlanNode, TaskId as TaskIdPb, TaskOutputId, }; use risingwave_pb::common::{BatchQueryEpoch, HostAddress, WorkerNode}; -use risingwave_pb::task_service::{AbortTaskRequest, TaskInfoResponse}; +use risingwave_pb::task_service::{CancelTaskRequest, TaskInfoResponse}; use risingwave_rpc_client::ComputeClientPoolRef; use tokio::spawn; use tokio::sync::mpsc::{Receiver, Sender}; @@ -536,7 +536,7 @@ impl StageRunner { self.stage.id, self.tasks.len() ); - self.abort_all_scheduled_tasks().await?; + self.cancel_all_scheducancled_tasks().await?; tracing::trace!( "Stage runner [{:?}-{:?}] existed. ", @@ -761,7 +761,7 @@ impl StageRunner { /// Abort all registered tasks. Note that here we do not care which part of tasks has already /// failed or completed, cuz the abort task will not fail if the task has already die. /// See PR (#4560). - async fn abort_all_scheduled_tasks(&self) -> SchedulerResult<()> { + async fn cancel_all_scheducancled_tasks(&self) -> SchedulerResult<()> { // Set state to failed. // { // let mut state = self.state.write().await; @@ -789,7 +789,7 @@ impl StageRunner { let task_id = *task; spawn(async move { if let Err(e) = client - .abort(AbortTaskRequest { + .cancel(CancelTaskRequest { task_id: Some(risingwave_pb::batch_plan::TaskId { query_id: query_id.clone(), stage_id, diff --git a/src/rpc_client/src/compute_client.rs b/src/rpc_client/src/compute_client.rs index d7cff5eb9bc93..a0ad224106290 100644 --- a/src/rpc_client/src/compute_client.rs +++ b/src/rpc_client/src/compute_client.rs @@ -30,7 +30,7 @@ use risingwave_pb::monitor_service::{ use risingwave_pb::task_service::exchange_service_client::ExchangeServiceClient; use risingwave_pb::task_service::task_service_client::TaskServiceClient; use risingwave_pb::task_service::{ - AbortTaskRequest, AbortTaskResponse, CreateTaskRequest, ExecuteRequest, GetDataRequest, + CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest, GetDataRequest, GetDataResponse, GetStreamRequest, GetStreamResponse, TaskInfoResponse, }; use tokio::sync::mpsc; @@ -157,11 +157,11 @@ impl ComputeClient { Ok(self.task_client.to_owned().execute(req).await?.into_inner()) } - pub async fn abort(&self, req: AbortTaskRequest) -> Result { + pub async fn cancel(&self, req: CancelTaskRequest) -> Result { Ok(self .task_client .to_owned() - .abort_task(req) + .cancel_task(req) .await? .into_inner()) }