Skip to content
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package scalafix.internal.patch

import scala.annotation.tailrec

import scala.meta._
import scala.meta.internal.trees._

Expand All @@ -12,6 +14,11 @@ import scalafix.syntax._
import scalafix.v0._

object ReplaceSymbolOps {
private case class ImportInfo(
globalImports: Seq[Import],
globalImportedSymbols: Map[String, Symbol]
)

private object Select {
def unapply(arg: Ref): Option[(Ref, Name)] = arg match {
case Term.Select(a: Ref, b) => Some(a -> b)
Expand All @@ -20,10 +27,47 @@ object ReplaceSymbolOps {
}
}

private def extractImports(stats: Seq[Stat]): Seq[Import] = {
stats.collect { case i: Import => i }
}

private def extractImportInfo(
tree: Tree
)(implicit index: SemanticdbIndex): ImportInfo = {
@tailrec
def getGlobalImports(ast: Tree): Seq[Import] = ast match {
case Pkg(_, Seq(pkg: Pkg)) => getGlobalImports(pkg)
case Source(Seq(pkg: Pkg)) => getGlobalImports(pkg)
case Pkg(_, stats) => extractImports(stats)
case Source(stats) => extractImports(stats)
case _ => Nil
}

val globalImports = getGlobalImports(tree)

// pre-compute global imported symbols for O(1) collision detection
// since ctx.addGlobalImport adds imports at global scope
val globalImportedSymbols = globalImports.flatMap { importStat =>
importStat.importers.flatMap { importer =>
importer.importees.collect {
case Importee.Name(name) =>
name.value -> name.symbol.getOrElse(Symbol.None)
case Importee.Rename(_, rename) =>
rename.value -> rename.symbol.getOrElse(Symbol.None)
}
}
}.toMap

ImportInfo(globalImports, globalImportedSymbols)
}

def naiveMoveSymbolPatch(
moveSymbols: Seq[ReplaceSymbol]
)(implicit ctx: RuleCtx, index: SemanticdbIndex): Patch = {
if (moveSymbols.isEmpty) return Patch.empty

val importInfo = extractImportInfo(ctx.tree)(index)

val moves: Map[String, Symbol.Global] =
moveSymbols.iterator.flatMap {
case ReplaceSymbol(
Expand Down Expand Up @@ -126,10 +170,15 @@ object ReplaceSymbolOps {
if sig.name != parent.value =>
Patch.empty // do nothing because it was a renamed symbol
case Some(_) =>
val causesCollision =
importInfo.globalImportedSymbols.contains(to.signature.name)
val addImport =
if (n.isDefinition) Patch.empty
if (n.isDefinition || causesCollision) Patch.empty
else ctx.addGlobalImport(to)
addImport + ctx.replaceTree(n, to.signature.name)
if (causesCollision)
addImport + ctx.replaceTree(n, to.owner.syntax + to.signature.name)
Copy link

Copilot AI Aug 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

String concatenation for building qualified names could be error-prone. Consider using a more robust method like to.owner.syntax + "." + to.signature.name or a dedicated method to ensure proper formatting.

Suggested change
addImport + ctx.replaceTree(n, to.owner.syntax + to.signature.name)
addImport + ctx.replaceTree(n, to.owner.syntax + "." + to.signature.name)

Copilot uses AI. Check for mistakes.

else
addImport + ctx.replaceTree(n, to.signature.name)
Copy link

Copilot AI Aug 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The collision handling logic creates duplicated code paths. Consider extracting the replacement logic into a helper function to reduce duplication between the collision and non-collision cases.

Suggested change
addImport + ctx.replaceTree(n, to.signature.name)
patchForCollision(n, to, addImport, causesCollision)

Copilot uses AI. Check for mistakes.

case _ =>
Patch.empty
}
Expand Down
4 changes: 4 additions & 0 deletions scalafix-tests/input/src/main/scala/test/ReplaceSymbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ patches.replaceSymbols = [
to = "com.geirsson.mutable.CoolBuffer" }
{ from = "scala.collection.mutable.HashMap"
to = "com.geirsson.mutable.unsafe.CoolMap" }
{ from = "scala.collection.immutable.TreeMap"
to = "com.geirsson.immutable.SortedMap" }
{ from = "scala.math.sqrt"
to = "com.geirsson.fastmath.sqrt" }
// normalized symbol renames all overloaded methods
Expand All @@ -29,6 +31,7 @@ patches.replaceSymbols = [
*/
package fix

import scala.collection.immutable.{ SortedMap, TreeMap }
import scala.collection.mutable.HashMap
import scala.collection.mutable.ListBuffer
import scala.collection.mutable
Expand All @@ -42,6 +45,7 @@ object ReplaceSymbol {
"blah".substring(1)
"blah".substring(1, 2)
val u: mutable.HashMap[Int, Int] = HashMap.empty[Int, Int]
val v: SortedMap[Int, Int] = TreeMap.empty[Int, Int]
val x: ListBuffer[Int] = ListBuffer.empty[Int]
val y: mutable.ListBuffer[Int] = mutable.ListBuffer.empty[Int]
val z: scala.collection.mutable.ListBuffer[Int] =
Expand Down
11 changes: 11 additions & 0 deletions scalafix-tests/output/src/main/scala/com/geirsson/immutable.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.geirsson

import scala.collection.immutable.TreeMap

object immutable {
type SortedMap[A, B] = TreeMap[A, B]

object SortedMap {
def empty[A : Ordering, B]: SortedMap[A, B] = TreeMap.empty[A, B]
}
}
2 changes: 2 additions & 0 deletions scalafix-tests/output/src/main/scala/test/ReplaceSymbol.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package fix

import scala.collection.immutable.SortedMap
import com.geirsson.Future
import com.geirsson.{ fastmath, mutable }
import com.geirsson.mutable.{ CoolBuffer, unsafe }
Expand All @@ -13,6 +14,7 @@ object ReplaceSymbol {
"blah".substringFrom(1)
"blah".substringBetween(1, 2)
val u: unsafe.CoolMap[Int, Int] = CoolMap.empty[Int, Int]
val v: SortedMap[Int, Int] = com.geirsson.immutable.SortedMap.empty[Int, Int]
val x: CoolBuffer[Int] = CoolBuffer.empty[Int]
val y: mutable.CoolBuffer[Int] = mutable.CoolBuffer.empty[Int]
val z: com.geirsson.mutable.CoolBuffer[Int] =
Expand Down