From 29b4d9285867bfd7e5b92276ca25137e38d9fb2a Mon Sep 17 00:00:00 2001 From: min-guk Date: Sat, 21 Dec 2024 16:04:56 +0100 Subject: [PATCH] [SYSTEMDS-3790] Rework FedPlanner memo table, cost estimator, enumerator Closes #2147. --- .../hops/fedplanner/FederatedMemoTable.java | 288 ++++++++++++++++++ .../FederatedPlanCostEnumerator.java | 136 +++++++++ .../FederatedPlanCostEstimator.java | 116 +++++++ .../sysds/hops/fedplanner/MemoTable.java | 160 ---------- .../FederatedPlanCostEnumeratorTest.java | 87 ++++++ .../component/federated/MemoTableTest.java | 186 ----------- src/test/scripts/functions/federated/cost.dml | 25 ++ 7 files changed, 652 insertions(+), 346 deletions(-) create mode 100644 src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java create mode 100644 src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java create mode 100644 src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java delete mode 100644 src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java create mode 100644 src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java delete mode 100644 src/test/java/org/apache/sysds/test/component/federated/MemoTableTest.java create mode 100644 src/test/scripts/functions/federated/cost.dml diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java new file mode 100644 index 00000000000..16240f0281f --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java @@ -0,0 +1,288 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.fedplanner; + +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.ArrayList; +import java.util.Map; +import java.util.HashSet; +import java.util.Set; + +/** + * A Memoization Table for managing federated plans (FedPlan) based on combinations of Hops and fedOutTypes. + * This table stores and manages different execution plan variants for each Hop and fedOutType combination, + * facilitating the optimization of federated execution plans. + */ +public class FederatedMemoTable { + // Maps Hop ID and fedOutType pairs to their plan variants + private final Map, FedPlanVariants> hopMemoTable = new HashMap<>(); + + /** + * Adds a new federated plan to the memo table. + * Creates a new variant list if none exists for the given Hop and fedOutType. + * + * @param hop The Hop node + * @param fedOutType The federated output type + * @param planChilds List of child plan references + * @return The newly created FedPlan + */ + public FedPlan addFedPlan(Hop hop, FederatedOutput fedOutType, List> planChilds) { + long hopID = hop.getHopID(); + FedPlanVariants fedPlanVariantList; + + if (contains(hopID, fedOutType)) { + fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); + } else { + fedPlanVariantList = new FedPlanVariants(hop, fedOutType); + hopMemoTable.put(new ImmutablePair<>(hopID, fedOutType), fedPlanVariantList); + } + + FedPlan newPlan = new FedPlan(planChilds, fedPlanVariantList); + fedPlanVariantList.addFedPlan(newPlan); + + return newPlan; + } + + /** + * Retrieves the minimum cost child plan considering the parent's output type. + * The cost is calculated using getParentViewCost to account for potential type mismatches. + * + * @param childHopID ? + * @param childFedOutType ? + * @return ? + */ + public FedPlan getMinCostChildFedPlan(long childHopID, FederatedOutput childFedOutType) { + FedPlanVariants fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(childHopID, childFedOutType)); + return fedPlanVariantList._fedPlanVariants.stream() + .min(Comparator.comparingDouble(FedPlan::getTotalCost)) + .orElse(null); + } + + public FedPlanVariants getFedPlanVariants(long hopID, FederatedOutput fedOutType) { + return hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); + } + + /** + * Checks if the memo table contains an entry for a given Hop and fedOutType. + * + * @param hopID The Hop ID. + * @param fedOutType The associated fedOutType. + * @return True if the entry exists, false otherwise. + */ + public boolean contains(long hopID, FederatedOutput fedOutType) { + return hopMemoTable.containsKey(new ImmutablePair<>(hopID, fedOutType)); + } + + /** + * Prunes all entries in the memo table, retaining only the minimum-cost + * FedPlan for each entry. + */ + public void pruneMemoTable() { + for (Map.Entry, FedPlanVariants> entry : hopMemoTable.entrySet()) { + List fedPlanList = entry.getValue().getFedPlanVariants(); + if (fedPlanList.size() > 1) { + // Find the FedPlan with the minimum cost + FedPlan minCostPlan = fedPlanList.stream() + .min(Comparator.comparingDouble(FedPlan::getTotalCost)) + .orElse(null); + + // Retain only the minimum cost plan + fedPlanList.clear(); + fedPlanList.add(minCostPlan); + } + } + } + + /** + * Recursively prints a tree representation of the DAG starting from the given root FedPlan. + * Includes information about hopID, fedOutType, TotalCost, SelfCost, and NetCost for each node. + * + * @param rootFedPlan The starting point FedPlan to print + */ + public void printFedPlanTree(FedPlan rootFedPlan) { + Set visited = new HashSet<>(); + printFedPlanTreeRecursive(rootFedPlan, visited, 0, true); + } + + /** + * Helper method to recursively print the FedPlan tree. + * + * @param plan The current FedPlan to print + * @param visited Set to keep track of visited FedPlans (prevents cycles) + * @param depth The current depth level for indentation + * @param isLast Whether this node is the last child of its parent + */ + private void printFedPlanTreeRecursive(FedPlan plan, Set visited, int depth, boolean isLast) { + if (plan == null || visited.contains(plan)) { + return; + } + + visited.add(plan); + + Hop hop = plan.getHopRef(); + StringBuilder sb = new StringBuilder(); + + // Add FedPlan information + sb.append(String.format("(%d) ", plan.getHopRef().getHopID())) + .append(plan.getHopRef().getOpString()) + .append(" [") + .append(plan.getFedOutType()) + .append("]"); + + StringBuilder childs = new StringBuilder(); + childs.append(" ("); + boolean childAdded = false; + for( Hop input : hop.getInput()){ + childs.append(childAdded?",":""); + childs.append(input.getHopID()); + childAdded = true; + } + childs.append(")"); + if( childAdded ) + sb.append(childs.toString()); + + + sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}", + plan.getTotalCost(), + plan.getSelfCost(), + plan.getNetTransferCost())); + + // Add matrix characteristics + sb.append(" [") + .append(hop.getDim1()).append(", ") + .append(hop.getDim2()).append(", ") + .append(hop.getBlocksize()).append(", ") + .append(hop.getNnz()); + + if (hop.getUpdateType().isInPlace()) { + sb.append(", ").append(hop.getUpdateType().toString().toLowerCase()); + } + sb.append("]"); + + // Add memory estimates + sb.append(" [") + .append(OptimizerUtils.toMB(hop.getInputMemEstimate())).append(", ") + .append(OptimizerUtils.toMB(hop.getIntermediateMemEstimate())).append(", ") + .append(OptimizerUtils.toMB(hop.getOutputMemEstimate())).append(" -> ") + .append(OptimizerUtils.toMB(hop.getMemEstimate())).append("MB]"); + + // Add reblock and checkpoint requirements + if (hop.requiresReblock() && hop.requiresCheckpoint()) { + sb.append(" [rblk, chkpt]"); + } else if (hop.requiresReblock()) { + sb.append(" [rblk]"); + } else if (hop.requiresCheckpoint()) { + sb.append(" [chkpt]"); + } + + // Add execution type + if (hop.getExecType() != null) { + sb.append(", ").append(hop.getExecType()); + } + + System.out.println(sb); + + // Process child nodes + List> childRefs = plan.getChildFedPlans(); + for (int i = 0; i < childRefs.size(); i++) { + Pair childRef = childRefs.get(i); + FedPlanVariants childVariants = getFedPlanVariants(childRef.getLeft(), childRef.getRight()); + if (childVariants == null || childVariants.getFedPlanVariants().isEmpty()) + continue; + + boolean isLastChild = (i == childRefs.size() - 1); + for (FedPlan childPlan : childVariants.getFedPlanVariants()) { + printFedPlanTreeRecursive(childPlan, visited, depth + 1, isLastChild); + } + } + } + + /** + * Represents a collection of federated execution plan variants for a specific Hop. + * Contains cost information and references to the associated plans. + */ + public static class FedPlanVariants { + protected final Hop hopRef; // Reference to the associated Hop + protected double selfCost; // Current execution cost (compute + memory access) + protected double netTransferCost; // Network transfer cost + private final FederatedOutput fedOutType; // Output type (FOUT/LOUT) + protected List _fedPlanVariants; // List of plan variants + + public FedPlanVariants(Hop hopRef, FederatedOutput fedOutType) { + this.hopRef = hopRef; + this.fedOutType = fedOutType; + this.selfCost = 0; + this.netTransferCost = 0; + this._fedPlanVariants = new ArrayList<>(); + } + + public int size() {return _fedPlanVariants.size();} + public void addFedPlan(FedPlan fedPlan) {_fedPlanVariants.add(fedPlan);} + public List getFedPlanVariants() {return _fedPlanVariants;} + } + + /** + * Represents a single federated execution plan with its associated costs and dependencies. + * Contains: + * 1. selfCost: Cost of current hop (compute + input/output memory access) + * 2. totalCost: Cumulative cost including this plan and all child plans + * 3. netTransferCost: Network transfer cost for this plan to parent plan. + */ + public static class FedPlan { + private double totalCost; // Total cost including child plans + private final FedPlanVariants fedPlanVariants; // Reference to variant list + private final List> childFedPlans; // Child plan references + + public FedPlan(List> childFedPlans, FedPlanVariants fedPlanVariants) { + this.totalCost = 0; + this.childFedPlans = childFedPlans; + this.fedPlanVariants = fedPlanVariants; + } + + public void setTotalCost(double totalCost) {this.totalCost = totalCost;} + public void setSelfCost(double selfCost) {fedPlanVariants.selfCost = selfCost;} + public void setNetTransferCost(double netTransferCost) {fedPlanVariants.netTransferCost = netTransferCost;} + + public Hop getHopRef() {return fedPlanVariants.hopRef;} + public FederatedOutput getFedOutType() {return fedPlanVariants.fedOutType;} + public double getTotalCost() {return totalCost;} + public double getSelfCost() {return fedPlanVariants.selfCost;} + private double getNetTransferCost() {return fedPlanVariants.netTransferCost;} + public List> getChildFedPlans() {return childFedPlans;} + + /** + * Calculates the conditional network transfer cost based on output type compatibility. + * Returns 0 if output types match, otherwise returns the network transfer cost. + * @param parentFedOutType ? + * @return ? + */ + public double getCondNetTransferCost(FederatedOutput parentFedOutType) { + if (parentFedOutType == getFedOutType()) return 0; + return fedPlanVariants.netTransferCost; + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java new file mode 100644 index 00000000000..73e8d5d6930 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.fedplanner; +import java.util.ArrayList; +import java.util.List; +import java.util.Comparator; +import java.util.Objects; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; +import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants; +import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; + +/** + * Enumerates and evaluates all possible federated execution plans for a given Hop DAG. + * Works with FederatedMemoTable to store plan variants and FederatedPlanCostEstimator + * to compute their costs. + */ +public class FederatedPlanCostEnumerator { + /** + * Entry point for federated plan enumeration. Creates a memo table and returns + * the minimum cost plan for the entire DAG. + * + * @param rootHop ? + * @param printTree ? + * @return ? + */ + public static FedPlan enumerateFederatedPlanCost(Hop rootHop, boolean printTree) { + // Create new memo table to store all plan variants + FederatedMemoTable memoTable = new FederatedMemoTable(); + + // Recursively enumerate all possible plans + enumerateFederatedPlanCost(rootHop, memoTable); + + // Return the minimum cost plan for the root node + FedPlan optimalPlan = getMinCostRootFedPlan(rootHop.getHopID(), memoTable); + memoTable.pruneMemoTable(); + if (printTree) memoTable.printFedPlanTree(optimalPlan); + + return optimalPlan; + } + + /** + * Recursively enumerates all possible federated execution plans for a Hop DAG. + * For each node: + * 1. First processes all input nodes recursively if not already processed + * 2. Generates all possible combinations of federation types (FOUT/LOUT) for inputs + * 3. Creates and evaluates both FOUT and LOUT variants for current node with each input combination + * + * The enumeration uses a bottom-up approach where: + * - Each input combination is represented by a binary number (i) + * - Bit j in i determines whether input j is FOUT (1) or LOUT (0) + * - Total number of combinations is 2^numInputs + * + * @param hop ? + * @param memoTable ? + */ + private static void enumerateFederatedPlanCost(Hop hop, FederatedMemoTable memoTable) { + int numInputs = hop.getInput().size(); + + // Process all input nodes first if not already in memo table + for (Hop inputHop : hop.getInput()) { + if (!memoTable.contains(inputHop.getHopID(), FederatedOutput.FOUT) + && !memoTable.contains(inputHop.getHopID(), FederatedOutput.LOUT)) { + enumerateFederatedPlanCost(inputHop, memoTable); + } + } + + // Generate all possible input combinations using binary representation + // i represents a specific combination of FOUT/LOUT for inputs + for (int i = 0; i < (1 << numInputs); i++) { + List> planChilds = new ArrayList<>(); + + // For each input, determine if it should be FOUT or LOUT based on bit j in i + for (int j = 0; j < numInputs; j++) { + Hop inputHop = hop.getInput().get(j); + // If bit j is set (1), use FOUT; otherwise use LOUT + FederatedOutput childType = ((i & (1 << j)) != 0) ? + FederatedOutput.FOUT : FederatedOutput.LOUT; + planChilds.add(Pair.of(inputHop.getHopID(), childType)); + } + + // Create and evaluate FOUT variant for current input combination + FedPlan fOutPlan = memoTable.addFedPlan(hop, FederatedOutput.FOUT, planChilds); + FederatedPlanCostEstimator.computeFederatedPlanCost(fOutPlan, memoTable); + + // Create and evaluate LOUT variant for current input combination + FedPlan lOutPlan = memoTable.addFedPlan(hop, FederatedOutput.LOUT, planChilds); + FederatedPlanCostEstimator.computeFederatedPlanCost(lOutPlan, memoTable); + } + } + + /** + * Returns the minimum cost plan for the root Hop, comparing both FOUT and LOUT variants. + * Used to select the final execution plan after enumeration. + * + * @param HopID ? + * @param memoTable ? + * @return ? + */ + private static FedPlan getMinCostRootFedPlan(long HopID, FederatedMemoTable memoTable) { + FedPlanVariants fOutFedPlanVariants = memoTable.getFedPlanVariants(HopID, FederatedOutput.FOUT); + FedPlanVariants lOutFedPlanVariants = memoTable.getFedPlanVariants(HopID, FederatedOutput.LOUT); + + FedPlan minFOutFedPlan = fOutFedPlanVariants._fedPlanVariants.stream() + .min(Comparator.comparingDouble(FedPlan::getTotalCost)) + .orElse(null); + FedPlan minlOutFedPlan = lOutFedPlanVariants._fedPlanVariants.stream() + .min(Comparator.comparingDouble(FedPlan::getTotalCost)) + .orElse(null); + + if (Objects.requireNonNull(minFOutFedPlan).getTotalCost() + < Objects.requireNonNull(minlOutFedPlan).getTotalCost()) { + return minFOutFedPlan; + } + return minlOutFedPlan; + } +} diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java new file mode 100644 index 00000000000..a716c3321db --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.fedplanner; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.cost.ComputeCost; +import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; +import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; + +/** + * Cost estimator for federated execution plans. + * Calculates computation, memory access, and network transfer costs for federated operations. + * Works in conjunction with FederatedMemoTable to evaluate different execution plan variants. + */ +public class FederatedPlanCostEstimator { + // Default value is used as a reasonable estimate since we only need + // to compare relative costs between different federated plans + // Memory bandwidth for local computations (25 GB/s) + private static final double DEFAULT_MBS_MEMORY_BANDWIDTH = 25000.0; + // Network bandwidth for data transfers between federated sites (1 Gbps) + private static final double DEFAULT_MBS_NETWORK_BANDWIDTH = 125.0; + + /** + * Computes total cost of federated plan by: + * 1. Computing current node cost (if not cached) + * 2. Adding minimum-cost child plans + * 3. Including network transfer costs when needed + * + * @param currentPlan Plan to compute cost for + * @param memoTable Table containing all plan variants + */ + public static void computeFederatedPlanCost(FedPlan currentPlan, FederatedMemoTable memoTable) { + double totalCost = 0; + Hop currentHop = currentPlan.getHopRef(); + + // Step 1: Calculate current node costs if not already computed + if (currentPlan.getSelfCost() == 0) { + // Compute cost for current node (computation + memory access) + totalCost = computeCurrentCost(currentHop); + currentPlan.setSelfCost(totalCost); + // Calculate potential network transfer cost if federation type changes + currentPlan.setNetTransferCost(computeHopNetworkAccessCost(currentHop.getOutputMemEstimate())); + } else { + totalCost = currentPlan.getSelfCost(); + } + + // Step 2: Process each child plan and add their costs + for (Pair planRefMeta : currentPlan.getChildFedPlans()) { + // Find minimum cost child plan considering federation type compatibility + // Note: This approach might lead to suboptimal or wrong solutions when a child has multiple parents + // because we're selecting child plans independently for each parent + FedPlan planRef = memoTable.getMinCostChildFedPlan(planRefMeta.getLeft(), planRefMeta.getRight()); + + // Add child plan cost (includes network transfer cost if federation types differ) + totalCost += planRef.getTotalCost() + planRef.getCondNetTransferCost(currentPlan.getFedOutType()); + } + + // Step 3: Set final cumulative cost including current node + currentPlan.setTotalCost(totalCost); + } + + /** + * Computes the cost for the current Hop node. + * + * @param currentHop The Hop node whose cost needs to be computed + * @return The total cost for the current node's operation + */ + private static double computeCurrentCost(Hop currentHop){ + double computeCost = ComputeCost.getHOPComputeCost(currentHop); + double inputAccessCost = computeHopMemoryAccessCost(currentHop.getInputMemEstimate()); + double ouputAccessCost = computeHopMemoryAccessCost(currentHop.getOutputMemEstimate()); + + // Compute total cost assuming: + // 1. Computation and input access can be overlapped (hence taking max) + // 2. Output access must wait for both to complete (hence adding) + return Math.max(computeCost, inputAccessCost) + ouputAccessCost; + } + + /** + * Calculates the memory access cost based on data size and memory bandwidth. + * + * @param memSize Size of data to be accessed (in bytes) + * @return Time cost for memory access (in seconds) + */ + private static double computeHopMemoryAccessCost(double memSize) { + return memSize / (1024*1024) / DEFAULT_MBS_MEMORY_BANDWIDTH; + } + + /** + * Calculates the network transfer cost based on data size and network bandwidth. + * Used when federation status changes between parent and child plans. + * + * @param memSize Size of data to be transferred (in bytes) + * @return Time cost for network transfer (in seconds) + */ + private static double computeHopNetworkAccessCost(double memSize) { + return memSize / (1024*1024) / DEFAULT_MBS_NETWORK_BANDWIDTH; + } +} diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java deleted file mode 100644 index f11b17b9849..00000000000 --- a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysds.hops.fedplanner; - -import org.apache.sysds.hops.Hop; -import org.apache.sysds.hops.fedplanner.FTypes.FType; -import org.apache.commons.lang3.tuple.Pair; -import org.apache.commons.lang3.tuple.ImmutablePair; - -import java.util.Comparator; -import java.util.HashMap; -import java.util.List; -import java.util.ArrayList; -import java.util.Map; - -/** - * A Memoization Table for managing federated plans (`FedPlan`) based on - * combinations of Hops and FTypes. Each combination is mapped to a list - * of possible execution plans, allowing for pruning and optimization. - */ -public class MemoTable { - - // Maps combinations of Hop ID and FType to lists of FedPlans - private final Map, List> hopMemoTable = new HashMap<>(); - - /** - * Represents a federated execution plan with its cost and associated references. - */ - public static class FedPlan { - @SuppressWarnings("unused") - private final Hop hopRef; // The associated Hop object - private final double cost; // Cost of this federated plan - @SuppressWarnings("unused") - private final List> planRefs; // References to dependent plans - - public FedPlan(Hop hopRef, double cost, List> planRefs) { - this.hopRef = hopRef; - this.cost = cost; - this.planRefs = planRefs; - } - - public double getCost() { - return cost; - } - } - - /** - * Adds a single FedPlan to the memo table for a given Hop and FType. - * If the entry already exists, the new FedPlan is appended to the list. - * - * @param hop The Hop object. - * @param fType The associated FType. - * @param fedPlan The FedPlan to add. - */ - public void addFedPlan(Hop hop, FType fType, FedPlan fedPlan) { - if (contains(hop, fType)) { - List fedPlanList = get(hop, fType); - fedPlanList.add(fedPlan); - } else { - List fedPlanList = new ArrayList<>(); - fedPlanList.add(fedPlan); - hopMemoTable.put(new ImmutablePair<>(hop.getHopID(), fType), fedPlanList); - } - } - - /** - * Adds multiple FedPlans to the memo table for a given Hop and FType. - * If the entry already exists, the new FedPlans are appended to the list. - * - * @param hop The Hop object. - * @param fType The associated FType. - * @param fedPlanList The list of FedPlans to add. - */ - public void addFedPlanList(Hop hop, FType fType, List fedPlanList) { - if (contains(hop, fType)) { - List prevFedPlanList = get(hop, fType); - prevFedPlanList.addAll(fedPlanList); - } else { - hopMemoTable.put(new ImmutablePair<>(hop.getHopID(), fType), fedPlanList); - } - } - - /** - * Retrieves the list of FedPlans associated with a given Hop and FType. - * - * @param hop The Hop object. - * @param fType The associated FType. - * @return The list of FedPlans, or null if no entry exists. - */ - public List get(Hop hop, FType fType) { - return hopMemoTable.get(new ImmutablePair<>(hop.getHopID(), fType)); - } - - /** - * Checks if the memo table contains an entry for a given Hop and FType. - * - * @param hop The Hop object. - * @param fType The associated FType. - * @return True if the entry exists, false otherwise. - */ - public boolean contains(Hop hop, FType fType) { - return hopMemoTable.containsKey(new ImmutablePair<>(hop.getHopID(), fType)); - } - - /** - * Prunes the FedPlans associated with a specific Hop and FType, - * keeping only the plan with the minimum cost. - * - * @param hop The Hop object. - * @param fType The associated FType. - */ - public void prunePlan(Hop hop, FType fType) { - prunePlan(hopMemoTable.get(new ImmutablePair<>(hop.getHopID(), fType))); - } - - /** - * Prunes all entries in the memo table, retaining only the minimum-cost - * FedPlan for each entry. - */ - public void pruneAll() { - for (Map.Entry, List> entry : hopMemoTable.entrySet()) { - prunePlan(entry.getValue()); - } - } - - /** - * Prunes the given list of FedPlans to retain only the plan with the minimum cost. - * - * @param fedPlanList The list of FedPlans to prune. - */ - private void prunePlan(List fedPlanList) { - if (fedPlanList.size() > 1) { - // Find the FedPlan with the minimum cost - FedPlan minCostPlan = fedPlanList.stream() - .min(Comparator.comparingDouble(plan -> plan.cost)) - .orElse(null); - - // Retain only the minimum cost plan - fedPlanList.clear(); - fedPlanList.add(minCostPlan); - } - } -} diff --git a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java new file mode 100644 index 00000000000..56de8cf3c4f --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.federated; + +import java.io.IOException; +import java.util.HashMap; + +import org.apache.sysds.hops.Hop; +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.conf.DMLConfig; +import org.apache.sysds.parser.DMLProgram; +import org.apache.sysds.parser.DMLTranslator; +import org.apache.sysds.parser.ParserFactory; +import org.apache.sysds.parser.ParserWrapper; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator; + + +public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase +{ + private static final String TEST_DIR = "functions/federated/"; + private static final String HOME = SCRIPT_DIR + TEST_DIR; + private static final String TEST_CLASS_DIR = TEST_DIR + FederatedPlanCostEnumeratorTest.class.getSimpleName() + "/"; + + @Override + public void setUp() {} + + @Test + public void testDependencyAnalysis1() { runTest("cost.dml"); } + + private void runTest( String scriptFilename ) { + int index = scriptFilename.lastIndexOf(".dml"); + String testName = scriptFilename.substring(0, index > 0 ? index : scriptFilename.length()); + TestConfiguration testConfig = new TestConfiguration(TEST_CLASS_DIR, testName, new String[] {}); + addTestConfiguration(testName, testConfig); + loadTestConfiguration(testConfig); + + try { + DMLConfig conf = new DMLConfig(getCurConfigFile().getPath()); + ConfigurationManager.setLocalConfig(conf); + + //read script + String dmlScriptString = DMLScript.readDMLScript(true, HOME + scriptFilename); + + //parsing and dependency analysis + ParserWrapper parser = ParserFactory.createParser(); + DMLProgram prog = parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new HashMap<>()); + DMLTranslator dmlt = new DMLTranslator(prog); + dmlt.liveVariableAnalysis(prog); + dmlt.validateParseTree(prog); + dmlt.constructHops(prog); + dmlt.rewriteHopsDAG(prog); + dmlt.constructLops(prog); + /* TODO) In the current DAG, Hop's _outputMemEstimate is not initialized + // This leads to incorrect fedplan generation, so test code needs to be modified + // If needed, modify costEstimator to handle cases where _outputMemEstimate is not initialized + */ + Hop hops = prog.getStatementBlocks().get(0).getHops().get(0); + FederatedPlanCostEnumerator.enumerateFederatedPlanCost(hops, true); + } + catch (IOException e) { + e.printStackTrace(); + Assert.fail(); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/component/federated/MemoTableTest.java b/src/test/java/org/apache/sysds/test/component/federated/MemoTableTest.java deleted file mode 100644 index e3928c12630..00000000000 --- a/src/test/java/org/apache/sysds/test/component/federated/MemoTableTest.java +++ /dev/null @@ -1,186 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysds.test.component.federated; - -import org.apache.sysds.hops.Hop; -import org.apache.sysds.hops.fedplanner.FTypes; -import org.apache.sysds.hops.fedplanner.MemoTable; -import org.apache.sysds.hops.fedplanner.MemoTable.FedPlan; -import org.apache.commons.lang3.tuple.Pair; -import org.junit.Before; -import org.junit.Test; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; - -import java.util.ArrayList; -import java.util.List; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.when; - -public class MemoTableTest { - - private MemoTable memoTable; - - @Mock - private Hop mockHop1; - - @Mock - private Hop mockHop2; - - private java.util.Random rand; - - @Before - public void setUp() { - MockitoAnnotations.openMocks(this); - memoTable = new MemoTable(); - - // Set up unique IDs for mock Hops - when(mockHop1.getHopID()).thenReturn(1L); - when(mockHop2.getHopID()).thenReturn(2L); - - // Initialize random generator with fixed seed for reproducible tests - rand = new java.util.Random(42); - } - - @Test - public void testAddAndGetSingleFedPlan() { - // Initialize test data - List> planRefs = new ArrayList<>(); - FedPlan fedPlan = new FedPlan(mockHop1, 10.0, planRefs); - - // Verify initial state - List result = memoTable.get(mockHop1, FTypes.FType.FULL); - assertNull("Initial FedPlan list should be null before adding any plans", result); - - // Add single FedPlan - memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, fedPlan); - - // Verify after addition - result = memoTable.get(mockHop1, FTypes.FType.FULL); - assertNotNull("FedPlan list should exist after adding a plan", result); - assertEquals("FedPlan list should contain exactly one plan", 1, result.size()); - assertEquals("FedPlan cost should be exactly 10.0", 10.0, result.get(0).getCost(), 0.001); - } - - @Test - public void testAddMultipleDuplicatedFedPlans() { - // Initialize test data with duplicate costs - List> planRefs = new ArrayList<>(); - List fedPlans = new ArrayList<>(); - fedPlans.add(new FedPlan(mockHop1, 10.0, planRefs)); // Unique cost - fedPlans.add(new FedPlan(mockHop1, 20.0, planRefs)); // First duplicate - fedPlans.add(new FedPlan(mockHop1, 20.0, planRefs)); // Second duplicate - - // Add multiple plans including duplicates - memoTable.addFedPlanList(mockHop1, FTypes.FType.FULL, fedPlans); - - // Verify handling of duplicate plans - List result = memoTable.get(mockHop1, FTypes.FType.FULL); - assertNotNull("FedPlan list should exist after adding multiple plans", result); - assertEquals("FedPlan list should maintain all plans including duplicates", 3, result.size()); - } - - @Test - public void testContains() { - // Initialize test data - List> planRefs = new ArrayList<>(); - FedPlan fedPlan = new FedPlan(mockHop1, 10.0, planRefs); - - // Verify initial state - assertFalse("MemoTable should not contain any entries initially", - memoTable.contains(mockHop1, FTypes.FType.FULL)); - - // Add plan and verify presence - memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, fedPlan); - - assertTrue("MemoTable should contain entry after adding FedPlan", - memoTable.contains(mockHop1, FTypes.FType.FULL)); - assertFalse("MemoTable should not contain entries for different Hop", - memoTable.contains(mockHop2, FTypes.FType.FULL)); - } - - @Test - public void testPrunePlanPruneAll() { - // Initialize base test data - List> planRefs = new ArrayList<>(); - // Create separate FedPlan lists for independent testing of each Hop - List fedPlans1 = new ArrayList<>(); // Plans for mockHop1 - List fedPlans2 = new ArrayList<>(); // Plans for mockHop2 - - // Generate random cost FedPlans for both Hops - double minCost = Double.MAX_VALUE; - int size = 100; - for(int i = 0; i < size; i++) { - double cost = rand.nextDouble() * 1000; // Random cost between 0 and 1000 - fedPlans1.add(new FedPlan(mockHop1, cost, planRefs)); - fedPlans2.add(new FedPlan(mockHop2, cost, planRefs)); - minCost = Math.min(minCost, cost); - } - - // Add FedPlan lists to MemoTable - memoTable.addFedPlanList(mockHop1, FTypes.FType.FULL, fedPlans1); - memoTable.addFedPlanList(mockHop2, FTypes.FType.FULL, fedPlans2); - - // Test selective pruning on mockHop1 - memoTable.prunePlan(mockHop1, FTypes.FType.FULL); - - // Get results for verification - List result1 = memoTable.get(mockHop1, FTypes.FType.FULL); - List result2 = memoTable.get(mockHop2, FTypes.FType.FULL); - - // Verify selective pruning results - assertNotNull("Pruned mockHop1 should maintain a FedPlan list", result1); - assertEquals("Pruned mockHop1 should contain exactly one minimum cost plan", 1, result1.size()); - assertEquals("Pruned mockHop1's plan should have the minimum cost", minCost, result1.get(0).getCost(), 0.001); - - // Verify unpruned Hop state - assertNotNull("Unpruned mockHop2 should maintain a FedPlan list", result2); - assertEquals("Unpruned mockHop2 should maintain all original plans", size, result2.size()); - - // Add additional plans to both Hops - for(int i = 0; i < size; i++) { - double cost = rand.nextDouble() * 1000; - memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, new FedPlan(mockHop1, cost, planRefs)); - memoTable.addFedPlan(mockHop2, FTypes.FType.FULL, new FedPlan(mockHop2, cost, planRefs)); - minCost = Math.min(minCost, cost); - } - - // Test global pruning - memoTable.pruneAll(); - - // Verify global pruning results - assertNotNull("mockHop1 should maintain a FedPlan list after global pruning", result1); - assertEquals("mockHop1 should contain exactly one minimum cost plan after global pruning", - 1, result1.size()); - assertEquals("mockHop1's plan should have the global minimum cost", - minCost, result1.get(0).getCost(), 0.001); - - assertNotNull("mockHop2 should maintain a FedPlan list after global pruning", result2); - assertEquals("mockHop2 should contain exactly one minimum cost plan after global pruning", - 1, result2.size()); - assertEquals("mockHop2's plan should have the global minimum cost", - minCost, result2.get(0).getCost(), 0.001); - } -} diff --git a/src/test/scripts/functions/federated/cost.dml b/src/test/scripts/functions/federated/cost.dml new file mode 100644 index 00000000000..ec34d45bb65 --- /dev/null +++ b/src/test/scripts/functions/federated/cost.dml @@ -0,0 +1,25 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +a = matrix(7,10,10); +b = a + a^2; +c = sqrt(b); +print(sum(c)); \ No newline at end of file