Skip to content

Commit

Permalink
KAFKA-15045: (KIP-924 pt. 4) Generify rack graph solving utilities (a…
Browse files Browse the repository at this point in the history
…pache#15956)

The graph solving utilities are currently hardcoded to work with ClientState, but don't actually depend on anything in those state classes.

This change allows the MinTrafficGraphConstructor and BalanceSubtopologyGraphConstructor to be reused with KafkaStreamsStates instead.

Reviewers: Anna Sophie Blee-Goldman <[email protected]>, Almog Gavra <[email protected]>
  • Loading branch information
apourchet authored May 16, 2024
1 parent 056d232 commit fafa3c7
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
import org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.CostFunction;

public class BalanceSubtopologyGraphConstructor implements RackAwareGraphConstructor {
public class BalanceSubtopologyGraphConstructor<T> implements RackAwareGraphConstructor<T> {

private final Map<Subtopology, Set<TaskId>> tasksForTopicGroup;

Expand Down Expand Up @@ -71,10 +71,10 @@ private static int getSecondStageClientNodeId(final List<TaskId> taskIdList, fin
public Graph<Integer> constructTaskGraph(
final List<UUID> clientList,
final List<TaskId> taskIdList,
final Map<UUID, ClientState> clientStates,
final Map<UUID, T> clientStates,
final Map<TaskId, UUID> taskClientMap,
final Map<UUID, Integer> originalAssignedTaskNumber,
final BiPredicate<ClientState, TaskId> hasAssignedTask,
final BiPredicate<T, TaskId> hasAssignedTask,
final CostFunction costFunction,
final int trafficCost,
final int nonOverlapCost,
Expand All @@ -86,7 +86,7 @@ public Graph<Integer> constructTaskGraph(
final Graph<Integer> graph = new Graph<>();

for (final TaskId taskId : taskIdList) {
for (final Entry<UUID, ClientState> clientState : clientStates.entrySet()) {
for (final Entry<UUID, T> clientState : clientStates.entrySet()) {
if (hasAssignedTask.test(clientState.getValue(), taskId)) {
originalAssignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum);
}
Expand Down Expand Up @@ -122,12 +122,12 @@ public boolean assignTaskFromMinCostFlow(
final Graph<Integer> graph,
final List<UUID> clientList,
final List<TaskId> taskIdList,
final Map<UUID, ClientState> clientStates,
final Map<UUID, T> clientStates,
final Map<UUID, Integer> originalAssignedTaskNumber,
final Map<TaskId, UUID> taskClientMap,
final BiConsumer<ClientState, TaskId> assignTask,
final BiConsumer<ClientState, TaskId> unAssignTask,
final BiPredicate<ClientState, TaskId> hasAssignedTask
final BiConsumer<T, TaskId> assignTask,
final BiConsumer<T, TaskId> unAssignTask,
final BiPredicate<T, TaskId> hasAssignedTask
) {
final SortedMap<Subtopology, Set<TaskId>> sortedTasksForTopicGroup = new TreeMap<>(tasksForTopicGroup);
final Set<TaskId> taskIdSet = new HashSet<>(taskIdList);
Expand Down Expand Up @@ -170,10 +170,10 @@ private void constructEdges(
final Graph<Integer> graph,
final List<TaskId> taskIdList,
final List<UUID> clientList,
final Map<UUID, ClientState> clientStates,
final Map<UUID, T> clientStates,
final Map<TaskId, UUID> taskClientMap,
final Map<UUID, Integer> originalAssignedTaskNumber,
final BiPredicate<ClientState, TaskId> hasAssignedTask,
final BiPredicate<T, TaskId> hasAssignedTask,
final CostFunction costFunction,
final int trafficCost,
final int nonOverlapCost,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
import org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.CostFunction;

public class MinTrafficGraphConstructor implements RackAwareGraphConstructor {
public class MinTrafficGraphConstructor<T> implements RackAwareGraphConstructor<T> {

@Override
public int getSinkNodeID(
Expand All @@ -53,10 +53,10 @@ public int getClientIndex(final int clientNodeId, final List<TaskId> taskIdList,
public Graph<Integer> constructTaskGraph(
final List<UUID> clientList,
final List<TaskId> taskIdList,
final Map<UUID, ClientState> clientStates,
final Map<UUID, T> clientStates,
final Map<TaskId, UUID> taskClientMap,
final Map<UUID, Integer> originalAssignedTaskNumber,
final BiPredicate<ClientState, TaskId> hasAssignedTask,
final BiPredicate<T, TaskId> hasAssignedTask,
final CostFunction costFunction,
final int trafficCost,
final int nonOverlapCost,
Expand All @@ -66,7 +66,7 @@ public Graph<Integer> constructTaskGraph(
final Graph<Integer> graph = new Graph<>();

for (final TaskId taskId : taskIdList) {
for (final Entry<UUID, ClientState> clientState : clientStates.entrySet()) {
for (final Entry<UUID, T> clientState : clientStates.entrySet()) {
if (hasAssignedTask.test(clientState.getValue(), taskId)) {
originalAssignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum);
}
Expand Down Expand Up @@ -122,12 +122,12 @@ public boolean assignTaskFromMinCostFlow(
final Graph<Integer> graph,
final List<UUID> clientList,
final List<TaskId> taskIdList,
final Map<UUID, ClientState> clientStates,
final Map<UUID, T> clientStates,
final Map<UUID, Integer> originalAssignedTaskNumber,
final Map<TaskId, UUID> taskClientMap,
final BiConsumer<ClientState, TaskId> assignTask,
final BiConsumer<ClientState, TaskId> unAssignTask,
final BiPredicate<ClientState, TaskId> hasAssignedTask
final BiConsumer<T, TaskId> assignTask,
final BiConsumer<T, TaskId> unAssignTask,
final BiPredicate<T, TaskId> hasAssignedTask
) {
int tasksAssigned = 0;
boolean taskMoved = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
/**
* Construct graph for rack aware task assignor
*/
public interface RackAwareGraphConstructor {
public interface RackAwareGraphConstructor<T> {
int SOURCE_ID = -1;

int getSinkNodeID(final List<TaskId> taskIdList, final List<UUID> clientList, final Map<Subtopology, Set<TaskId>> tasksForTopicGroup);
Expand All @@ -45,10 +45,10 @@ public interface RackAwareGraphConstructor {
Graph<Integer> constructTaskGraph(
final List<UUID> clientList,
final List<TaskId> taskIdList,
final Map<UUID, ClientState> clientStates,
final Map<UUID, T> clientStates,
final Map<TaskId, UUID> taskClientMap,
final Map<UUID, Integer> originalAssignedTaskNumber,
final BiPredicate<ClientState, TaskId> hasAssignedTask,
final BiPredicate<T, TaskId> hasAssignedTask,
final CostFunction costFunction,
final int trafficCost,
final int nonOverlapCost,
Expand All @@ -59,24 +59,24 @@ boolean assignTaskFromMinCostFlow(
final Graph<Integer> graph,
final List<UUID> clientList,
final List<TaskId> taskIdList,
final Map<UUID, ClientState> clientStates,
final Map<UUID, T> clientStates,
final Map<UUID, Integer> originalAssignedTaskNumber,
final Map<TaskId, UUID> taskClientMap,
final BiConsumer<ClientState, TaskId> assignTask,
final BiConsumer<ClientState, TaskId> unAssignTask,
final BiPredicate<ClientState, TaskId> hasAssignedTask);
final BiConsumer<T, TaskId> assignTask,
final BiConsumer<T, TaskId> unAssignTask,
final BiPredicate<T, TaskId> hasAssignedTask);

default KeyValue<Boolean, Integer> assignTaskToClient(
final Graph<Integer> graph,
final TaskId taskId,
final int taskNodeId,
final int topicGroupIndex,
final Map<UUID, ClientState> clientStates,
final Map<UUID, T> clientStates,
final List<UUID> clientList,
final List<TaskId> taskIdList,
final Map<TaskId, UUID> taskClientMap,
final BiConsumer<ClientState, TaskId> assignTask,
final BiConsumer<ClientState, TaskId> unAssignTask
final BiConsumer<T, TaskId> assignTask,
final BiConsumer<T, TaskId> unAssignTask
) {
int tasksAssigned = 0;
boolean taskMoved = false;
Expand Down Expand Up @@ -104,9 +104,9 @@ default KeyValue<Boolean, Integer> assignTaskToClient(
default void validateAssignedTask(
final List<TaskId> taskIdList,
final int tasksAssigned,
final Map<UUID, ClientState> clientStates,
final Map<UUID, T> clientStates,
final Map<UUID, Integer> originalAssignedTaskNumber,
final BiPredicate<ClientState, TaskId> hasAssignedTask
final BiPredicate<T, TaskId> hasAssignedTask
) {
// Validate task assigned
if (tasksAssigned != taskIdList.size()) {
Expand All @@ -117,7 +117,7 @@ default void validateAssignedTask(
// Validate original assigned task number matches
final Map<UUID, Integer> assignedTaskNumber = new HashMap<>();
for (final TaskId taskId : taskIdList) {
for (final Entry<UUID, ClientState> clientState : clientStates.entrySet()) {
for (final Entry<UUID, T> clientState : clientStates.entrySet()) {
if (hasAssignedTask.test(clientState.getValue(), taskId)) {
assignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@

public class RackAwareGraphConstructorFactory {

static RackAwareGraphConstructor create(final AssignmentConfigs assignmentConfigs, final Map<Subtopology, Set<TaskId>> tasksForTopicGroup) {
static <T> RackAwareGraphConstructor<T> create(final AssignmentConfigs assignmentConfigs, final Map<Subtopology, Set<TaskId>> tasksForTopicGroup) {
switch (assignmentConfigs.rackAwareAssignmentStrategy) {
case StreamsConfig.RACK_AWARE_ASSIGNMENT_STRATEGY_MIN_TRAFFIC:
return new MinTrafficGraphConstructor();
return new MinTrafficGraphConstructor<T>();
case StreamsConfig.RACK_AWARE_ASSIGNMENT_STRATEGY_BALANCE_SUBTOPOLOGY:
return new BalanceSubtopologyGraphConstructor(tasksForTopicGroup);
return new BalanceSubtopologyGraphConstructor<T>(tasksForTopicGroup);
default:
throw new IllegalArgumentException("Rack aware assignment is disabled");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ private long tasksCost(final SortedSet<TaskId> tasks,
}
final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
final List<TaskId> taskIdList = new ArrayList<>(tasks);
final Graph<Integer> graph = new MinTrafficGraphConstructor()
final Graph<Integer> graph = new MinTrafficGraphConstructor<ClientState>()
.constructTaskGraph(
clientList,
taskIdList,
Expand Down Expand Up @@ -373,7 +373,7 @@ public long optimizeActiveTasks(final SortedSet<TaskId> activeTasks,
final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
final Map<TaskId, UUID> taskClientMap = new HashMap<>();
final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
final RackAwareGraphConstructor graphConstructor = RackAwareGraphConstructorFactory.create(assignmentConfigs, tasksForTopicGroup);
final RackAwareGraphConstructor<ClientState> graphConstructor = RackAwareGraphConstructorFactory.create(assignmentConfigs, tasksForTopicGroup);
final Graph<Integer> graph = graphConstructor.constructTaskGraph(
clientList,
taskIdList,
Expand Down Expand Up @@ -419,7 +419,7 @@ public long optimizeStandbyTasks(final SortedMap<UUID, ClientState> clientStates

boolean taskMoved = true;
int round = 0;
final RackAwareGraphConstructor graphConstructor = new MinTrafficGraphConstructor();
final RackAwareGraphConstructor<ClientState> graphConstructor = new MinTrafficGraphConstructor<>();
while (taskMoved && round < STANDBY_OPTIMIZER_MAX_ITERATION) {
taskMoved = false;
round++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public class RackAwareGraphConstructorTest {
private final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
private final Map<Subtopology, Set<TaskId>> tasksForTopicGroup = getTasksForTopicGroup(TP_SIZE,
PARTITION_SIZE);
private RackAwareGraphConstructor constructor;
private RackAwareGraphConstructor<ClientState> constructor;

@Parameter
public String constructorType;
Expand All @@ -86,9 +86,9 @@ public void setUp() {
randomAssignTasksToClient(taskIdList, clientStateMap);

if (constructorType.equals(MIN_COST)) {
constructor = new MinTrafficGraphConstructor();
constructor = new MinTrafficGraphConstructor<>();
} else if (constructorType.equals(BALANCE_SUBTOPOLOGY)) {
constructor = new BalanceSubtopologyGraphConstructor(tasksForTopicGroup);
constructor = new BalanceSubtopologyGraphConstructor<>(tasksForTopicGroup);
}
graph = constructor.constructTaskGraph(
clientList, taskIdList, clientStateMap, taskClientMap, originalAssignedTaskNumber, ClientState::hasAssignedTask, this::getCost, 10, 1, false, false);
Expand Down

0 comments on commit fafa3c7

Please sign in to comment.