Skip to content

Commit

Permalink
KAFKA-15045: (KIP-924 pt. 12) Wiring in new assignment configs and lo…
Browse files Browse the repository at this point in the history
…gic (apache#16074)

This PR creates the new public config of KIP-924 in StreamsConfig and uses it to instantiate user-created TaskAssignors. If such a TaskAssignor is found and successfully created, we use that assignor to perform the task assignment; otherwise, we fall back to the pre-KIP-924 behavior with the internal task assignors.

Reviewers: Anna Sophie Blee-Goldman <[email protected]>, Almog Gavra <[email protected]>
  • Loading branch information
apourchet authored May 29, 2024
1 parent 56ee139 commit 8d243df
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 47 deletions.
10 changes: 10 additions & 0 deletions streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.apache.kafka.streams.internals.UpgradeFromValues;
import org.apache.kafka.streams.processor.FailOnInvalidTimestamp;
import org.apache.kafka.streams.processor.TimestampExtractor;
import org.apache.kafka.streams.processor.assignment.TaskAssignor;
import org.apache.kafka.streams.processor.internals.DefaultKafkaClientSupplier;
import org.apache.kafka.streams.processor.internals.StreamsPartitionAssignor;
import org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor;
Expand Down Expand Up @@ -820,6 +821,10 @@ public class StreamsConfig extends AbstractConfig {
+ "optimization algorithm favors minimizing cross rack traffic or minimize the movement of tasks in existing assignment. If set a larger value <code>" + RackAwareTaskAssignor.class.getName() + "</code> will "
+ "optimize to maintain the existing assignment. The default value is null which means it will use default non_overlap cost values in different assignors.";

/**
 * {@code task.assignor.class}
 */
@SuppressWarnings("WeakerAccess")
public static final String TASK_ASSIGNOR_CLASS_CONFIG = "task.assignor.class";
// NOTE(review): this config is registered with Type.STRING and a null default, so the
// "class or class name" and "Defaults to" wording below describes intended behavior rather
// than the literal config default -- confirm against the fallback logic in the assignor.
private static final String TASK_ASSIGNOR_CLASS_DOC = "A task assignor class or class name implementing the <code>" +
TaskAssignor.class.getName() + "</code> interface. Defaults to the <code>HighAvailabilityTaskAssignor</code> class.";

/**
* {@code topology.optimization}
Expand Down Expand Up @@ -980,6 +985,11 @@ public class StreamsConfig extends AbstractConfig {
null,
Importance.MEDIUM,
RACK_AWARE_ASSIGNMENT_TRAFFIC_COST_DOC)
.define(TASK_ASSIGNOR_CLASS_CONFIG,
Type.STRING,
null,
Importance.MEDIUM,
TASK_ASSIGNOR_CLASS_DOC)
.define(REPLICATION_FACTOR_CONFIG,
Type.INT,
-1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.kafka.streams.processor.internals;

import java.time.Instant;
import java.util.Optional;
import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.ListOffsetsResult;
Expand Down Expand Up @@ -214,6 +215,7 @@ public String toString() {
private RebalanceProtocol rebalanceProtocol;
private AssignmentListener assignmentListener;

private Supplier<Optional<org.apache.kafka.streams.processor.assignment.TaskAssignor>> userTaskAssignorSupplier;
private Supplier<TaskAssignor> taskAssignorSupplier;
private byte uniqueField;
private Map<String, String> clientTags;
Expand Down Expand Up @@ -248,6 +250,7 @@ public void configure(final Map<String, ?> configs) {
internalTopicManager = assignorConfiguration.internalTopicManager();
copartitionedTopicsEnforcer = assignorConfiguration.copartitionedTopicsEnforcer();
rebalanceProtocol = assignorConfiguration.rebalanceProtocol();
userTaskAssignorSupplier = assignorConfiguration::userTaskAssignor;
taskAssignorSupplier = assignorConfiguration::taskAssignor;
assignmentListener = assignorConfiguration.assignmentListener();
uniqueField = 0;
Expand Down Expand Up @@ -400,9 +403,6 @@ public GroupAssignment assign(final Cluster metadata, final GroupSubscription gr
}

try {
final boolean versionProbing =
checkMetadataVersions(minReceivedMetadataVersion, minSupportedMetadataVersion, futureMetadataVersion);

log.debug("Constructed client metadata {} from the member subscriptions.", clientMetadataMap);

// ---------------- Step One ---------------- //
Expand Down Expand Up @@ -440,7 +440,9 @@ public GroupAssignment assign(final Cluster metadata, final GroupSubscription gr

final Set<TaskId> statefulTasks = new HashSet<>();

final boolean probingRebalanceNeeded = assignTasksToClients(fullMetadata, allSourceTopics, topicGroups,
final boolean versionProbing =
checkMetadataVersions(minReceivedMetadataVersion, minSupportedMetadataVersion, futureMetadataVersion);
assignTasksToClients(fullMetadata, allSourceTopics, topicGroups,
clientMetadataMap, partitionsForTask, racksForProcessConsumer, statefulTasks);

// ---------------- Step Three ---------------- //
Expand All @@ -465,8 +467,7 @@ public GroupAssignment assign(final Cluster metadata, final GroupSubscription gr
allOwnedPartitions,
minReceivedMetadataVersion,
minSupportedMetadataVersion,
versionProbing,
probingRebalanceNeeded
versionProbing
);

return new GroupAssignment(assignment);
Expand Down Expand Up @@ -570,6 +571,9 @@ private static void processStreamsPartitionAssignment(final Map<UUID, ClientMeta
final ProcessId processId = kafkaStreamsAssignment.processId();
final ClientMetadata clientMetadata = clientMetadataMap.get(processId.id());
clientMetadata.state.setAssignedTasks(kafkaStreamsAssignment);
if (kafkaStreamsAssignment.followupRebalanceDeadline().isPresent()) {
clientMetadata.state.setFollowupRebalanceDeadline(kafkaStreamsAssignment.followupRebalanceDeadline().get());
}
});
}

Expand Down Expand Up @@ -712,15 +716,14 @@ private void checkAllPartitions(final Set<String> allSourceTopics,
/**
* Assigns a set of tasks to each client (Streams instance) using the configured task assignor, and also
* populate the stateful tasks that have been assigned to the clients
* @return true if a probing rebalance should be triggered
*/
private boolean assignTasksToClients(final Cluster fullMetadata,
final Set<String> allSourceTopics,
final Map<Subtopology, TopicsInfo> topicGroups,
final Map<UUID, ClientMetadata> clientMetadataMap,
final Map<TaskId, Set<TopicPartition>> partitionsForTask,
final Map<UUID, Map<String, Optional<String>>> racksForProcessConsumer,
final Set<TaskId> statefulTasks) {
private void assignTasksToClients(final Cluster fullMetadata,
final Set<String> allSourceTopics,
final Map<Subtopology, TopicsInfo> topicGroups,
final Map<UUID, ClientMetadata> clientMetadataMap,
final Map<TaskId, Set<TopicPartition>> partitionsForTask,
final Map<UUID, Map<String, Optional<String>>> racksForProcessConsumer,
final Set<TaskId> statefulTasks) {
if (!statefulTasks.isEmpty()) {
throw new TaskAssignmentException("The stateful tasks should not be populated before assigning tasks to clients");
}
Expand Down Expand Up @@ -760,23 +763,45 @@ private boolean assignTasksToClients(final Cluster fullMetadata,
log.debug("Assigning tasks and {} standby replicas to client nodes {}",
numStandbyReplicas(), clientStates);

final TaskAssignor taskAssignor = createTaskAssignor(lagComputationSuccessful);

final RackAwareTaskAssignor rackAwareTaskAssignor = new RackAwareTaskAssignor(
fullMetadata,
partitionsForTask,
changelogTopics.changelogPartionsForTask(),
tasksForTopicGroup,
racksForProcessConsumer,
internalTopicManager,
assignmentConfigs,
time
);
final boolean probingRebalanceNeeded = taskAssignor.assign(clientStates,
allTasks,
statefulTasks,
rackAwareTaskAssignor,
assignmentConfigs);
final Optional<org.apache.kafka.streams.processor.assignment.TaskAssignor> userTaskAssignor =
userTaskAssignorSupplier.get();
if (userTaskAssignor.isPresent()) {
final ApplicationState applicationState = buildApplicationState(
taskManager.topologyMetadata(),
clientMetadataMap,
topicGroups,
fullMetadata
);
final TaskAssignment taskAssignment = userTaskAssignor.get().assign(applicationState);
processStreamsPartitionAssignment(clientMetadataMap, taskAssignment);
} else {
final TaskAssignor taskAssignor = createTaskAssignor(lagComputationSuccessful);
final RackAwareTaskAssignor rackAwareTaskAssignor = new RackAwareTaskAssignor(
fullMetadata,
partitionsForTask,
changelogTopics.changelogPartionsForTask(),
tasksForTopicGroup,
racksForProcessConsumer,
internalTopicManager,
assignmentConfigs,
time
);
final boolean probingRebalanceNeeded = taskAssignor.assign(clientStates,
allTasks,
statefulTasks,
rackAwareTaskAssignor,
assignmentConfigs);
if (probingRebalanceNeeded) {
// Arbitrarily choose the leader's client to be responsible for triggering the probing rebalance,
// note once we pick the first consumer within the process to trigger probing rebalance, other consumer
// would not set to trigger any more.
final ClientMetadata rebalanceClientMetadata = clientMetadataMap.get(taskManager.processId());
if (rebalanceClientMetadata != null) {
final Instant rebalanceDeadline = Instant.ofEpochMilli(time.milliseconds() + probingRebalanceIntervalMs());
rebalanceClientMetadata.state.setFollowupRebalanceDeadline(rebalanceDeadline);
}
}
}

// Break this up into multiple logs to make sure the summary info gets through, which helps avoid
// info loss for example due to long line truncation with large apps
Expand All @@ -789,8 +814,6 @@ private boolean assignTasksToClients(final Cluster fullMetadata,
.sorted(comparingByKey())
.map(entry -> entry.getKey() + "=" + entry.getValue().currentAssignment())
.collect(Collectors.joining(Utils.NL)));

return probingRebalanceNeeded;
}

private TaskAssignor createTaskAssignor(final boolean lagComputationSuccessful) {
Expand Down Expand Up @@ -948,9 +971,8 @@ private Map<String, Assignment> computeNewAssignment(final Set<TaskId> statefulT
final Set<TopicPartition> allOwnedPartitions,
final int minUserMetadataVersion,
final int minSupportedMetadataVersion,
final boolean versionProbing,
final boolean shouldTriggerProbingRebalance) {
boolean rebalanceRequired = shouldTriggerProbingRebalance || versionProbing;
final boolean versionProbing) {
boolean rebalanceRequired = versionProbing;
final Map<String, Assignment> assignment = new HashMap<>();

// within the client, distribute tasks to its owned consumers
Expand Down Expand Up @@ -992,10 +1014,7 @@ private Map<String, Assignment> computeNewAssignment(final Set<TaskId> statefulT
activeTaskAssignment.get(threadEntry.getKey()).addAll(threadEntry.getValue());
}

// Arbitrarily choose the leader's client to be responsible for triggering the probing rebalance,
// note once we pick the first consumer within the process to trigger probing rebalance, other consumer
// would not set to trigger any more.
final boolean encodeNextProbingRebalanceTime = shouldTriggerProbingRebalance && clientId.equals(taskManager.processId());
final boolean isNextProbingRebalanceEncoded = clientMetadata.state.followupRebalanceDeadline().isPresent();

final boolean tasksRevoked = addClientAssignments(
statefulTasks,
Expand All @@ -1008,11 +1027,10 @@ private Map<String, Assignment> computeNewAssignment(final Set<TaskId> statefulT
activeTaskAssignment,
standbyTaskAssignment,
minUserMetadataVersion,
minSupportedMetadataVersion,
encodeNextProbingRebalanceTime
minSupportedMetadataVersion
);

if (tasksRevoked || encodeNextProbingRebalanceTime) {
if (tasksRevoked || isNextProbingRebalanceEncoded) {
rebalanceRequired = true;
log.debug("Requested client {} to schedule a followup rebalance", clientId);
}
Expand Down Expand Up @@ -1056,12 +1074,12 @@ private boolean addClientAssignments(final Set<TaskId> statefulTasks,
final Map<String, List<TaskId>> activeTaskAssignments,
final Map<String, List<TaskId>> standbyTaskAssignments,
final int minUserMetadataVersion,
final int minSupportedMetadataVersion,
final boolean probingRebalanceNeeded) {
final int minSupportedMetadataVersion) {
boolean followupRebalanceRequiredForRevokedTasks = false;

// We only want to encode a scheduled probing rebalance for a single member in this client
boolean shouldEncodeProbingRebalance = probingRebalanceNeeded;
final Optional<Instant> followupRebalanceDeadline = clientMetadata.state.followupRebalanceDeadline();
boolean shouldEncodeProbingRebalance = followupRebalanceDeadline.isPresent();

// Loop through the consumers and build their assignment
for (final String consumer : clientMetadata.consumers) {
Expand Down Expand Up @@ -1108,7 +1126,7 @@ private boolean addClientAssignments(final Set<TaskId> statefulTasks,
// Don't bother to schedule a probing rebalance if an immediate one is already scheduled
shouldEncodeProbingRebalance = false;
} else if (shouldEncodeProbingRebalance) {
final long nextRebalanceTimeMs = time.milliseconds() + probingRebalanceIntervalMs();
final long nextRebalanceTimeMs = followupRebalanceDeadline.get().toEpochMilli();
log.info("Requesting followup rebalance be scheduled by {} for {} to probe for caught-up replica tasks.",
consumer, Utils.toLogDateTimeFormat(nextRebalanceTimeMs));
info.setNextRebalanceTime(nextRebalanceTimeMs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.kafka.streams.processor.internals.assignment;

import java.util.Optional;
import org.apache.kafka.clients.CommonClientConfigs;
import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.RebalanceProtocol;
import org.apache.kafka.common.KafkaException;
Expand Down Expand Up @@ -253,6 +254,24 @@ public TaskAssignor taskAssignor() {
}
}

/**
 * Instantiates the task assignor named by {@link StreamsConfig#TASK_ASSIGNOR_CLASS_CONFIG}, if one is configured.
 * A new instance is created on every call.
 *
 * @return the newly created assignor, or an empty {@code Optional} when the config value is {@code null}
 * @throws IllegalArgumentException if the configured class name cannot be found
 */
public Optional<org.apache.kafka.streams.processor.assignment.TaskAssignor> userTaskAssignor() {
    final String assignorClassName = streamsConfig.getString(StreamsConfig.TASK_ASSIGNOR_CLASS_CONFIG);
    if (assignorClassName != null) {
        try {
            final org.apache.kafka.streams.processor.assignment.TaskAssignor userAssignor = Utils.newInstance(
                assignorClassName,
                org.apache.kafka.streams.processor.assignment.TaskAssignor.class
            );
            log.info("Instantiated {} as the task assignor.", assignorClassName);
            return Optional.of(userAssignor);
        } catch (final ClassNotFoundException e) {
            // Surface a missing class as a configuration error, keeping the original cause attached.
            final String message = "Expected an instantiable class name for " + StreamsConfig.TASK_ASSIGNOR_CLASS_CONFIG
                + " but got " + assignorClassName;
            throw new IllegalArgumentException(message, e);
        }
    }
    // Nothing configured: the caller decides which assignor to use instead.
    return Optional.empty();
}

public AssignmentListener assignmentListener() {
final Object o = internalConfigs.get(InternalConfig.ASSIGNMENT_LISTENER);
if (o == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.kafka.streams.processor.internals.assignment;

import java.time.Instant;
import java.util.Optional;
import java.util.SortedMap;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.streams.processor.TaskId;
Expand Down Expand Up @@ -62,6 +64,8 @@ public class ClientState {
private final ClientStateTask previousStandbyTasks = new ClientStateTask(null, null);
private final ClientStateTask revokingActiveTasks = new ClientStateTask(null, new TreeMap<>());
private final UUID processId;

private Optional<Instant> followupRebalanceDeadline = Optional.empty();
private int capacity;

public ClientState() {
Expand Down Expand Up @@ -143,6 +147,14 @@ boolean reachedCapacity() {
return assignedTaskCount() >= capacity;
}

/**
 * Returns the followup rebalance deadline currently stored on this client state.
 *
 * @return the stored deadline, which may be an empty {@code Optional}
 */
public Optional<Instant> followupRebalanceDeadline() {
    return this.followupRebalanceDeadline;
}

/**
 * Stores the given instant as this client's followup rebalance deadline.
 *
 * @param followupRebalanceDeadline the deadline to store; it is wrapped with {@link Optional#of(Object)},
 *                                  which rejects {@code null}
 */
public void setFollowupRebalanceDeadline(final Instant followupRebalanceDeadline) {
    final Optional<Instant> deadline = Optional.of(followupRebalanceDeadline);
    this.followupRebalanceDeadline = deadline;
}

public Set<TaskId> activeTasks() {
return unmodifiableSet(assignedActiveTasks.taskIds());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
import static org.apache.kafka.streams.StreamsConfig.RACK_AWARE_ASSIGNMENT_NON_OVERLAP_COST_CONFIG;
import static org.apache.kafka.streams.StreamsConfig.RACK_AWARE_ASSIGNMENT_TRAFFIC_COST_CONFIG;
import static org.apache.kafka.streams.StreamsConfig.STATE_DIR_CONFIG;
import static org.apache.kafka.streams.StreamsConfig.TASK_ASSIGNOR_CLASS_CONFIG;
import static org.apache.kafka.streams.StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG;
import static org.apache.kafka.streams.StreamsConfig.adminClientPrefix;
import static org.apache.kafka.streams.StreamsConfig.consumerPrefix;
Expand Down Expand Up @@ -1457,6 +1458,12 @@ public void shouldReturnRackAwareAssignmentNonOverlapCost() {
assertEquals(Integer.valueOf(10), new StreamsConfig(props).getInt(RACK_AWARE_ASSIGNMENT_NON_OVERLAP_COST_CONFIG));
}

// Verifies that a configured task assignor class name is returned unchanged by StreamsConfig.
@Test
public void shouldReturnTaskAssignorClass() {
    final String assignorClassName = "StickyTaskAssignor";
    props.put(TASK_ASSIGNOR_CLASS_CONFIG, assignorClassName);

    final StreamsConfig config = new StreamsConfig(props);

    assertEquals(assignorClassName, config.getString(TASK_ASSIGNOR_CLASS_CONFIG));
}

@Test
public void shouldReturnDefaultClientSupplier() {
final KafkaClientSupplier supplier = streamsConfig.getKafkaClientSupplier();
Expand Down

0 comments on commit 8d243df

Please sign in to comment.