1717
1818package org .apache .spark .sql .catalyst .expressions
1919
20- import java .nio .{ByteBuffer , ByteOrder }
21-
2220import org .apache .spark .sql .catalyst .InternalRow
2321import org .apache .spark .sql .catalyst .analysis .TypeCheckResult
2422import org .apache .spark .sql .catalyst .analysis .TypeCheckResult .DataTypeMismatch
@@ -37,6 +35,7 @@ import org.apache.spark.sql.types.{
3735 StringType ,
3836 StructType
3937}
38+ import org .apache .spark .unsafe .Platform
4039
4140// scalastyle:off line.size.limit
4241@ ExpressionDescription (
@@ -345,10 +344,10 @@ case class VectorNormalize(vector: Expression, degree: Expression)
345344}
346345
// Base trait for vector aggregate functions (vector_avg, vector_sum).
// Provides a unified aggregate buffer schema: (acc: BINARY, count: LONG)
// - acc: BINARY representation of the running vector (sum or average)
// - count: number of valid vectors seen so far
// - dimension is inferred from acc.length / 4 (4 bytes per float)
// Subclasses only need to implement the element-wise update and merge logic.
353352trait VectorAggregateBase extends ImperativeAggregate
354353 with UnaryLike [Expression ]
@@ -375,17 +374,17 @@ trait VectorAggregateBase extends ImperativeAggregate
375374 }
376375 }
377376
// Aggregate buffer layout, in slot order: (acc: BINARY, count: LONG).
// `acc` is nullable because initialize() stores null in that slot; `count` is a
// plain non-null LONG.
private lazy val accAttr =
  AttributeReference("acc", BinaryType, nullable = true)()
private lazy val countAttr = AttributeReference(
  "count",
  LongType,
  nullable = false
)()
386385
// Buffer attributes in slot order: the accumulator first, then the count. This is
// the same order as accIndex (0) and countIndex (1).
override def aggBufferAttributes: Seq[AttributeReference] =
  accAttr :: countAttr :: Nil
389388
// Struct schema of the aggregation buffer, derived from aggBufferAttributes.
override def aggBufferSchema: StructType = {
  DataTypeUtils.fromAttributes(aggBufferAttributes)
}
@@ -394,33 +393,33 @@ trait VectorAggregateBase extends ImperativeAggregate
394393 aggBufferAttributes.map(_.newInstance())
395394
// Slot positions inside this function's region of the aggregation buffer.
// They are relative: users add mutableAggBufferOffset or inputAggBufferOffset.
protected val accIndex: Int = 0
protected val countIndex: Int = 1
399398
// Whether the child's array type admits null elements. The cast assumes
// child.dataType is an ArrayType and fails otherwise.
protected lazy val inputContainsNull = {
  val arrayType = child.dataType.asInstanceOf[ArrayType]
  arrayType.containsNull
}
402401
// Resets this function's slots in `buffer` to the empty state: a zero count and
// no accumulated vector (null in the acc slot).
override def initialize(buffer: InternalRow): Unit = {
  val base = mutableAggBufferOffset
  buffer.setLong(base + countIndex, 0L)
  buffer.update(base + accIndex, null)
}
407406
// Number of float elements encoded in `bytes`, at 4 bytes per float.
// Integer division: any trailing bytes that do not form a whole float are ignored.
protected def getDim(bytes: Array[Byte]): Int = {
  val byteCount = bytes.length
  byteCount / java.lang.Float.BYTES
}
410409
// Hook for folding one more input vector into a non-empty accumulator.
//   accBytes   - backing bytes of the running vector; implementations overwrite
//                them in place
//   inputArray - the incoming vector, read element by element
//   dim        - number of float elements in accBytes (the caller has already
//                checked that inputArray has the same length)
//   newCount   - number of vectors accumulated once this input is included
protected def updateElements(
    accBytes: Array[Byte],
    inputArray: ArrayData,
    dim: Int,
    newCount: Long): Unit
