Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support refreshing CondaEnvironment auth in specfile mode #669

Merged
merged 26 commits into from
Apr 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1879,6 +1879,10 @@ class SparkContext(config: SparkConf) extends SafeLogging {
condaEnvironmentOrFail().setChannels(urls)
}

/**
 * Hands the given package-URL user info to the environment returned by
 * `condaEnvironmentOrFail()`.
 *
 * @param userInfo user info to apply to the bootstrap package URLs, or `None`
 */
def setPackageUrlsUserInfo(userInfo: Option[String]): Unit = {
  val condaEnv = condaEnvironmentOrFail()
  condaEnv.setPackageUrlsUserInfo(userInfo)
}

/**
 * Asks the conda environment, when `condaEnvironment()` yields one, for its setup
 * instructions; otherwise the result is empty.
 */
private[spark] def buildCondaInstructions(): Option[CondaSetupInstructions] = {
  condaEnvironment().map { condaEnv =>
    condaEnv.buildSetupInstructions
  }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ final class CondaEnvironment(
bootstrapMode: CondaBootstrapMode,
bootstrapPackages: Seq[String],
bootstrapPackageUrls: Seq[String],
private var bootstrapPackageUrlsUserInfo: Option[String],
bootstrapChannels: Seq[String],
extraArgs: Seq[String] = Nil,
envVars: Map[String, String] = Map.empty) extends Logging {
Expand Down Expand Up @@ -74,6 +75,10 @@ final class CondaEnvironment(
channels ++= urls.iterator.map(AuthenticatedChannel.apply)
}

/**
 * Overwrites the stored user info for the bootstrap package URLs.
 *
 * @param userInfo the new user info, or `None` to clear it
 */
def setPackageUrlsUserInfo(userInfo: Option[String]): Unit = {
  this.bootstrapPackageUrlsUserInfo = userInfo
}

/**
 * Returns whatever the manager's explicit package listing reports for this
 * environment's directory, addressed by its absolute path.
 */
def getTransitivePackageUrls(): List[String] = {
  val absoluteEnvDir = condaEnvDir.toAbsolutePath.toString
  manager.listPackagesExplicit(absoluteEnvDir)
}
Expand Down Expand Up @@ -115,6 +120,7 @@ final class CondaEnvironment(
bootstrapMode,
packages.toList,
bootstrapPackageUrls,
bootstrapPackageUrlsUserInfo,
channels.toList,
extraArgs,
envVars)
Expand Down Expand Up @@ -177,6 +183,7 @@ object CondaEnvironment {
mode: CondaBootstrapMode,
packages: Seq[String],
packageUrls: Seq[String],
packageUrlsUserInfo: Option[String],
unauthenticatedChannels: Seq[UnauthenticatedChannel],
extraArgs: Seq[String],
envVars: Map[String, String])
Expand All @@ -200,11 +207,13 @@ object CondaEnvironment {
mode: CondaBootstrapMode,
packages: Seq[String],
packageUrls: Seq[String],
packageUrlsUserInfo: Option[String],
channels: Seq[AuthenticatedChannel],
extraArgs: Seq[String],
envVars: Map[String, String]): CondaSetupInstructions = {
val ChannelsWithCreds(unauthed, userInfos) = unauthenticateChannels(channels)
CondaSetupInstructions(mode, packages, packageUrls, unauthed, extraArgs, envVars)(userInfos)
CondaSetupInstructions(
mode, packages, packageUrls, packageUrlsUserInfo, unauthed, extraArgs, envVars)(userInfos)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
*/
package org.apache.spark.api.conda

import java.net.URI
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.Paths
import java.nio.file.attribute.PosixFilePermission
import java.util.regex.Pattern
import javax.ws.rs.core.UriBuilder

import scala.collection.JavaConverters._
import scala.collection.mutable
Expand Down Expand Up @@ -82,14 +84,16 @@ final class CondaEnvironmentManager(condaBinaryPath: String,
condaMode: CondaBootstrapMode,
condaPackages: Seq[String],
condaPackageUrls: Seq[String],
condaPackageUrlsUserInfo: Option[String],
condaChannelUrls: Seq[String],
condaExtraArgs: Seq[String] = Nil,
condaEnvVars: Map[String, String] = Map.empty): CondaEnvironment = {
condaMode match {
case CondaBootstrapMode.Solve =>
create(baseDir, condaPackages, condaChannelUrls, condaExtraArgs, condaEnvVars)
case CondaBootstrapMode.File =>
createWithFile(baseDir, condaPackageUrls, condaExtraArgs, condaEnvVars)
createWithFile(
baseDir, condaPackageUrls, condaPackageUrlsUserInfo, condaExtraArgs, condaEnvVars)
}
}

Expand Down Expand Up @@ -131,16 +135,21 @@ final class CondaEnvironmentManager(condaBinaryPath: String,
CondaBootstrapMode.Solve,
condaPackages,
Nil,
None,
condaChannelUrls,
condaExtraArgs)
}

def createWithFile(
baseDir: String,
condaPackageUrls: Seq[String],
condaPackageUrlsUserInfo: Option[String],
condaExtraArgs: Seq[String] = Nil,
condaEnvVars: Map[String, String] = Map.empty): CondaEnvironment = {
// A spec file needs at least one package URL.
require(condaPackageUrls.nonEmpty, "Expected at least one conda package url.")
// Credentials must arrive through the dedicated user-info argument, never inlined in a URL.
require(condaPackageUrls.forall(packageUrl => new URI(packageUrl).getUserInfo == null),
  "Cannot pass condaPackageUrls with inlined auth; pass UserInfo " +
    "via spark.conda.bootstrapPackageUrlsUserInfo.")
val name = "conda-env"

// must link in /tmp to reduce path length in case baseDir is very long...
Expand All @@ -152,9 +161,18 @@ final class CondaEnvironmentManager(condaBinaryPath: String,

val verbosityFlags = 0.until(verbosity).map(_ => "-v").toList

// Authenticate URLs if we have a UserInfo argument; otherwise use them untouched.
val finalCondaPackageUrls = condaPackageUrlsUserInfo match {
  case Some(userInfo) =>
    condaPackageUrls.map { packageUrl =>
      UriBuilder.fromUri(packageUrl).userInfo(userInfo).build().toString
    }
  case None =>
    condaPackageUrls
}

// Create spec file with URLs
val specFilePath = linkedBaseDir.resolve("spec-file")
Files.write(specFilePath, ("@EXPLICIT" +: condaPackageUrls).asJava)
Files.write(specFilePath, ("@EXPLICIT" +: finalCondaPackageUrls).asJava)

// Attempt to create environment
runCondaProcess(
Expand All @@ -175,6 +193,7 @@ final class CondaEnvironmentManager(condaBinaryPath: String,
CondaBootstrapMode.File,
Nil,
condaPackageUrls,
condaPackageUrlsUserInfo,
Nil,
condaExtraArgs)
}
Expand Down Expand Up @@ -312,7 +331,6 @@ object CondaEnvironmentManager extends Logging {
private[this] def createCondaEnvironment(
instructions: CondaSetupInstructions): CondaEnvironment = {
val condaPackages = instructions.packages
val condaPackageUrls = instructions.packageUrls
val env = SparkEnv.get
val condaEnvManager = CondaEnvironmentManager.fromConf(env.conf)
val envDir = {
Expand All @@ -326,7 +344,8 @@ object CondaEnvironmentManager extends Logging {
envDir,
instructions.mode,
condaPackages,
condaPackageUrls,
instructions.packageUrls,
instructions.packageUrlsUserInfo,
instructions.channels,
instructions.extraArgs)
}
Expand Down
6 changes: 4 additions & 2 deletions core/src/main/scala/org/apache/spark/deploy/CondaRunner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ object CondaRunner {
if (CondaEnvironmentManager.isConfigured(sparkConf)) {
val condaBootstrapMode = CondaBootstrapMode.fromString(sparkConf.get(CONDA_BOOTSTRAP_MODE))
val condaBootstrapDeps = sparkConf.get(CONDA_BOOTSTRAP_PACKAGES)
val condaBootstrapDepUrls = sparkConf.get(CONDA_BOOTSTRAP_PACKAGE_URLS)
val condaBootstrapPackageUrls = sparkConf.get(CONDA_BOOTSTRAP_PACKAGE_URLS)
val condaBootstrapPackageUrlsUserInfo = sparkConf.get(CONDA_BOOTSTRAP_PACKAGE_URLS_USER_INFO)
val condaChannelUrls = sparkConf.get(CONDA_CHANNEL_URLS)
val condaExtraArgs = sparkConf.get(CONDA_EXTRA_ARGUMENTS)
val condaEnvVariables = extractEnvVariables(sparkConf)
Expand All @@ -63,7 +64,8 @@ object CondaRunner {
condaBaseDir,
condaBootstrapMode,
condaBootstrapDeps,
condaBootstrapDepUrls,
condaBootstrapPackageUrls,
condaBootstrapPackageUrlsUserInfo,
condaChannelUrls,
condaExtraArgs,
condaEnvVariables)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,13 @@ package object config {
.toSequence
.createWithDefault(Nil)

// User info ('user:pw', per the doc string below) for the bootstrap package URLs.
// Declared with createOptional, so the entry has no default value.
private[spark] val CONDA_BOOTSTRAP_PACKAGE_URLS_USER_INFO =
ConfigBuilder("spark.conda.bootstrapPackageUrlsUserInfo")
.doc("Basic auth information (in 'user:pw' format) to be added to package urls that will " +
"be added to the conda environment. Only relevant when main class is CondaRunner.")
.stringConf
.createOptional

private[spark] val CONDA_CHANNEL_URLS = ConfigBuilder("spark.conda.channelUrls")
.doc("The URLs the Conda channels to use when resolving the conda packages. "
+ "Only relevant when main class is CondaRunner.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package org.apache.spark.api.conda

import java.nio.file.Files

import org.apache.spark.SparkConf
import org.apache.spark.internal.config._
import org.apache.spark.util.TempDirectory

class CondaEnvironmentManagerTest extends org.apache.spark.SparkFunSuite with TempDirectory {
Expand All @@ -43,4 +45,30 @@ class CondaEnvironmentManagerTest extends org.apache.spark.SparkFunSuite with Te
"[http://us_r:<password>@yy.bar:222"
assert(CondaEnvironmentManager.redactCredentials(original) == redacted)
}

test("CondaEnvironmentManager.failOnAuthenticatedPackageUrls") {
  // A package URL that already embeds credentials must be rejected, even when
  // separate user info is supplied alongside it.
  val packageUrl =
    "https://myuser:mypassword@example.com/whatever/else/linux-64/package-0.0.1-py_0.tar.bz2"
  val userInfo = "anotheruser:theirpassword"

  // The binary and the environment directory only need to exist on disk.
  val binaryPath = tempDir.toPath.resolve("dummy-conda.bin")
  val condaEnvDir = tempDir.toPath.resolve("test-conda-env")
  Files.createFile(binaryPath)
  Files.createDirectory(condaEnvDir)

  val conf = new SparkConf()
  conf.set(CONDA_BINARY_PATH, binaryPath.toString)
  conf.set(CONDA_BOOTSTRAP_MODE, "File")
  conf.set(CONDA_BOOTSTRAP_PACKAGE_URLS, Seq(packageUrl))
  conf.set(CONDA_BOOTSTRAP_PACKAGE_URLS_USER_INFO, userInfo)

  val thrown = intercept[IllegalArgumentException] {
    CondaEnvironmentManager.fromConf(conf)
      .createWithFile(condaEnvDir.toString, Seq(packageUrl), Some(userInfo))
  }

  // `===` makes a failure report both the actual and the expected message.
  assert(thrown.getMessage ===
    "requirement failed: Cannot pass condaPackageUrls with inlined auth; pass UserInfo " +
      "via spark.conda.bootstrapPackageUrlsUserInfo.")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
extraEnv = extraEnvVars,
extraConf = extraConf,
outFile = outFile,
timeoutDuration = 4.minutes) // give it a bit longer
timeoutDuration = 5.minutes) // give it a bit longer
checkResult(finalState, result, outFile = outFile)
}

Expand Down