Commit 0782024

sarutak authored and HyukjinKwon committed
[SPARK-37026][ML][BUILD] Ensure the element type of ResolvedRFormula.terms is scala.Seq for Scala 2.13
### What changes were proposed in this pull request?

This PR fixes the issue that `scala.Seq[scala.collection.mutable.ArraySeq$ofRef]` will be passed to `ResolvedRFormula.terms` though it expects `scala.Seq[scala.Seq[String]]` with Scala 2.13. As of Scala 2.13, `scala.Seq` is `scala.collection.immutable.Seq`, so this issue happens.

### Why are the changes needed?

Bug fix. Due to this issue, `ResolvedRFormula.toString` throws `ClassCastException`.

```
java.lang.ClassCastException: scala.collection.mutable.ArraySeq$ofRef cannot be cast to scala.collection.immutable.Seq
	at scala.collection.immutable.List.map(List.scala:246)
	at scala.collection.immutable.List.map(List.scala:79)
	at org.apache.spark.ml.feature.ResolvedRFormula.toString(RFormulaParser.scala:143)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.lang.Thread.run(Thread.java:748)
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New test is added and `build/sbt -Pscala-2.13 "testOnly org.apache.spark.ml.feature.RFormulaSuite"` passes.
CIs should ensure that this change works with Scala 2.12 too.

Closes apache#34301 from sarutak/fix-rformula-scala-2.13.

Authored-by: Kousuke Saruta <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
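The failure mode described in the commit message can be reproduced without Spark. The following is a minimal sketch, not code from this commit: `ErasedSeqCast` and `render` are hypothetical names, and the `Any`-typed input is only a stand-in for data whose static element type was asserted rather than checked. Because type arguments are erased, the cast to `Seq[Seq[String]]` itself succeeds, and the mismatch should only surface when each element is used as a `Seq[String]`.

```scala
import scala.collection.mutable
import scala.util.Try

// Hypothetical illustration; none of these names exist in Spark.
object ErasedSeqCast {
  // Casts `raw` to Seq[Seq[String]] and joins each inner sequence with ':'.
  // The cast only checks that `raw` is a Seq; element types are erased and
  // are not checked until each element is actually used inside `map`.
  def render(raw: Any): Try[Seq[String]] =
    Try(raw.asInstanceOf[Seq[Seq[String]]].map(_.mkString(":")))

  def main(args: Array[String]): Unit = {
    // Inner collections are mutable.ArraySeq, as in the reported stack trace.
    val raw: Any = List(mutable.ArraySeq("a", "b"), mutable.ArraySeq("c"))
    // Expected to be a Failure(ClassCastException) on Scala 2.13, where
    // scala.Seq is immutable.Seq, and a Success on Scala 2.12, where
    // scala.Seq is scala.collection.Seq.
    println(render(raw))
  }
}
```

The version-dependent outcome is the reason the commit message asks CI to confirm the change on Scala 2.12 as well as 2.13.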
1 parent f9cc7fb commit 0782024

2 files changed: +12 -1 lines changed
mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala

Lines changed: 1 addition & 1 deletion
@@ -449,7 +449,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] {
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath).select("label", "terms", "hasIntercept").head()
       val label = data.getString(0)
-      val terms = data.getSeq[Seq[String]](1)
+      val terms = data.getSeq[scala.collection.Seq[String]](1).map(_.toSeq)
       val hasIntercept = data.getBoolean(2)
       val resolvedRFormula = ResolvedRFormula(label, terms, hasIntercept)
 
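The one-line change above reads the inner collections under the wider type `scala.collection.Seq[String]` and then converts each one with `toSeq`. A minimal standalone sketch of that pattern follows; `NormalizeTerms` and `normalize` are hypothetical names used only for illustration, and the extra outer `toSeq` is an assumption needed here because the sketch starts from a plain `scala.collection.Seq` rather than from the result of `Row.getSeq`.

```scala
import scala.collection.mutable

// Hypothetical helper for illustration; not part of Spark.
object NormalizeTerms {
  // Accepts any scala.collection.Seq of sequences (mutable or immutable) and
  // returns scala.Seq[scala.Seq[String]] for the Scala version in use.
  // On Scala 2.13 each toSeq yields an immutable.Seq; on 2.12 scala.Seq is
  // already scala.collection.Seq, so the conversion changes nothing observable.
  def normalize(raw: scala.collection.Seq[scala.collection.Seq[String]]): Seq[Seq[String]] =
    raw.map(_.toSeq).toSeq

  def main(args: Array[String]): Unit = {
    val raw = mutable.ArraySeq(mutable.ArraySeq("a", "b"), mutable.ArraySeq("c"))
    val terms = normalize(raw)
    println(terms.map(_.mkString(":")).mkString(" + ")) // prints "a:b + c"
  }
}
```

Declaring the wider element type first keeps the code honest about what may actually be stored, and the explicit `toSeq` makes the conversion visible instead of relying on an unchecked type argument.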

mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,4 +627,15 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
     assert(get_output("keep").count() == 6)
   }
 
+  test("SPARK-37026: Ensure the element type of ResolvedRFormula.terms is " +
+    "scala.Seq for Scala 2.13") {
+    withTempPath { path =>
+      val dataset = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b")
+      val rFormula = new RFormula().setFormula("id ~ a:b")
+      val model = rFormula.fit(dataset)
+      model.save(path.getCanonicalPath)
+      val newModel = RFormulaModel.load(path.getCanonicalPath)
+      newModel.resolvedFormula.toString
+    }
+  }
 }