418417
419418 // Element-wise merge of two non-empty buffers.
420- // currentBuffer contains the left running vector; update it in-place.
419+ // accBytes contains the left running vector; update it in-place.
421420 protected def mergeElements (
422- currentBuffer : ByteBuffer ,
423- inputBuffer : ByteBuffer ,
421+ accBytes : Array [ Byte ] ,
422+ inputBytes : Array [ Byte ] ,
424423 dim : Int ,
425424 currentCount : Long ,
426425 inputCount : Long ,
@@ -447,95 +446,86 @@ trait VectorAggregateBase extends ImperativeAggregate
447446 }
448447 }
449448
450- val currentOffset = mutableAggBufferOffset + currentIndex
449+ val accOffset = mutableAggBufferOffset + accIndex
451450 val countOffset = mutableAggBufferOffset + countIndex
452451
453452 val currentCount = buffer.getLong(countOffset)
454453
455454 if (currentCount == 0L ) {
456455 // First valid vector - just copy it
457- val byteBuffer =
458- ByteBuffer .allocate(inputLen * 4 ).order(ByteOrder .LITTLE_ENDIAN )
456+ val bytes = new Array [Byte ](inputLen * 4 )
459457 var i = 0
460458 while (i < inputLen) {
461- byteBuffer .putFloat(inputArray.getFloat(i))
459+ Platform .putFloat(bytes, Platform . BYTE_ARRAY_OFFSET + i.toLong * 4 , inputArray.getFloat(i))
462460 i += 1
463461 }
464- buffer.update(currentOffset, byteBuffer.array() )
462+ buffer.update(accOffset, bytes )
465463 buffer.setLong(countOffset, 1L )
466464 } else {
467- val currentBytes = buffer.getBinary(currentOffset )
468- val currentDim = getDim(currentBytes )
465+ val accBytes = buffer.getBinary(accOffset )
466+ val accDim = getDim(accBytes )
469467
470468 // Empty array case - if current is empty and input is empty, keep empty
471- if (currentDim == 0 && inputLen == 0 ) {
469+ if (accDim == 0 && inputLen == 0 ) {
472470 buffer.setLong(countOffset, currentCount + 1L )
473471 return
474472 }
475473
476474 // Dimension mismatch check
477- if (currentDim != inputLen) {
475+ if (accDim != inputLen) {
478476 throw QueryExecutionErrors .vectorDimensionMismatchError(
479477 prettyName,
480- currentDim ,
478+ accDim ,
481479 inputLen
482480 )
483481 }
484482
485483 val newCount = currentCount + 1L
486- // reuse the buffer without reallocation
487- val currentBuffer =
488- ByteBuffer .wrap(currentBytes).order(ByteOrder .LITTLE_ENDIAN )
489- updateElements(currentBuffer, inputArray, currentDim, newCount)
484+ updateElements(accBytes, inputArray, accDim, newCount)
490485 buffer.setLong(countOffset, newCount)
491486 }
492487 }
493488
494489 override def merge (buffer : InternalRow , inputBuffer : InternalRow ): Unit = {
495- val currentOffset = mutableAggBufferOffset + currentIndex
490+ val accOffset = mutableAggBufferOffset + accIndex
496491 val countOffset = mutableAggBufferOffset + countIndex
497- val inputCurrentOffset = inputAggBufferOffset + currentIndex
492+ val inputAccOffset = inputAggBufferOffset + accIndex
498493 val inputCountOffset = inputAggBufferOffset + countIndex
499494
500495 val inputCount = inputBuffer.getLong(inputCountOffset)
501496 if (inputCount == 0L ) {
502497 return
503498 }
504499
505- val inputCurrentBytes = inputBuffer.getBinary(inputCurrentOffset )
500+ val inputAccBytes = inputBuffer.getBinary(inputAccOffset )
506501 val currentCount = buffer.getLong(countOffset)
507502
508503 if (currentCount == 0L ) {
509504 // Copy input buffer to current buffer
510- buffer.update(currentOffset, inputCurrentBytes .clone())
505+ buffer.update(accOffset, inputAccBytes .clone())
511506 buffer.setLong(countOffset, inputCount)
512507 } else {
513- val currentBytes = buffer.getBinary(currentOffset )
514- val currentDim = getDim(currentBytes )
515- val inputDim = getDim(inputCurrentBytes )
508+ val accBytes = buffer.getBinary(accOffset )
509+ val accDim = getDim(accBytes )
510+ val inputDim = getDim(inputAccBytes )
516511
517512 // Empty array case
518- if (currentDim == 0 && inputDim == 0 ) {
513+ if (accDim == 0 && inputDim == 0 ) {
519514 buffer.setLong(countOffset, currentCount + inputCount)
520515 return
521516 }
522517
523518 // Dimension mismatch check
524- if (currentDim != inputDim) {
519+ if (accDim != inputDim) {
525520 throw QueryExecutionErrors .vectorDimensionMismatchError(
526521 prettyName,
527- currentDim ,
522+ accDim ,
528523 inputDim
529524 )
530525 }
531526
532527 val newCount = currentCount + inputCount
533- // reuse the buffer without reallocation
534- val currentBuf =
535- ByteBuffer .wrap(currentBytes).order(ByteOrder .LITTLE_ENDIAN )
536- val inputBuf =
537- ByteBuffer .wrap(inputCurrentBytes).order(ByteOrder .LITTLE_ENDIAN )
538- mergeElements(currentBuf, inputBuf, currentDim,
528+ mergeElements(accBytes, inputAccBytes, accDim,
539529 currentCount, inputCount, newCount)
540530 buffer.setLong(countOffset, newCount)
541531 }
@@ -546,14 +536,12 @@ trait VectorAggregateBase extends ImperativeAggregate
546536 if (count == 0L ) {
547537 null
548538 } else {
549- val currentBytes = buffer.getBinary(mutableAggBufferOffset + currentIndex)
550- val dim = getDim(currentBytes)
551- val currentBuffer =
552- ByteBuffer .wrap(currentBytes).order(ByteOrder .LITTLE_ENDIAN )
539+ val accBytes = buffer.getBinary(mutableAggBufferOffset + accIndex)
540+ val dim = getDim(accBytes)
553541 val result = new Array [Float ](dim)
554542 var i = 0
555543 while (i < dim) {
556- result(i) = currentBuffer .getFloat()
544+ result(i) = Platform .getFloat(accBytes, Platform . BYTE_ARRAY_OFFSET + i.toLong * 4 )
557545 i += 1
558546 }
559547 ArrayData .toArrayData(result)
@@ -601,26 +589,24 @@ case class VectorAvg(
601589 copy(inputAggBufferOffset = newInputAggBufferOffset)
602590
// Folds one input vector into the running average stored in accBytes, in place,
// using the incremental form: new_avg = old_avg + (new_value - old_avg) / new_count.
override protected def updateElements(
    accBytes: Array[Byte],
    inputArray: ArrayData,
    dim: Int,
    newCount: Long): Unit = {
  // Reciprocal is computed once so the loop multiplies instead of dividing.
  val reciprocal = 1.0f / newCount
  var elem = 0
  // Byte position of the current float inside accBytes, as Platform addresses it.
  var pos = Platform.BYTE_ARRAY_OFFSET.toLong
  while (elem < dim) {
    val prev = Platform.getFloat(accBytes, pos)
    val incoming = inputArray.getFloat(elem)
    Platform.putFloat(accBytes, pos, prev + (incoming - prev) * reciprocal)
    elem += 1
    pos += 4L // 4 bytes per float
  }
}
620606
621607 override protected def mergeElements (
622- currentBuffer : ByteBuffer ,
623- inputBuffer : ByteBuffer ,
608+ accBytes : Array [ Byte ] ,
609+ inputBytes : Array [ Byte ] ,
624610 dim : Int ,
625611 currentCount : Long ,
626612 inputCount : Long ,
@@ -630,13 +616,12 @@ case class VectorAvg(
630616 val leftWeight = currentCount.toFloat / newCount
631617 val rightWeight = inputCount.toFloat / newCount
632618 var i = 0
633- var idx = 0
634619 while (i < dim) {
635- val leftAvg = currentBuffer.getFloat(idx)
636- val rightAvg = inputBuffer.getFloat(idx)
637- currentBuffer.putFloat(idx, leftAvg * leftWeight + rightAvg * rightWeight)
620+ val off = Platform .BYTE_ARRAY_OFFSET + i.toLong * 4
621+ val leftAvg = Platform .getFloat(accBytes, off)
622+ val rightAvg = Platform .getFloat(inputBytes, off)
623+ Platform .putFloat(accBytes, off, leftAvg * leftWeight + rightAvg * rightWeight)
638624 i += 1
639- idx += 4 // 4 bytes per float
640625 }
641626 }
642627
@@ -684,34 +669,33 @@ case class VectorSum(
684669 copy(inputAggBufferOffset = newInputAggBufferOffset)
685670
// Adds one input vector element-wise to the running sum stored in accBytes, in
// place. newCount is not used by the sum.
override protected def updateElements(
    accBytes: Array[Byte],
    inputArray: ArrayData,
    dim: Int,
    newCount: Long): Unit = {
  var elem = 0
  // Byte position of the current float inside accBytes, as Platform addresses it.
  var pos = Platform.BYTE_ARRAY_OFFSET.toLong
  while (elem < dim) {
    val running = Platform.getFloat(accBytes, pos)
    Platform.putFloat(accBytes, pos, running + inputArray.getFloat(elem))
    elem += 1
    pos += 4L // 4 bytes per float
  }
}
700684
// Merges two partial sums element-wise: accBytes becomes accBytes + inputBytes,
// written in place. The three count parameters are not used by the sum.
override protected def mergeElements(
    accBytes: Array[Byte],
    inputBytes: Array[Byte],
    dim: Int,
    currentCount: Long,
    inputCount: Long,
    newCount: Long): Unit = {
  var elem = 0
  // Both arrays share the same layout, so one byte position addresses both.
  var pos = Platform.BYTE_ARRAY_OFFSET.toLong
  while (elem < dim) {
    val left = Platform.getFloat(accBytes, pos)
    val right = Platform.getFloat(inputBytes, pos)
    Platform.putFloat(accBytes, pos, left + right)
    elem += 1
    pos += 4L // 4 bytes per float
  }
}
717701
0 commit comments