Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: run DAP for test suites in Metals #6551

Merged
merged 1 commit into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,11 @@ final class ImplementationProvider(
classOwner = classOwnerInfoOpt.map(_.symbol),
alternativeSymbols = alternativeSymbols.toList,
overriddenSymbols = info.overriddenSymbols.toList,
properties = if (info.isAbstract) List(PcSymbolProperty.ABSTRACT) else Nil,
properties =
if (info.isAbstract) List(PcSymbolProperty.ABSTRACT) else Nil,
recursiveParents = parents,
annotations = info.annotations.map(_.toString()).toList,
memberDefsAnnotations = Nil,
)
}
}
Expand Down
12 changes: 12 additions & 0 deletions metals/src/main/scala/scala/meta/internal/metals/Compilers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,18 @@ class Compilers(
.getOrElse(Future(None))
}

def info(
id: BuildTargetIdentifier,
symbol: String,
): Future[Option[PcSymbolInformation]] = {
loadCompiler(id)
.map(
_.info(symbol).asScala
.map(_.asScala.map(PcSymbolInformation.from))
)
.getOrElse(Future(None))
}

private def definition(
params: TextDocumentPositionParams,
token: CancelToken,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,9 @@ object MetalsEnrichments
value.replace('/', ':')
else value
}

def symbolToFullQualifiedName: String =
value.replaceAll("/|#", ".").stripSuffix(".")
}

implicit class XtensionTextDocumentSemanticdb(textDocument: s.TextDocument) {
Expand Down Expand Up @@ -1311,16 +1314,40 @@ object MetalsEnrichments
}

implicit class XtensionDebugSessionParams(params: b.DebugSessionParams) {
def asScalaMainClass(): Either[String, b.ScalaMainClass] =
def asScalaMainClass(): Either[String, b.ScalaMainClass] = {
lazy val className = "ScalaMainClass"
params.getDataKind() match {
case b.DebugSessionParamsDataKind.SCALA_MAIN_CLASS =>
decodeJson(params.getData(), classOf[b.ScalaMainClass])
.toRight(s"Cannot decode $params as `ScalaMainClass`.")
case _ =>
Left(
s"Cannot decode params as `ScalaMainClass` incorrect data kind: ${params.getDataKind()}."
.toRight(cannotDecode(className))
case _ => incorrectKind(className)
}
}

def asScalaTestSuites(): Either[String, b.ScalaTestSuites] = {
lazy val className = "ScalaTestSuites"
params.getDataKind() match {
case b.TestParamsDataKind.SCALA_TEST_SUITES_SELECTION =>
decodeJson(params.getData(), classOf[b.ScalaTestSuites])
.toRight(cannotDecode(className))
case b.TestParamsDataKind.SCALA_TEST_SUITES =>
(for (
tests <- decodeJson(params.getData(), classOf[util.List[String]])
)
yield {
val suites =
tests.map(new b.ScalaTestSuiteSelection(_, Nil.asJava))
new b.ScalaTestSuites(suites, Nil.asJava, Nil.asJava)
}).toRight(cannotDecode(className))
case _ => incorrectKind(className)
}
}
def cannotDecode(className: String): String =
s"Cannot decode $params as `$className`."
def incorrectKind(className: String): Left[String, Nothing] =
Left(
s"Cannot decode params as `$className` incorrect data kind: ${params.getDataKind()}."
)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,15 @@ final class RunTestCodeLens(
buildServerCanDebug,
isJVM,
)
} else if (buildServerCanDebug || clientConfig.isRunProvider()) {
} else
codeLenses(
textDocument,
buildTargetId,
classes,
distance,
path,
buildServerCanDebug,
isJVM,
)
} else { Nil }

}

