@@ -17,13 +17,14 @@ use vortex_dtype::{
1717use vortex_error:: {
1818 VortexError , VortexExpect , VortexResult , vortex_bail, vortex_err, vortex_panic,
1919} ;
20- use vortex_mask:: { AllOr , Mask } ;
20+ use vortex_mask:: { AllOr , Mask , MaskMut } ;
2121use vortex_scalar:: { PValue , Scalar } ;
2222use vortex_utils:: aliases:: hash_map:: HashMap ;
2323
2424use crate :: arrays:: PrimitiveArray ;
2525use crate :: compute:: { cast, filter, is_sorted, take} ;
2626use crate :: search_sorted:: { SearchResult , SearchSorted , SearchSortedSide } ;
27+ use crate :: validity:: Validity ;
2728use crate :: vtable:: ValidityHelper ;
2829use crate :: { Array , ArrayRef , IntoArray , ToCanonical } ;
2930
@@ -797,6 +798,43 @@ impl Patches {
797798 } ) )
798799 }
799800
801+ /// Apply patches to a mutable buffer and validity mask.
802+ ///
803+ /// This method applies the patch values to the given buffer at the positions specified by the
804+ /// patch indices. For non-null patch values, it updates the buffer and marks the position as
805+ /// valid. For null patch values, it marks the position as invalid.
806+ ///
807+ /// # Safety
808+ ///
809+ /// - All patch indices after offset adjustment must be valid indices into the buffer.
810+ /// - The buffer and validity mask must have the same length.
811+ pub unsafe fn apply_to_buffer < P : NativePType > ( & self , buffer : & mut [ P ] , validity : & mut MaskMut ) {
812+ let patch_indices = self . indices . to_primitive ( ) ;
813+ let patch_values = self . values . to_primitive ( ) ;
814+ let patches_validity = patch_values. validity ( ) ;
815+
816+ let patch_values_slice = patch_values. as_slice :: < P > ( ) ;
817+ match_each_unsigned_integer_ptype ! ( patch_indices. ptype( ) , |I | {
818+ let patch_indices_slice = patch_indices. as_slice:: <I >( ) ;
819+
820+ // SAFETY:
821+ // - `Patches` invariant guarantees indices are sorted and within array bounds.
822+ // - `patch_indices` and `patch_values` have equal length (from `Patches` invariant).
823+ // - `buffer` and `validity` have equal length (precondition).
824+ // - All patch indices are valid after offset adjustment (precondition).
825+ unsafe {
826+ apply_patches_to_buffer_inner(
827+ buffer,
828+ validity,
829+ patch_indices_slice,
830+ self . offset,
831+ patch_values_slice,
832+ patches_validity,
833+ ) ;
834+ }
835+ } ) ;
836+ }
837+
800838 pub fn map_values < F > ( self , f : F ) -> VortexResult < Self >
801839 where
802840 F : FnOnce ( ArrayRef ) -> VortexResult < ArrayRef > ,
@@ -821,6 +859,78 @@ impl Patches {
821859 }
822860}
823861
862+ /// Helper function to apply patches to a buffer.
863+ ///
864+ /// # Safety
865+ ///
866+ /// - All indices in `patch_indices` after subtracting `patch_offset` must be valid indices
867+ /// into both `buffer` and `validity`.
868+ /// - `patch_indices` must be sorted in ascending order.
869+ /// - `patch_indices` and `patch_values` must have the same length.
870+ /// - `buffer` and `validity` must have the same length.
871+ unsafe fn apply_patches_to_buffer_inner < P , I > (
872+ buffer : & mut [ P ] ,
873+ validity : & mut MaskMut ,
874+ patch_indices : & [ I ] ,
875+ patch_offset : usize ,
876+ patch_values : & [ P ] ,
877+ patches_validity : & Validity ,
878+ ) where
879+ P : NativePType ,
880+ I : UnsignedPType ,
881+ {
882+ debug_assert ! ( !patch_indices. is_empty( ) ) ;
883+ debug_assert_eq ! ( patch_indices. len( ) , patch_values. len( ) ) ;
884+ debug_assert_eq ! ( buffer. len( ) , validity. len( ) ) ;
885+
886+ match patches_validity {
887+ Validity :: NonNullable | Validity :: AllValid => {
888+ // All patch values are valid, apply them all.
889+ for ( & i, & value) in patch_indices. iter ( ) . zip_eq ( patch_values) {
890+ let index = i. as_ ( ) - patch_offset;
891+
892+ // SAFETY: `index` is valid because caller guarantees all patch indices are within
893+ // bounds after offset adjustment.
894+ unsafe {
895+ validity. set_unchecked ( index) ;
896+ }
897+ buffer[ index] = value;
898+ }
899+ }
900+ Validity :: AllInvalid => {
901+ // All patch values are null, just mark positions as invalid.
902+ for & i in patch_indices {
903+ let index = i. as_ ( ) - patch_offset;
904+
905+ // SAFETY: `index` is valid because caller guarantees all patch indices are within
906+ // bounds after offset adjustment.
907+ unsafe {
908+ validity. unset_unchecked ( index) ;
909+ }
910+ }
911+ }
912+ Validity :: Array ( array) => {
913+ // Some patch values may be null, check each one.
914+ let bool_array = array. to_bool ( ) ;
915+ let mask = bool_array. bit_buffer ( ) ;
916+ for ( patch_idx, ( & i, & value) ) in patch_indices. iter ( ) . zip_eq ( patch_values) . enumerate ( ) {
917+ let index = i. as_ ( ) - patch_offset;
918+
919+ // SAFETY: `index` and `patch_idx` are valid because caller guarantees all patch
920+ // indices are within bounds after offset adjustment.
921+ unsafe {
922+ if mask. value_unchecked ( patch_idx) {
923+ buffer[ index] = value;
924+ validity. set_unchecked ( index) ;
925+ } else {
926+ validity. unset_unchecked ( index) ;
927+ }
928+ }
929+ }
930+ }
931+ }
932+ }
933+
824934fn take_map < I : NativePType + Hash + Eq + TryFrom < usize > , T : NativePType > (
825935 indices : & [ I ] ,
826936 take_indices : PrimitiveArray ,
0 commit comments