Skip to content

Commit 2e7fef6

Browse files
committed
.
1 parent f526951 commit 2e7fef6

File tree

1 file changed

+112
-0
lines changed

1 file changed

+112
-0
lines changed
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import sbt._
2+
import sbt.Keys._
3+
4+
/**
5+
* ShadedSourceGenerator - A build plugin for creating shaded versions of external dependencies
6+
*
7+
* This generator downloads source JARs for specified dependencies (currently pprint, fansi, and sourcecode),
8+
* extracts them, and applies patches to:
9+
* 1. Add the dotty.shaded package prefix
10+
* 2. Rewrite imports to use _root_ to avoid conflicts
11+
* 3. Apply Scala 3 compatibility fixes (mostly due to enforcing null safety in scala/scala3)
12+
*
13+
* The shaded sources are placed in the managed source directory and included in compilation.
14+
* This allows the Scala 3 compiler to bundle these utilities without external dependencies.
15+
*/
16+
object ShadedSourceGenerator {
17+
18+
val task = Def.task {
19+
val s = streams.value
20+
val cacheDir = s.cacheDirectory
21+
val dest = (Compile / sourceManaged).value / "downloaded"
22+
val lm = dependencyResolution.value
23+
24+
val dependencies = Seq(
25+
("com.lihaoyi", "pprint_3", "0.9.3"),
26+
("com.lihaoyi", "fansi_3", "0.5.1"),
27+
("com.lihaoyi", "sourcecode_3", "0.4.4"),
28+
)
29+
30+
// Create a marker file that tracks the dependencies for cache invalidation
31+
val markerFile = cacheDir / "shaded-sources-marker"
32+
val markerContent = dependencies.map { case (org, name, version) => s"$org:$name:$version:sources" }.mkString("\n")
33+
if (!markerFile.exists || IO.read(markerFile) != markerContent) {
34+
IO.write(markerFile, markerContent)
35+
}
36+
37+
FileFunction.cached(cacheDir / "fetchShadedSources",
38+
FilesInfo.lastModified, FilesInfo.exists) { _ =>
39+
s.log.info(s"Downloading and processing shaded sources to $dest...")
40+
41+
if (dest.exists) {
42+
IO.delete(dest)
43+
}
44+
IO.createDirectory(dest)
45+
46+
for((org, name, version) <- dependencies) {
47+
import sbt.librarymanagement._
48+
49+
val moduleId = ModuleID(org, name, version).sources()
50+
val retrieveDir = cacheDir / "retrieved" / s"$org-$name-$version-sources"
51+
52+
s.log.info(s"Retrieving $org:$name:$version:sources...")
53+
val retrieved = lm.retrieve(moduleId, scalaModuleInfo = None, retrieveDir, s.log)
54+
val jarFiles = retrieved.fold(
55+
w => throw w.resolveException,
56+
files => files.filter(_.getName.contains("-sources.jar"))
57+
)
58+
59+
jarFiles.foreach { jarFile =>
60+
s.log.info(s"Extracting ${jarFile.getName}...")
61+
IO.unzip(jarFile, dest)
62+
}
63+
}
64+
65+
val scalaFiles = (dest ** "*.scala").get
66+
67+
// Define patches as a map from search text to replacement text
68+
val patches = Map(
69+
"import scala" -> "import _root_.scala",
70+
" scala.collection." -> " _root_.scala.collection.",
71+
"_root_.pprint" -> "_root_.dotty.shaded.pprint",
72+
"_root_.fansi" -> "_root_.dotty.shaded.fansi",
73+
"def apply(c: Char): Trie[T]" -> "def apply(c: Char): Trie[T] | Null",
74+
"var head: Iterator[T] = null" -> "var head: Iterator[T] | Null = null",
75+
"if (head != null && head.hasNext) true" -> "if (head != null && head.nn.hasNext) true",
76+
"head.next()" -> "head.nn.next()",
77+
"abstract class Walker" -> "@scala.annotation.nowarn abstract class Walker",
78+
"object TPrintLowPri" -> "@scala.annotation.nowarn object TPrintLowPri",
79+
"x.toString match{" -> "scala.runtime.ScalaRunTime.stringOf(x) match{"
80+
)
81+
82+
val patchUsageCounter = scala.collection.mutable.Map(patches.keys.map(_ -> 0).toSeq: _*)
83+
84+
scalaFiles.foreach { file =>
85+
val text = IO.read(file)
86+
if (!file.getName.equals("CollectionName.scala")) {
87+
var processedText = "package dotty.shaded\n" + text
88+
89+
// Apply patches and count usage
90+
patches.foreach { case (search, replacement) =>
91+
if (processedText.contains(search)) {
92+
processedText = processedText.replace(search, replacement)
93+
patchUsageCounter(search) += 1
94+
}
95+
}
96+
97+
IO.write(file, processedText)
98+
}
99+
}
100+
101+
// Assert that all patches were applied at least once
102+
val unappliedPatches = patchUsageCounter.filter(_._2 == 0).keys
103+
if (unappliedPatches.nonEmpty) {
104+
throw new RuntimeException(s"Patches were not applied: ${unappliedPatches.mkString(", ")}")
105+
}
106+
107+
scalaFiles.toSet
108+
} (Set(markerFile)).toSeq
109+
110+
}
111+
112+
}

0 commit comments

Comments
 (0)