Skip to content

Commit

Permalink
feat: run DAP for test suites in Metals
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiaMarek committed Jun 28, 2024
1 parent f116e13 commit aefe215
Show file tree
Hide file tree
Showing 25 changed files with 1,353 additions and 167 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,11 @@ final class ImplementationProvider(
classOwner = classOwnerInfoOpt.map(_.symbol),
alternativeSymbols = alternativeSymbols.toList,
overriddenSymbols = info.overriddenSymbols.toList,
properties = if (info.isAbstract) List(PcSymbolProperty.ABSTRACT) else Nil,
properties =
if (info.isAbstract) List(PcSymbolProperty.ABSTRACT) else Nil,
recursiveParents = parents,
annotations = info.annotations.map(_.toString()).toList,
memberDefsAnnotations = Nil,
)
}
}
Expand Down
12 changes: 12 additions & 0 deletions metals/src/main/scala/scala/meta/internal/metals/Compilers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,18 @@ class Compilers(
.getOrElse(Future(None))
}

def info(
id: BuildTargetIdentifier,
symbol: String,
): Future[Option[PcSymbolInformation]] = {
loadCompiler(id)
.map(
_.info(symbol).asScala
.map(_.asScala.map(PcSymbolInformation.from))
)
.getOrElse(Future(None))
}

private def definition(
params: TextDocumentPositionParams,
token: CancelToken,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,9 @@ object MetalsEnrichments
value.replace('/', ':')
else value
}

def symbolToFullQualifiedName: String =
value.replaceAll("/|#", ".").stripSuffix(".")
}

implicit class XtensionTextDocumentSemanticdb(textDocument: s.TextDocument) {
Expand Down Expand Up @@ -1311,16 +1314,40 @@ object MetalsEnrichments
}

implicit class XtensionDebugSessionParams(params: b.DebugSessionParams) {
def asScalaMainClass(): Either[String, b.ScalaMainClass] =
def asScalaMainClass(): Either[String, b.ScalaMainClass] = {
lazy val className = "ScalaMainClass"
params.getDataKind() match {
case b.DebugSessionParamsDataKind.SCALA_MAIN_CLASS =>
decodeJson(params.getData(), classOf[b.ScalaMainClass])
.toRight(s"Cannot decode $params as `ScalaMainClass`.")
case _ =>
Left(
s"Cannot decode params as `ScalaMainClass` incorrect data kind: ${params.getDataKind()}."
.toRight(cannotDecode(className))
case _ => incorrectKind(className)
}
}

def asScalaTestSuites(): Either[String, b.ScalaTestSuites] = {
lazy val className = "ScalaTestSuites"
params.getDataKind() match {
case b.TestParamsDataKind.SCALA_TEST_SUITES_SELECTION =>
decodeJson(params.getData(), classOf[b.ScalaTestSuites])
.toRight(cannotDecode(className))
case b.TestParamsDataKind.SCALA_TEST_SUITES =>
(for (
tests <- decodeJson(params.getData(), classOf[util.List[String]])
)
yield {
val suites =
tests.map(new b.ScalaTestSuiteSelection(_, Nil.asJava))
new b.ScalaTestSuites(suites, Nil.asJava, Nil.asJava)
}).toRight(cannotDecode(className))
case _ => incorrectKind(className)
}
}
def cannotDecode(className: String): String =
s"Cannot decode $params as `$className`."
def incorrectKind(className: String): Left[String, Nothing] =
Left(
s"Cannot decode params as `$className` incorrect data kind: ${params.getDataKind()}."
)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,15 @@ final class RunTestCodeLens(
buildServerCanDebug,
isJVM,
)
} else if (buildServerCanDebug || clientConfig.isRunProvider()) {
} else
codeLenses(
textDocument,
buildTargetId,
classes,
distance,
path,
buildServerCanDebug,
isJVM,
)
} else { Nil }

}

