diff --git a/scalafix-core/src/main/scala/scalafix/internal/patch/ReplaceSymbolOps.scala b/scalafix-core/src/main/scala/scalafix/internal/patch/ReplaceSymbolOps.scala index 052dbaf50..19ab9ea99 100644 --- a/scalafix-core/src/main/scala/scalafix/internal/patch/ReplaceSymbolOps.scala +++ b/scalafix-core/src/main/scala/scalafix/internal/patch/ReplaceSymbolOps.scala @@ -1,5 +1,7 @@ package scalafix.internal.patch +import scala.annotation.tailrec + import scala.meta._ import scala.meta.internal.trees._ @@ -12,6 +14,11 @@ import scalafix.syntax._ import scalafix.v0._ object ReplaceSymbolOps { + private case class ImportInfo( + globalImports: Seq[Import], + globalImportedSymbols: Map[String, Symbol] + ) + private object Select { def unapply(arg: Ref): Option[(Ref, Name)] = arg match { case Term.Select(a: Ref, b) => Some(a -> b) @@ -20,10 +27,47 @@ object ReplaceSymbolOps { } } + private def extractImports(stats: Seq[Stat]): Seq[Import] = { + stats.collect { case i: Import => i } + } + + private def extractImportInfo( + tree: Tree + )(implicit index: SemanticdbIndex): ImportInfo = { + @tailrec + def getGlobalImports(ast: Tree): Seq[Import] = ast match { + case Pkg(_, Seq(pkg: Pkg)) => getGlobalImports(pkg) + case Source(Seq(pkg: Pkg)) => getGlobalImports(pkg) + case Pkg(_, stats) => extractImports(stats) + case Source(stats) => extractImports(stats) + case _ => Nil + } + + val globalImports = getGlobalImports(tree) + + // pre-compute global imported symbols for O(1) collision detection + // since ctx.addGlobalImport adds imports at global scope + val globalImportedSymbols = globalImports.flatMap { importStat => + importStat.importers.flatMap { importer => + importer.importees.collect { + case Importee.Name(name) => + name.value -> name.symbol.getOrElse(Symbol.None) + case Importee.Rename(_, rename) => + rename.value -> rename.symbol.getOrElse(Symbol.None) + } + } + }.toMap + + ImportInfo(globalImports, globalImportedSymbols) + } + def naiveMoveSymbolPatch( moveSymbols: Seq[ReplaceSymbol] )(implicit ctx: RuleCtx, index: SemanticdbIndex): Patch = { if (moveSymbols.isEmpty) return Patch.empty + + val importInfo = extractImportInfo(ctx.tree)(index) + val moves: Map[String, Symbol.Global] = moveSymbols.iterator.flatMap { case ReplaceSymbol( @@ -126,10 +170,15 @@ object ReplaceSymbolOps { if sig.name != parent.value => Patch.empty // do nothing because it was a renamed symbol case Some(_) => + val causesCollision = + importInfo.globalImportedSymbols.contains(to.signature.name) val addImport = - if (n.isDefinition) Patch.empty + if (n.isDefinition || causesCollision) Patch.empty else ctx.addGlobalImport(to) - addImport + ctx.replaceTree(n, to.signature.name) + if (causesCollision) + addImport + ctx.replaceTree(n, to.owner.syntax + to.signature.name) + else + addImport + ctx.replaceTree(n, to.signature.name) case _ => Patch.empty } diff --git a/scalafix-tests/input/src/main/scala/test/ReplaceSymbol.scala b/scalafix-tests/input/src/main/scala/test/ReplaceSymbol.scala index 5281f1dca..a74f7322b 100644 --- a/scalafix-tests/input/src/main/scala/test/ReplaceSymbol.scala +++ b/scalafix-tests/input/src/main/scala/test/ReplaceSymbol.scala @@ -12,6 +12,8 @@ patches.replaceSymbols = [ to = "com.geirsson.mutable.CoolBuffer" } { from = "scala.collection.mutable.HashMap" to = "com.geirsson.mutable.unsafe.CoolMap" } + { from = "scala.collection.immutable.TreeMap" + to = "com.geirsson.immutable.SortedMap" } { from = "scala.math.sqrt" to = "com.geirsson.fastmath.sqrt" } // normalized symbol renames all overloaded methods @@ -29,6 +31,7 @@ patches.replaceSymbols = [ */ package fix +import scala.collection.immutable.{ SortedMap, TreeMap } import scala.collection.mutable.HashMap import scala.collection.mutable.ListBuffer import scala.collection.mutable @@ -42,6 +45,7 @@ object ReplaceSymbol { "blah".substring(1) "blah".substring(1, 2) val u: mutable.HashMap[Int, Int] = HashMap.empty[Int, Int] + val v: SortedMap[Int, Int] = TreeMap.empty[Int, Int] val x: ListBuffer[Int] = ListBuffer.empty[Int] val y: mutable.ListBuffer[Int] = mutable.ListBuffer.empty[Int] val z: scala.collection.mutable.ListBuffer[Int] = diff --git a/scalafix-tests/output/src/main/scala/com/geirsson/immutable.scala b/scalafix-tests/output/src/main/scala/com/geirsson/immutable.scala new file mode 100644 index 000000000..0e1ad286f --- /dev/null +++ b/scalafix-tests/output/src/main/scala/com/geirsson/immutable.scala @@ -0,0 +1,11 @@ +package com.geirsson + +import scala.collection.immutable.TreeMap + +object immutable { + type SortedMap[A, B] = TreeMap[A, B] + + object SortedMap { + def empty[A : Ordering, B]: SortedMap[A, B] = TreeMap.empty[A, B] + } +} diff --git a/scalafix-tests/output/src/main/scala/test/ReplaceSymbol.scala b/scalafix-tests/output/src/main/scala/test/ReplaceSymbol.scala index 55fc14b58..24f745d68 100644 --- a/scalafix-tests/output/src/main/scala/test/ReplaceSymbol.scala +++ b/scalafix-tests/output/src/main/scala/test/ReplaceSymbol.scala @@ -1,5 +1,6 @@ package fix +import scala.collection.immutable.SortedMap import com.geirsson.Future import com.geirsson.{ fastmath, mutable } import com.geirsson.mutable.{ CoolBuffer, unsafe } @@ -13,6 +14,7 @@ object ReplaceSymbol { "blah".substringFrom(1) "blah".substringBetween(1, 2) val u: unsafe.CoolMap[Int, Int] = CoolMap.empty[Int, Int] + val v: SortedMap[Int, Int] = com.geirsson.immutable.SortedMap.empty[Int, Int] val x: CoolBuffer[Int] = CoolBuffer.empty[Int] val y: mutable.CoolBuffer[Int] = mutable.CoolBuffer.empty[Int] val z: com.geirsson.mutable.CoolBuffer[Int] =