Skip to content

Commit aa76ca0

Browse files
committed
comments
1 parent 20b62a4 commit aa76ca0

File tree

1 file changed

+57
-73
lines changed

1 file changed

+57
-73
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/vectorExpressions.scala

Lines changed: 57 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import java.nio.{ByteBuffer, ByteOrder}
21-
2220
import org.apache.spark.sql.catalyst.InternalRow
2321
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2422
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -37,6 +35,7 @@ import org.apache.spark.sql.types.{
3735
StringType,
3836
StructType
3937
}
38+
import org.apache.spark.unsafe.Platform
4039

4140
// scalastyle:off line.size.limit
4241
@ExpressionDescription(
@@ -345,10 +344,10 @@ case class VectorNormalize(vector: Expression, degree: Expression)
345344
}
346345

347346
// Base trait for vector aggregate functions (vector_avg, vector_sum).
348-
// Provides a unified aggregate buffer schema: (current: BINARY, count: LONG)
349-
// - current: BINARY representation of the running vector (sum or average)
347+
// Provides a unified aggregate buffer schema: (acc: BINARY, count: LONG)
348+
// - acc: BINARY representation of the running vector (sum or average)
350349
// - count: number of valid vectors seen so far
351-
// - dimension is inferred from current.length / 4 (4 bytes per float)
350+
// - dimension is inferred from acc.length / 4 (4 bytes per float)
352351
// Subclasses only need to implement the element-wise update and merge logic.
353352
trait VectorAggregateBase extends ImperativeAggregate
354353
with UnaryLike[Expression]
@@ -375,17 +374,17 @@ trait VectorAggregateBase extends ImperativeAggregate
375374
}
376375
}
377376

378-
// Aggregate buffer schema: (current: BINARY, count: LONG)
379-
private lazy val currentAttr = AttributeReference(
380-
"current",
377+
// Aggregate buffer schema: (acc: BINARY, count: LONG)
378+
private lazy val accAttr = AttributeReference(
379+
"acc",
381380
BinaryType,
382381
nullable = true
383382
)()
384383
private lazy val countAttr =
385384
AttributeReference("count", LongType, nullable = false)()
386385

387386
override def aggBufferAttributes: Seq[AttributeReference] =
388-
Seq(currentAttr, countAttr)
387+
Seq(accAttr, countAttr)
389388

390389
override def aggBufferSchema: StructType =
391390
DataTypeUtils.fromAttributes(aggBufferAttributes)
@@ -394,33 +393,33 @@ trait VectorAggregateBase extends ImperativeAggregate
394393
aggBufferAttributes.map(_.newInstance())
395394

396395
// Buffer indices
397-
protected val currentIndex = 0
396+
protected val accIndex = 0
398397
protected val countIndex = 1
399398

400399
protected lazy val inputContainsNull =
401400
child.dataType.asInstanceOf[ArrayType].containsNull
402401

403402
override def initialize(buffer: InternalRow): Unit = {
404-
buffer.update(mutableAggBufferOffset + currentIndex, null)
403+
buffer.update(mutableAggBufferOffset + accIndex, null)
405404
buffer.setLong(mutableAggBufferOffset + countIndex, 0L)
406405
}
407406

408407
// Infer vector dimension from byte array length (4 bytes per float)
409408
protected def getDim(bytes: Array[Byte]): Int = bytes.length / 4
410409

411410
// Element-wise update for non-first vectors.
412-
// currentBuffer contains the running vector; update it in-place with inputArray.
411+
// accBytes contains the running vector; update it in-place with inputArray.
413412
protected def updateElements(
414-
currentBuffer: ByteBuffer,
413+
accBytes: Array[Byte],
415414
inputArray: ArrayData,
416415
dim: Int,
417416
newCount: Long): Unit
418417

419418
// Element-wise merge of two non-empty buffers.
420-
// currentBuffer contains the left running vector; update it in-place.
419+
// accBytes contains the left running vector; update it in-place.
421420
protected def mergeElements(
422-
currentBuffer: ByteBuffer,
423-
inputBuffer: ByteBuffer,
421+
accBytes: Array[Byte],
422+
inputBytes: Array[Byte],
424423
dim: Int,
425424
currentCount: Long,
426425
inputCount: Long,
@@ -447,95 +446,86 @@ trait VectorAggregateBase extends ImperativeAggregate
447446
}
448447
}
449448

450-
val currentOffset = mutableAggBufferOffset + currentIndex
449+
val accOffset = mutableAggBufferOffset + accIndex
451450
val countOffset = mutableAggBufferOffset + countIndex
452451

453452
val currentCount = buffer.getLong(countOffset)
454453

