@@ -265,71 +265,66 @@ CentroidSupplier createCentroidSupplier(IndexInput centroidsInput, int numCentro
265265 return new OffHeapCentroidSupplier (centroidsInput , numCentroids , fieldInfo );
266266 }
267267
268- static void writeCentroids (float [][] centroids , FieldInfo fieldInfo , float [] globalCentroid , IndexOutput centroidOutput )
269- throws IOException {
268+ @ Override
269+ void writeCentroids (
270+ FieldInfo fieldInfo ,
271+ CentroidSupplier centroidSupplier ,
272+ float [] globalCentroid ,
273+ long [] offsets ,
274+ IndexOutput centroidOutput
275+ ) throws IOException {
276+
270277 final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer (fieldInfo .getVectorSimilarityFunction ());
271278 int [] quantizedScratch = new int [fieldInfo .getVectorDimension ()];
272279 float [] centroidScratch = new float [fieldInfo .getVectorDimension ()];
273280 final byte [] quantized = new byte [fieldInfo .getVectorDimension ()];
274281 // TODO do we want to store these distances as well for future use?
275282 // TODO: sort centroids by global centroid (was doing so previously here)
276283 // TODO: sorting tanks recall possibly because centroids ordinals no longer are aligned
277- for (float [] centroid : centroids ) {
284+ for (int i = 0 ; i < centroidSupplier .size (); i ++) {
285+ float [] centroid = centroidSupplier .centroid (i );
278286 System .arraycopy (centroid , 0 , centroidScratch , 0 , centroid .length );
279287 OptimizedScalarQuantizer .QuantizationResult result = osq .scalarQuantize (
280288 centroidScratch ,
281289 quantizedScratch ,
282290 (byte ) 4 ,
283291 globalCentroid
284292 );
285- for (int i = 0 ; i < quantizedScratch .length ; i ++) {
286- quantized [i ] = (byte ) quantizedScratch [i ];
293+ for (int j = 0 ; j < quantizedScratch .length ; j ++) {
294+ quantized [j ] = (byte ) quantizedScratch [j ];
287295 }
288296 writeQuantizedValue (centroidOutput , quantized , result );
289297 }
290298 final ByteBuffer buffer = ByteBuffer .allocate (fieldInfo .getVectorDimension () * Float .BYTES ).order (ByteOrder .LITTLE_ENDIAN );
291- for (float [] centroid : centroids ) {
299+ for (int i = 0 ; i < centroidSupplier .size (); i ++) {
300+ float [] centroid = centroidSupplier .centroid (i );
292301 buffer .asFloatBuffer ().put (centroid );
302+ // write the centroids
293303 centroidOutput .writeBytes (buffer .array (), buffer .array ().length );
304+ // write the offset of this posting list
305+ centroidOutput .writeLong (offsets [i ]);
294306 }
295307 }
296308
297- @ Override
298- CentroidAssignments calculateAndWriteCentroids (
299- FieldInfo fieldInfo ,
300- FloatVectorValues floatVectorValues ,
301- IndexOutput centroidOutput ,
302- MergeState mergeState ,
303- float [] globalCentroid
304- ) throws IOException {
305- // TODO: take advantage of prior generated clusters from mergeState in the future
306- return calculateAndWriteCentroids (fieldInfo , floatVectorValues , centroidOutput , globalCentroid );
307- }
308-
309309 /**
310- * Calculate the centroids for the given field and write them to the given centroid output .
310+ * Calculate the centroids for the given field.
311311 * We use the {@link HierarchicalKMeans} algorithm to partition the space of all vectors across merging segments
312312 *
313313 * @param fieldInfo merging field info
314314 * @param floatVectorValues the float vector values to merge
315- * @param centroidOutput the centroid output
316315 * @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids
317316 * @return the vector assignments, soar assignments, and if asked the centroids themselves that were computed
318317 * @throws IOException if an I/O error occurs
319318 */
320319 @ Override
321- CentroidAssignments calculateAndWriteCentroids (
322- FieldInfo fieldInfo ,
323- FloatVectorValues floatVectorValues ,
324- IndexOutput centroidOutput ,
325- float [] globalCentroid
326- ) throws IOException {
320+ CentroidAssignments calculateCentroids (FieldInfo fieldInfo , FloatVectorValues floatVectorValues , float [] globalCentroid )
321+ throws IOException {
327322
328323 long nanoTime = System .nanoTime ();
329324
330325 // TODO: consider hinting / bootstrapping hierarchical kmeans with the prior segments centroids
331- KMeansResult kMeansResult = new HierarchicalKMeans ( floatVectorValues . dimension ()). cluster (floatVectorValues , vectorPerCluster );
332- float [][] centroids = kMeansResult .centroids ();
326+ CentroidAssignments centroidAssignments = buildCentroidAssignments (floatVectorValues , vectorPerCluster );
327+ float [][] centroids = centroidAssignments .centroids ();
333328 // TODO: for flush we are doing this over the vectors and here centroids which seems duplicative
334329 // preliminary tests suggest recall is good using only centroids but need to do further evaluation
335330 // TODO: push this logic into vector util?
@@ -342,17 +337,15 @@ CentroidAssignments calculateAndWriteCentroids(
342337 globalCentroid [j ] /= centroids .length ;
343338 }
344339
345- // write centroids
346- writeCentroids (centroids , fieldInfo , globalCentroid , centroidOutput );
347-
348340 if (logger .isDebugEnabled ()) {
349341 logger .debug ("calculate centroids and assign vectors time ms: {}" , (System .nanoTime () - nanoTime ) / 1000000.0 );
350342 logger .debug ("final centroid count: {}" , centroids .length );
351343 }
352- return buildCentroidAssignments ( kMeansResult ) ;
344+ return centroidAssignments ;
353345 }
354346
355- static CentroidAssignments buildCentroidAssignments (KMeansResult kMeansResult ) {
347+ static CentroidAssignments buildCentroidAssignments (FloatVectorValues floatVectorValues , int vectorPerCluster ) throws IOException {
348+ KMeansResult kMeansResult = new HierarchicalKMeans (floatVectorValues .dimension ()).cluster (floatVectorValues , vectorPerCluster );
356349 float [][] centroids = kMeansResult .centroids ();
357350 int [] assignments = kMeansResult .assignments ();
358351 int [] soarAssignments = kMeansResult .soarAssignments ();
@@ -374,15 +367,13 @@ static class OffHeapCentroidSupplier implements CentroidSupplier {
374367 private final int numCentroids ;
375368 private final int dimension ;
376369 private final float [] scratch ;
377- private final long rawCentroidOffset ;
378370 private int currOrd = -1 ;
379371
380372 OffHeapCentroidSupplier (IndexInput centroidsInput , int numCentroids , FieldInfo info ) {
381373 this .centroidsInput = centroidsInput ;
382374 this .numCentroids = numCentroids ;
383375 this .dimension = info .getVectorDimension ();
384376 this .scratch = new float [dimension ];
385- this .rawCentroidOffset = (dimension + 3 * Float .BYTES + Short .BYTES ) * numCentroids ;
386377 }
387378
388379 @ Override
@@ -395,7 +386,7 @@ public float[] centroid(int centroidOrdinal) throws IOException {
395386 if (centroidOrdinal == currOrd ) {
396387 return scratch ;
397388 }
398- centroidsInput .seek (rawCentroidOffset + (long ) centroidOrdinal * dimension * Float .BYTES );
389+ centroidsInput .seek ((long ) centroidOrdinal * dimension * Float .BYTES );
399390 centroidsInput .readFloats (scratch , 0 , dimension );
400391 this .currOrd = centroidOrdinal ;
401392 return scratch ;
0 commit comments