From 093989eca9b495e7e2d7c6c9a584d0b66a71f34e Mon Sep 17 00:00:00 2001 From: Katarzyna Marek Date: Fri, 28 Jun 2024 12:00:35 +0200 Subject: [PATCH] feat: run DAP for test suites in Metals --- .../ImplementationProvider.scala | 6 +- .../meta/internal/metals/Compilers.scala | 12 ++ .../internal/metals/MetalsEnrichments.scala | 37 +++- .../metals/codelenses/RunTestCodeLens.scala | 7 +- .../metals/debug/BuildTargetClasses.scala | 65 ++++-- .../internal/metals/debug/DebugProvider.scala | 85 ++++++-- .../internal/metals/debug/server/Logger.scala | 27 +++ .../debug/server/MainClassDebugAdapter.scala | 105 +--------- .../internal/metals/debug/server/Run.scala | 84 ++++++++ .../debug/server/TestSuiteDebugAdapter.scala | 159 ++++++++++++++ .../debug/server/testing/FilteredLoader.scala | 28 +++ .../server/testing/LoggingEventHandler.scala | 186 +++++++++++++++++ .../testing/SerializableFingerprints.scala | 21 ++ .../debug/server/testing/TestInternals.scala | 192 +++++++++++++++++ .../debug/server/testing/TestServer.scala | 197 ++++++++++++++++++ .../testProvider/TestSuitesProvider.scala | 15 +- .../scala/meta/pc/PcSymbolInformation.java | 12 ++ .../internal/pc/PcSymbolInformation.scala | 20 +- .../internal/pc/WorkspaceSymbolSearch.scala | 20 +- .../pc/SymbolInformationProvider.scala | 32 ++- .../tests/mill/MillDebugDiscoverySuite.scala | 164 +++++++++++++++ .../tests/mill/MillServerCodeLensSuite.scala | 1 + .../main/scala/tests/BuildServerLayout.scala | 42 +++- .../tests/debug/BaseBreakpointDapSuite.scala | 1 + .../test/scala/tests/DebugProtocolSuite.scala | 2 + 25 files changed, 1353 insertions(+), 167 deletions(-) create mode 100644 metals/src/main/scala/scala/meta/internal/metals/debug/server/Logger.scala create mode 100644 metals/src/main/scala/scala/meta/internal/metals/debug/server/Run.scala create mode 100644 metals/src/main/scala/scala/meta/internal/metals/debug/server/TestSuiteDebugAdapter.scala create mode 100644 metals/src/main/scala/scala/meta/internal/metals/debug/server/testing/FilteredLoader.scala create mode 100644 metals/src/main/scala/scala/meta/internal/metals/debug/server/testing/LoggingEventHandler.scala create mode 100644 metals/src/main/scala/scala/meta/internal/metals/debug/server/testing/SerializableFingerprints.scala create mode 100644 metals/src/main/scala/scala/meta/internal/metals/debug/server/testing/TestInternals.scala create mode 100644 metals/src/main/scala/scala/meta/internal/metals/debug/server/testing/TestServer.scala create mode 100644 tests/slow/src/test/scala/tests/mill/MillDebugDiscoverySuite.scala diff --git a/metals/src/main/scala/scala/meta/internal/implementation/ImplementationProvider.scala b/metals/src/main/scala/scala/meta/internal/implementation/ImplementationProvider.scala index 6cf59eab16c..1a7d06f8ac6 100644 --- a/metals/src/main/scala/scala/meta/internal/implementation/ImplementationProvider.scala +++ b/metals/src/main/scala/scala/meta/internal/implementation/ImplementationProvider.scala @@ -525,7 +525,11 @@ final class ImplementationProvider( classOwner = classOwnerInfoOpt.map(_.symbol), alternativeSymbols = alternativeSymbols.toList, overriddenSymbols = info.overriddenSymbols.toList, - properties = if (info.isAbstract) List(PcSymbolProperty.ABSTRACT) else Nil, + properties = + if (info.isAbstract) List(PcSymbolProperty.ABSTRACT) else Nil, + recursiveParents = parents, + annotations = info.annotations.map(_.toString()).toList, + memberDefsAnnotations = Nil, ) } } diff --git a/metals/src/main/scala/scala/meta/internal/metals/Compilers.scala b/metals/src/main/scala/scala/meta/internal/metals/Compilers.scala index 538be4464e7..c1786c98d36 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/Compilers.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/Compilers.scala @@ -941,6 +941,18 @@ class Compilers( .getOrElse(Future(None)) } + def info( + id: BuildTargetIdentifier, + symbol: String, + ): Future[Option[PcSymbolInformation]] = { + loadCompiler(id) + .map( + _.info(symbol).asScala + .map(_.asScala.map(PcSymbolInformation.from)) + ) + .getOrElse(Future(None)) + } + private def definition( params: TextDocumentPositionParams, token: CancelToken, diff --git a/metals/src/main/scala/scala/meta/internal/metals/MetalsEnrichments.scala b/metals/src/main/scala/scala/meta/internal/metals/MetalsEnrichments.scala index 4062c6d3dc9..ccc0c45f4ca 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/MetalsEnrichments.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/MetalsEnrichments.scala @@ -800,6 +800,9 @@ object MetalsEnrichments value.replace('/', ':') else value } + + def symbolToFullQualifiedName: String = + value.replaceAll("/|#", ".").stripSuffix(".") } implicit class XtensionTextDocumentSemanticdb(textDocument: s.TextDocument) { @@ -1311,16 +1314,40 @@ object MetalsEnrichments } implicit class XtensionDebugSessionParams(params: b.DebugSessionParams) { - def asScalaMainClass(): Either[String, b.ScalaMainClass] = + def asScalaMainClass(): Either[String, b.ScalaMainClass] = { + lazy val className = "ScalaMainClass" params.getDataKind() match { case b.DebugSessionParamsDataKind.SCALA_MAIN_CLASS => decodeJson(params.getData(), classOf[b.ScalaMainClass]) - .toRight(s"Cannot decode $params as `ScalaMainClass`.") - case _ => - Left( - s"Cannot decode params as `ScalaMainClass` incorrect data kind: ${params.getDataKind()}." + .toRight(cannotDecode(className)) + case _ => incorrectKind(className) + } + } + + def asScalaTestSuites(): Either[String, b.ScalaTestSuites] = { + lazy val className = "ScalaTestSuites" + params.getDataKind() match { + case b.TestParamsDataKind.SCALA_TEST_SUITES_SELECTION => + decodeJson(params.getData(), classOf[b.ScalaTestSuites]) + .toRight(cannotDecode(className)) + case b.TestParamsDataKind.SCALA_TEST_SUITES => + (for ( + tests <- decodeJson(params.getData(), classOf[util.List[String]]) ) + yield { + val suites = + tests.map(new b.ScalaTestSuiteSelection(_, Nil.asJava)) + new b.ScalaTestSuites(suites, Nil.asJava, Nil.asJava) + }).toRight(cannotDecode(className)) + case _ => incorrectKind(className) } + } + def cannotDecode(className: String): String = + s"Cannot decode $params as `$className`." + def incorrectKind(className: String): Left[String, Nothing] = + Left( + s"Cannot decode params as `$className` incorrect data kind: ${params.getDataKind()}." + ) } /** diff --git a/metals/src/main/scala/scala/meta/internal/metals/codelenses/RunTestCodeLens.scala b/metals/src/main/scala/scala/meta/internal/metals/codelenses/RunTestCodeLens.scala index f754325a0a2..076989d25bc 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/codelenses/RunTestCodeLens.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/codelenses/RunTestCodeLens.scala @@ -108,17 +108,15 @@ final class RunTestCodeLens( buildServerCanDebug, isJVM, ) - } else if (buildServerCanDebug || clientConfig.isRunProvider()) { + } else codeLenses( textDocument, buildTargetId, classes, distance, path, - buildServerCanDebug, isJVM, ) - } else { Nil } } @@ -198,7 +196,6 @@ final class RunTestCodeLens( classes: BuildTargetClasses.Classes, distance: TokenEditDistance, path: AbsolutePath, - buildServerCanDebug: Boolean, isJVM: Boolean, ): Seq[l.CodeLens] = { for { @@ -212,7 +209,7 @@ final class RunTestCodeLens( .getOrElse(Nil) lazy val tests = // Currently tests can only be run via DAP - if (clientConfig.isDebuggingProvider() && buildServerCanDebug) + if (clientConfig.isDebuggingProvider()) testClasses(target, classes, symbol, isJVM) else Nil val fromAnnot = DebugProvider diff --git a/metals/src/main/scala/scala/meta/internal/metals/debug/BuildTargetClasses.scala b/metals/src/main/scala/scala/meta/internal/metals/debug/BuildTargetClasses.scala index 0f659c9c101..a7f4886a288 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/debug/BuildTargetClasses.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/debug/BuildTargetClasses.scala @@ -8,6 +8,7 @@ import scala.meta.internal.metals.BatchedFunction import scala.meta.internal.metals.BuildTargets import scala.meta.internal.metals.MetalsEnrichments._ import scala.meta.internal.metals.debug.BuildTargetClasses.Classes +import scala.meta.internal.metals.debug.BuildTargetClasses.TestSymbolInfo import scala.meta.internal.semanticdb.Scala.Descriptor import scala.meta.internal.semanticdb.Scala.Symbols @@ -16,9 +17,9 @@ import ch.epfl.scala.{bsp4j => b} /** * In-memory index of main class symbols grouped by their enclosing build target */ -final class BuildTargetClasses( - buildTargets: BuildTargets -)(implicit val ec: ExecutionContext) { +final class BuildTargetClasses(buildTargets: BuildTargets)(implicit + val ec: ExecutionContext +) { private val index = TrieMap.empty[b.BuildTargetIdentifier, Classes] private val jvmRunEnvironments : TrieMap[b.BuildTargetIdentifier, b.JvmEnvironmentItem] = @@ -56,6 +57,19 @@ final class BuildTargetClasses( .map(_.fullyQualifiedName) ) + def getTestClasses( + name: String, + id: b.BuildTargetIdentifier, + ): List[(String, TestSymbolInfo)] = { + index.get(id).toList.flatMap { + _.testClasses + .filter { case (_, info) => + info.fullyQualifiedName == name + } + .toList + } + } + private def findClassesBy[A]( f: Classes => Option[A] ): List[(A, b.BuildTargetIdentifier)] = { @@ -77,25 +91,22 @@ final class BuildTargetClasses( Future.successful(()) case (Some(connection), targets0) => val targetsList = targets0.asJava - targetsList.forEach(invalidate) val classes = targets0.map(t => (t, new Classes)).toMap val updateMainClasses = connection .mainClasses(new b.ScalaMainClassesParams(targetsList)) .map(cacheMainClasses(classes, _)) - // Currently tests are only run using DAP val updateTestClasses = - if (connection.isDebuggingProvider || connection.isSbt) - connection - .testClasses(new b.ScalaTestClassesParams(targetsList)) - .map(cacheTestClasses(classes, _)) - else Future.unit + connection + .testClasses(new b.ScalaTestClassesParams(targetsList)) + .map(cacheTestClasses(classes, _)) for { _ <- updateMainClasses _ <- updateTestClasses } yield { + targetsList.forEach(invalidate) classes.foreach { case (id, classes) => index.put(id, classes) } @@ -214,7 +225,10 @@ final class BuildTargetClasses( } } -sealed abstract class TestFramework(val canResolveChildren: Boolean) +sealed abstract class TestFramework(val canResolveChildren: Boolean) { + def names: List[String] +} + object TestFramework { def apply(framework: Option[String]): TestFramework = framework .map { @@ -226,11 +240,30 @@ object TestFramework { } .getOrElse(Unknown) } -case object JUnit4 extends TestFramework(true) -case object MUnit extends TestFramework(true) -case object Scalatest extends TestFramework(true) -case object WeaverCatsEffect extends TestFramework(true) -case object Unknown extends TestFramework(false) + +case object JUnit4 extends TestFramework(true) { + def names: List[String] = List("com.novocode.junit.JUnitFramework") +} + +case object MUnit extends TestFramework(true) { + def names: List[String] = List("munit.Framework") +} + +case object Scalatest extends TestFramework(true) { + def names: List[String] = + List( + "org.scalatest.tools.Framework", + "org.scalatest.tools.ScalaTestFramework", + ) +} + +case object WeaverCatsEffect extends TestFramework(true) { + def names: List[String] = List("weaver.BaseCatsSuite") +} + +case object Unknown extends TestFramework(false) { + def names: List[String] = Nil +} object BuildTargetClasses { type Symbol = String diff --git a/metals/src/main/scala/scala/meta/internal/metals/debug/DebugProvider.scala b/metals/src/main/scala/scala/meta/internal/metals/debug/DebugProvider.scala index 060a35c7d50..0097a732495 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/debug/DebugProvider.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/debug/DebugProvider.scala @@ -54,9 +54,11 @@ import scala.meta.internal.metals.config.RunType import scala.meta.internal.metals.config.RunType._ import scala.meta.internal.metals.debug.server.DebugLogger import scala.meta.internal.metals.debug.server.DebugeeParamsCreator +import scala.meta.internal.metals.debug.server.Discovered import scala.meta.internal.metals.debug.server.MainClassDebugAdapter import scala.meta.internal.metals.debug.server.MetalsDebugToolsResolver import scala.meta.internal.metals.debug.server.MetalsDebuggee +import scala.meta.internal.metals.debug.server.TestSuiteDebugAdapter import scala.meta.internal.metals.testProvider.TestSuitesProvider import scala.meta.internal.mtags.DefinitionAlternatives.GlobalSymbol import scala.meta.internal.mtags.OnDemandSymbolIndex @@ -328,38 +330,60 @@ class DebugProvider( buildServer: BuildServerConnection, params: DebugSessionParams, cancelPromise: Promise[Unit], - ) = + )(implicit ec: ExecutionContext) = if (buildServer.isDebuggingProvider || buildServer.isSbt) { buildServer.startDebugSession(params, cancelPromise) } else { - def getDebugee: Either[String, MetalsDebuggee] = + def getDebugee: Either[String, Future[MetalsDebuggee]] = { + def buildTarget = params + .getTargets() + .asScala + .headOption + .toRight(s"Missing build target in debug params.") params.getDataKind() match { case b.DebugSessionParamsDataKind.SCALA_MAIN_CLASS => for { - id <- params - .getTargets() - .asScala - .headOption - .toRight(s"Missing build target in debug params.") + id <- buildTarget projectInfo <- debugConfigCreator.create(id) scalaMainClass <- params.asScalaMainClass() - } yield new MainClassDebugAdapter( - workspace, - scalaMainClass, - projectInfo, - userConfig().javaHome, + } yield Future.successful( + new MainClassDebugAdapter( + workspace, + scalaMainClass, + projectInfo, + userConfig().javaHome, + ) ) + case (b.TestParamsDataKind.SCALA_TEST_SUITES_SELECTION | + b.TestParamsDataKind.SCALA_TEST_SUITES) => + for { + id <- buildTarget + project <- debugConfigCreator.create(id) + testSuites <- params.asScalaTestSuites() + } yield { + for { + discovered <- discoverTests(id, testSuites) + } yield new TestSuiteDebugAdapter( + workspace, + testSuites, + project, + userConfig().javaHome, + discovered, + ) + } case kind => Left(s"Starting debug session for $kind in not supported.") } + } for { _ <- compilations.compileTargets(params.getTargets().asScala.toSeq) - } yield { - val debuggee = getDebugee match { + debuggee <- getDebugee match { case Right(debuggee) => debuggee - case Left(errorMessage) => throw new RuntimeException(errorMessage) + case Left(errorMessage) => + Future.failed(new RuntimeException(errorMessage)) } + } yield { val dapLogger = new DebugLogger() val resolver = new MetalsDebugToolsResolver() val handler = @@ -373,6 +397,37 @@ class DebugProvider( } } + private def discoverTests( + id: BuildTargetIdentifier, + testClasses: b.ScalaTestSuites, + ): Future[Map[TestFramework, List[Discovered]]] = { + val symbolInfosList = + for { + selection <- testClasses.getSuites().asScala.toList + (sym, info) <- buildTargetClasses.getTestClasses( + selection.getClassName(), + id, + ) + } yield compilers.info(id, sym).map(_.map(pcInfo => (info, pcInfo))) + + Future.sequence(symbolInfosList).map { + _.flatten.groupBy(_._1.framework).map { case (framework, testSuites) => + ( + framework, + testSuites.map { case (testInfo, pcInfo) => + new Discovered( + pcInfo.symbol, + testInfo.fullyQualifiedName, + pcInfo.recursiveParents.map(_.symbolToFullQualifiedName).toSet, + (pcInfo.annotations ++ pcInfo.memberDefsAnnotations).toSet, + isModule = false, + ) + }, + ) + } + } + } + /** * Given a BuildTargetIdentifier either get the displayName of that build * target or default to the full URI to display to the user. diff --git a/metals/src/main/scala/scala/meta/internal/metals/debug/server/Logger.scala b/metals/src/main/scala/scala/meta/internal/metals/debug/server/Logger.scala new file mode 100644 index 00000000000..4a100a350a9 --- /dev/null +++ b/metals/src/main/scala/scala/meta/internal/metals/debug/server/Logger.scala @@ -0,0 +1,27 @@ +package scala.meta.internal.metals.debug.server + +import java.net.InetSocketAddress +import java.util.concurrent.atomic.AtomicBoolean + +import ch.epfl.scala.debugadapter.DebuggeeListener + +class Logger(listener: DebuggeeListener) { + private val initialized = new AtomicBoolean(false) + private final val JDINotificationPrefix = + "Listening for transport dt_socket at address: " + + def logError(errorMessage: String): Unit = { + listener.err(errorMessage) + scribe.error(errorMessage) + } + + def logOutput(msg: String): Unit = { + if (msg.startsWith(JDINotificationPrefix)) { + if (initialized.compareAndSet(false, true)) { + val port = Integer.parseInt(msg.drop(JDINotificationPrefix.length)) + val address = new InetSocketAddress("127.0.0.1", port) + listener.onListening(address) + } + } else listener.out(msg) + } +} diff --git a/metals/src/main/scala/scala/meta/internal/metals/debug/server/MainClassDebugAdapter.scala b/metals/src/main/scala/scala/meta/internal/metals/debug/server/MainClassDebugAdapter.scala index e3df1b4b406..ae0d0631f53 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/debug/server/MainClassDebugAdapter.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/debug/server/MainClassDebugAdapter.scala @@ -1,16 +1,9 @@ package scala.meta.internal.metals.debug.server -import java.io.File -import java.net.InetSocketAddress -import java.util.concurrent.atomic.AtomicBoolean - import scala.concurrent.ExecutionContext -import scala.meta.internal.metals.JavaBinary import scala.meta.internal.metals.JdkSources -import scala.meta.internal.metals.ManifestJar import scala.meta.internal.metals.MetalsEnrichments._ -import scala.meta.internal.process.SystemProcess import scala.meta.io.AbsolutePath import ch.epfl.scala.bsp4j.ScalaMainClass @@ -29,17 +22,12 @@ class MainClassDebugAdapter( )(implicit ec: ExecutionContext) extends MetalsDebuggee() { - private val initialized = new AtomicBoolean(false) - override def modules: Seq[Module] = project.modules override def libraries: Seq[Library] = project.libraries override def unmanagedEntries: Seq[UnmanagedEntry] = project.unmanagedEntries - private final val JDINotificationPrefix = - "Listening for transport dt_socket at address: " - protected def scalaVersionOpt: Option[String] = project.scalaVersion val javaRuntime: Option[JavaRuntime] = @@ -50,86 +38,15 @@ class MainClassDebugAdapter( def name: String = s"${getClass.getSimpleName}(${project.name}, ${mainClass.getClassName()})" - def run(listener: DebuggeeListener): CancelableFuture[Unit] = { - val jvmOptions = - mainClass.getJvmOptions.asScala.toList :+ enableDebugInterface - val fullClasspathStr = - classPath.map(_.toString()).mkString(File.pathSeparator) - val java = JavaBinary(userJavaHome).toString() - val classpathOption = "-cp" :: fullClasspathStr :: Nil - val appOptions = - mainClass.getClassName :: mainClass.getArguments().asScala.toList - val cmd = java :: jvmOptions ::: classpathOption ::: appOptions - val cmdLength = cmd.foldLeft(0)(_ + _.length) - val envOptions = - mainClass - .getEnvironmentVariables() - .asScala - .flatMap { line => - val eqIdx = line.indexOf("=") - if (eqIdx > 0 && eqIdx != line.length - 1) { - val key = line.substring(0, eqIdx) - val value = line.substring(eqIdx + 1) - Some(key -> value) - } else None - } - .toMap - - def logError(errorMessage: String) = { - listener.err(errorMessage) - scribe.error(errorMessage) - } - - def logOutput(msg: String) = { - if (msg.startsWith(JDINotificationPrefix)) { - if (initialized.compareAndSet(false, true)) { - val port = Integer.parseInt(msg.drop(JDINotificationPrefix.length)) - val address = new InetSocketAddress("127.0.0.1", port) - listener.onListening(address) - } - } else { - listener.out(msg) - } - } - - // Note that we current only shorten the classpath portion and not other options - // Thus we do not yet *guarantee* that the command will not exceed OS limits - val process = - if (cmdLength <= SystemProcess.processCmdCharLimit) { - SystemProcess.run( - cmd, - root, - redirectErrorOutput = false, - envOptions, - processErr = Some(logError), - processOut = Some(logOutput), - ) - } else { - ManifestJar.withTempManifestJar(classPath) { manifestJar => - val shortClasspathOption = "-cp" :: manifestJar.syntax :: Nil - val shortCmd = - java :: jvmOptions ::: shortClasspathOption ::: appOptions - SystemProcess.run( - shortCmd, - root, - redirectErrorOutput = false, - envOptions, - processErr = Some(logError), - processOut = Some(logOutput), - ) - } - } - - new CancelableFuture[Unit] { - def future = process.complete.map { code => - if (code != 0) - throw new Exception(s"debuggee failed with error code $code") - } - def cancel(): Unit = process.cancel - } - } - - private def enableDebugInterface: String = { - s"-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,quiet=n" - } + def run(listener: DebuggeeListener): CancelableFuture[Unit] = + Run.runMain( + root = root, + classPath = classPath, + userJavaHome = userJavaHome, + className = mainClass.getClassName, + args = mainClass.getArguments().asScala.toList, + jvmOptions = mainClass.getJvmOptions.asScala.toList, + evnVariables = mainClass.getEnvironmentVariables().asScala.toList, + logger = new Logger(listener), + ) } diff --git a/metals/src/main/scala/scala/meta/internal/metals/debug/server/Run.scala b/metals/src/main/scala/scala/meta/internal/metals/debug/server/Run.scala new file mode 100644 index 00000000000..3c58773e938 --- /dev/null +++ b/metals/src/main/scala/scala/meta/internal/metals/debug/server/Run.scala @@ -0,0 +1,84 @@ +package scala.meta.internal.metals.debug.server + +import java.io.File +import java.nio.file.Path + +import scala.concurrent.ExecutionContext + +import scala.meta.internal.metals.JavaBinary +import scala.meta.internal.metals.ManifestJar +import scala.meta.internal.process.SystemProcess +import scala.meta.io.AbsolutePath + +import ch.epfl.scala.debugadapter.CancelableFuture + +object Run { + + def runMain( + root: AbsolutePath, + classPath: Seq[Path], + userJavaHome: Option[String], + className: String, + args: List[String], + jvmOptions: List[String], + evnVariables: List[String], + logger: Logger, + )(implicit ec: ExecutionContext): CancelableFuture[Unit] = { + val fullClasspathStr = + classPath.map(_.toString()).mkString(File.pathSeparator) + val java = JavaBinary(userJavaHome).toString() + val classpathOption = "-cp" :: fullClasspathStr :: Nil + val cmd = + java :: (jvmOptions :+ enableDebugInterface) ::: classpathOption ::: (className :: args) + val cmdLength = cmd.foldLeft(0)(_ + _.length) + val envOptions = + evnVariables.flatMap { line => + val eqIdx = line.indexOf("=") + if (eqIdx > 0 && eqIdx != line.length - 1) { + val key = line.substring(0, eqIdx) + val value = line.substring(eqIdx + 1) + Some(key -> value) + } else None + }.toMap + + // Note that we current only shorten the classpath portion and not other options + // Thus we do not yet *guarantee* that the command will not exceed OS limits + val process = + if (cmdLength <= SystemProcess.processCmdCharLimit) { + SystemProcess.run( + cmd, + root, + redirectErrorOutput = false, + envOptions, + processErr = Some(logger.logError), + processOut = Some(logger.logOutput), + ) + } else { + ManifestJar.withTempManifestJar(classPath) { manifestJar => + val shortClasspathOption = "-cp" :: manifestJar.syntax :: Nil + val shortCmd = + java :: jvmOptions ::: shortClasspathOption ::: (className :: args) // appOptions + SystemProcess.run( + shortCmd, + root, + redirectErrorOutput = false, + envOptions, + processErr = Some(logger.logError), + processOut = Some(logger.logOutput), + ) + } + } + + new CancelableFuture[Unit] { + def future = process.complete.map { code => + if (code != 0) + throw new Exception(s"debuggee failed with error code $code") + } + def cancel(): Unit = process.cancel + } + } + + private def enableDebugInterface: String = { + s"-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,quiet=n" + } +} diff --git a/metals/src/main/scala/scala/meta/internal/metals/debug/server/TestSuiteDebugAdapter.scala b/metals/src/main/scala/scala/meta/internal/metals/debug/server/TestSuiteDebugAdapter.scala new file mode 100644 index 00000000000..b19745186a0 --- /dev/null +++ b/metals/src/main/scala/scala/meta/internal/metals/debug/server/TestSuiteDebugAdapter.scala @@ -0,0 +1,159 @@ +package scala.meta.internal.metals.debug.server + +import java.net.URLClassLoader + +import scala.collection.mutable +import scala.concurrent.ExecutionContext + +import scala.meta.internal.metals.JdkSources +import scala.meta.internal.metals.MetalsEnrichments._ +import scala.meta.internal.metals.debug.TestFramework +import scala.meta.internal.metals.debug.server.testing.FingerprintInfo +import scala.meta.internal.metals.debug.server.testing.LoggingEventHandler +import scala.meta.internal.metals.debug.server.testing.TestInternals +import scala.meta.internal.metals.debug.server.testing.TestServer +import scala.meta.io.AbsolutePath + +import ch.epfl.scala.bsp4j.ScalaTestSuites +import ch.epfl.scala.debugadapter.CancelableFuture +import ch.epfl.scala.debugadapter.DebuggeeListener +import ch.epfl.scala.debugadapter.JavaRuntime +import ch.epfl.scala.debugadapter.Library +import ch.epfl.scala.debugadapter.Module +import ch.epfl.scala.debugadapter.UnmanagedEntry +import sbt.testing.Framework +import sbt.testing.SuiteSelector +import sbt.testing.TaskDef +import sbt.testing.TestSelector + +class TestSuiteDebugAdapter( + root: AbsolutePath, + testClasses: ScalaTestSuites, + project: DebugeeProject, + userJavaHome: Option[String], + discoveredTests: Map[TestFramework, List[Discovered]], +)(implicit ec: ExecutionContext) + extends MetalsDebuggee() { + + override def name: String = { + val selectedTests = testClasses + .getSuites() + .asScala + .map { suite => + val tests = suite.getTests.asScala.mkString("(", ",", ")") + s"${suite.getClassName()}$tests" + } + .mkString("[", ", ", "]") + s"${getClass.getSimpleName}($selectedTests)" + } + + override def modules: Seq[Module] = project.modules + override def libraries: Seq[Library] = project.libraries + override def unmanagedEntries: Seq[UnmanagedEntry] = project.unmanagedEntries + override protected def scalaVersionOpt: Option[String] = project.scalaVersion + + override val javaRuntime: Option[JavaRuntime] = + JdkSources + .defaultJavaHome(userJavaHome) + .flatMap(path => JavaRuntime(path.toNIO)) + .headOption + + def newClassLoader(parent: Option[ClassLoader]): ClassLoader = { + val classpathEntries = classPath.map(_.toUri.toURL).toArray + new URLClassLoader(classpathEntries, parent.orNull) + } + + def suites(frameworks: Seq[Framework]): Map[Framework, List[TaskDef]] = { + val testFilter = TestInternals.parseFilters( + testClasses.getSuites.asScala.map(_.getClassName()).toList + ) + val (subclassPrints, annotatedPrints) = + TestInternals.getFingerprints(frameworks) + val tasks = mutable.ArrayBuffer.empty[(Framework, TaskDef)] + val seen = mutable.Set.empty[String] + discoveredTests + .flatMap { case (testFramework, discovered) => + discovered.map((testFramework, _)) + } + .foreach { case (_, discovered) => + TestInternals + .matchingFingerprints(subclassPrints, annotatedPrints, discovered) + .foreach { case FingerprintInfo(_, _, framework, fingerprint) => + if (seen.add(discovered.className)) { + tasks += (framework -> new TaskDef( + discovered.className, + fingerprint, + false, + Array(new SuiteSelector), + )) + } + } + } + val selectedTests = testClasses + .getSuites() + .asScala + .map(entry => (entry.getClassName(), entry.getTests().asScala.toList)) + .toMap + tasks.toSeq + .filter { case (_, taskDef) => + val fqn = taskDef.fullyQualifiedName() + testFilter(fqn) + } + .groupBy(_._1) + .mapValues(_.map { case (_, taskDef) => + selectedTests.get(taskDef.fullyQualifiedName()).getOrElse(Nil) match { + case Nil => taskDef + case selectedTests => + new TaskDef( + taskDef.fullyQualifiedName(), + taskDef.fingerprint(), + false, + selectedTests.map(test => new TestSelector(test)).toList.toArray, + ) + } + }.toList) + .toMap + } + + override def run(listener: DebuggeeListener): CancelableFuture[Unit] = { + val loader = newClassLoader(Some(TestInternals.filteredLoader)) + val frameworks = discoveredTests.flatMap { case (framework, _) => + TestInternals.loadFramework(loader, framework.names) + } + val handler = new LoggingEventHandler(listener) + val jvmOptions = testClasses.getJvmOptions.asScala.toList + val envOptions = testClasses.getEnvironmentVariables().asScala.toList + + scribe.debug("Starting forked test execution...") + val resolvedSuites = suites(frameworks.toSeq) + + val server = + new TestServer(handler, loader, resolvedSuites) + val forkMain = classOf[sbt.ForkMain].getCanonicalName + val arguments = List(server.port.toString) + val testAgentJars = + TestInternals.testAgentFiles.filter(_.toString.endsWith(".jar")) + scribe.debug("Test agent JARs: " + testAgentJars.mkString(", ")) + + server.listenToTests + Run.runMain( + root, + classPath ++ testAgentJars, + userJavaHome, + forkMain, + arguments, + jvmOptions, + envOptions, + new Logger(listener), + ) + } + +} + +final case class Discovered( + symbol: String, + className: String, + baseClasses: Set[String], + annotations: Set[String], + isModule: Boolean, +) diff --git a/metals/src/main/scala/scala/meta/internal/metals/debug/server/testing/FilteredLoader.scala b/metals/src/main/scala/scala/meta/internal/metals/debug/server/testing/FilteredLoader.scala new file mode 100644 index 00000000000..dcb051bb673 --- /dev/null +++ b/metals/src/main/scala/scala/meta/internal/metals/debug/server/testing/FilteredLoader.scala @@ -0,0 +1,28 @@ +package scala.meta.internal.metals.debug.server.testing + +/** + * Delegates class loading to `parent` for all classes included by `filter`. An attempt to load classes excluded by `filter` + * results in a `ClassNotFoundException`. + */ +final class FilteredLoader(parent: ClassLoader, filter: IncludeClassFilter) + extends ClassLoader(parent) { + require( + parent != null + ) // included because a null parent is legitimate in Java + + @throws(classOf[ClassNotFoundException]) + override final def loadClass( + className: String, + resolve: Boolean, + ): Class[_] = { + if (filter.include(className)) + super.loadClass(className, resolve) + else + throw new ClassNotFoundException(className) + } +} + +class IncludeClassFilter(packages: Set[String]) { + def include(className: String): Boolean = + packages.exists(className.startsWith) +} diff --git a/metals/src/main/scala/scala/meta/internal/metals/debug/server/testing/LoggingEventHandler.scala b/metals/src/main/scala/scala/meta/internal/metals/debug/server/testing/LoggingEventHandler.scala new file mode 100644 index 00000000000..b3533c08262 --- /dev/null +++ b/metals/src/main/scala/scala/meta/internal/metals/debug/server/testing/LoggingEventHandler.scala @@ -0,0 +1,186 @@ +package scala.meta.internal.metals.debug.server.testing + +import java.util.Locale +import java.util.concurrent.TimeUnit + +import scala.collection.mutable + +import ch.epfl.scala.debugadapter.DebuggeeListener +import ch.epfl.scala.debugadapter.testing.TestSuiteEvent +import ch.epfl.scala.debugadapter.testing.TestSuiteEventHandler +import ch.epfl.scala.debugadapter.testing.TestUtils +import sbt.testing.Event +import sbt.testing.Status + +class LoggingEventHandler(listener: DebuggeeListener) + extends TestSuiteEventHandler { + type SuiteName = String + type TestName = String + type FailureMessage = String + + private val failedStatuses = + Set(Status.Error, Status.Canceled, Status.Failure) + + protected var suitesDuration = 0L + protected var suitesPassed = 0 + protected var suitesAborted = 0 + protected val testsFailedBySuite + : mutable.SortedMap[SuiteName, Map[TestName, FailureMessage]] = + mutable.SortedMap.empty[SuiteName, Map[TestName, FailureMessage]] + protected var suitesTotal = 0 + + protected def formatMetrics(metrics: List[(Int, String)]): String = { + val relevant = metrics.iterator.filter(_._1 > 0) + relevant.map { case (value, metric) => value + " " + metric }.mkString(", ") + } + + override def handle(event: TestSuiteEvent): Unit = event match { + case TestSuiteEvent.Error(message) => + listener.err(message) + error(message) + case TestSuiteEvent.Warn(message) => scribe.warn(message) + case TestSuiteEvent.Info(message) => info(message) + case TestSuiteEvent.Debug(message) => scribe.debug(message) + case TestSuiteEvent.Trace(throwable) => + error("Test suite aborted") + scribe.trace(throwable) + suitesAborted += 1 + suitesTotal += 1 + + case results @ TestSuiteEvent.Results(testSuite, events) => + val summary = TestSuiteEventHandler.summarizeResults(results) + listener.testResult(summary) + val testsTotal = events.length + + info( + s"Execution took ${TimeFormat.readableMillis(results.duration)}" + ) + val regularMetrics = List( + testsTotal -> "tests", + results.passed -> "passed", + results.pending -> "pending", + results.ignored -> "ignored", + results.skipped -> "skipped", + ) + + // If test metrics + val failureCount = results.failed + results.canceled + results.errors + val failureMetrics = + List( + results.failed -> "failed", + results.canceled -> "canceled", + results.errors -> "errors", + ) + val testMetrics = formatMetrics(regularMetrics ++ failureMetrics) + if (!testMetrics.isEmpty) info(testMetrics) + + if (failureCount > 0) { + val currentFailedTests = extractErrors(events) + val previousFailedTests = + testsFailedBySuite.getOrElse(testSuite, Map.empty) + testsFailedBySuite += testSuite -> (previousFailedTests ++ currentFailedTests) + } else if (testsTotal <= 0) info("No test suite was run") + else { + suitesPassed += 1 + info(s"All tests in $testSuite passed") + } + + suitesTotal += 1 + suitesDuration += results.duration + + case TestSuiteEvent.Done => () + } + + private def extractErrors(events: List[Event]) = + events + .filter(e => failedStatuses.contains(e.status())) + .map { event => + val selectorOpt = TestUtils.printSelector(event.selector) + if (selectorOpt.isEmpty) { + scribe.debug( + s"Unexpected test selector ${event.selector} won't be pretty printed!" + ) + } + val key = selectorOpt.getOrElse("") + val value = TestUtils.printThrowable(event.throwable()).getOrElse("") + key -> value + } + .toMap + + def report(): Unit = { + // TODO: Shall we think of a better way to format this delimiter based on screen length? + info("===============================================") + info(s"Total duration: ${TimeFormat.readableMillis(suitesDuration)}") + + if (suitesTotal == 0) { + info(s"No test suites were run.") + } else if (suitesPassed == suitesTotal) { + info(s"All $suitesPassed test suites passed.") + } else { + val metrics = List( + suitesPassed -> "passed", + testsFailedBySuite.size -> "failed", + suitesAborted -> "aborted", + ) + + info(formatMetrics(metrics)) + if (testsFailedBySuite.nonEmpty) { + info("") + info("Failed:") + testsFailedBySuite.foreach { case (suiteName, failedTests) => + info(s"- $suiteName:") + val summary = failedTests.map { case (suiteName, failureMsg) => + TestSuiteEventHandler.formatError( + suiteName, + failureMsg, + indentSize = 2, + ) + } + summary.foreach(s => info(s)) + } + } + } + + info("===============================================") + } + + def info(message: String): Unit = { + listener.out(message) + scribe.info(message) + } + + def error(message: String): Unit = { + listener.err(message) + scribe.error(message) + } +} + +object TimeFormat { + def readableMillis(nanos: Long): String = { + import java.text.DecimalFormat + import java.text.DecimalFormatSymbols + val seconds = TimeUnit.MILLISECONDS.toSeconds(nanos) + if (seconds > 9) readableSeconds(seconds) + else { + val ms = TimeUnit.MILLISECONDS.toMillis(nanos) + if (ms < 100) { + s"${ms}ms" + } else { + val partialSeconds = ms.toDouble / 1000 + new DecimalFormat("#.##s", new DecimalFormatSymbols(Locale.US)) + .format(partialSeconds) + } + } + } + + def readableSeconds(n: Long): String = { + val minutes = n / 60 + val seconds = n % 60 + if (minutes > 0) { + if (seconds == 0) s"${minutes}m" + else s"${minutes}m${seconds}s" + } else { + s"${seconds}s" + } + } +} diff --git a/metals/src/main/scala/scala/meta/internal/metals/debug/server/testing/SerializableFingerprints.scala b/metals/src/main/scala/scala/meta/internal/metals/debug/server/testing/SerializableFingerprints.scala new file mode 100644 index 00000000000..6c763cc6081 --- /dev/null +++ b/metals/src/main/scala/scala/meta/internal/metals/debug/server/testing/SerializableFingerprints.scala @@ -0,0 +1,21 @@ +package sbt + +import java.io.Serializable + +import sbt.testing.AnnotatedFingerprint +import sbt.testing.Fingerprint +import sbt.testing.SubclassFingerprint + +/** + * This object is there only to instantiate the subclasses of Fingerprint that are + * expected by the remote test runner (they're package private in sbt) + */ +object SerializableFingerprints { + // Copied from ForkTests.scala in sbt + def forkFingerprint(f: Fingerprint): Fingerprint with Serializable = + f match { + case s: SubclassFingerprint => new ForkMain.SubclassFingerscan(s) + case a: AnnotatedFingerprint => new ForkMain.AnnotatedFingerscan(a) + case _ => sys.error("Unknown fingerprint type: " + f.getClass) + } +} diff --git a/metals/src/main/scala/scala/meta/internal/metals/debug/server/testing/TestInternals.scala b/metals/src/main/scala/scala/meta/internal/metals/debug/server/testing/TestInternals.scala new file mode 100644 index 00000000000..f4fa8d23633 --- /dev/null +++ b/metals/src/main/scala/scala/meta/internal/metals/debug/server/testing/TestInternals.scala @@ -0,0 +1,192 @@ +package scala.meta.internal.metals.debug.server.testing + +import java.nio.file.Path +import java.util.regex.Pattern + +import scala.collection.mutable +import scala.util.control.NonFatal + +import scala.meta.internal.metals.Embedded +import scala.meta.internal.metals.debug.server.Discovered + +import coursierapi.Dependency +import org.scalatools.testing.{Framework => OldFramework} +import sbt.testing.AnnotatedFingerprint +import sbt.testing.Fingerprint +import sbt.testing.Framework +import sbt.testing.Runner +import sbt.testing.SubclassFingerprint + +final case class FingerprintInfo[+Print <: Fingerprint]( + name: String, + isModule: Boolean, + framework: Framework, + fingerprint: Print, +) + +object TestInternals { + private final val sbtOrg = "org.scala-sbt" + private final val testAgentId = "test-agent" + private final val testAgentVersion = "1.8.0" + + lazy val testAgentFiles: List[Path] = { + val dependency = Dependency.of(sbtOrg, testAgentId, testAgentVersion) + try { + Embedded.downloadDependency(dependency, None) + } catch { + case NonFatal(e) => + scribe.warn(e) + Nil + } + } + + /** + * Parses `filters` to produce a filtering function for the tests. + * Only the tests accepted by this filter will be run. + * + * `*` is interpreter as wildcard. Each filter can start with `-`, in which case it means + * that it is an exclusion filter. + * + * @param filters A list of strings, representing inclusion or exclusion patterns + * @return A function that determines whether a test should be run given its FQCN. + */ + def parseFilters(filters: List[String]): String => Boolean = { + val (exclusionFilters, inclusionFilters) = + filters.map(_.trim).partition(_.startsWith("-")) + val inc = inclusionFilters.map(toPattern) + val exc = exclusionFilters.map(f => toPattern(f.tail)) + + (inc, exc) match { + case (Nil, Nil) => + (_ => true) + case (inc, Nil) => + (s => inc.exists(_.matcher(s).matches)) + case (Nil, exc) => + (s => !exc.exists(_.matcher(s).matches)) + case (inc, exc) => + ( + s => + inc.exists(_.matcher(s).matches) && !exc.exists( + _.matcher(s).matches + ) + ) + } + } + + lazy val filteredLoader: FilteredLoader = { + val filter = new IncludeClassFilter( + Set( + "jdk.", "java.", "javax.", "sun.", "sbt.testing.", + "org.scalatools.testing.", "org.xml.sax.", + ) + ) + new FilteredLoader(getClass.getClassLoader, filter) + } + + def loadFramework(l: ClassLoader, fqns: List[String]): Option[Framework] = { + fqns match { + case head :: tail => loadFramework(l, head).orElse(loadFramework(l, tail)) + case Nil => None + } + } + + def getFingerprints( + frameworks: Seq[Framework] + ): ( + List[FingerprintInfo[SubclassFingerprint]], + List[FingerprintInfo[AnnotatedFingerprint]], + ) = { + // The tests need to be run with the first matching framework, so we use a LinkedHashSet + // to keep the ordering of `frameworks`. + val subclasses = + mutable.LinkedHashSet.empty[FingerprintInfo[SubclassFingerprint]] + val annotated = + mutable.LinkedHashSet.empty[FingerprintInfo[AnnotatedFingerprint]] + for { + framework <- frameworks + fingerprint <- framework.fingerprints() + } fingerprint match { + case sub: SubclassFingerprint => + subclasses += FingerprintInfo( + sub.superclassName, + sub.isModule, + framework, + sub, + ) + case ann: AnnotatedFingerprint => + annotated += FingerprintInfo( + ann.annotationName, + ann.isModule, + framework, + ann, + ) + } + (subclasses.toList, annotated.toList) + } + + // Slightly adapted from sbt/sbt + def matchingFingerprints( + subclassPrints: List[FingerprintInfo[SubclassFingerprint]], + annotatedPrints: List[FingerprintInfo[AnnotatedFingerprint]], + d: Discovered, + ): List[FingerprintInfo[Fingerprint]] = { + defined(subclassPrints, d.baseClasses, d.isModule) ++ + defined(annotatedPrints, d.annotations, d.isModule) + } + + def getRunner( + framework: Framework, + testClassLoader: ClassLoader, + ): Runner = framework.runner(Array.empty, Array.empty, testClassLoader) + + // Slightly adapted from sbt/sbt + private def defined[T <: Fingerprint]( + in: List[FingerprintInfo[T]], + names: Set[String], + IsModule: Boolean, + ): List[FingerprintInfo[T]] = { + in.collect { + case info @ FingerprintInfo(name, IsModule, _, _) if names(name) => info + } + } + + private def loadFramework( + loader: ClassLoader, + fqn: String, + ): Option[Framework] = { + try { + Class + .forName(fqn, true, loader) + .getDeclaredConstructor() + .newInstance() match { + case framework: Framework => Some(framework) + case _: OldFramework => + scribe.warn(s"Old frameworks are not supported: $fqn"); None + } + } catch { + case _: ClassNotFoundException => None + case NonFatal(t) => + scribe.error(s"Initialisation of test framework $fqn failed", t) + None + } + } + + /** + * Converts the input string to a compiled `Pattern`. + * + * The string is split at `*` (representing wildcards). + * + * @param filter The input filter + * @return The compiled pattern matching the input filter. + */ + private def toPattern(filter: String): Pattern = { + val parts = filter + .split("\\*", -1) + .map { // Don't discard trailing empty string, if any. + case "" => "" + case str => Pattern.quote(str) + } + Pattern.compile(parts.mkString(".*")) + } + +} diff --git a/metals/src/main/scala/scala/meta/internal/metals/debug/server/testing/TestServer.scala b/metals/src/main/scala/scala/meta/internal/metals/debug/server/testing/TestServer.scala new file mode 100644 index 00000000000..eb291bb3779 --- /dev/null +++ b/metals/src/main/scala/scala/meta/internal/metals/debug/server/testing/TestServer.scala @@ -0,0 +1,197 @@ +package scala.meta.internal.metals.debug.server.testing + +import java.io.ObjectInputStream +import java.io.ObjectOutputStream +import java.net.ServerSocket + +import scala.concurrent.ExecutionContext +import scala.concurrent.Future +import scala.concurrent.Promise +import scala.util.Failure +import scala.util.Try +import scala.util.control.NonFatal + +import ch.epfl.scala.debugadapter.testing.TestSuiteEvent +import sbt.ForkConfiguration +import sbt.ForkTags +import sbt.testing.Event +import sbt.testing.Framework +import sbt.testing.TaskDef + +/** + * Implements the protocol that the forked remote JVM talks with the host process. + * + * This protocol is not formal and has been implemented after sbt's `ForkTests`. + */ +final class TestServer( + eventHandler: LoggingEventHandler, + classLoader: ClassLoader, + discoveredTests: Map[Framework, List[TaskDef]], +)(implicit ec: ExecutionContext) { + + private val server = new ServerSocket(0) + private val (runners, tasks) = { + def getRunner(framework: Framework) = { + val frameworkClass = framework.getClass.getName + frameworkClass -> TestInternals.getRunner(framework, classLoader) + } + // Return frameworks and tasks in order to ensure a deterministic test execution + val sorted = discoveredTests.toList.sortBy(_._1.name()) + ( + sorted.map(_._1).map(getRunner), + sorted.flatMap(_._2.sortBy(_.fullyQualifiedName())), + ) + } + + case class TestOrchestrator(startServer: Future[Unit], reporter: Future[Unit]) + val port = server.getLocalPort + def listenToTests: TestOrchestrator = { + def forkFingerprint(td: TaskDef): TaskDef = { + val newFingerprint = + sbt.SerializableFingerprints.forkFingerprint(td.fingerprint) + new TaskDef( + td.fullyQualifiedName, + newFingerprint, + td.explicitlySpecified, + td.selectors, + ) + } + + @annotation.tailrec + def receiveLogs(is: ObjectInputStream, os: ObjectOutputStream): Unit = { + is.readObject() match { + case Array(ForkTags.`Error`, s: String) => + eventHandler.handle(TestSuiteEvent.Error(s)) + receiveLogs(is, os) + case Array(ForkTags.`Warn`, s: String) => + eventHandler.handle(TestSuiteEvent.Warn(s)) + receiveLogs(is, os) + case Array(ForkTags.`Info`, s: String) => + eventHandler.handle(TestSuiteEvent.Info(s)) + receiveLogs(is, os) + case Array(ForkTags.`Debug`, s: String) => + eventHandler.handle(TestSuiteEvent.Debug(s)) + receiveLogs(is, os) + case t: Throwable => + eventHandler.handle(TestSuiteEvent.Trace(t)) + receiveLogs(is, os) + case Array(testSuite: String, events: Array[Event]) => + eventHandler.handle(TestSuiteEvent.Results(testSuite, events.toList)) + receiveLogs(is, os) + case ForkTags.`Done` => + eventHandler.handle(TestSuiteEvent.Done) + os.writeObject(ForkTags.Done) + os.flush() + } + } + + def talk( + is: ObjectInputStream, + os: ObjectOutputStream, + config: ForkConfiguration, + ): Unit = { + try { + os.writeObject(config) + val taskDefs = tasks.map(forkFingerprint) + os.writeObject(taskDefs.toArray) + os.writeInt(runners.size) + taskDefs.foreach { taskDef => + taskDef.fingerprint() + } + + val taskDefsDescription = taskDefs.map { taskDef => + val selectors = + taskDef.selectors().toList.map(_.toString()).mkString("(", ",", ")") + s"${taskDef.fullyQualifiedName()}$selectors" + } + scribe.debug(s"Sent task defs to test server: $taskDefsDescription") + + runners.foreach { case (frameworkClass, runner) => + scribe.debug( + s"Sending runner to test server: ${frameworkClass} ${runner.args.toList}" + ) + os.writeObject(Array(frameworkClass)) + os.writeObject(runner.args) + os.writeObject(runner.remoteArgs) + } + + os.flush() + receiveLogs(is, os) + } catch { + case NonFatal(t) => + scribe.error(s"Failed to initialize communication: ${t.getMessage}") + scribe.trace(t) + } + } + + val serverStarted = Promise[Unit]() + val clientConnection = Future { + scribe.debug(s"Firing up test server at $port. Waiting for client...") + serverStarted.trySuccess(()) + server.accept() + } + + val testListeningTask = clientConnection.flatMap { socket => + scribe.debug("Test server established connection with remote JVM.") + val os = new ObjectOutputStream(socket.getOutputStream) + os.flush() + val is = new ObjectInputStream(socket.getInputStream) + val config = new ForkConfiguration( + /* ansiCodesSupported = */ false, /* parallel = */ false, + ) + + @volatile var alreadyClosed: Boolean = false + def cleanSocketResources() = Future { + if (!alreadyClosed) { + for { + _ <- Try(is.close()) + _ <- Try(os.close()) + _ <- Try(socket.close()) + } yield { + alreadyClosed = false + } + () + } + } + + val talkFuture = Future(talk(is, os, config)) + talkFuture.onComplete(_ => cleanSocketResources()) + talkFuture + } + + def closeServer(t: Option[Throwable], fromCancel: Boolean) = Future { + t.foreach { + case NonFatal(e) => + scribe.error( + s"Unexpected error during remote test execution: '${e.getMessage}'." + ) + scribe.trace(e) + case _ => + } + + runners.foreach(_._2.done()) + + server.close() + // Do both just in case the logger streams have been closed by nailgun + if (fromCancel) { + // opts.ngout.println("The test execution was successfully cancelled.") + scribe.debug("Test server has been successfully cancelled.") + } else { + // opts.ngout.println("The test execution was successfully closed.") + scribe.debug("Test server has been successfully closed.") + } + } + + val listener = { + testListeningTask.onComplete { + case Failure(exception) => closeServer(Some(exception), true) + case _ => closeServer(None, false) + } + testListeningTask + } + + TestOrchestrator(serverStarted.future, listener) + } +} + +case class TestArgument(args: List[String], frameworkNames: List[String]) diff --git a/metals/src/main/scala/scala/meta/internal/metals/testProvider/TestSuitesProvider.scala b/metals/src/main/scala/scala/meta/internal/metals/testProvider/TestSuitesProvider.scala index 27dcb82dde7..85338186304 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/testProvider/TestSuitesProvider.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/testProvider/TestSuitesProvider.scala @@ -553,12 +553,15 @@ final class TestSuitesProvider( def getFramework( target: BuildTarget, selection: ScalaTestSuiteSelection, - ): TestFramework = { - val framework = - for { - testEntry <- index.get(target, FullyQualifiedName(selection.className)) - } yield testEntry.suiteDetails.framework - framework.getOrElse(Unknown) + ): TestFramework = getFromCache(target, selection.className) + .map(_.suiteDetails.framework) + .getOrElse(Unknown) + + def getFromCache( + target: BuildTarget, + className: String, + ): Option[TestEntry] = { + index.get(target, FullyQualifiedName(className)) } } diff --git a/mtags-interfaces/src/main/java/scala/meta/pc/PcSymbolInformation.java b/mtags-interfaces/src/main/java/scala/meta/pc/PcSymbolInformation.java index 43e57233420..a43b82b6007 100644 --- a/mtags-interfaces/src/main/java/scala/meta/pc/PcSymbolInformation.java +++ b/mtags-interfaces/src/main/java/scala/meta/pc/PcSymbolInformation.java @@ -1,6 +1,7 @@ package scala.meta.pc; import java.util.List; +import java.util.Collections; public interface PcSymbolInformation { String symbol(); @@ -12,4 +13,15 @@ public interface PcSymbolInformation { // overloaded methods List alternativeSymbols(); List properties(); + default List recursiveParents() { + return Collections.emptyList(); + } + + default List annotations() { + return Collections.emptyList(); + } + + default List memberDefsAnnotations() { + return Collections.emptyList(); + } } diff --git a/mtags-shared/src/main/scala/scala/meta/internal/pc/PcSymbolInformation.scala b/mtags-shared/src/main/scala/scala/meta/internal/pc/PcSymbolInformation.scala index aa35cbabc47..6e8853d237a 100644 --- a/mtags-shared/src/main/scala/scala/meta/internal/pc/PcSymbolInformation.scala +++ b/mtags-shared/src/main/scala/scala/meta/internal/pc/PcSymbolInformation.scala @@ -15,7 +15,10 @@ case class PcSymbolInformation( classOwner: Option[String], overriddenSymbols: List[String], alternativeSymbols: List[String], - properties: List[PcSymbolProperty] + properties: List[PcSymbolProperty], + recursiveParents: List[String], + annotations: List[String], + memberDefsAnnotations: List[String] ) { def asJava: PcSymbolInformationJava = PcSymbolInformationJava( @@ -26,7 +29,10 @@ case class PcSymbolInformation( classOwner.getOrElse(""), overriddenSymbols.asJava, alternativeSymbols.asJava, - properties.asJava + properties.asJava, + recursiveParents.asJava, + annotations.asJava, + memberDefsAnnotations.asJava ) } @@ -38,7 +44,10 @@ case class PcSymbolInformationJava( classOwner: String, overriddenSymbols: ju.List[String], alternativeSymbols: ju.List[String], - properties: ju.List[PcSymbolProperty] + properties: ju.List[PcSymbolProperty], + override val recursiveParents: ju.List[String], + override val annotations: ju.List[String], + override val memberDefsAnnotations: ju.List[String] ) extends IPcSymbolInformation object PcSymbolInformation { @@ -52,6 +61,9 @@ object PcSymbolInformation { else None, info.overriddenSymbols().asScala.toList, info.alternativeSymbols().asScala.toList, - info.properties().asScala.toList + info.properties().asScala.toList, + info.recursiveParents().asScala.toList, + info.annotations().asScala.toList, + info.memberDefsAnnotations().asScala.toList ) } diff --git a/mtags/src/main/scala-2/scala/meta/internal/pc/WorkspaceSymbolSearch.scala b/mtags/src/main/scala-2/scala/meta/internal/pc/WorkspaceSymbolSearch.scala index 94405317f67..baab1b815dc 100644 --- a/mtags/src/main/scala-2/scala/meta/internal/pc/WorkspaceSymbolSearch.scala +++ b/mtags/src/main/scala-2/scala/meta/internal/pc/WorkspaceSymbolSearch.scala @@ -3,6 +3,7 @@ package scala.meta.internal.pc import java.nio.file.Path import scala.annotation.tailrec +import scala.collection.mutable import scala.reflect.NameTransformer import scala.util.control.NonFatal @@ -56,6 +57,20 @@ trait WorkspaceSymbolSearch { compiler: MetalsGlobal => searchedSymbol match { case compilerSymbol :: _ => + val allParents = { + val visited = mutable.Set[Symbol]() + def collect(sym: Symbol): Unit = { + visited += sym + sym.parentSymbols.foreach { + case parent if !visited(parent) => collect(parent) + case _ => + } + } + collect(compilerSymbol) + visited.toList.map(semanticdbSymbol) + } + val defnAnn = + compilerSymbol.info.members.filter(_.isMethod).flatMap(_.annotations) Some( PcSymbolInformation( symbol = symbol, @@ -72,7 +87,10 @@ trait WorkspaceSymbolSearch { compiler: MetalsGlobal => compilerSymbol.isAbstractClass || compilerSymbol.isAbstractType ) List(PcSymbolProperty.ABSTRACT) - else Nil + else Nil, + allParents, + compilerSymbol.annotations.map(_.toString()).distinct, + defnAnn.map(_.toString()).toList.distinct ) ) case _ => None diff --git a/mtags/src/main/scala-3/scala/meta/internal/pc/SymbolInformationProvider.scala b/mtags/src/main/scala-3/scala/meta/internal/pc/SymbolInformationProvider.scala index 489569878ef..2d4e6a8b355 100644 --- a/mtags/src/main/scala-3/scala/meta/internal/pc/SymbolInformationProvider.scala +++ b/mtags/src/main/scala-3/scala/meta/internal/pc/SymbolInformationProvider.scala @@ -1,15 +1,15 @@ package scala.meta.internal.pc +import scala.collection.mutable import scala.util.control.NonFatal +import scala.meta.internal.mtags.MtagsEnrichments.allSymbols import scala.meta.internal.mtags.MtagsEnrichments.metalsDealias import scala.meta.internal.mtags.MtagsEnrichments.stripBackticks import scala.meta.pc.PcSymbolKind import scala.meta.pc.PcSymbolProperty import dotty.tools.dotc.core.Contexts.Context -import dotty.tools.dotc.core.Denotations.Denotation -import dotty.tools.dotc.core.Denotations.MultiDenotation import dotty.tools.dotc.core.Flags import dotty.tools.dotc.core.Names import dotty.tools.dotc.core.Names.* @@ -34,11 +34,28 @@ class SymbolInformationProvider(using Context): if classSym.isClass then classSym.asClass.parentSyms.map(SemanticdbSymbols.symbolName) else Nil + + val allParents = { + val visited = mutable.Set[Symbol]() + def collect(sym: Symbol): Unit = { + visited += sym + if sym.isClass + then sym.asClass.parentSyms.foreach { + case parent if !visited(parent) => + collect(parent) + case _ => + } + } + collect(classSym) + visited.toList.map(SemanticdbSymbols.symbolName) + } + val dealisedSymbol = if sym.isAliasType then sym.info.metalsDealias.typeSymbol else sym val classOwner = sym.ownersIterator.drop(1).find(s => s.isClass || s.is(Flags.Module)) val overridden = sym.denot.allOverriddenSymbols.toList + val memberDefAnnots = sym.info.membersBasedOnFlags(Flags.Method, Flags.EmptyFlags).flatMap(_.allSymbols).flatMap(_.denot.annotations) val pcSymbolInformation = PcSymbolInformation( @@ -53,6 +70,9 @@ class SymbolInformationProvider(using Context): properties = if sym.is(Flags.Abstract) then List(PcSymbolProperty.ABSTRACT) else Nil, + allParents, + sym.denot.annotations.map(_.symbol.showFullName), + memberDefAnnots.map(_.symbol.showFullName).toList ) Some(pcSymbolInformation) @@ -88,12 +108,6 @@ object SymbolProvider: pkg.replace("/", ".").stripSuffix(".") private def toSymbols(info: SymbolInfo.SymbolParts)(using Context): List[Symbol] = - def collectSymbols(denotation: Denotation): List[Symbol] = - denotation match - case MultiDenotation(denot1, denot2) => - collectSymbols(denot1) ++ collectSymbols(denot2) - case denot => List(denot.symbol) - def loop( owners: List[Symbol], parts: List[(String, Boolean)], @@ -106,7 +120,7 @@ object SymbolProvider: val next = if isClass then owner.info.member(typeName(name)) else owner.info.member(termName(name)) - collectSymbols(next).filter(_.exists) + next.allSymbols.filter(_.exists) } if foundSymbols.nonEmpty then loop(foundSymbols, tl) else Nil diff --git a/tests/slow/src/test/scala/tests/mill/MillDebugDiscoverySuite.scala b/tests/slow/src/test/scala/tests/mill/MillDebugDiscoverySuite.scala new file mode 100644 index 00000000000..bc71be1c59c --- /dev/null +++ b/tests/slow/src/test/scala/tests/mill/MillDebugDiscoverySuite.scala @@ -0,0 +1,164 @@ +package tests.mill + +import java.util.concurrent.TimeUnit + +import scala.meta.internal.metals.BuildInfo.scala3 +import scala.meta.internal.metals.DebugDiscoveryParams +import scala.meta.internal.metals.JsonParser._ +import scala.meta.internal.metals.MetalsEnrichments._ +import scala.meta.internal.metals.ScalaTestSuiteSelection +import scala.meta.internal.metals.ScalaTestSuites +import scala.meta.internal.metals.debug.JUnit4 +import scala.meta.internal.metals.debug.Scalatest + +import ch.epfl.scala.bsp4j.TestParamsDataKind +import tests.BaseDapSuite +import tests.MillBuildLayout +import tests.MillServerInitializer + +class MillDebugDiscoverySuite + extends BaseDapSuite( + "mill-debug-discovery", + MillServerInitializer, + MillBuildLayout, + ) { + + private val fooPath = "a/test/src/Foo.scala" + private val barPath = "a/test/src/Bar.scala" + + for (scala <- List(scalaVersion, scala3)) { + + test(s"testTarget-$scala") { + cleanWorkspace() + for { + _ <- initialize( + MillBuildLayout( + s""" + |/${fooPath} + |package a + |class Foo extends org.scalatest.funsuite.AnyFunSuite { + | test("foo") {} + |} + |/${barPath} + |package a + |class Bar extends org.scalatest.funsuite.AnyFunSuite { + | test("bart") {} + |} + |""".stripMargin, + scala, + Some(Scalatest), + ) + ) + _ <- server.didOpen(barPath) + _ <- server.didSave(barPath)(identity) + _ <- server.waitFor(TimeUnit.SECONDS.toMillis(10)) + debugger <- server.startDebuggingUnresolved( + new DebugDiscoveryParams( + server.toPath(barPath).toURI.toString, + "testTarget", + ).toJson + ) + _ <- debugger.initialize + _ <- debugger.launch + _ <- debugger.configurationDone + _ <- debugger.shutdown + output <- debugger.allOutput + } yield assert(output.contains("All tests in a.Bar passed")) + } + + test(s"junit-$scala") { + cleanWorkspace() + for { + _ <- initialize( + MillBuildLayout( + s"""|/${fooPath} + |package a + |import org.junit.Test + |import org.junit.Assert._ + | + |class Foo { + | @Test + | def testOneIsPositive = { + | assertTrue(1 > 0) + | } + | + | @Test + | def testMinusOneIsNegative = { + | assertTrue(-1 < 0) + | } + |} + |""".stripMargin, + scala, + Some(JUnit4), + ) + ) + _ <- server.didOpen(fooPath) + _ <- server.didSave(fooPath)(identity) + _ <- server.waitFor(TimeUnit.SECONDS.toMillis(10)) + debugger <- server.startDebugging( + "a.test", + TestParamsDataKind.SCALA_TEST_SUITES_SELECTION, + ScalaTestSuites( + List( + ScalaTestSuiteSelection("a.Foo", Nil.asJava) + ).asJava, + Nil.asJava, + Nil.asJava, + ), + ) + _ <- debugger.initialize + _ <- debugger.launch + _ <- debugger.configurationDone + _ <- debugger.shutdown + output <- debugger.allOutput + } yield assert(output.contains("All tests in a.Foo passed")) + } + } + + test(s"test-selection") { + cleanWorkspace() + for { + _ <- initialize( + MillBuildLayout( + s""" + |/${fooPath} + |package a + |class Foo extends org.scalatest.funsuite.AnyFunSuite { + | test("foo") {} + | test("bar") {} + |} + |""".stripMargin, + scalaVersion, + Some(Scalatest), + ) + ) + _ <- server.didOpen(fooPath) + _ <- server.didSave(fooPath)(identity) + _ <- server.waitFor(TimeUnit.SECONDS.toMillis(10)) + debugger <- server.startDebugging( + "a.test", + TestParamsDataKind.SCALA_TEST_SUITES_SELECTION, + ScalaTestSuites( + List( + ScalaTestSuiteSelection("a.Foo", List("foo").asJava) + ).asJava, + Nil.asJava, + Nil.asJava, + ), + ) + _ <- debugger.initialize + _ <- debugger.launch + _ <- debugger.configurationDone + _ <- debugger.shutdown + output <- debugger.allOutput + } yield assertNoDiff( + output.replaceFirst("[0-9]+ms", "xxx"), + """|Foo: + |- foo + |Execution took xxx + |1 tests, 1 passed + |All tests in a.Foo passed + |""".stripMargin, + ) + } +} diff --git a/tests/slow/src/test/scala/tests/mill/MillServerCodeLensSuite.scala b/tests/slow/src/test/scala/tests/mill/MillServerCodeLensSuite.scala index 92dbd61cf88..e465db16afc 100644 --- a/tests/slow/src/test/scala/tests/mill/MillServerCodeLensSuite.scala +++ b/tests/slow/src/test/scala/tests/mill/MillServerCodeLensSuite.scala @@ -57,6 +57,7 @@ class MillServerCodeLensSuite _ <- assertCodeLenses( "a/test/src/Foo.scala", """|// no test lense as debug is not supported + |<><> |class Foo extends munit.FunSuite {} |""".stripMargin, ) diff --git a/tests/unit/src/main/scala/tests/BuildServerLayout.scala b/tests/unit/src/main/scala/tests/BuildServerLayout.scala index cdf2a1c011b..290d17a5165 100644 --- a/tests/unit/src/main/scala/tests/BuildServerLayout.scala +++ b/tests/unit/src/main/scala/tests/BuildServerLayout.scala @@ -1,5 +1,9 @@ package tests +import scala.meta.internal.metals.debug.JUnit4 +import scala.meta.internal.metals.debug.MUnit +import scala.meta.internal.metals.debug.Scalatest +import scala.meta.internal.metals.debug.TestFramework import scala.meta.internal.metals.{BuildInfo => V} trait BuildToolLayout { @@ -51,22 +55,34 @@ object SbtBuildLayout extends BuildToolLayout { object MillBuildLayout extends BuildToolLayout { override def apply(sourceLayout: String, scalaVersion: String): String = - apply(sourceLayout, scalaVersion, includeMunit = false) + apply(sourceLayout, scalaVersion, None) def apply( sourceLayout: String, scalaVersion: String, - includeMunit: Boolean, + testDep: Option[TestFramework], ): String = { + val optDepModule = + testDep.map { + case Scalatest => ("ScalaTest", "org.scalatest::scalatest:3.2.16") + case MUnit => ("Munit", "org.scalameta::munit::0.7.29") + case JUnit4 => ("Junit4", "com.github.sbt:junit-interface:0.13.2") + case testFramework => + throw new RuntimeException( + s"No implementation for layout for $testFramework" + ) + } val munitModule = - if (includeMunit) - """|object test extends ScalaTests with TestModule.Munit { - | def ivyDeps = Agg( - | ivy"org.scalameta::munit::0.7.29" - | ) - | } - |""".stripMargin - else "" + optDepModule match { + case Some((module, dep)) => + s"""|object test extends ScalaTests with TestModule.$module { + | def ivyDeps = Agg( + | ivy"$dep" + | ) + | } + |""".stripMargin + case _ => "" + } s"""|/build.sc |import mill._, scalalib._ @@ -88,7 +104,11 @@ object MillBuildLayout extends BuildToolLayout { s"""|/.mill-version |$millVersion |${apply(sourceLayout, scalaVersion)} - |${apply(sourceLayout, scalaVersion, includeMunit)} + |${apply( + sourceLayout, + scalaVersion, + if (includeMunit) Some(MUnit) else None, + )} |""".stripMargin } diff --git a/tests/unit/src/main/scala/tests/debug/BaseBreakpointDapSuite.scala b/tests/unit/src/main/scala/tests/debug/BaseBreakpointDapSuite.scala index 6f5c835c174..2421783617e 100644 --- a/tests/unit/src/main/scala/tests/debug/BaseBreakpointDapSuite.scala +++ b/tests/unit/src/main/scala/tests/debug/BaseBreakpointDapSuite.scala @@ -677,6 +677,7 @@ abstract class BaseBreakpointDapSuite( ) test("remove-breakpoints") { + cleanWorkspace() val debugLayout = DebugWorkspaceLayout( """|/a/src/main/scala/a/Main.scala |package a diff --git a/tests/unit/src/test/scala/tests/DebugProtocolSuite.scala b/tests/unit/src/test/scala/tests/DebugProtocolSuite.scala index 93e89657104..a9fe2fd6e0a 100644 --- a/tests/unit/src/test/scala/tests/DebugProtocolSuite.scala +++ b/tests/unit/src/test/scala/tests/DebugProtocolSuite.scala @@ -443,6 +443,8 @@ class DebugProtocolSuite |} |""".stripMargin ) + _ <- server.server.indexingPromise.future + _ <- server.didOpen("a/src/main/scala/a/Foo.scala") debugger <- server.startDebuggingUnresolved( new DebugUnresolvedTestClassParams( "a.Foo"