Skip to content

Commit 568f920

Browse files
committed
[SPARK-52051][ML][CONNECT] Enable model summary when memory control is enabled
### What changes were proposed in this pull request? Enable model summary in Spark Connect when memory control is enabled. ### Why are the changes needed? Motivation: model summary is necessary in many use-cases. Although offloading is not yet supported for summaries, we can still enable them; users can use the summary object within the offloading timeout. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes apache#50843 from WeichenXu123/spark-connect-enable-summary. Lead-authored-by: Weichen Xu <[email protected]> Co-authored-by: WeichenXu <[email protected]> Signed-off-by: Weichen Xu <[email protected]>
1 parent 207e296 commit 568f920

File tree

18 files changed

+116
-232
lines changed

18 files changed

+116
-232
lines changed

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -826,15 +826,17 @@
826826
},
827827
"CACHE_INVALID" : {
828828
"message" : [
829-
"Cannot retrieve <objectName> from the ML cache. It is probably because the entry has been evicted."
829+
"Cannot retrieve Summary object <objectName> from the ML cache.",
830+
"The Summary object is evicted if it hasn't been used for a specified period of time.",
831+
"You can configure the timeout by setting the Spark cluster config 'spark.connect.session.connectML.mlCache.memoryControl.offloadingTimeout'."
830832
]
831833
},
832834
"ML_CACHE_SIZE_OVERFLOW_EXCEPTION" : {
833835
"message" : [
834836
"The model cache size in current session is about to exceed",
835837
"<mlCacheMaxSize> bytes.",
836838
"Please delete existing cached model by executing 'del model' in python client before fitting new model or loading new model,",
837-
"or increase Spark config 'spark.connect.session.connectML.mlCache.memoryControl.maxSize'."
839+
"or increase Spark config 'spark.connect.session.connectML.mlCache.memoryControl.maxStorageSize'."
838840
]
839841
},
840842
"MODEL_SIZE_OVERFLOW_EXCEPTION" : {

mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -224,18 +224,15 @@ class FMClassifier @Since("3.0.0") (
224224
val model = copyValues(new FMClassificationModel(uid, intercept, linear, factors))
225225
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
226226

227-
if (SummaryUtils.enableTrainingSummary) {
228-
val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
229-
val summary = new FMClassificationTrainingSummaryImpl(
230-
summaryModel.transform(dataset),
231-
probabilityColName,
232-
predictionColName,
233-
$(labelCol),
234-
weightColName,
235-
objectiveHistory)
236-
model.setSummary(Some(summary))
237-
}
238-
model
227+
val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
228+
val summary = new FMClassificationTrainingSummaryImpl(
229+
summaryModel.transform(dataset),
230+
probabilityColName,
231+
predictionColName,
232+
$(labelCol),
233+
weightColName,
234+
objectiveHistory)
235+
model.setSummary(Some(summary))
239236
}
240237

241238
@Since("3.0.0")

mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -277,18 +277,15 @@ class LinearSVC @Since("2.2.0") (
277277
val model = copyValues(new LinearSVCModel(uid, coefficients, intercept))
278278
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
279279

280-
if (SummaryUtils.enableTrainingSummary) {
281-
val (summaryModel, rawPredictionColName, predictionColName) = model.findSummaryModel()
282-
val summary = new LinearSVCTrainingSummaryImpl(
283-
summaryModel.transform(dataset),
284-
rawPredictionColName,
285-
predictionColName,
286-
$(labelCol),
287-
weightColName,
288-
objectiveHistory)
289-
model.setSummary(Some(summary))
290-
}
291-
model
280+
val (summaryModel, rawPredictionColName, predictionColName) = model.findSummaryModel()
281+
val summary = new LinearSVCTrainingSummaryImpl(
282+
summaryModel.transform(dataset),
283+
rawPredictionColName,
284+
predictionColName,
285+
$(labelCol),
286+
weightColName,
287+
objectiveHistory)
288+
model.setSummary(Some(summary))
292289
}
293290

294291
private def trainImpl(

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -711,30 +711,27 @@ class LogisticRegression @Since("1.2.0") (
711711
numClasses, checkMultinomial(numClasses)))
712712
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
713713

714-
if (SummaryUtils.enableTrainingSummary) {
715-
val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
716-
val logRegSummary = if (numClasses <= 2) {
717-
new BinaryLogisticRegressionTrainingSummaryImpl(
718-
summaryModel.transform(dataset),
719-
probabilityColName,
720-
predictionColName,
721-
$(labelCol),
722-
$(featuresCol),
723-
weightColName,
724-
objectiveHistory)
725-
} else {
726-
new LogisticRegressionTrainingSummaryImpl(
727-
summaryModel.transform(dataset),
728-
probabilityColName,
729-
predictionColName,
730-
$(labelCol),
731-
$(featuresCol),
732-
weightColName,
733-
objectiveHistory)
734-
}
735-
model.setSummary(Some(logRegSummary))
714+
val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
715+
val logRegSummary = if (numClasses <= 2) {
716+
new BinaryLogisticRegressionTrainingSummaryImpl(
717+
summaryModel.transform(dataset),
718+
probabilityColName,
719+
predictionColName,
720+
$(labelCol),
721+
$(featuresCol),
722+
weightColName,
723+
objectiveHistory)
724+
} else {
725+
new LogisticRegressionTrainingSummaryImpl(
726+
summaryModel.transform(dataset),
727+
probabilityColName,
728+
predictionColName,
729+
$(labelCol),
730+
$(featuresCol),
731+
weightColName,
732+
objectiveHistory)
736733
}
737-
model
734+
model.setSummary(Some(logRegSummary))
738735
}
739736

740737
private def createBounds(

mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -249,17 +249,14 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
249249
objectiveHistory: Array[Double]): MultilayerPerceptronClassificationModel = {
250250
val model = copyValues(new MultilayerPerceptronClassificationModel(uid, weights))
251251

252-
if (SummaryUtils.enableTrainingSummary) {
253-
val (summaryModel, _, predictionColName) = model.findSummaryModel()
254-
val summary = new MultilayerPerceptronClassificationTrainingSummaryImpl(
255-
summaryModel.transform(dataset),
256-
predictionColName,
257-
$(labelCol),
258-
"",
259-
objectiveHistory)
260-
model.setSummary(Some(summary))
261-
}
262-
model
252+
val (summaryModel, _, predictionColName) = model.findSummaryModel()
253+
val summary = new MultilayerPerceptronClassificationTrainingSummaryImpl(
254+
summaryModel.transform(dataset),
255+
predictionColName,
256+
$(labelCol),
257+
"",
258+
objectiveHistory)
259+
model.setSummary(Some(summary))
263260
}
264261
}
265262

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -185,26 +185,23 @@ class RandomForestClassifier @Since("1.4.0") (
185185
val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
186186

187187
val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
188-
if (SummaryUtils.enableTrainingSummary) {
189-
val rfSummary = if (numClasses <= 2) {
190-
new BinaryRandomForestClassificationTrainingSummaryImpl(
191-
summaryModel.transform(dataset),
192-
probabilityColName,
193-
predictionColName,
194-
$(labelCol),
195-
weightColName,
196-
Array(0.0))
197-
} else {
198-
new RandomForestClassificationTrainingSummaryImpl(
199-
summaryModel.transform(dataset),
200-
predictionColName,
201-
$(labelCol),
202-
weightColName,
203-
Array(0.0))
204-
}
205-
model.setSummary(Some(rfSummary))
188+
val rfSummary = if (numClasses <= 2) {
189+
new BinaryRandomForestClassificationTrainingSummaryImpl(
190+
summaryModel.transform(dataset),
191+
probabilityColName,
192+
predictionColName,
193+
$(labelCol),
194+
weightColName,
195+
Array(0.0))
196+
} else {
197+
new RandomForestClassificationTrainingSummaryImpl(
198+
summaryModel.transform(dataset),
199+
predictionColName,
200+
$(labelCol),
201+
weightColName,
202+
Array(0.0))
206203
}
207-
model
204+
model.setSummary(Some(rfSummary))
208205
}
209206

210207
@Since("1.4.1")

mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -303,19 +303,16 @@ class BisectingKMeans @Since("2.0.0") (
303303
val parentModel = bkm.runWithWeight(instances, handlePersistence, Some(instr))
304304
val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this))
305305

