@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
2020import org .apache .spark .SparkFunSuite
2121import org .apache .spark .sql .catalyst .InternalRow
2222import org .apache .spark .sql .catalyst .expressions .{BoundReference , SpecificInternalRow , VectorAvg , VectorSum }
23- import org .apache .spark .sql .catalyst .util .ArrayData
23+ import org .apache .spark .sql .catalyst .util .{ ArrayData , GenericArrayData }
2424import org .apache .spark .sql .types .{ArrayType , FloatType }
2525
2626class VectorAggSuite extends SparkFunSuite {
@@ -55,6 +55,15 @@ class VectorAggSuite extends SparkFunSuite {
5555 row
5656 }
5757
58+ // Helper to create input row with array containing null elements
59+ def createInputRowWithNullElement (values : Array [java.lang.Float ]): InternalRow = {
60+ val arrayData = new GenericArrayData (values.map {
61+ case null => null
62+ case v => v.floatValue().asInstanceOf [AnyRef ]
63+ })
64+ InternalRow (arrayData)
65+ }
66+
5867 // Helper to extract result as float array
5968 def evalAsFloatArray (agg : VectorSum , buffer : InternalRow ): Array [Float ] = {
6069 val result = agg.eval(buffer)
@@ -358,4 +367,131 @@ class VectorAggSuite extends SparkFunSuite {
358367 // Average of 1 to 100 = 50.5
359368 assertFloatArrayEquals(result, Array (50.5f , 50.5f ), tolerance = 1e-3f )
360369 }
370+
371+ test(" VectorSum - mathematical correctness: element-wise sum" ) {
372+ val (agg, buffer, _) = createVectorSum()
373+ agg.update(buffer, createInputRow(Array (1.0f , 2.0f , 3.0f )))
374+ agg.update(buffer, createInputRow(Array (10.0f , 20.0f , 30.0f )))
375+ val result = evalAsFloatArray(agg, buffer)
376+ // [1, 2, 3] + [10, 20, 30] = [11, 22, 33]
377+ assert(result === Array (11.0f , 22.0f , 33.0f ))
378+ }
379+
380+ test(" VectorAvg - mathematical correctness: element-wise average" ) {
381+ val (agg, buffer, _) = createVectorAvg()
382+ agg.update(buffer, createInputRow(Array (0.0f , 0.0f )))
383+ agg.update(buffer, createInputRow(Array (10.0f , 20.0f )))
384+ val result = evalAsFloatArray(agg, buffer)
385+ // avg([0, 0], [10, 20]) = [5, 10]
386+ assert(result === Array (5.0f , 10.0f ))
387+ }
388+
389+ test(" VectorAvg - mathematical correctness: negative values" ) {
390+ val (agg, buffer, _) = createVectorAvg()
391+ agg.update(buffer, createInputRow(Array (- 5.0f , 10.0f )))
392+ agg.update(buffer, createInputRow(Array (5.0f , - 10.0f )))
393+ val result = evalAsFloatArray(agg, buffer)
394+ // avg([-5, 10], [5, -10]) = [0, 0]
395+ assert(result === Array (0.0f , 0.0f ))
396+ }
397+
398+ test(" VectorSum - vectors with null elements are skipped" ) {
399+ val (agg, buffer, _) = createVectorSum()
400+ agg.update(buffer, createInputRow(Array (1.0f , 2.0f )))
401+ agg.update(buffer, createInputRowWithNullElement(Array (null , 10.0f )))
402+ agg.update(buffer, createInputRow(Array (3.0f , 4.0f )))
403+ val result = evalAsFloatArray(agg, buffer)
404+ // Vector with null element is skipped, so [1, 2] + [3, 4] = [4, 6]
405+ assert(result === Array (4.0f , 6.0f ))
406+ }
407+
408+ test(" VectorAvg - vectors with null elements are skipped" ) {
409+ val (agg, buffer, _) = createVectorAvg()
410+ agg.update(buffer, createInputRow(Array (1.0f , 2.0f )))
411+ agg.update(buffer, createInputRowWithNullElement(Array (null , 10.0f )))
412+ agg.update(buffer, createInputRow(Array (3.0f , 4.0f )))
413+ val result = evalAsFloatArray(agg, buffer)
414+ // Vector with null element is skipped, so avg([1, 2], [3, 4]) = [2, 3]
415+ assert(result === Array (2.0f , 3.0f ))
416+ }
417+
418+ test(" VectorSum - only vectors with null elements returns null" ) {
419+ val (agg, buffer, _) = createVectorSum()
420+ agg.update(buffer, createInputRowWithNullElement(Array (1.0f , null )))
421+ agg.update(buffer, createInputRowWithNullElement(Array (null , 2.0f )))
422+ assert(agg.eval(buffer) === null )
423+ }
424+
425+ test(" VectorAvg - only vectors with null elements returns null" ) {
426+ val (agg, buffer, _) = createVectorAvg()
427+ agg.update(buffer, createInputRowWithNullElement(Array (1.0f , null )))
428+ agg.update(buffer, createInputRowWithNullElement(Array (null , 2.0f )))
429+ assert(agg.eval(buffer) === null )
430+ }
431+
432+ test(" VectorSum - mix of null vectors and vectors with null elements" ) {
433+ val (agg, buffer, _) = createVectorSum()
434+ agg.update(buffer, createNullInputRow())
435+ agg.update(buffer, createInputRowWithNullElement(Array (1.0f , null )))
436+ agg.update(buffer, createInputRow(Array (1.0f , 2.0f )))
437+ agg.update(buffer, createInputRow(Array (3.0f , 4.0f )))
438+ val result = evalAsFloatArray(agg, buffer)
439+ // Only valid vectors are summed: [1, 2] + [3, 4] = [4, 6]
440+ assert(result === Array (4.0f , 6.0f ))
441+ }
442+
443+ test(" VectorSum - large vectors (16 elements)" ) {
444+ val (agg, buffer, _) = createVectorSum()
445+ val vec1 = Array (1.0f , 2.0f , 3.0f , 4.0f , 5.0f , 6.0f , 7.0f , 8.0f ,
446+ 9.0f , 10.0f , 11.0f , 12.0f , 13.0f , 14.0f , 15.0f , 16.0f )
447+ val vec2 = Array (16.0f , 15.0f , 14.0f , 13.0f , 12.0f , 11.0f , 10.0f , 9.0f ,
448+ 8.0f , 7.0f , 6.0f , 5.0f , 4.0f , 3.0f , 2.0f , 1.0f )
449+ agg.update(buffer, createInputRow(vec1))
450+ agg.update(buffer, createInputRow(vec2))
451+ val result = evalAsFloatArray(agg, buffer)
452+ // Each element should sum to 17
453+ assert(result === Array .fill(16 )(17.0f ))
454+ }
455+
456+ test(" VectorAvg - large vectors (16 elements)" ) {
457+ val (agg, buffer, _) = createVectorAvg()
458+ val vec1 = Array (1.0f , 2.0f , 3.0f , 4.0f , 5.0f , 6.0f , 7.0f , 8.0f ,
459+ 9.0f , 10.0f , 11.0f , 12.0f , 13.0f , 14.0f , 15.0f , 16.0f )
460+ val vec2 = Array (16.0f , 15.0f , 14.0f , 13.0f , 12.0f , 11.0f , 10.0f , 9.0f ,
461+ 8.0f , 7.0f , 6.0f , 5.0f , 4.0f , 3.0f , 2.0f , 1.0f )
462+ agg.update(buffer, createInputRow(vec1))
463+ agg.update(buffer, createInputRow(vec2))
464+ val result = evalAsFloatArray(agg, buffer)
465+ // Each element average should be 8.5
466+ assert(result === Array .fill(16 )(8.5f ))
467+ }
468+
469+ test(" VectorSum - large vector with null element is skipped" ) {
470+ val (agg, buffer, _) = createVectorSum()
471+ val vec1 = Array [java.lang.Float ](1.0f , 2.0f , 3.0f , 4.0f , 5.0f , null , 7.0f , 8.0f ,
472+ 9.0f , 10.0f , 11.0f , 12.0f , 13.0f , 14.0f , 15.0f , 16.0f )
473+ val vec2 = Array (16.0f , 15.0f , 14.0f , 13.0f , 12.0f , 11.0f , 10.0f , 9.0f ,
474+ 8.0f , 7.0f , 6.0f , 5.0f , 4.0f , 3.0f , 2.0f , 1.0f )
475+ agg.update(buffer, createInputRowWithNullElement(vec1))
476+ agg.update(buffer, createInputRow(vec2))
477+ val result = evalAsFloatArray(agg, buffer)
478+ // First vector is skipped due to null element, result is just vec2
479+ assert(result === vec2)
480+ }
481+
482+ test(" VectorSum - single element vectors" ) {
483+ val (agg, buffer, _) = createVectorSum()
484+ agg.update(buffer, createInputRow(Array (5.0f )))
485+ agg.update(buffer, createInputRow(Array (3.0f )))
486+ val result = evalAsFloatArray(agg, buffer)
487+ assert(result === Array (8.0f ))
488+ }
489+
490+ test(" VectorAvg - single element vectors" ) {
491+ val (agg, buffer, _) = createVectorAvg()
492+ agg.update(buffer, createInputRow(Array (5.0f )))
493+ agg.update(buffer, createInputRow(Array (3.0f )))
494+ val result = evalAsFloatArray(agg, buffer)
495+ assert(result === Array (4.0f ))
496+ }
361497}
0 commit comments