455454
if (currentCount == 0L) {
456455
// First valid vector - just copy it
457-
val byteBuffer =
458-
ByteBuffer.allocate(inputLen * 4).order(ByteOrder.LITTLE_ENDIAN)
456+
val bytes = new Array[Byte](inputLen * 4)
459457
var i = 0
460458
while (i < inputLen) {
461-
byteBuffer.putFloat(inputArray.getFloat(i))
459+
Platform.putFloat(bytes, Platform.BYTE_ARRAY_OFFSET + i.toLong * 4, inputArray.getFloat(i))
462460
i += 1
463461
}
464-
buffer.update(currentOffset, byteBuffer.array())
462+
buffer.update(accOffset, bytes)
465463
buffer.setLong(countOffset, 1L)
466464
} else {
467-
val currentBytes = buffer.getBinary(currentOffset)
468-
val currentDim = getDim(currentBytes)
465+
val accBytes = buffer.getBinary(accOffset)
466+
val accDim = getDim(accBytes)
469467

470468
// Empty array case - if both the accumulator and the input are empty, keep the accumulator empty and only bump the count
471-
if (currentDim == 0 && inputLen == 0) {
469+
if (accDim == 0 && inputLen == 0) {
472470
buffer.setLong(countOffset, currentCount + 1L)
473471
return
474472
}
475473

476474
// Dimension mismatch check
477-
if (currentDim != inputLen) {
475+
if (accDim != inputLen) {
478476
throw QueryExecutionErrors.vectorDimensionMismatchError(
479477
prettyName,
480-
currentDim,
478+
accDim,
481479
inputLen
482480
)
483481
}
484482

485483
val newCount = currentCount + 1L
486-
// reuse the buffer without reallocation
487-
val currentBuffer =
488-
ByteBuffer.wrap(currentBytes).order(ByteOrder.LITTLE_ENDIAN)
489-
updateElements(currentBuffer, inputArray, currentDim, newCount)
484+
updateElements(accBytes, inputArray, accDim, newCount)
490485
buffer.setLong(countOffset, newCount)
491486
}
492487
}
493488

494489
override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = {
495-
val currentOffset = mutableAggBufferOffset + currentIndex
490+
val accOffset = mutableAggBufferOffset + accIndex
496491
val countOffset = mutableAggBufferOffset + countIndex
497-
val inputCurrentOffset = inputAggBufferOffset + currentIndex
492+
val inputAccOffset = inputAggBufferOffset + accIndex
498493
val inputCountOffset = inputAggBufferOffset + countIndex
499494

500495
val inputCount = inputBuffer.getLong(inputCountOffset)
501496
if (inputCount == 0L) {
502497
return
503498
}
504499

505-
val inputCurrentBytes = inputBuffer.getBinary(inputCurrentOffset)
500+
val inputAccBytes = inputBuffer.getBinary(inputAccOffset)
506501
val currentCount = buffer.getLong(countOffset)
507502

508503
if (currentCount == 0L) {
509504
// Copy input buffer to current buffer
510-
buffer.update(currentOffset, inputCurrentBytes.clone())
505+
buffer.update(accOffset, inputAccBytes.clone())
511506
buffer.setLong(countOffset, inputCount)
512507
} else {
513-
val currentBytes = buffer.getBinary(currentOffset)
514-
val currentDim = getDim(currentBytes)
515-
val inputDim = getDim(inputCurrentBytes)
508+
val accBytes = buffer.getBinary(accOffset)
509+
val accDim = getDim(accBytes)
510+
val inputDim = getDim(inputAccBytes)
516511

517512
// Empty array case
518-
if (currentDim == 0 && inputDim == 0) {
513+
if (accDim == 0 && inputDim == 0) {
519514
buffer.setLong(countOffset, currentCount + inputCount)
520515
return
521516
}
522517

523518
// Dimension mismatch check
524-
if (currentDim != inputDim) {
519+
if (accDim != inputDim) {
525520
throw QueryExecutionErrors.vectorDimensionMismatchError(
526521
prettyName,
527-
currentDim,
522+
accDim,
528523
inputDim
529524
)
530525
}
531526

532527
val newCount = currentCount + inputCount
533-
// reuse the buffer without reallocation
534-
val currentBuf =
535-
ByteBuffer.wrap(currentBytes).order(ByteOrder.LITTLE_ENDIAN)
536-
val inputBuf =
537-
ByteBuffer.wrap(inputCurrentBytes).order(ByteOrder.LITTLE_ENDIAN)
538-
mergeElements(currentBuf, inputBuf, currentDim,
528+
mergeElements(accBytes, inputAccBytes, accDim,
539529
currentCount, inputCount, newCount)
540530
buffer.setLong(countOffset, newCount)
541531
}
@@ -546,14 +536,12 @@ trait VectorAggregateBase extends ImperativeAggregate
546536
if (count == 0L) {
547537
null
548538
} else {
549-
val currentBytes = buffer.getBinary(mutableAggBufferOffset + currentIndex)
550-
val dim = getDim(currentBytes)
551-
val currentBuffer =
552-
ByteBuffer.wrap(currentBytes).order(ByteOrder.LITTLE_ENDIAN)
539+
val accBytes = buffer.getBinary(mutableAggBufferOffset + accIndex)
540+
val dim = getDim(accBytes)
553541
val result = new Array[Float](dim)
554542
var i = 0
555543
while (i < dim) {
556-
result(i) = currentBuffer.getFloat()
544+
result(i) = Platform.getFloat(accBytes, Platform.BYTE_ARRAY_OFFSET + i.toLong * 4)
557545
i += 1
558546
}
559547
ArrayData.toArrayData(result)
@@ -601,26 +589,24 @@ case class VectorAvg(
601589
copy(inputAggBufferOffset = newInputAggBufferOffset)
602590

603591
override protected def updateElements(
604-
currentBuffer: ByteBuffer,
592+
accBytes: Array[Byte],
605593
inputArray: ArrayData,
606594
dim: Int,
607595
newCount: Long): Unit = {
608596
// Update running average: new_avg = old_avg + (new_value - old_avg) / new_count
609597
val invCount = 1.0f / newCount
610598
var i = 0
611-
var idx = 0
612599
while (i < dim) {
613-
val oldAvg = currentBuffer.getFloat(idx)
614-
val newVal = inputArray.getFloat(i)
615-
currentBuffer.putFloat(idx, oldAvg + (newVal - oldAvg) * invCount)
600+
val off = Platform.BYTE_ARRAY_OFFSET + i.toLong * 4
601+
val oldAvg = Platform.getFloat(accBytes, off)
602+
Platform.putFloat(accBytes, off, oldAvg + (inputArray.getFloat(i) - oldAvg) * invCount)
616603
i += 1
617-
idx += 4 // 4 bytes per float
618604
}
619605
}
620606

621607
override protected def mergeElements(
622-
currentBuffer: ByteBuffer,
623-
inputBuffer: ByteBuffer,
608+
accBytes: Array[Byte],
609+
inputBytes: Array[Byte],
624610
dim: Int,
625611
currentCount: Long,
626612
inputCount: Long,
@@ -630,13 +616,12 @@ case class VectorAvg(
630616
val leftWeight = currentCount.toFloat / newCount
631617
val rightWeight = inputCount.toFloat / newCount
632618
var i = 0
633-
var idx = 0
634619
while (i < dim) {
635-
val leftAvg = currentBuffer.getFloat(idx)
636-
val rightAvg = inputBuffer.getFloat(idx)
637-
currentBuffer.putFloat(idx, leftAvg * leftWeight + rightAvg * rightWeight)
620+
val off = Platform.BYTE_ARRAY_OFFSET + i.toLong * 4
621+
val leftAvg = Platform.getFloat(accBytes, off)
622+
val rightAvg = Platform.getFloat(inputBytes, off)
623+
Platform.putFloat(accBytes, off, leftAvg * leftWeight + rightAvg * rightWeight)
638624
i += 1
639-
idx += 4 // 4 bytes per float
640625
}
641626
}
642627

@@ -684,34 +669,33 @@ case class VectorSum(
684669
copy(inputAggBufferOffset = newInputAggBufferOffset)
685670

686671
override protected def updateElements(
687-
currentBuffer: ByteBuffer,
672+
accBytes: Array[Byte],
688673
inputArray: ArrayData,
689674
dim: Int,
690675
newCount: Long): Unit = {
691676
// Update sum: new_sum = old_sum + new_value
692677
var i = 0
693-
var idx = 0
694678
while (i < dim) {
695-
currentBuffer.putFloat(idx, currentBuffer.getFloat(idx) + inputArray.getFloat(i))
679+
val off = Platform.BYTE_ARRAY_OFFSET + i.toLong * 4
680+
Platform.putFloat(accBytes, off, Platform.getFloat(accBytes, off) + inputArray.getFloat(i))
696681
i += 1
697-
idx += 4 // 4 bytes per float
698682
}
699683
}
700684

701685
override protected def mergeElements(
702-
currentBuffer: ByteBuffer,
703-
inputBuffer: ByteBuffer,
686+
accBytes: Array[Byte],
687+
inputBytes: Array[Byte],
704688
dim: Int,
705689
currentCount: Long,
706690
inputCount: Long,
707691
newCount: Long): Unit = {
708692
// Merge sums: combined_sum = left_sum + right_sum
709693
var i = 0
710-
var idx = 0
711694
while (i < dim) {
712-
currentBuffer.putFloat(idx, currentBuffer.getFloat(idx) + inputBuffer.getFloat(idx))
695+
val off = Platform.BYTE_ARRAY_OFFSET + i.toLong * 4
696+
Platform.putFloat(accBytes, off,
697+
Platform.getFloat(accBytes, off) + Platform.getFloat(inputBytes, off))
713698
i += 1
714-
idx += 4 // 4 bytes per float
715699
}
716700
}
717701

0 commit comments

Comments
 (0)