Expand Down Expand Up @@ -198,7 +196,6 @@ final class RunTestCodeLens(
classes: BuildTargetClasses.Classes,
distance: TokenEditDistance,
path: AbsolutePath,
buildServerCanDebug: Boolean,
isJVM: Boolean,
): Seq[l.CodeLens] = {
for {
Expand All @@ -212,7 +209,7 @@ final class RunTestCodeLens(
.getOrElse(Nil)
lazy val tests =
// Currently tests can only be run via DAP
if (clientConfig.isDebuggingProvider() && buildServerCanDebug)
if (clientConfig.isDebuggingProvider())
testClasses(target, classes, symbol, isJVM)
else Nil
val fromAnnot = DebugProvider
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import scala.meta.internal.metals.BatchedFunction
import scala.meta.internal.metals.BuildTargets
import scala.meta.internal.metals.MetalsEnrichments._
import scala.meta.internal.metals.debug.BuildTargetClasses.Classes
import scala.meta.internal.metals.debug.BuildTargetClasses.TestSymbolInfo
import scala.meta.internal.semanticdb.Scala.Descriptor
import scala.meta.internal.semanticdb.Scala.Symbols

Expand All @@ -16,9 +17,9 @@ import ch.epfl.scala.{bsp4j => b}
/**
* In-memory index of main class symbols grouped by their enclosing build target
*/
final class BuildTargetClasses(
buildTargets: BuildTargets
)(implicit val ec: ExecutionContext) {
final class BuildTargetClasses(buildTargets: BuildTargets)(implicit
val ec: ExecutionContext
) {
private val index = TrieMap.empty[b.BuildTargetIdentifier, Classes]
private val jvmRunEnvironments
: TrieMap[b.BuildTargetIdentifier, b.JvmEnvironmentItem] =
Expand Down Expand Up @@ -56,6 +57,19 @@ final class BuildTargetClasses(
.map(_.fullyQualifiedName)
)

def getTestClasses(
name: String,
id: b.BuildTargetIdentifier,
): List[(String, TestSymbolInfo)] = {
index.get(id).toList.flatMap {
_.testClasses
.filter { case (_, info) =>
info.fullyQualifiedName == name
}
.toList
}
}

private def findClassesBy[A](
f: Classes => Option[A]
): List[(A, b.BuildTargetIdentifier)] = {
Expand All @@ -77,25 +91,22 @@ final class BuildTargetClasses(
Future.successful(())
case (Some(connection), targets0) =>
val targetsList = targets0.asJava
targetsList.forEach(invalidate)
val classes = targets0.map(t => (t, new Classes)).toMap

val updateMainClasses = connection
.mainClasses(new b.ScalaMainClassesParams(targetsList))
.map(cacheMainClasses(classes, _))

// Currently tests are only run using DAP
val updateTestClasses =
if (connection.isDebuggingProvider || connection.isSbt)
connection
.testClasses(new b.ScalaTestClassesParams(targetsList))
.map(cacheTestClasses(classes, _))
else Future.unit
connection
.testClasses(new b.ScalaTestClassesParams(targetsList))
.map(cacheTestClasses(classes, _))

for {
_ <- updateMainClasses
_ <- updateTestClasses
} yield {
targetsList.forEach(invalidate)
classes.foreach { case (id, classes) =>
index.put(id, classes)
}
Expand Down Expand Up @@ -214,7 +225,10 @@ final class BuildTargetClasses(
}
}

sealed abstract class TestFramework(val canResolveChildren: Boolean)
sealed abstract class TestFramework(val canResolveChildren: Boolean) {
def names: List[String]
}

object TestFramework {
def apply(framework: Option[String]): TestFramework = framework
.map {
Expand All @@ -226,11 +240,30 @@ object TestFramework {
}
.getOrElse(Unknown)
}
case object JUnit4 extends TestFramework(true)
case object MUnit extends TestFramework(true)
case object Scalatest extends TestFramework(true)
case object WeaverCatsEffect extends TestFramework(true)
case object Unknown extends TestFramework(false)

case object JUnit4 extends TestFramework(true) {
def names: List[String] = List("com.novocode.junit.JUnitFramework")
}

case object MUnit extends TestFramework(true) {
def names: List[String] = List("munit.Framework")
}

case object Scalatest extends TestFramework(true) {
def names: List[String] =
List(
"org.scalatest.tools.Framework",
"org.scalatest.tools.ScalaTestFramework",
)
}

case object WeaverCatsEffect extends TestFramework(true) {
def names: List[String] = List("weaver.BaseCatsSuite")
}

case object Unknown extends TestFramework(false) {
def names: List[String] = Nil
}

object BuildTargetClasses {
type Symbol = String
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ import scala.meta.internal.metals.config.RunType
import scala.meta.internal.metals.config.RunType._
import scala.meta.internal.metals.debug.server.DebugLogger
import scala.meta.internal.metals.debug.server.DebugeeParamsCreator
import scala.meta.internal.metals.debug.server.Discovered
import scala.meta.internal.metals.debug.server.MainClassDebugAdapter
import scala.meta.internal.metals.debug.server.MetalsDebugToolsResolver
import scala.meta.internal.metals.debug.server.MetalsDebuggee
import scala.meta.internal.metals.debug.server.TestSuiteDebugAdapter
import scala.meta.internal.metals.testProvider.TestSuitesProvider
import scala.meta.internal.mtags.DefinitionAlternatives.GlobalSymbol
import scala.meta.internal.mtags.OnDemandSymbolIndex
Expand Down Expand Up @@ -328,38 +330,60 @@ class DebugProvider(
buildServer: BuildServerConnection,
params: DebugSessionParams,
cancelPromise: Promise[Unit],
) =
)(implicit ec: ExecutionContext) =
if (buildServer.isDebuggingProvider || buildServer.isSbt) {
buildServer.startDebugSession(params, cancelPromise)
} else {
def getDebugee: Either[String, MetalsDebuggee] =
def getDebugee: Either[String, Future[MetalsDebuggee]] = {
def buildTarget = params
.getTargets()
.asScala
.headOption
.toRight(s"Missing build target in debug params.")
params.getDataKind() match {
case b.DebugSessionParamsDataKind.SCALA_MAIN_CLASS =>
for {
id <- params
.getTargets()
.asScala
.headOption
.toRight(s"Missing build target in debug params.")
id <- buildTarget
projectInfo <- debugConfigCreator.create(id)
scalaMainClass <- params.asScalaMainClass()
} yield new MainClassDebugAdapter(
workspace,
scalaMainClass,
projectInfo,
userConfig().javaHome,
} yield Future.successful(
new MainClassDebugAdapter(
workspace,
scalaMainClass,
projectInfo,
userConfig().javaHome,
)
)
case (b.TestParamsDataKind.SCALA_TEST_SUITES_SELECTION |
b.TestParamsDataKind.SCALA_TEST_SUITES) =>
for {
id <- buildTarget
project <- debugConfigCreator.create(id)
testSuites <- params.asScalaTestSuites()
} yield {
for {
discovered <- discoverTests(id, testSuites)
} yield new TestSuiteDebugAdapter(
workspace,
testSuites,
project,
userConfig().javaHome,
discovered,
)
}
case kind =>
Left(s"Starting debug session for $kind in not supported.")
}
}

for {
_ <- compilations.compileTargets(params.getTargets().asScala.toSeq)
} yield {
val debuggee = getDebugee match {
debuggee <- getDebugee match {
case Right(debuggee) => debuggee
case Left(errorMessage) => throw new RuntimeException(errorMessage)
case Left(errorMessage) =>
Future.failed(new RuntimeException(errorMessage))
}
} yield {
val dapLogger = new DebugLogger()
val resolver = new MetalsDebugToolsResolver()
val handler =
Expand All @@ -373,6 +397,37 @@ class DebugProvider(
}
}

private def discoverTests(
id: BuildTargetIdentifier,
testClasses: b.ScalaTestSuites,
): Future[Map[TestFramework, List[Discovered]]] = {
val symbolInfosList =
for {
selection <- testClasses.getSuites().asScala.toList
(sym, info) <- buildTargetClasses.getTestClasses(
selection.getClassName(),
id,
)
} yield compilers.info(id, sym).map(_.map(pcInfo => (info, pcInfo)))
tgodzik marked this conversation as resolved.
Show resolved Hide resolved

Future.sequence(symbolInfosList).map {
_.flatten.groupBy(_._1.framework).map { case (framework, testSuites) =>
(
framework,
testSuites.map { case (testInfo, pcInfo) =>
new Discovered(
pcInfo.symbol,
testInfo.fullyQualifiedName,
pcInfo.recursiveParents.map(_.symbolToFullQualifiedName).toSet,
(pcInfo.annotations ++ pcInfo.memberDefsAnnotations).toSet,
isModule = false,
)
},
)
}
}
}

/**
* Given a BuildTargetIdentifier either get the displayName of that build
* target or default to the full URI to display to the user.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package scala.meta.internal.metals.debug.server

import java.net.InetSocketAddress
import java.util.concurrent.atomic.AtomicBoolean

import ch.epfl.scala.debugadapter.DebuggeeListener

class Logger(listener: DebuggeeListener) {
private val initialized = new AtomicBoolean(false)
private final val JDINotificationPrefix =
"Listening for transport dt_socket at address: "

def logError(errorMessage: String): Unit = {
listener.err(errorMessage)
scribe.error(errorMessage)
}

def logOutput(msg: String): Unit = {
if (msg.startsWith(JDINotificationPrefix)) {
if (initialized.compareAndSet(false, true)) {
val port = Integer.parseInt(msg.drop(JDINotificationPrefix.length))
val address = new InetSocketAddress("127.0.0.1", port)
listener.onListening(address)
}
} else listener.out(msg)
}
}
Loading