306-
if (SummaryUtils.enableTrainingSummary) {
307-
val summary = new BisectingKMeansSummary(
308-
model.transform(dataset),
309-
$(predictionCol),
310-
$(featuresCol),
311-
$(k),
312-
$(maxIter),
313-
parentModel.trainingCost)
314-
instr.logNamedValue("clusterSizes", summary.clusterSizes)
315-
instr.logNumFeatures(model.clusterCenters.head.size)
316-
model.setSummary(Some(summary))
317-
}
318-
model
306+
val summary = new BisectingKMeansSummary(
307+
model.transform(dataset),
308+
$(predictionCol),
309+
$(featuresCol),
310+
$(k),
311+
$(maxIter),
312+
parentModel.trainingCost)
313+
instr.logNamedValue("clusterSizes", summary.clusterSizes)
314+
instr.logNumFeatures(model.clusterCenters.head.size)
315+
model.setSummary(Some(summary))
319316
}
320317

321318
@Since("2.0.0")

mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -430,14 +430,11 @@ class GaussianMixture @Since("2.0.0") (
430430

431431
val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists))
432432
.setParent(this)
433-
if (SummaryUtils.enableTrainingSummary) {
434-
val summary = new GaussianMixtureSummary(model.transform(dataset),
435-
$(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iteration)
436-
instr.logNamedValue("logLikelihood", logLikelihood)
437-
instr.logNamedValue("clusterSizes", summary.clusterSizes)
438-
model.setSummary(Some(summary))
439-
}
440-
model
433+
val summary = new GaussianMixtureSummary(model.transform(dataset),
434+
$(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iteration)
435+
instr.logNamedValue("logLikelihood", logLikelihood)
436+
instr.logNamedValue("clusterSizes", summary.clusterSizes)
437+
model.setSummary(Some(summary))
441438
}
442439

443440
private def trainImpl(

mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -391,18 +391,16 @@ class KMeans @Since("1.5.0") (
391391
}
392392

393393
val model = copyValues(new KMeansModel(uid, oldModel).setParent(this))
394-
if (SummaryUtils.enableTrainingSummary) {
395-
val summary = new KMeansSummary(
396-
model.transform(dataset),
397-
$(predictionCol),
398-
$(featuresCol),
399-
$(k),
400-
oldModel.numIter,
401-
oldModel.trainingCost)
402-
403-
model.setSummary(Some(summary))
404-
instr.logNamedValue("clusterSizes", summary.clusterSizes)
405-
}
394+
val summary = new KMeansSummary(
395+
model.transform(dataset),
396+
$(predictionCol),
397+
$(featuresCol),
398+
$(k),
399+
oldModel.numIter,
400+
oldModel.trainingCost)
401+
402+
model.setSummary(Some(summary))
403+
instr.logNamedValue("clusterSizes", summary.clusterSizes)
406404
model
407405
}
408406

mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -418,12 +418,9 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
418418
val model = copyValues(
419419
new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept)
420420
.setParent(this))
421-
if (SummaryUtils.enableTrainingSummary) {
422-
val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
423-
wlsModel.diagInvAtWA.toArray, 1, getSolver)
424-
model.setSummary(Some(trainingSummary))
425-
}
426-
model
421+
val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
422+
wlsModel.diagInvAtWA.toArray, 1, getSolver)
423+
model.setSummary(Some(trainingSummary))
427424
} else {
428425
val instances = validated.rdd.map {
429426
case Row(label: Double, weight: Double, offset: Double, features: Vector) =>
@@ -438,12 +435,9 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
438435
val model = copyValues(
439436
new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept)
440437
.setParent(this))
441-
if (SummaryUtils.enableTrainingSummary) {
442-
val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
443-
irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver)
444-
model.setSummary(Some(trainingSummary))
445-
}
446-
model
438+
val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
439+
irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver)
440+
model.setSummary(Some(trainingSummary))
447441
}
448442

449443
model

0 commit comments

Comments
 (0)