Expand Down Expand Up @@ -198,7 +196,6 @@ final class RunTestCodeLens(
classes: BuildTargetClasses.Classes,
distance: TokenEditDistance,
path: AbsolutePath,
buildServerCanDebug: Boolean,
isJVM: Boolean,
): Seq[l.CodeLens] = {
for {
Expand All @@ -212,7 +209,7 @@ final class RunTestCodeLens(
.getOrElse(Nil)
lazy val tests =
// Currently tests can only be run via DAP
if (clientConfig.isDebuggingProvider() && buildServerCanDebug)
if (clientConfig.isDebuggingProvider())
testClasses(target, classes, symbol, isJVM)
else Nil
val fromAnnot = DebugProvider
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import scala.meta.internal.metals.BatchedFunction
import scala.meta.internal.metals.BuildTargets
import scala.meta.internal.metals.MetalsEnrichments._
import scala.meta.internal.metals.debug.BuildTargetClasses.Classes
import scala.meta.internal.metals.debug.BuildTargetClasses.TestSymbolInfo
import scala.meta.internal.semanticdb.Scala.Descriptor
import scala.meta.internal.semanticdb.Scala.Symbols

Expand All @@ -16,9 +17,9 @@ import ch.epfl.scala.{bsp4j => b}
/**
* In-memory index of main class symbols grouped by their enclosing build target
*/
final class BuildTargetClasses(
buildTargets: BuildTargets
)(implicit val ec: ExecutionContext) {
final class BuildTargetClasses(buildTargets: BuildTargets)(implicit
val ec: ExecutionContext
) {
private val index = TrieMap.empty[b.BuildTargetIdentifier, Classes]
private val jvmRunEnvironments
: TrieMap[b.BuildTargetIdentifier, b.JvmEnvironmentItem] =
Expand Down Expand Up @@ -56,6 +57,19 @@ final class BuildTargetClasses(
.map(_.fullyQualifiedName)
)

def getTestClasses(
name: String,
id: b.BuildTargetIdentifier,
): List[(String, TestSymbolInfo)] = {
index.get(id).toList.flatMap {
_.testClasses
.filter { case (_, info) =>
info.fullyQualifiedName == name
}
.toList
}
}

private def findClassesBy[A](
f: Classes => Option[A]
): List[(A, b.BuildTargetIdentifier)] = {
Expand All @@ -77,25 +91,22 @@ final class BuildTargetClasses(
Future.successful(())
case (Some(connection), targets0) =>
val targetsList = targets0.asJava
targetsList.forEach(invalidate)
val classes = targets0.map(t => (t, new Classes)).toMap

val updateMainClasses = connection
.mainClasses(new b.ScalaMainClassesParams(targetsList))
.map(cacheMainClasses(classes, _))

// Currently tests are only run using DAP
val updateTestClasses =
if (connection.isDebuggingProvider || connection.isSbt)
connection
.testClasses(new b.ScalaTestClassesParams(targetsList))
.map(cacheTestClasses(classes, _))
else Future.unit
connection
.testClasses(new b.ScalaTestClassesParams(targetsList))
.map(cacheTestClasses(classes, _))

for {
_ <- updateMainClasses
_ <- updateTestClasses
} yield {
targetsList.forEach(invalidate)
classes.foreach { case (id, classes) =>
index.put(id, classes)
}
Expand Down Expand Up @@ -214,7 +225,10 @@ final class BuildTargetClasses(
}
}

sealed abstract class TestFramework(val canResolveChildren: Boolean)
sealed abstract class TestFramework(val canResolveChildren: Boolean) {
def names: List[String]
}

object TestFramework {
def apply(framework: Option[String]): TestFramework = framework
.map {
Expand All @@ -226,11 +240,30 @@ object TestFramework {
}
.getOrElse(Unknown)
}
case object JUnit4 extends TestFramework(true)
case object MUnit extends TestFramework(true)
case object Scalatest extends TestFramework(true)
case object WeaverCatsEffect extends TestFramework(true)
case object Unknown extends TestFramework(false)

case object JUnit4 extends TestFramework(true) {
def names: List[String] = List("com.novocode.junit.JUnitFramework")
}

case object MUnit extends TestFramework(true) {
def names: List[String] = List("munit.Framework")
}

case object Scalatest extends TestFramework(true) {
def names: List[String] =
List(
"org.scalatest.tools.Framework",
"org.scalatest.tools.ScalaTestFramework",
)
}

case object WeaverCatsEffect extends TestFramework(true) {
def names: List[String] = List("weaver.BaseCatsSuite")
}

case object Unknown extends TestFramework(false) {
def names: List[String] = Nil
}

object BuildTargetClasses {
type Symbol = String
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ import scala.meta.internal.metals.config.RunType
import scala.meta.internal.metals.config.RunType._
import scala.meta.internal.metals.debug.server.DebugLogger
import scala.meta.internal.metals.debug.server.DebugeeParamsCreator
import scala.meta.internal.metals.debug.server.Discovered
import scala.meta.internal.metals.debug.server.MainClassDebugAdapter
import scala.meta.internal.metals.debug.server.MetalsDebugToolsResolver
import scala.meta.internal.metals.debug.server.MetalsDebuggee
import scala.meta.internal.metals.debug.server.TestSuiteDebugAdapter
import scala.meta.internal.metals.testProvider.TestSuitesProvider
import scala.meta.internal.mtags.DefinitionAlternatives.GlobalSymbol
import scala.meta.internal.mtags.OnDemandSymbolIndex
Expand Down Expand Up @@ -328,38 +330,60 @@ class DebugProvider(
buildServer: BuildServerConnection,
params: DebugSessionParams,
cancelPromise: Promise[Unit],
) =
)(implicit ec: ExecutionContext) =
if (buildServer.isDebuggingProvider || buildServer.isSbt) {
buildServer.startDebugSession(params, cancelPromise)
} else {
def getDebugee: Either[String, MetalsDebuggee] =
def getDebugee: Either[String, Future[MetalsDebuggee]] = {
def buildTarget = params
.getTargets()
.asScala
.headOption
.toRight(s"Missing build target in debug params.")
params.getDataKind() match {
case b.DebugSessionParamsDataKind.SCALA_MAIN_CLASS =>
for {
id <- params
.getTargets()
.asScala
.headOption
.toRight(s"Missing build target in debug params.")
id <- buildTarget
projectInfo <- debugConfigCreator.create(id)
scalaMainClass <- params.asScalaMainClass()
} yield new MainClassDebugAdapter(
workspace,
scalaMainClass,
projectInfo,
userConfig().javaHome,
} yield Future.successful(
new MainClassDebugAdapter(
workspace,
scalaMainClass,
projectInfo,
userConfig().javaHome,
)
)
case (b.TestParamsDataKind.SCALA_TEST_SUITES_SELECTION |
b.TestParamsDataKind.SCALA_TEST_SUITES) =>
for {
id <- buildTarget
project <- debugConfigCreator.create(id)
testSuites <- params.asScalaTestSuites()
} yield {
for {
discovered <- discoverTests(id, testSuites)
} yield new TestSuiteDebugAdapter(
workspace,
testSuites,
project,
userConfig().javaHome,
discovered,
)
}
case kind =>
Left(s"Starting debug session for $kind in not supported.")
}
}

for {
_ <- compilations.compileTargets(params.getTargets().asScala.toSeq)
} yield {
val debuggee = getDebugee match {
debuggee <- getDebugee match {
case Right(debuggee) => debuggee
case Left(errorMessage) => throw new RuntimeException(errorMessage)
case Left(errorMessage) =>
Future.failed(new RuntimeException(errorMessage))
}
} yield {
val dapLogger = new DebugLogger()
val resolver = new MetalsDebugToolsResolver()
val handler =
Expand All @@ -373,6 +397,37 @@ class DebugProvider(
}
}

private def discoverTests(
id: BuildTargetIdentifier,
testClasses: b.ScalaTestSuites,
): Future[Map[TestFramework, List[Discovered]]] = {
val symbolInfosList =
for {
selection <- testClasses.getSuites().asScala.toList
(sym, info) <- buildTargetClasses.getTestClasses(
selection.getClassName(),
id,
)
} yield compilers.info(id, sym).map(_.map(pcInfo => (info, pcInfo)))

Future.sequence(symbolInfosList).map {
_.flatten.groupBy(_._1.framework).map { case (framework, testSuites) =>
(
framework,
testSuites.map { case (testInfo, pcInfo) =>
new Discovered(
pcInfo.symbol,
testInfo.fullyQualifiedName,
pcInfo.recursiveParents.map(_.symbolToFullQualifiedName).toSet,
(pcInfo.annotations ++ pcInfo.memberDefsAnnotations).toSet,
isModule = false,
)
},
)
}
}
}

/**
* Given a BuildTargetIdentifier either get the displayName of that build
* target or default to the full URI to display to the user.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package scala.meta.internal.metals.debug.server

import java.net.InetSocketAddress
import java.util.concurrent.atomic.AtomicBoolean

import ch.epfl.scala.debugadapter.DebuggeeListener

class Logger(listener: DebuggeeListener) {
private val initialized = new AtomicBoolean(false)
private final val JDINotificationPrefix =
"Listening for transport dt_socket at address: "

def logError(errorMessage: String): Unit = {
listener.err(errorMessage)
scribe.error(errorMessage)
}

def logOutput(msg: String): Unit = {
if (msg.startsWith(JDINotificationPrefix)) {
if (initialized.compareAndSet(false, true)) {
val port = Integer.parseInt(msg.drop(JDINotificationPrefix.length))
val address = new InetSocketAddress("127.0.0.1", port)
listener.onListening(address)
}
} else listener.out(msg)
}
}
Loading

0 comments on commit aefe215

Please sign in to comment.