[SPARK-25299] Yet another attempt to integrate API with scheduler #559

Closed · wants to merge 5 commits
ShuffleDataIO.java
@@ -18,6 +18,7 @@
package org.apache.spark.api.shuffle;

import org.apache.spark.annotation.Experimental;
import org.apache.spark.api.java.Optional;

/**
* :: Experimental ::
@@ -31,4 +32,6 @@ public interface ShuffleDataIO {

ShuffleDriverComponents driver();
ShuffleExecutorComponents executor();
Optional<ShuffleLocationComponents> shuffleLocations();

}
ShuffleDriverComponents.java
@@ -30,4 +30,5 @@ public interface ShuffleDriverComponents {
void cleanupApplication() throws IOException;

void removeShuffleData(int shuffleId, boolean blocking) throws IOException;

}
ShuffleLocationComponents.java
@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.shuffle;

import org.apache.spark.annotation.Experimental;

/**
* :: Experimental ::
* An interface for interaction with shuffle locations.
*
* @since 3.0.0
*/
@Experimental
public interface ShuffleLocationComponents {

/**
* Returns whether the MapShuffleLocations is now missing data as a result of
* removing the lost shuffle location.
*/
boolean shouldRemoveMapOutputOnLostBlock(
ShuffleLocation lostLocation,
MapShuffleLocations mapOutputLocations);
}
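
As an illustration of the contract above (not part of this change), a hypothetical plugin that keeps shuffle blocks on remote storage could report that no map output is lost when an executor's location disappears; a minimal Scala sketch:

// Hypothetical plugin implementation (illustrative only): blocks live on remote
// storage, so a lost executor location never invalidates registered map output.
import org.apache.spark.api.shuffle.{MapShuffleLocations, ShuffleLocation, ShuffleLocationComponents}

class RemoteStorageShuffleLocationComponents extends ShuffleLocationComponents {
  override def shouldRemoveMapOutputOnLostBlock(
      lostLocation: ShuffleLocation,
      mapOutputLocations: MapShuffleLocations): Boolean = false
}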
DefaultShuffleDataIO.java
@@ -18,9 +18,10 @@
package org.apache.spark.shuffle.sort.io;

import org.apache.spark.SparkConf;
import org.apache.spark.api.shuffle.ShuffleDriverComponents;
import org.apache.spark.api.shuffle.ShuffleExecutorComponents;
import org.apache.spark.api.shuffle.ShuffleDataIO;
import org.apache.spark.api.java.Optional;
import org.apache.spark.api.shuffle.*;
import org.apache.spark.internal.config.package$;
import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations;
import org.apache.spark.shuffle.sort.lifecycle.DefaultShuffleDriverComponents;

public class DefaultShuffleDataIO implements ShuffleDataIO {
@@ -40,4 +41,9 @@ public ShuffleExecutorComponents executor() {
public ShuffleDriverComponents driver() {
return new DefaultShuffleDriverComponents();
}

@Override
public Optional<ShuffleLocationComponents> shuffleLocations() {
return Optional.of(new DefaultShuffleLocationComponents(sparkConf));
}
}
DefaultShuffleLocationComponents.java
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.sort.io;

import org.apache.spark.SparkConf;
import org.apache.spark.api.shuffle.MapShuffleLocations;
import org.apache.spark.api.shuffle.ShuffleLocation;
import org.apache.spark.api.shuffle.ShuffleLocationComponents;
import org.apache.spark.internal.config.package$;
import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations;

public class DefaultShuffleLocationComponents implements ShuffleLocationComponents {

private final boolean externalShuffleServiceEnabled;
private final boolean unRegisterOutputHostOnFetchFailure;

public DefaultShuffleLocationComponents(SparkConf sparkConf) {
externalShuffleServiceEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_SERVICE_ENABLED());
unRegisterOutputHostOnFetchFailure = (boolean)
sparkConf.get(package$.MODULE$.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE());
}

@Override
public boolean shouldRemoveMapOutputOnLostBlock(
ShuffleLocation lostLocation,
MapShuffleLocations mapOutputLocations) {
DefaultMapShuffleLocations mapStatusLoc = (DefaultMapShuffleLocations) mapOutputLocations;
DefaultMapShuffleLocations lostLoc = (DefaultMapShuffleLocations) lostLocation;
if (externalShuffleServiceEnabled && unRegisterOutputHostOnFetchFailure) {
return mapStatusLoc.getBlockManagerId().host().equals(lostLoc.getBlockManagerId().host());
} else {
return mapStatusLoc.getBlockManagerId().executorId().equals(lostLoc.getBlockManagerId().executorId());
}
}
}
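
For context, the two configs read above decide whether a lost location invalidates map outputs per host or per executor. A minimal sketch of wiring this up (the second config key is assumed here; check the fork's config package for the exact name):

// Illustrative only: with both flags true, a fetch failure removes every map output
// registered on the failed host; otherwise only outputs of the failed executor.
import org.apache.spark.SparkConf
import org.apache.spark.shuffle.sort.io.DefaultShuffleLocationComponents

val conf = new SparkConf()
  .set("spark.shuffle.service.enabled", "true")
  .set("spark.files.fetchFailure.unRegisterOutputOnHost", "true") // assumed key
val locationComponents = new DefaultShuffleLocationComponents(conf)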
DefaultShuffleDriverComponents.java
@@ -18,8 +18,13 @@
package org.apache.spark.shuffle.sort.lifecycle;

import com.google.common.collect.ImmutableMap;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.api.shuffle.MapShuffleLocations;
import org.apache.spark.api.shuffle.ShuffleDriverComponents;
import org.apache.spark.api.shuffle.ShuffleLocation;
import org.apache.spark.internal.config.package$;
import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations;
import org.apache.spark.storage.BlockManagerMaster;

import java.io.IOException;
46 changes: 37 additions & 9 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -28,7 +28,7 @@ import scala.concurrent.duration.Duration
import scala.reflect.ClassTag
import scala.util.control.NonFatal

import org.apache.spark.api.shuffle.{MapShuffleLocations, ShuffleLocation}
import org.apache.spark.api.shuffle.{MapShuffleLocations, ShuffleDriverComponents, ShuffleLocation, ShuffleLocationComponents}
import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
@@ -98,12 +98,12 @@ private class ShuffleStatus(numPartitions: Int) {
}

/**
* Remove the map output which was served by the specified block manager.
* This is a no-op if there is no registered map output or if the registered output is from a
* different block manager.
* Remove the map output which contains the specific shuffle location for the given reduce Id.
*/
def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized {
if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) {
def removeMapOutput(mapId: Int, reduceId: Int, shuffleLoc: ShuffleLocation)
: Unit = synchronized {
if (mapStatuses(mapId) != null && mapStatuses(mapId).mapShuffleLocations != null &&
mapStatuses(mapId).mapShuffleLocations.getLocationForBlock(reduceId) == shuffleLoc) {
_numAvailableOutputs -= 1
mapStatuses(mapId) = null
Review comment:

I think I understand what's going on with removeMapAtLocation, but I'm confused by this part. When there is a fetch failure from host X for map M from reduce R, the DAGScheduler removes shuffle outputs in two different ways (I'm not sure why): (1) it uses the path you're changing below in removeMapAtLocation to remove all map output on host X, and I see how that logic is moving into ShuffleLocationComponents; (2) here, it removes map M from host X. Before this change, it removed the entire map output of map M. You're changing it to take the reduce R as a parameter, but you still remove all map output of map M (mapStatuses(mapId) = null), which doesn't seem consistent.

Or is there some special logic in the extra condition mapStatuses(mapId).mapShuffleLocations.getLocationForBlock(reduceId) == shuffleLoc? I don't understand when it would be true or false; it seems like it should always be true. Can one ShuffleLocation really represent multiple locations? I'd still expect them to be equal here, since the MapOutputTracker would have stored multiple locations and the fetch failure would also send back multiple locations, right?

(The discussion above ignores the host/executor distinction for the external shuffle service in the current implementation, just to keep things simple.)

Author reply:

On the first point about this function taking a reducer parameter R: it uses R to validate that the ShuffleLocation does indeed exist in the MapStatus for mapper M. I was mimicking the previous behavior, which checks that the bmAddress exists in the MapStatus before removing it, although I wasn't exactly sure why that check is needed (could it be that we might have received an obsolete FetchFailed?). However, you're right that I can replace the mapStatuses(mapId).mapShuffleLocations.getLocationForBlock(reduceId) == shuffleLoc logic with something that calls into ShuffleLocationComponents, which would be more consistent.

On ShuffleLocation representing multiple locations: yes, that's right. I tried to code up what that would look like in DAGSchedulerFileServerSuite. A fetch failure would then send back a single ShuffleLocation that encodes multiple host/port combos.

Review comment:

Can an obsolete FetchFailed be a problem? I would leave the check.

Does the order of the multiple host/port combos matter in this equality check?

invalidateSerializedMapOutputStatusCache()
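
A minimal sketch of the author's suggestion in the thread above (illustrative only; the parameter wiring is assumed, not the PR's final API): the location-equality check could be delegated to the plugin's shouldRemoveMapOutputOnLostBlock.

// Hypothetical variant of ShuffleStatus.removeMapOutput: instead of comparing
// getLocationForBlock(reduceId) to the lost location directly, it asks the
// plugin whether this map output should be treated as lost.
def removeMapOutput(
    mapId: Int,
    lostLocation: ShuffleLocation,
    shouldRemove: (ShuffleLocation, MapShuffleLocations) => Boolean): Unit = synchronized {
  val status = mapStatuses(mapId)
  if (status != null && status.mapShuffleLocations != null &&
      shouldRemove(lostLocation, status.mapShuffleLocations)) {
    _numAvailableOutputs -= 1
    mapStatuses(mapId) = null
    invalidateSerializedMapOutputStatusCache()
  }
}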
@@ -141,6 +141,18 @@ private class ShuffleStatus(numPartitions: Int) {
}
}

def removeOutputsByShuffleLocation(
shuffleLoc: ShuffleLocation,
f: (ShuffleLocation, MapShuffleLocations) => Boolean) : Unit = synchronized {
for (mapId <- 0 until mapStatuses.length) {
if (mapStatuses(mapId) != null && f(shuffleLoc, mapStatuses(mapId).mapShuffleLocations)) {
_numAvailableOutputs -= 1
mapStatuses(mapId) = null
invalidateSerializedMapOutputStatusCache()
}
}
}

/**
* Number of partitions that have shuffle outputs.
*/
@@ -319,6 +331,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
private[spark] class MapOutputTrackerMaster(
conf: SparkConf,
broadcastManager: BroadcastManager,
shuffleLocationComponents: Option[ShuffleLocationComponents],
isLocal: Boolean)
extends MapOutputTracker(conf) {

@@ -423,17 +436,32 @@ private[spark] class MapOutputTrackerMaster(
shuffleStatuses(shuffleId).addMapOutput(mapId, status)
}

/** Unregister map output information of the given shuffle, mapper and block manager */
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
/** Unregister map output information of the given shuffle, mapper, reducer and location */
def unregisterMapOutput(
shuffleId: Int,
mapId: Int,
reduceId: Int,
shuffleLoc: ShuffleLocation): Unit = {
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) =>
shuffleStatus.removeMapOutput(mapId, bmAddress)
shuffleStatus.removeMapOutput(mapId, reduceId, shuffleLoc)
incrementEpoch()
case None =>
throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
}
}

def removeMapAtLocation(shuffleLoc: ShuffleLocation): Unit = {
shuffleStatuses.valuesIterator.foreach { mapStatuses =>
if (shuffleLocationComponents.isDefined) {
Review comment: Can just do shuffleLocationComponents.foreach.

mapStatuses.removeOutputsByShuffleLocation(
shuffleLoc,
shuffleLocationComponents.get.shouldRemoveMapOutputOnLostBlock)
}
}
incrementEpoch()
}
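
A minimal sketch of the review suggestion above, assuming shuffleLocationComponents is a Scala Option (illustrative, not the PR's code):

// Hypothetical rewrite using Option.foreach instead of isDefined/get;
// the intended behavior is identical to the version above.
def removeMapAtLocation(shuffleLoc: ShuffleLocation): Unit = {
  shuffleLocationComponents.foreach { components =>
    shuffleStatuses.valuesIterator.foreach { shuffleStatus =>
      shuffleStatus.removeOutputsByShuffleLocation(
        shuffleLoc, components.shouldRemoveMapOutputOnLostBlock)
    }
  }
  incrementEpoch()
}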

/** Unregister all map output information of the given shuffle. */
def unregisterAllMapOutput(shuffleId: Int) {
shuffleStatuses.get(shuffleId) match {
35 changes: 22 additions & 13 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -43,7 +43,7 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.conda.CondaEnvironment
import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions
import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents}
import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents, ShuffleLocationComponents}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{CondaRunner, LocalSparkCluster, SparkHadoopUtil}
import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat}
@@ -216,7 +216,7 @@ class SparkContext(config: SparkConf) extends SafeLogging {
private var _shutdownHookRef: AnyRef = _
private var _statusStore: AppStatusStore = _
private var _heartbeater: Heartbeater = _
private var _shuffleDriverComponents: ShuffleDriverComponents = _
private var _shuffleDataIo: ShuffleDataIO = _

/* ------------------------------------------------------------------------------------- *
| Accessors and public fields. These provide access to the internal state of the |
@@ -257,8 +257,10 @@ class SparkContext(config: SparkConf) extends SafeLogging {
private[spark] def createSparkEnv(
conf: SparkConf,
isLocal: Boolean,
listenerBus: LiveListenerBus): SparkEnv = {
SparkEnv.createDriverEnv(conf, isLocal, listenerBus, SparkContext.numDriverCores(master, conf))
listenerBus: LiveListenerBus,
shuffleDataIO: ShuffleDataIO): SparkEnv = {
SparkEnv.createDriverEnv(conf, isLocal, listenerBus, SparkContext.numDriverCores(master, conf),
shuffleDataIO)
}

private[spark] def env: SparkEnv = _env
@@ -308,6 +310,10 @@ class SparkContext(config: SparkConf) extends SafeLogging {
_dagScheduler = ds
}

private[spark] def shuffleLocationComponents: Some[ShuffleLocationComponents] = {
Some(_shuffleDataIo.shuffleLocations().orNull())
}
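
One note on the accessor above: Some(...orNull()) yields Some(null) when the plugin supplies no ShuffleLocationComponents. A minimal alternative sketch, assuming the intent is an empty Option in that case:

// Illustrative alternative, not the PR's code: Option(...) maps null to None,
// so callers can use foreach/map without a null check.
private[spark] def shuffleLocationComponents: Option[ShuffleLocationComponents] =
  Option(_shuffleDataIo.shuffleLocations().orNull())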

/**
* A unique identifier for the Spark application.
* Its format depends on the scheduler implementation.
@@ -429,8 +435,14 @@ class SparkContext(config: SparkConf) extends SafeLogging {
_statusStore = AppStatusStore.createLiveStore(conf, appStatusSource)
listenerBus.addToStatusQueue(_statusStore.listener.get)


val configuredPluginClasses = conf.get(SHUFFLE_IO_PLUGIN_CLASS)
val maybeIO = Utils.loadExtensions(
classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf)
require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses")
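
For reference, the plugin loaded here is selected by the SHUFFLE_IO_PLUGIN_CLASS setting read above; a hedged usage sketch, using the built-in DefaultShuffleDataIO as the example value:

// Illustrative only: selecting the ShuffleDataIO plugin by class name.
import org.apache.spark.SparkConf
import org.apache.spark.internal.config.SHUFFLE_IO_PLUGIN_CLASS
import org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO

val conf = new SparkConf()
  .set(SHUFFLE_IO_PLUGIN_CLASS.key, classOf[DefaultShuffleDataIO].getName)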

// Create the Spark execution environment (cache, map output tracker, etc)
_env = createSparkEnv(_conf, isLocal, listenerBus)
_env = createSparkEnv(_conf, isLocal, listenerBus, maybeIO.head)
SparkEnv.set(_env)

// If running the REPL, register the repl's output dir with the file server.
@@ -493,12 +505,9 @@ class SparkContext(config: SparkConf) extends SafeLogging {
executorEnvs ++= _conf.getExecutorEnv
executorEnvs("SPARK_USER") = sparkUser

val configuredPluginClasses = conf.get(SHUFFLE_IO_PLUGIN_CLASS)
val maybeIO = Utils.loadExtensions(
classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf)
require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses")
_shuffleDriverComponents = maybeIO.head.driver()
_shuffleDriverComponents.initializeApplication().asScala.foreach {
_shuffleDataIo = maybeIO.head
maybeIO.head.driver()
.initializeApplication().asScala.foreach {
Review comment: Indentation seems weird.

case (k, v) => _conf.set(ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX + k, v) }

// We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will
@@ -570,7 +579,7 @@ class SparkContext(config: SparkConf) extends SafeLogging {

_cleaner =
if (_conf.get(CLEANER_REFERENCE_TRACKING)) {
Some(new ContextCleaner(this, _shuffleDriverComponents))
Some(new ContextCleaner(this, _shuffleDataIo.driver()))
} else {
None
}
@@ -1960,7 +1969,7 @@ class SparkContext(config: SparkConf) extends SafeLogging {
}
_heartbeater = null
}
_shuffleDriverComponents.cleanupApplication()
_shuffleDataIo.driver().cleanupApplication()
if (env != null && _heartbeatReceiver != null) {
Utils.tryLogNonFatalError {
env.rpcEnv.stop(_heartbeatReceiver)
8 changes: 7 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -31,6 +31,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions
import org.apache.spark.api.python.PythonWorkerFactory
import org.apache.spark.api.shuffle.ShuffleDataIO
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.internal.config._
@@ -200,6 +201,7 @@ object SparkEnv extends Logging {
isLocal: Boolean,
listenerBus: LiveListenerBus,
numCores: Int,
shuffleDataIO: ShuffleDataIO,
mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
assert(conf.contains(DRIVER_HOST_ADDRESS),
s"${DRIVER_HOST_ADDRESS.key} is not set on the driver!")
@@ -221,6 +223,7 @@
isLocal,
numCores,
ioEncryptionKey,
shuffleDataIO = Some(shuffleDataIO),
listenerBus = listenerBus,
mockOutputCommitCoordinator = mockOutputCommitCoordinator
)
@@ -254,6 +257,7 @@
/**
* Helper method to create a SparkEnv for a driver or an executor.
*/
// scalastyle:off
private def create(
conf: SparkConf,
executorId: String,
@@ -263,6 +267,7 @@
isLocal: Boolean,
numUsableCores: Int,
ioEncryptionKey: Option[Array[Byte]],
shuffleDataIO: Option[ShuffleDataIO] = Option.empty,
listenerBus: LiveListenerBus = null,
mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {

@@ -341,7 +346,8 @@ object SparkEnv extends Logging {
val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)

val mapOutputTracker = if (isDriver) {
new MapOutputTrackerMaster(conf, broadcastManager, isLocal)
new MapOutputTrackerMaster(
conf, broadcastManager, Some(shuffleDataIO.get.shuffleLocations().orNull()), isLocal)
} else {
new MapOutputTrackerWorker(conf)
}