KAFKA-15045: (KIP-924 pt. 15) Implement #defaultStandbyTaskAssignment and finish rack-aware standby optimization (apache#16129)
This fills in the implementation details of the standby task assignment utility functions within TaskAssignmentUtils.

Reviewers: Anna Sophie Blee-Goldman <[email protected]>
apourchet authored May 30, 2024
1 parent 7c1bb15 commit 370e5ea
Showing 9 changed files with 543 additions and 84 deletions.
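Most of the new standby-assignment logic lands in TaskAssignmentUtils, whose diff is collapsed further down; the files rendered below carry the supporting API changes, chiefly switching KafkaStreamsAssignment from a Set-based to a Map-based view of its tasks. As a rough sketch of the calling side, a custom assignor might build per-client assignments with that API before handing standby placement to the new utility. Anything below that is not visible in this diff, in particular the AssignedTask(TaskId, Type) constructor and ApplicationState#kafkaStreamsStates(boolean), is assumed from KIP-924 rather than confirmed by this commit.

// Hypothetical sketch, not part of this commit: spread every task across the clients
// round-robin as ACTIVE, using the Map-based KafkaStreamsAssignment API changed below.
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.assignment.ApplicationState;
import org.apache.kafka.streams.processor.assignment.KafkaStreamsAssignment;
import org.apache.kafka.streams.processor.assignment.KafkaStreamsAssignment.AssignedTask;
import org.apache.kafka.streams.processor.assignment.ProcessId;

public final class RoundRobinActiveAssignment {

    public static Map<ProcessId, KafkaStreamsAssignment> compute(final ApplicationState applicationState) {
        // kafkaStreamsStates(boolean) is assumed from KIP-924; it is not part of this diff.
        final List<ProcessId> clients = new ArrayList<>(applicationState.kafkaStreamsStates(false).keySet());

        // Start every client with an empty assignment.
        final Map<ProcessId, KafkaStreamsAssignment> assignments = new HashMap<>();
        for (final ProcessId processId : clients) {
            assignments.put(processId, KafkaStreamsAssignment.of(processId, new HashSet<>()));
        }

        // Hand each task to the next client in turn as an active task.
        int i = 0;
        for (final TaskId taskId : applicationState.allTasks().keySet()) {
            final KafkaStreamsAssignment target = assignments.get(clients.get(i++ % clients.size()));
            target.assignTask(new AssignedTask(taskId, AssignedTask.Type.ACTIVE));
        }

        // Standby replicas would then be layered on by the utilities this commit fills in,
        // e.g. TaskAssignmentUtils#defaultStandbyTaskAssignment (diff not rendered on this page).
        return assignments;
    }
}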
========== changed file 1 of 9 ==========
@@ -16,11 +16,13 @@
*/
package org.apache.kafka.streams.processor.assignment;

import static java.util.Collections.unmodifiableMap;

import java.time.Instant;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.kafka.streams.processor.TaskId;

@@ -31,7 +33,7 @@
public class KafkaStreamsAssignment {

private final ProcessId processId;
private final Map<TaskId, AssignedTask> assignment;
private final Map<TaskId, AssignedTask> tasks;
private final Optional<Instant> followupRebalanceDeadline;

/**
@@ -45,7 +47,8 @@ public class KafkaStreamsAssignment {
* @return a new KafkaStreamsAssignment object with the given processId and assignment
*/
public static KafkaStreamsAssignment of(final ProcessId processId, final Set<AssignedTask> assignment) {
return new KafkaStreamsAssignment(processId, assignment, Optional.empty());
final Map<TaskId, AssignedTask> tasks = assignment.stream().collect(Collectors.toMap(AssignedTask::id, Function.identity()));
return new KafkaStreamsAssignment(processId, tasks, Optional.empty());
}

/**
@@ -62,14 +65,14 @@ public static KafkaStreamsAssignment of(final ProcessId processId, final Set<Ass
* @return a new KafkaStreamsAssignment object with the same processId and assignment but with the given rebalanceDeadline
*/
public KafkaStreamsAssignment withFollowupRebalance(final Instant rebalanceDeadline) {
return new KafkaStreamsAssignment(this.processId(), this.assignment(), Optional.of(rebalanceDeadline));
return new KafkaStreamsAssignment(this.processId(), this.tasks(), Optional.of(rebalanceDeadline));
}

private KafkaStreamsAssignment(final ProcessId processId,
final Set<AssignedTask> assignment,
final Map<TaskId, AssignedTask> tasks,
final Optional<Instant> followupRebalanceDeadline) {
this.processId = processId;
this.assignment = assignment.stream().collect(Collectors.toMap(AssignedTask::id, t -> t));
this.tasks = tasks;
this.followupRebalanceDeadline = followupRebalanceDeadline;
}

@@ -83,24 +86,18 @@ public ProcessId processId() {

/**
*
* @return a set of assigned tasks that are part of this {@code KafkaStreamsAssignment}
* @return a read-only map of the assigned tasks that are part of this {@code KafkaStreamsAssignment}
*/
public Set<AssignedTask> assignment() {
// TODO change assignment to return a map so we aren't forced to copy this into a Set
return new HashSet<>(assignment.values());
}

// TODO: merge this with #assignment by having it return a Map<TaskId, AssignedTask>
public Set<TaskId> assignedTaskIds() {
return assignment.keySet();
public Map<TaskId, AssignedTask> tasks() {
return unmodifiableMap(tasks);
}

public void assignTask(final AssignedTask newTask) {
assignment.put(newTask.id(), newTask);
tasks.put(newTask.id(), newTask);
}

public void removeTask(final AssignedTask removedTask) {
assignment.remove(removedTask.id());
tasks.remove(removedTask.id());
}

/**
@@ -140,5 +137,25 @@ public TaskId id() {
public Type type() {
return taskType;
}

@Override
public int hashCode() {
final int prime = 31;
int result = prime + this.id.hashCode();
result = prime * result + this.type().hashCode();
return result;
}

@Override
public boolean equals(final Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
final AssignedTask other = (AssignedTask) obj;
return this.id.equals(other.id()) && this.taskType == other.taskType;
}
}
}
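A short usage sketch of the reworked API above: tasks() now returns a read-only Map view keyed by TaskId, so mutation goes through assignTask and removeTask. The AssignedTask(TaskId, Type) and ProcessId(UUID) constructors are assumed from KIP-924 and are not shown in this diff.

// Hypothetical usage sketch of the KafkaStreamsAssignment API changed above.
import java.util.Collections;
import java.util.Map;
import java.util.UUID;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.assignment.KafkaStreamsAssignment;
import org.apache.kafka.streams.processor.assignment.KafkaStreamsAssignment.AssignedTask;
import org.apache.kafka.streams.processor.assignment.ProcessId;

public final class KafkaStreamsAssignmentExample {
    public static void main(final String[] args) {
        // ProcessId(UUID) and AssignedTask(TaskId, Type) are assumed from KIP-924.
        final ProcessId processId = new ProcessId(UUID.randomUUID());
        final AssignedTask active = new AssignedTask(new TaskId(0, 0), AssignedTask.Type.ACTIVE);

        final KafkaStreamsAssignment assignment =
            KafkaStreamsAssignment.of(processId, Collections.singleton(active));

        // Mutation goes through assignTask and removeTask ...
        assignment.assignTask(new AssignedTask(new TaskId(0, 1), AssignedTask.Type.STANDBY));
        assignment.removeTask(active);

        // ... while tasks() is a read-only Map view; the unmodifiableMap wrapper above
        // makes writes through the returned map throw UnsupportedOperationException.
        final Map<TaskId, AssignedTask> tasks = assignment.tasks();
        System.out.println(tasks.keySet()); // prints [0_1]
    }
}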

========== changed file 2 of 9 ==========

Large diffs are not rendered by default.

========== changed file 3 of 9 ==========
@@ -278,7 +278,7 @@ public void processOptimizedAssignments(final Map<ProcessId, KafkaStreamsAssignm

for (final Map.Entry<ProcessId, KafkaStreamsAssignment> entry : optimizedAssignments.entrySet()) {
final ProcessId processId = entry.getKey();
final Set<AssignedTask> assignedTasks = optimizedAssignments.get(processId).assignment();
final Set<AssignedTask> assignedTasks = new HashSet<>(optimizedAssignments.get(processId).tasks().values());
newAssignments.put(processId, assignedTasks);

for (final AssignedTask task : assignedTasks) {
========== changed file 4 of 9 ==========
@@ -18,6 +18,7 @@

import java.time.Instant;
import java.util.Optional;
import java.util.function.Function;
import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.ListOffsetsResult;
import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo;
@@ -555,18 +556,21 @@ private ApplicationState buildApplicationState(final TopologyMetadata topologyMe

RackUtils.annotateTopicPartitionsWithRackInfo(cluster, internalTopicManager, allTopicPartitions);

final Set<TaskInfo> logicalTasks = logicalTaskIds.stream().map(taskId -> {
final Set<String> stateStoreNames = topologyMetadata
.stateStoreNameToSourceTopicsForTopology(taskId.topologyName())
.keySet();
final Set<TaskTopicPartition> topicPartitions = topicPartitionsForTask.get(taskId);
return new DefaultTaskInfo(
taskId,
!stateStoreNames.isEmpty(),
stateStoreNames,
topicPartitions
);
}).collect(Collectors.toSet());
final Map<TaskId, TaskInfo> logicalTasks = logicalTaskIds.stream().collect(Collectors.toMap(
Function.identity(),
taskId -> {
final Set<String> stateStoreNames = topologyMetadata
.stateStoreNameToSourceTopicsForTopology(taskId.topologyName())
.keySet();
final Set<TaskTopicPartition> topicPartitions = topicPartitionsForTask.get(taskId);
return new DefaultTaskInfo(
taskId,
!stateStoreNames.isEmpty(),
stateStoreNames,
topicPartitions
);
}
));

return new DefaultApplicationState(
assignmentConfigs.toPublicAssignmentConfigs(),
@@ -728,12 +732,12 @@ private void checkAllPartitions(final Set<String> allSourceTopics,
* populate the stateful tasks that have been assigned to the clients
*/
private UserTaskAssignmentListener assignTasksToClients(final Cluster fullMetadata,
final Set<String> allSourceTopics,
final Map<Subtopology, TopicsInfo> topicGroups,
final Map<UUID, ClientMetadata> clientMetadataMap,
final Map<TaskId, Set<TopicPartition>> partitionsForTask,
final Map<UUID, Map<String, Optional<String>>> racksForProcessConsumer,
final Set<TaskId> statefulTasks) {
final Set<String> allSourceTopics,
final Map<Subtopology, TopicsInfo> topicGroups,
final Map<UUID, ClientMetadata> clientMetadataMap,
final Map<TaskId, Set<TopicPartition>> partitionsForTask,
final Map<UUID, Map<String, Optional<String>>> racksForProcessConsumer,
final Set<TaskId> statefulTasks) {
if (!statefulTasks.isEmpty()) {
throw new TaskAssignmentException("The stateful tasks should not be populated before assigning tasks to clients");
}
@@ -775,7 +779,7 @@ private UserTaskAssignmentListener assignTasksToClients(final Cluster fullMetada

final Optional<org.apache.kafka.streams.processor.assignment.TaskAssignor> userTaskAssignor =
userTaskAssignorSupplier.get();
UserTaskAssignmentListener userTaskAssignmentListener = (GroupAssignment assignment, GroupSubscription subscription) -> { };
final UserTaskAssignmentListener userTaskAssignmentListener;
if (userTaskAssignor.isPresent()) {
final ApplicationState applicationState = buildApplicationState(
taskManager.topologyMetadata(),
@@ -785,12 +789,11 @@ private UserTaskAssignmentListener assignTasksToClients(final Cluster fullMetada
);
final org.apache.kafka.streams.processor.assignment.TaskAssignor assignor = userTaskAssignor.get();
final TaskAssignment taskAssignment = assignor.assign(applicationState);
processStreamsPartitionAssignment(clientMetadataMap, taskAssignment);
final AssignmentError assignmentError = validateTaskAssignment(applicationState, taskAssignment);
userTaskAssignmentListener = (GroupAssignment assignment, GroupSubscription subscription) -> {
assignor.onAssignmentComputed(assignment, subscription, assignmentError);
};
processStreamsPartitionAssignment(clientMetadataMap, taskAssignment);
userTaskAssignmentListener = (assignment, subscription) -> assignor.onAssignmentComputed(assignment, subscription, assignmentError);
} else {
userTaskAssignmentListener = (assignment, subscription) -> { };
final TaskAssignor taskAssignor = createTaskAssignor(lagComputationSuccessful);
final RackAwareTaskAssignor rackAwareTaskAssignor = new RackAwareTaskAssignor(
fullMetadata,
@@ -1564,7 +1567,7 @@ private AssignmentError validateTaskAssignment(final ApplicationState applicatio
final Map<TaskId, ProcessId> standbyTasksInOutput = new HashMap<>();
for (final KafkaStreamsAssignment assignment : assignments) {
final Set<TaskId> tasksForAssignment = new HashSet<>();
for (final KafkaStreamsAssignment.AssignedTask task : assignment.assignment()) {
for (final KafkaStreamsAssignment.AssignedTask task : assignment.tasks().values()) {
if (activeTasksInOutput.containsKey(task.id()) && task.type() == KafkaStreamsAssignment.AssignedTask.Type.ACTIVE) {
log.error("Assignment is invalid: active task {} was assigned to multiple KafkaStreams clients: {} and {}",
task.id(), assignment.processId().id(), activeTasksInOutput.get(task.id()).id());
@@ -1614,7 +1617,7 @@ private AssignmentError validateTaskAssignment(final ApplicationState applicatio

final Set<TaskId> taskIdsInInput = applicationState.allTasks().keySet();
for (final KafkaStreamsAssignment assignment : assignments) {
for (final KafkaStreamsAssignment.AssignedTask task : assignment.assignment()) {
for (final KafkaStreamsAssignment.AssignedTask task : assignment.tasks().values()) {
if (!taskIdsInInput.contains(task.id())) {
log.error("Assignment is invalid: task {} assigned to KafkaStreams client {} was unknown",
task.id(), assignment.processId().id());
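The assignTasksToClients() changes above only take the user-assignor branch when a custom TaskAssignor is configured. A minimal wiring sketch follows; the "task.assignor.class" key and the assignor class name are assumptions taken from KIP-924, not from this diff.

// Hypothetical sketch: enabling a custom KIP-924 TaskAssignor. The config key
// "task.assignor.class" and com.example.MyCustomTaskAssignor are assumptions.
import java.util.Properties;
import org.apache.kafka.streams.StreamsConfig;

public final class AssignorConfigExample {
    public static Properties streamsConfig() {
        final Properties props = new Properties();
        props.put(StreamsConfig.APPLICATION_ID_CONFIG, "rack-aware-demo");
        props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
        // Assumed KIP-924 key; check StreamsConfig for the actual constant name.
        props.put("task.assignor.class", "com.example.MyCustomTaskAssignor");
        return props;
    }
}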
========== changed file 5 of 9 ==========
@@ -85,8 +85,8 @@ public ClientState(final UUID processId, final Map<String, String> clientTags) {
}

ClientState(final UUID processId, final int capacity, final Map<String, String> clientTags) {
previousStandbyTasks.taskIds(new TreeSet<>());
previousActiveTasks.taskIds(new TreeSet<>());
previousStandbyTasks.setTaskIds(new TreeSet<>());
previousActiveTasks.setTaskIds(new TreeSet<>());
taskOffsetSums = new TreeMap<>();
taskLagTotals = new TreeMap<>();
this.capacity = capacity;
@@ -110,8 +110,8 @@ public ClientState(final Set<TaskId> previousActiveTasks,
final Map<String, String> clientTags,
final int capacity,
final UUID processId) {
this.previousStandbyTasks.taskIds(unmodifiableSet(new TreeSet<>(previousStandbyTasks)));
this.previousActiveTasks.taskIds(unmodifiableSet(new TreeSet<>(previousActiveTasks)));
this.previousStandbyTasks.setTaskIds(unmodifiableSet(new TreeSet<>(previousStandbyTasks)));
this.previousActiveTasks.setTaskIds(unmodifiableSet(new TreeSet<>(previousActiveTasks)));
taskOffsetSums = emptyMap();
this.taskLagTotals = unmodifiableMap(taskLagTotals);
this.capacity = capacity;
@@ -489,14 +489,14 @@ public SortedMap<String, Set<TaskId>> taskIdsByPreviousConsumer() {
}

public void setAssignedTasks(final KafkaStreamsAssignment assignment) {
final Set<TaskId> activeTasks = assignment.assignment().stream()
final Set<TaskId> activeTasks = assignment.tasks().values().stream()
.filter(task -> task.type() == ACTIVE).map(KafkaStreamsAssignment.AssignedTask::id)
.collect(Collectors.toSet());
final Set<TaskId> standbyTasks = assignment.assignment().stream()
final Set<TaskId> standbyTasks = assignment.tasks().values().stream()
.filter(task -> task.type() == STANDBY).map(KafkaStreamsAssignment.AssignedTask::id)
.collect(Collectors.toSet());
assignedActiveTasks.taskIds(activeTasks);
assignedStandbyTasks.taskIds(standbyTasks);
assignedActiveTasks.setTaskIds(activeTasks);
assignedStandbyTasks.setTaskIds(standbyTasks);
}

public String currentAssignment() {
========== changed file 6 of 9 ==========
@@ -31,7 +31,7 @@ class ClientStateTask {
this.consumerToTaskIds = consumerToTaskIds;
}

void taskIds(final Set<TaskId> clientToTaskIds) {
void setTaskIds(final Set<TaskId> clientToTaskIds) {
taskIds = clientToTaskIds;
}

========== changed file 7 of 9 ==========
@@ -30,22 +30,22 @@
/**
* Wraps a priority queue of clients and returns the next valid candidate(s) based on the current task assignment
*/
class ConstrainedPrioritySet {
public class ConstrainedPrioritySet {

private final PriorityQueue<UUID> clientsByTaskLoad;
private final BiFunction<UUID, TaskId, Boolean> constraint;
private final Set<UUID> uniqueClients = new HashSet<>();

ConstrainedPrioritySet(final BiFunction<UUID, TaskId, Boolean> constraint,
final Function<UUID, Double> weight) {
public ConstrainedPrioritySet(final BiFunction<UUID, TaskId, Boolean> constraint,
final Function<UUID, Double> weight) {
this.constraint = constraint;
clientsByTaskLoad = new PriorityQueue<>(Comparator.comparing(weight).thenComparing(clientId -> clientId));
}

/**
* @return the next least loaded client that satisfies the given criteria, or null if none do
*/
UUID poll(final TaskId task, final Function<UUID, Boolean> extraConstraint) {
public UUID poll(final TaskId task, final Function<UUID, Boolean> extraConstraint) {
final Set<UUID> invalidPolledClients = new HashSet<>();
while (!clientsByTaskLoad.isEmpty()) {
final UUID candidateClient = pollNextClient();
@@ -66,17 +66,17 @@ UUID poll(final TaskId task, final Function<UUID, Boolean> extraConstraint) {
/**
* @return the next least loaded client that satisfies the given criteria, or null if none do
*/
UUID poll(final TaskId task) {
public UUID poll(final TaskId task) {
return poll(task, client -> true);
}

void offerAll(final Collection<UUID> clients) {
public void offerAll(final Collection<UUID> clients) {
for (final UUID client : clients) {
offer(client);
}
}

void offer(final UUID client) {
public void offer(final UUID client) {
if (uniqueClients.contains(client)) {
clientsByTaskLoad.remove(client);
} else {
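ConstrainedPrioritySet and its poll/offer methods are made public above so the standby utilities can reuse them. A rough behavioural sketch follows, using only the constructor and methods shown in this diff; the client UUIDs, the weight function, and the constraint are made up for illustration.

// Hypothetical sketch: pick the least-loaded client that a constraint allows for a task,
// using only the ConstrainedPrioritySet API made public above.
import java.util.Arrays;
import java.util.Map;
import java.util.UUID;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.assignment.ConstrainedPrioritySet;

public final class PrioritySetExample {
    public static void main(final String[] args) {
        final UUID clientA = UUID.randomUUID();
        final UUID clientB = UUID.randomUUID();
        // Made-up per-client load used as the priority weight.
        final Map<UUID, Double> taskLoad = Map.of(clientA, 3.0, clientB, 1.0);

        final ConstrainedPrioritySet clients = new ConstrainedPrioritySet(
            // Constraint: clientA is never allowed to take partition 0 of any subtopology.
            (client, task) -> !client.equals(clientA) || task.partition() != 0,
            taskLoad::get);

        clients.offerAll(Arrays.asList(clientA, clientB));

        // Returns the least-loaded client that satisfies the constraint, or null if none do.
        final UUID chosen = clients.poll(new TaskId(0, 0));
        System.out.println(chosen.equals(clientB)); // true: clientB is lighter and allowed
    }
}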
========== changed file 8 of 9 ==========
@@ -21,9 +21,7 @@
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import org.apache.kafka.streams.processor.assignment.TaskInfo;
import org.apache.kafka.streams.processor.internals.StreamsPartitionAssignor.ClientMetadata;
import org.apache.kafka.streams.processor.TaskId;
@@ -42,10 +40,10 @@ public class DefaultApplicationState implements ApplicationState {
private final Map<Boolean, Map<ProcessId, KafkaStreamsState>> cachedKafkaStreamStates;

public DefaultApplicationState(final AssignmentConfigs assignmentConfigs,
final Set<TaskInfo> tasks,
final Map<TaskId, TaskInfo> tasks,
final Map<UUID, ClientMetadata> clientStates) {
this.assignmentConfigs = assignmentConfigs;
this.tasks = unmodifiableMap(tasks.stream().collect(Collectors.toMap(TaskInfo::id, task -> task)));
this.tasks = unmodifiableMap(tasks);
this.clientStates = clientStates;
this.cachedKafkaStreamStates = new HashMap<>();
}
========== changed file 9 of 9 ==========
@@ -77,7 +77,7 @@ int getCost(final TaskId taskId,

// This number is picked based on testing. Usually the optimization for standby assignment
// stops after 3 rounds
private static final int STANDBY_OPTIMIZER_MAX_ITERATION = 4;
public static final int STANDBY_OPTIMIZER_MAX_ITERATION = 4;

private final Cluster fullMetadata;
private final Map<TaskId, Set<TopicPartition>> partitionsForTask;
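STANDBY_OPTIMIZER_MAX_ITERATION is made public above so the standby optimization in TaskAssignmentUtils can share the same cap. Conceptually it bounds a converge-or-give-up loop; the sketch below is hypothetical, assumes the constant lives in RackAwareTaskAssignor (as the surrounding getCost/partitionsForTask code suggests), and abstracts the per-round move logic away since the real optimizer code is not rendered on this page.

// Hypothetical sketch of a capped optimization loop; oneRoundImproves stands in for the
// real per-round standby move logic, which is not shown in this rendered diff.
import java.util.function.BooleanSupplier;
import org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor;

public final class CappedOptimizerLoop {
    public static int run(final BooleanSupplier oneRoundImproves) {
        int rounds = 0;
        // Stop as soon as a round makes no improvement, or when the cap is hit;
        // per the comment above, convergence usually happens within 3 rounds.
        while (rounds < RackAwareTaskAssignor.STANDBY_OPTIMIZER_MAX_ITERATION && oneRoundImproves.getAsBoolean()) {
            rounds++;
        }
        return rounds;
    }
}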
