Skip to content

Commit 47a84c5

Browse files
committed
optimizations
1 parent c94ce2c commit 47a84c5

File tree

1 file changed

+134
-83
lines changed

1 file changed

+134
-83
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/vectorExpressions.scala

Lines changed: 134 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,10 @@ case class VectorNormalize(vector: Expression, degree: Expression)
360360
group = "vector_funcs"
361361
)
362362
// scalastyle:on line.size.limit
363+
// Note: This implementation uses single-precision floating-point arithmetic (Float).
364+
// Precision loss is expected for very large aggregates due to:
365+
// 1. Accumulated rounding errors in incremental average updates
366+
// 2. Loss of significance when dividing by large counts (counts above 2^24 are inexact as Float)
363367
case class VectorAvg(
364368
child: Expression,
365369
mutableAggBufferOffset: Int = 0,
@@ -441,6 +445,8 @@ case class VectorAvg(
441445
buffer.setLong(mutableAggBufferOffset + countIndex, 0L)
442446
}
443447

448+
private lazy val inputContainsNull = child.dataType.asInstanceOf[ArrayType].containsNull
449+
444450
override def update(buffer: InternalRow, input: InternalRow): Unit = {
445451
val inputValue = child.eval(input)
446452
if (inputValue == null) {
@@ -451,30 +457,41 @@ case class VectorAvg(
451457
val inputLen = inputArray.numElements()
452458

453459
// Check for NULL elements in input vector - skip if any NULL element found
454-
for (i <- 0 until inputLen) {
455-
if (inputArray.isNullAt(i)) {
456-
return
460+
// Only check if the array type can contain nulls
461+
if (inputContainsNull) {
462+
var i = 0
463+
while (i < inputLen) {
464+
if (inputArray.isNullAt(i)) {
465+
return
466+
}
467+
i += 1
457468
}
458469
}
459470

460-
val currentCount = buffer.getLong(mutableAggBufferOffset + countIndex)
471+
val avgOffset = mutableAggBufferOffset + avgIndex
472+
val dimOffset = mutableAggBufferOffset + dimIndex
473+
val countOffset = mutableAggBufferOffset + countIndex
474+
475+
val currentCount = buffer.getLong(countOffset)
461476

462477
if (currentCount == 0L) {
463478
// First valid vector - just copy it as the initial average
464479
val byteBuffer =
465480
ByteBuffer.allocate(inputLen * 4).order(ByteOrder.LITTLE_ENDIAN)
466-
for (i <- 0 until inputLen) {
481+
var i = 0
482+
while (i < inputLen) {
467483
byteBuffer.putFloat(inputArray.getFloat(i))
484+
i += 1
468485
}
469-
buffer.update(mutableAggBufferOffset + avgIndex, byteBuffer.array())
470-
buffer.setInt(mutableAggBufferOffset + dimIndex, inputLen)
471-
buffer.setLong(mutableAggBufferOffset + countIndex, 1L)
486+
buffer.update(avgOffset, byteBuffer.array())
487+
buffer.setInt(dimOffset, inputLen)
488+
buffer.setLong(countOffset, 1L)
472489
} else {
473-
val currentDim = buffer.getInt(mutableAggBufferOffset + dimIndex)
490+
val currentDim = buffer.getInt(dimOffset)
474491

475492
// Empty array case - if current is empty and input is empty, keep empty
476493
if (currentDim == 0 && inputLen == 0) {
477-
buffer.setLong(mutableAggBufferOffset + countIndex, currentCount + 1L)
494+
buffer.setLong(countOffset, currentCount + 1L)
478495
return
479496
}
480497

@@ -489,45 +506,52 @@ case class VectorAvg(
489506

490507
// Update running average: new_avg = old_avg + (new_value - old_avg) / (count + 1)
491508
val newCount = currentCount + 1L
492-
val currentAvgBytes = buffer.getBinary(mutableAggBufferOffset + avgIndex)
493-
val currentAvgBuffer =
509+
val invCount = 1.0f / newCount
510+
val currentAvgBytes = buffer.getBinary(avgOffset)
511+
// updated in place; persists only if getBinary returns the backing array, not a copy
512+
val avgBuffer =
494513
ByteBuffer.wrap(currentAvgBytes).order(ByteOrder.LITTLE_ENDIAN)
495-
val newAvgBuffer =
496-
ByteBuffer.allocate(currentDim * 4).order(ByteOrder.LITTLE_ENDIAN)
497-
for (i <- 0 until currentDim) {
498-
val oldAvg = currentAvgBuffer.getFloat()
514+
var i = 0
515+
var idx = 0
516+
while (i < currentDim) {
517+
val oldAvg = avgBuffer.getFloat(idx)
499518
val newVal = inputArray.getFloat(i)
500-
newAvgBuffer.putFloat(oldAvg + ((newVal - oldAvg) / newCount.toFloat))
519+
avgBuffer.putFloat(idx, oldAvg + (newVal - oldAvg) * invCount)
520+
i += 1
521+
idx += 4 // 4 bytes per float
501522
}
502-
buffer.update(mutableAggBufferOffset + avgIndex, newAvgBuffer.array())
503-
buffer.setLong(mutableAggBufferOffset + countIndex, newCount)
523+
buffer.setLong(countOffset, newCount)
504524
}
505525
}
506526

507527
override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = {
508-
val inputCount = inputBuffer.getLong(inputAggBufferOffset + countIndex)
528+
val avgOffset = mutableAggBufferOffset + avgIndex
529+
val dimOffset = mutableAggBufferOffset + dimIndex
530+
val countOffset = mutableAggBufferOffset + countIndex
531+
val inputAvgOffset = inputAggBufferOffset + avgIndex
532+
val inputDimOffset = inputAggBufferOffset + dimIndex
533+
val inputCountOffset = inputAggBufferOffset + countIndex
534+
535+
val inputCount = inputBuffer.getLong(inputCountOffset)
509536
if (inputCount == 0L) {
510537
return
511538
}
512539

513-
val inputAvgBytes = inputBuffer.getBinary(inputAggBufferOffset + avgIndex)
514-
val inputDim = inputBuffer.getInt(inputAggBufferOffset + dimIndex)
515-
val currentCount = buffer.getLong(mutableAggBufferOffset + countIndex)
540+
val inputAvgBytes = inputBuffer.getBinary(inputAvgOffset)
541+
val inputDim = inputBuffer.getInt(inputDimOffset)
542+
val currentCount = buffer.getLong(countOffset)
516543

517544
if (currentCount == 0L) {
518545
// Copy input buffer to current buffer
519-
buffer.update(mutableAggBufferOffset + avgIndex, inputAvgBytes.clone())
520-
buffer.setInt(mutableAggBufferOffset + dimIndex, inputDim)
521-
buffer.setLong(mutableAggBufferOffset + countIndex, inputCount)
546+
buffer.update(avgOffset, inputAvgBytes.clone())
547+
buffer.setInt(dimOffset, inputDim)
548+
buffer.setLong(countOffset, inputCount)
522549
} else {
523-
val currentDim = buffer.getInt(mutableAggBufferOffset + dimIndex)
550+
val currentDim = buffer.getInt(dimOffset)
524551

525552
// Empty array case
526553
if (currentDim == 0 && inputDim == 0) {
527-
buffer.setLong(
528-
mutableAggBufferOffset + countIndex,
529-
currentCount + inputCount
530-
)
554+
buffer.setLong(countOffset, currentCount + inputCount)
531555
return
532556
}
533557

@@ -544,38 +568,41 @@ case class VectorAvg(
544568
// combined_avg = (left_avg * left_count) / (left_count + right_count) +
545569
// (right_avg * right_count) / (left_count + right_count)
546570
val newCount = currentCount + inputCount
547-
val currentAvgBytes = buffer.getBinary(mutableAggBufferOffset + avgIndex)
548-
val currentAvgBuffer =
571+
val leftWeight = currentCount.toFloat / newCount
572+
val rightWeight = inputCount.toFloat / newCount
573+
val currentAvgBytes = buffer.getBinary(avgOffset)
574+
// updated in place; persists only if getBinary returns the backing array, not a copy
575+
val avgBuffer =
549576
ByteBuffer.wrap(currentAvgBytes).order(ByteOrder.LITTLE_ENDIAN)
550577
val inputAvgBuffer =
551578
ByteBuffer.wrap(inputAvgBytes).order(ByteOrder.LITTLE_ENDIAN)
552-
val newAvgBuffer =
553-
ByteBuffer.allocate(currentDim * 4).order(ByteOrder.LITTLE_ENDIAN)
554-
for (_ <- 0 until currentDim) {
555-
// getFloat() will auto-increment the buffer's current position by 4
556-
val leftAvg = currentAvgBuffer.getFloat()
557-
val rightAvg = inputAvgBuffer.getFloat()
558-
newAvgBuffer.putFloat(
559-
(leftAvg * currentCount) / newCount.toFloat +
560-
(rightAvg * inputCount) / newCount.toFloat
561-
)
579+
var i = 0
580+
var idx = 0
581+
while (i < currentDim) {
582+
val leftAvg = avgBuffer.getFloat(idx)
583+
val rightAvg = inputAvgBuffer.getFloat(idx)
584+
avgBuffer.putFloat(idx, leftAvg * leftWeight + rightAvg * rightWeight)
585+
i += 1
586+
idx += 4 // 4 bytes per float
562587
}
563-
buffer.update(mutableAggBufferOffset + avgIndex, newAvgBuffer.array())
564-
buffer.setLong(mutableAggBufferOffset + countIndex, newCount)
588+
buffer.setLong(countOffset, newCount)
565589
}
566590
}
567591

568592
override def eval(buffer: InternalRow): Any = {
569-
val count = buffer.getLong(mutableAggBufferOffset + countIndex)
593+
val countOffset = mutableAggBufferOffset + countIndex
594+
val count = buffer.getLong(countOffset)
570595
if (count == 0L) {
571596
null
572597
} else {
573598
val dim = buffer.getInt(mutableAggBufferOffset + dimIndex)
574599
val avgBytes = buffer.getBinary(mutableAggBufferOffset + avgIndex)
575600
val avgBuffer = ByteBuffer.wrap(avgBytes).order(ByteOrder.LITTLE_ENDIAN)
576601
val result = new Array[Float](dim)
577-
for (i <- 0 until dim) {
602+
var i = 0
603+
while (i < dim) {
578604
result(i) = avgBuffer.getFloat()
605+
i += 1
579606
}
580607
ArrayData.toArrayData(result)
581608
}
@@ -600,6 +627,10 @@ case class VectorAvg(
600627
group = "vector_funcs"
601628
)
602629
// scalastyle:on line.size.limit
630+
// Note: This implementation uses single-precision floating-point arithmetic (Float).
631+
// Precision loss is expected for very large aggregates due to:
632+
// 1. Accumulated rounding errors when summing many values
633+
// 2. Loss of significance when adding small values to large accumulated sums
603634
case class VectorSum(
604635
child: Expression,
605636
mutableAggBufferOffset: Int = 0,
@@ -671,6 +702,8 @@ case class VectorSum(
671702
): ImperativeAggregate =
672703
copy(inputAggBufferOffset = newInputAggBufferOffset)
673704

705+
private lazy val inputContainsNull = child.dataType.asInstanceOf[ArrayType].containsNull
706+
674707
override def initialize(buffer: InternalRow): Unit = {
675708
buffer.update(mutableAggBufferOffset + sumIndex, null)
676709
buffer.update(mutableAggBufferOffset + dimIndex, null)
@@ -686,23 +719,33 @@ case class VectorSum(
686719
val inputLen = inputArray.numElements()
687720

688721
// Check for NULL elements in input vector - skip if any NULL element found
689-
for (i <- 0 until inputLen) {
690-
if (inputArray.isNullAt(i)) {
691-
return
722+
// Only check if the array type can contain nulls
723+
if (inputContainsNull) {
724+
var i = 0
725+
while (i < inputLen) {
726+
if (inputArray.isNullAt(i)) {
727+
return
728+
}
729+
i += 1
692730
}
693731
}
694732

695-
if (buffer.isNullAt(mutableAggBufferOffset + sumIndex)) {
733+
val sumOffset = mutableAggBufferOffset + sumIndex
734+
val dimOffset = mutableAggBufferOffset + dimIndex
735+
736+
if (buffer.isNullAt(sumOffset)) {
696737
// First valid vector - just copy it as the initial sum
697738
val byteBuffer =
698739
ByteBuffer.allocate(inputLen * 4).order(ByteOrder.LITTLE_ENDIAN)
699-
for (i <- 0 until inputLen) {
740+
var i = 0
741+
while (i < inputLen) {
700742
byteBuffer.putFloat(inputArray.getFloat(i))
743+
i += 1
701744
}
702-
buffer.update(mutableAggBufferOffset + sumIndex, byteBuffer.array())
703-
buffer.setInt(mutableAggBufferOffset + dimIndex, inputLen)
745+
buffer.update(sumOffset, byteBuffer.array())
746+
buffer.setInt(dimOffset, inputLen)
704747
} else {
705-
val currentDim = buffer.getInt(mutableAggBufferOffset + dimIndex)
748+
val currentDim = buffer.getInt(dimOffset)
706749

707750
// Empty array case - if current is empty and input is empty, keep empty
708751
if (currentDim == 0 && inputLen == 0) {
@@ -719,34 +762,39 @@ case class VectorSum(
719762
}
720763

721764
// Update sum: new_sum = old_sum + new_value
722-
val currentSumBytes = buffer.getBinary(mutableAggBufferOffset + sumIndex)
723-
val currentSumBuffer =
765+
val currentSumBytes = buffer.getBinary(sumOffset)
766+
// updated in place; persists only if getBinary returns the backing array, not a copy
767+
val sumBuffer =
724768
ByteBuffer.wrap(currentSumBytes).order(ByteOrder.LITTLE_ENDIAN)
725-
val newSumBuffer =
726-
ByteBuffer.allocate(currentDim * 4).order(ByteOrder.LITTLE_ENDIAN)
727-
for (i <- 0 until currentDim) {
728-
newSumBuffer.putFloat(
729-
currentSumBuffer.getFloat() + inputArray.getFloat(i)
730-
)
769+
var i = 0
770+
var idx = 0
771+
while (i < currentDim) {
772+
sumBuffer.putFloat(idx, sumBuffer.getFloat(idx) + inputArray.getFloat(i))
773+
i += 1
774+
idx += 4 // 4 bytes per float
731775
}
732-
buffer.update(mutableAggBufferOffset + sumIndex, newSumBuffer.array())
733776
}
734777
}
735778

736779
override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = {
737-
if (inputBuffer.isNullAt(inputAggBufferOffset + sumIndex)) {
780+
val sumOffset = mutableAggBufferOffset + sumIndex
781+
val dimOffset = mutableAggBufferOffset + dimIndex
782+
val inputSumOffset = inputAggBufferOffset + sumIndex
783+
val inputDimOffset = inputAggBufferOffset + dimIndex
784+
785+
if (inputBuffer.isNullAt(inputSumOffset)) {
738786
return
739787
}
740788

741-
val inputSumBytes = inputBuffer.getBinary(inputAggBufferOffset + sumIndex)
742-
val inputDim = inputBuffer.getInt(inputAggBufferOffset + dimIndex)
789+
val inputSumBytes = inputBuffer.getBinary(inputSumOffset)
790+
val inputDim = inputBuffer.getInt(inputDimOffset)
743791

744-
if (buffer.isNullAt(mutableAggBufferOffset + sumIndex)) {
792+
if (buffer.isNullAt(sumOffset)) {
745793
// Copy input buffer to current buffer
746-
buffer.update(mutableAggBufferOffset + sumIndex, inputSumBytes.clone())
747-
buffer.setInt(mutableAggBufferOffset + dimIndex, inputDim)
794+
buffer.update(sumOffset, inputSumBytes.clone())
795+
buffer.setInt(dimOffset, inputDim)
748796
} else {
749-
val currentDim = buffer.getInt(mutableAggBufferOffset + dimIndex)
797+
val currentDim = buffer.getInt(dimOffset)
750798

751799
// Empty array case
752800
if (currentDim == 0 && inputDim == 0) {
@@ -763,32 +811,35 @@ case class VectorSum(
763811
}
764812

765813
// Merge sums: combined_sum = left_sum + right_sum
766-
val currentSumBytes = buffer.getBinary(mutableAggBufferOffset + sumIndex)
767-
val currentSumBuffer =
814+
val currentSumBytes = buffer.getBinary(sumOffset)
815+
// updated in place; persists only if getBinary returns the backing array, not a copy
816+
val sumBuffer =
768817
ByteBuffer.wrap(currentSumBytes).order(ByteOrder.LITTLE_ENDIAN)
769818
val inputSumBuffer =
770819
ByteBuffer.wrap(inputSumBytes).order(ByteOrder.LITTLE_ENDIAN)
771-
val newSumBuffer =
772-
ByteBuffer.allocate(currentDim * 4).order(ByteOrder.LITTLE_ENDIAN)
773-
for (_ <- 0 until currentDim) {
774-
newSumBuffer.putFloat(
775-
currentSumBuffer.getFloat() + inputSumBuffer.getFloat()
776-
)
820+
var i = 0
821+
var idx = 0
822+
while (i < currentDim) {
823+
sumBuffer.putFloat(idx, sumBuffer.getFloat(idx) + inputSumBuffer.getFloat(idx))
824+
i += 1
825+
idx += 4 // 4 bytes per float
777826
}
778-
buffer.update(mutableAggBufferOffset + sumIndex, newSumBuffer.array())
779827
}
780828
}
781829

782830
override def eval(buffer: InternalRow): Any = {
783-
if (buffer.isNullAt(mutableAggBufferOffset + sumIndex)) {
831+
val sumOffset = mutableAggBufferOffset + sumIndex
832+
if (buffer.isNullAt(sumOffset)) {
784833
null
785834
} else {
786835
val dim = buffer.getInt(mutableAggBufferOffset + dimIndex)
787-
val sumBytes = buffer.getBinary(mutableAggBufferOffset + sumIndex)
836+
val sumBytes = buffer.getBinary(sumOffset)
788837
val sumBuffer = ByteBuffer.wrap(sumBytes).order(ByteOrder.LITTLE_ENDIAN)
789838
val result = new Array[Float](dim)
790-
for (i <- 0 until dim) {
839+
var i = 0
840+
while (i < dim) {
791841
result(i) = sumBuffer.getFloat()
842+
i += 1
792843
}
793844
ArrayData.toArrayData(result)
794845
}

0 commit comments

Comments
 (0)