Skip to content

Commit 6f38545

Browse files
committed
address comments
1 parent 636c8ec commit 6f38545

File tree

5 files changed

+214
-1167
lines changed

5 files changed

+214
-1167
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/vectorExpressions.scala

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
2525
import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
2626
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
2727
import org.apache.spark.sql.catalyst.trees.UnaryLike
28+
import org.apache.spark.sql.catalyst.types.DataTypeUtils
2829
import org.apache.spark.sql.catalyst.util.ArrayData
2930
import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors}
3031
import org.apache.spark.sql.types.{
@@ -35,7 +36,6 @@ import org.apache.spark.sql.types.{
3536
IntegerType,
3637
LongType,
3738
StringType,
38-
StructField,
3939
StructType
4040
}
4141

@@ -411,21 +411,12 @@ case class VectorAvg(
411411
// Aggregation-buffer attribute named "count": LongType, declared non-nullable.
// NOTE(review): presumably the number of vectors accumulated so far — confirm against update/merge.
private lazy val countAttr =
412412
AttributeReference("count", LongType, nullable = false)()
413413

414-
override def aggBufferSchema: StructType = StructType(
415-
Seq(
416-
StructField(
417-
"avg",
418-
BinaryType,
419-
nullable = true
420-
),
421-
StructField("dim", IntegerType, nullable = true),
422-
StructField("count", LongType, nullable = false)
423-
)
424-
)
425-
426414
// The aggregation buffer's attributes, in this order: avgAttr, dimAttr, countAttr.
override def aggBufferAttributes: Seq[AttributeReference] =
427415
Seq(avgAttr, dimAttr, countAttr)
428416

417+
// Builds the buffer schema by passing aggBufferAttributes to DataTypeUtils.fromAttributes,
// so the schema is derived from the attribute list rather than listed field by field.
// NOTE(review): presumably one StructField per attribute, in the same order — confirm.
override def aggBufferSchema: StructType =
418+
DataTypeUtils.fromAttributes(aggBufferAttributes)
419+
429420
// Lazily evaluated: the result of calling newInstance() on every attribute in aggBufferAttributes.
override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
430421
aggBufferAttributes.map(_.newInstance())
431422

@@ -657,16 +648,12 @@ case class VectorSum(
657648
nullable = true
658649
)()
659650

660-
override def aggBufferSchema: StructType = StructType(
661-
Seq(
662-
StructField("sum", BinaryType, nullable = true),
663-
StructField("dim", IntegerType, nullable = true)
664-
)
665-
)
666-
667651
// The aggregation buffer's attributes, in this order: sumAttr, dimAttr.
override def aggBufferAttributes: Seq[AttributeReference] =
668652
Seq(sumAttr, dimAttr)
669653

654+
// Builds the buffer schema by passing aggBufferAttributes to DataTypeUtils.fromAttributes,
// so the schema is derived from the attribute list rather than listed field by field.
// NOTE(review): presumably one StructField per attribute, in the same order — confirm.
override def aggBufferSchema: StructType =
655+
DataTypeUtils.fromAttributes(aggBufferAttributes)
656+
670657
// Lazily evaluated: the result of calling newInstance() on every attribute in aggBufferAttributes.
override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
671659
aggBufferAttributes.map(_.newInstance())
672659

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/VectorAggSuite.scala

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.sql.catalyst.InternalRow
2222
import org.apache.spark.sql.catalyst.expressions.{BoundReference, SpecificInternalRow, VectorAvg, VectorSum}
23-
import org.apache.spark.sql.catalyst.util.ArrayData
23+
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
2424
import org.apache.spark.sql.types.{ArrayType, FloatType}
2525

2626
class VectorAggSuite extends SparkFunSuite {
@@ -55,6 +55,15 @@ class VectorAggSuite extends SparkFunSuite {
5555
row
5656
}
5757

58+
// Helper to create input row with array containing null elements
59+
// Wraps `values` in a GenericArrayData and returns it as the single column of an InternalRow.
// Boxed java.lang.Float input is used so that individual elements can be null.
def createInputRowWithNullElement(values: Array[java.lang.Float]): InternalRow = {
60+
val arrayData = new GenericArrayData(values.map {
61+
// Null entries are passed through unchanged.
case null => null
62+
// Non-null entries are unboxed with floatValue() and cast back to AnyRef.
case v => v.floatValue().asInstanceOf[AnyRef]
63+
})
64+
InternalRow(arrayData)
65+
}
66+
5867
// Helper to extract result as float array
5968
def evalAsFloatArray(agg: VectorSum, buffer: InternalRow): Array[Float] = {
6069
val result = agg.eval(buffer)
@@ -358,4 +367,131 @@ class VectorAggSuite extends SparkFunSuite {
358367
// Average of 1 to 100 = 50.5
359368
assertFloatArrayEquals(result, Array(50.5f, 50.5f), tolerance = 1e-3f)
360369
}
370+
371+
// Updates a VectorSum buffer with two 3-element vectors and asserts the evaluated
// result equals their exact element-wise sum.
test("VectorSum - mathematical correctness: element-wise sum") {
372+
val (agg, buffer, _) = createVectorSum()
373+
agg.update(buffer, createInputRow(Array(1.0f, 2.0f, 3.0f)))
374+
agg.update(buffer, createInputRow(Array(10.0f, 20.0f, 30.0f)))
375+
val result = evalAsFloatArray(agg, buffer)
376+
// [1, 2, 3] + [10, 20, 30] = [11, 22, 33]
377+
assert(result === Array(11.0f, 22.0f, 33.0f))
378+
}
379+
380+
// Updates a VectorAvg buffer with a zero vector and [10, 20], then asserts the
// evaluated result equals the exact element-wise mean.
test("VectorAvg - mathematical correctness: element-wise average") {
381+
val (agg, buffer, _) = createVectorAvg()
382+
agg.update(buffer, createInputRow(Array(0.0f, 0.0f)))
383+
agg.update(buffer, createInputRow(Array(10.0f, 20.0f)))
384+
val result = evalAsFloatArray(agg, buffer)
385+
// avg([0, 0], [10, 20]) = [5, 10]
386+
assert(result === Array(5.0f, 10.0f))
387+
}
388+
389+
// Updates a VectorAvg buffer with two vectors whose elements are negations of each
// other and asserts the evaluated average is exactly zero in every position.
test("VectorAvg - mathematical correctness: negative values") {
390+
val (agg, buffer, _) = createVectorAvg()
391+
agg.update(buffer, createInputRow(Array(-5.0f, 10.0f)))
392+
agg.update(buffer, createInputRow(Array(5.0f, -10.0f)))
393+
val result = evalAsFloatArray(agg, buffer)
394+
// avg([-5, 10], [5, -10]) = [0, 0]
395+
assert(result === Array(0.0f, 0.0f))
396+
}
397+
398+
// Places a vector containing a null element between two fully populated vectors and
// asserts the evaluated sum reflects only the two populated ones.
test("VectorSum - vectors with null elements are skipped") {
399+
val (agg, buffer, _) = createVectorSum()
400+
agg.update(buffer, createInputRow(Array(1.0f, 2.0f)))
401+
agg.update(buffer, createInputRowWithNullElement(Array(null, 10.0f)))
402+
agg.update(buffer, createInputRow(Array(3.0f, 4.0f)))
403+
val result = evalAsFloatArray(agg, buffer)
404+
// Vector with null element is skipped, so [1, 2] + [3, 4] = [4, 6]
405+
assert(result === Array(4.0f, 6.0f))
406+
}
407+
408+
// Places a vector containing a null element between two fully populated vectors and
// asserts the evaluated average is taken over only the two populated ones.
test("VectorAvg - vectors with null elements are skipped") {
409+
val (agg, buffer, _) = createVectorAvg()
410+
agg.update(buffer, createInputRow(Array(1.0f, 2.0f)))
411+
agg.update(buffer, createInputRowWithNullElement(Array(null, 10.0f)))
412+
agg.update(buffer, createInputRow(Array(3.0f, 4.0f)))
413+
val result = evalAsFloatArray(agg, buffer)
414+
// Vector with null element is skipped, so avg([1, 2], [3, 4]) = [2, 3]
415+
assert(result === Array(2.0f, 3.0f))
416+
}
417+
418+
// Updates a VectorSum buffer only with vectors that each contain a null element and
// asserts that eval on the buffer returns null.
test("VectorSum - only vectors with null elements returns null") {
419+
val (agg, buffer, _) = createVectorSum()
420+
agg.update(buffer, createInputRowWithNullElement(Array(1.0f, null)))
421+
agg.update(buffer, createInputRowWithNullElement(Array(null, 2.0f)))
422+
assert(agg.eval(buffer) === null)
423+
}
424+
425+
// Updates a VectorAvg buffer only with vectors that each contain a null element and
// asserts that eval on the buffer returns null.
test("VectorAvg - only vectors with null elements returns null") {
426+
val (agg, buffer, _) = createVectorAvg()
427+
agg.update(buffer, createInputRowWithNullElement(Array(1.0f, null)))
428+
agg.update(buffer, createInputRowWithNullElement(Array(null, 2.0f)))
429+
assert(agg.eval(buffer) === null)
430+
}
431+
432+
// Feeds, in order, a row from createNullInputRow(), a vector with a null element, and
// two fully populated vectors; asserts the sum covers only the last two.
test("VectorSum - mix of null vectors and vectors with null elements") {
433+
val (agg, buffer, _) = createVectorSum()
434+
agg.update(buffer, createNullInputRow())
435+
agg.update(buffer, createInputRowWithNullElement(Array(1.0f, null)))
436+
agg.update(buffer, createInputRow(Array(1.0f, 2.0f)))
437+
agg.update(buffer, createInputRow(Array(3.0f, 4.0f)))
438+
val result = evalAsFloatArray(agg, buffer)
439+
// Only valid vectors are summed: [1, 2] + [3, 4] = [4, 6]
440+
assert(result === Array(4.0f, 6.0f))
441+
}
442+
443+
// Sums an ascending 1..16 vector with a descending 16..1 vector, so every one of the
// 16 positions is expected to equal 17.
test("VectorSum - large vectors (16 elements)") {
444+
val (agg, buffer, _) = createVectorSum()
445+
val vec1 = Array(1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
446+
9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f)
447+
val vec2 = Array(16.0f, 15.0f, 14.0f, 13.0f, 12.0f, 11.0f, 10.0f, 9.0f,
448+
8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f)
449+
agg.update(buffer, createInputRow(vec1))
450+
agg.update(buffer, createInputRow(vec2))
451+
val result = evalAsFloatArray(agg, buffer)
452+
// Each element should sum to 17
453+
assert(result === Array.fill(16)(17.0f))
454+
}
455+
456+
// Averages an ascending 1..16 vector with a descending 16..1 vector, so every one of
// the 16 positions is expected to equal 8.5.
test("VectorAvg - large vectors (16 elements)") {
457+
val (agg, buffer, _) = createVectorAvg()
458+
val vec1 = Array(1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
459+
9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f)
460+
val vec2 = Array(16.0f, 15.0f, 14.0f, 13.0f, 12.0f, 11.0f, 10.0f, 9.0f,
461+
8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f)
462+
agg.update(buffer, createInputRow(vec1))
463+
agg.update(buffer, createInputRow(vec2))
464+
val result = evalAsFloatArray(agg, buffer)
465+
// Each element average should be 8.5
466+
assert(result === Array.fill(16)(8.5f))
467+
}
468+
469+
// Feeds a 16-element vector with one null entry followed by a fully populated
// 16-element vector, and asserts the sum equals the second vector alone.
test("VectorSum - large vector with null element is skipped") {
470+
val (agg, buffer, _) = createVectorSum()
471+
// Typed as Array[java.lang.Float] so the sixth entry can be null.
val vec1 = Array[java.lang.Float](1.0f, 2.0f, 3.0f, 4.0f, 5.0f, null, 7.0f, 8.0f,
472+
9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f)
473+
val vec2 = Array(16.0f, 15.0f, 14.0f, 13.0f, 12.0f, 11.0f, 10.0f, 9.0f,
474+
8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f)
475+
agg.update(buffer, createInputRowWithNullElement(vec1))
476+
agg.update(buffer, createInputRow(vec2))
477+
val result = evalAsFloatArray(agg, buffer)
478+
// First vector is skipped due to null element, result is just vec2
479+
assert(result === vec2)
480+
}
481+
482+
// Boundary case: sums two one-element vectors, [5] + [3] = [8].
test("VectorSum - single element vectors") {
483+
val (agg, buffer, _) = createVectorSum()
484+
agg.update(buffer, createInputRow(Array(5.0f)))
485+
agg.update(buffer, createInputRow(Array(3.0f)))
486+
val result = evalAsFloatArray(agg, buffer)
487+
assert(result === Array(8.0f))
488+
}
489+
490+
// Boundary case: averages two one-element vectors, avg([5], [3]) = [4].
test("VectorAvg - single element vectors") {
491+
val (agg, buffer, _) = createVectorAvg()
492+
agg.update(buffer, createInputRow(Array(5.0f)))
493+
agg.update(buffer, createInputRow(Array(3.0f)))
494+
val result = evalAsFloatArray(agg, buffer)
495+
assert(result === Array(4.0f))
496+
}
361497
}

0 commit comments

Comments
 (0)