Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions compiler/src/dotty/tools/repl/DependencyResolver.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package dotty.tools.repl

import scala.language.unsafeNulls

import java.io.File
import java.net.{URL, URLClassLoader}
import scala.jdk.CollectionConverters.*
import scala.util.control.NonFatal

import coursierapi.{Repository, Dependency, MavenRepository}
import com.virtuslab.using_directives.UsingDirectivesProcessor
import com.virtuslab.using_directives.custom.model.{Path, StringValue, Value}

/** Handles dependency resolution using Coursier for the REPL */
object DependencyResolver:

/** Parse a dependency string of the form `org::artifact:version` or `org:artifact:version`
* and return the (organization, artifact, version) triple if successful.
*
* Supports both Maven-style (single colon) and Scala-style (double colon) notation:
* - Maven: `com.lihaoyi:scalatags_3:0.13.1`
* - Scala: `com.lihaoyi::scalatags:0.13.1` (automatically appends _3)
*/
def parseDependency(dep: String): Option[(String, String, String)] =
dep match
case s"$org::$artifact:$version" => Some((org, s"${artifact}_3", version))
case s"$org:$artifact:$version" => Some((org, artifact, version))
case _ =>
System.err.println("Unable to parse dependency \"" + dep + "\"")
None

/** Extract all dependencies from using directives in source code */
def extractDependencies(sourceCode: String): List[String] =
try
val directives = new UsingDirectivesProcessor().extract(sourceCode.toCharArray)
val deps = scala.collection.mutable.Buffer[String]()

for
directive <- directives.asScala
(path, values) <- directive.getFlattenedMap.asScala
do
if path.getPath.asScala.toList == List("dep") then
values.asScala.foreach {
case strValue: StringValue => deps += strValue.get()
case value => System.err.println("Unrecognized directive value " + value)
}
else
System.err.println("Unrecognized directive " + path.getPath)

deps.toList
catch
case NonFatal(e) => Nil // If parsing fails, fall back to empty list

/** Resolve dependencies using Coursier Interface and return the classpath as a list of File objects */
def resolveDependencies(dependencies: List[(String, String, String)]): Either[String, List[File]] =
if dependencies.isEmpty then Right(Nil)
else
try
// Add Maven Central and Sonatype repositories
val repos = Array(
MavenRepository.of("https://repo1.maven.org/maven2"),
MavenRepository.of("https://oss.sonatype.org/content/repositories/releases")
)

// Create dependency objects
val deps = dependencies
.map { case (org, artifact, version) => Dependency.of(org, artifact, version) }
.toArray

val fetch = coursierapi.Fetch.create()
.withRepositories(repos*)
.withDependencies(deps*)

Right(fetch.fetch().asScala.toList)

catch
case NonFatal(e) =>
Left(s"Failed to resolve dependencies: ${e.getMessage}")

/** Add resolved dependencies to the compiler classpath and classloader.
* Returns the new classloader.
*
* This follows the same pattern as the `:jar` command.
*/
def addToCompilerClasspath(
files: List[File],
prevClassLoader: ClassLoader,
prevOutputDir: dotty.tools.io.AbstractFile
)(using ctx: dotty.tools.dotc.core.Contexts.Context): AbstractFileClassLoader =
import dotty.tools.dotc.classpath.ClassPathFactory
import dotty.tools.dotc.core.SymbolLoaders
import dotty.tools.dotc.core.Symbols.defn
import dotty.tools.io.*
import dotty.tools.runner.ScalaClassLoader.fromURLsParallelCapable

// Create a classloader with all the resolved JAR files
val urls = files.map(_.toURI.toURL).toArray
val depsClassLoader = new URLClassLoader(urls, prevClassLoader)

// Add each JAR to the compiler's classpath
for file <- files do
val jarFile = AbstractFile.getDirectory(file.getAbsolutePath)
if jarFile != null then
val jarClassPath = ClassPathFactory.newClassPath(jarFile)
ctx.platform.addToClassPath(jarClassPath)
SymbolLoaders.mergeNewEntries(defn.RootClass, ClassPath.RootPackage, jarClassPath, ctx.platform.classPath)

// Create new classloader with previous output dir and resolved dependencies
new AbstractFileClassLoader(prevOutputDir, depsClassLoader)

end DependencyResolver
31 changes: 29 additions & 2 deletions compiler/src/dotty/tools/repl/ReplDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,35 @@ class ReplDriver(settings: Array[String],

protected def interpret(res: ParseResult)(using state: State): State = {
res match {
case parsed: Parsed if parsed.trees.nonEmpty =>
compile(parsed, state)
case parsed: Parsed =>
// Check for magic comments specifying dependencies
val sourceCode = parsed.source.content().mkString
val depStrings = DependencyResolver.extractDependencies(sourceCode)

if depStrings.nonEmpty then
val deps = depStrings.flatMap(DependencyResolver.parseDependency)
if deps.nonEmpty then
DependencyResolver.resolveDependencies(deps) match
case Right(files) =>
if files.nonEmpty then
inContext(state.context):
// Update both compiler classpath and classloader
val prevOutputDir = ctx.settings.outputDir.value
val prevClassLoader = rendering.classLoader()
rendering.myClassLoader = DependencyResolver.addToCompilerClasspath(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just a nod to just a nod to #24119 which has sensitive state towards whatever the current classloader is.

files,
prevClassLoader,
prevOutputDir
)
out.println(s"Resolved ${deps.size} dependencies (${files.size} JARs)")
case Left(error) =>
out.println(s"Error resolving dependencies: $error")

// Only compile if there are actual trees to compile
if parsed.trees.nonEmpty then
compile(parsed, state)
else
state

case SyntaxErrors(_, errs, _) =>
displayErrors(errs)
Expand Down
4 changes: 4 additions & 0 deletions project/Build.scala
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,8 @@ object Build {
"org.jline" % "jline-terminal" % "3.29.0",
"org.jline" % "jline-terminal-jni" % "3.29.0", // needed for Windows
("io.get-coursier" %% "coursier" % "2.0.16" % Test).cross(CrossVersion.for3Use2_13),
"io.get-coursier" % "interface" % "1.0.19", // used by the REPL for dependency resolution
"org.virtuslab" % "using_directives" % "1.1.4", // used by the REPL for parsing magic comments
),

(Compile / sourceGenerators) += ShadedSourceGenerator.task.taskValue,
Expand Down Expand Up @@ -2137,6 +2139,8 @@ object Build {
"org.jline" % "jline-terminal" % "3.29.0",
"org.jline" % "jline-terminal-jni" % "3.29.0",
("io.get-coursier" %% "coursier" % "2.0.16" % Test).cross(CrossVersion.for3Use2_13),
"io.get-coursier" % "interface" % "1.0.19", // used by the REPL for dependency resolution
"org.virtuslab" % "using_directives" % "1.1.4", // used by the REPL for parsing magic comments
),
// NOTE: The only difference here is that we drop `-Werror` and semanticDB for now
Compile / scalacOptions := Seq("-deprecation", "-feature", "-unchecked", "-encoding", "UTF8", "-language:implicitConversions"),
Expand Down
Loading