|
9 | 9 |
|
10 | 10 | package org.elasticsearch.index.mapper.vectors; |
11 | 11 |
|
| 12 | +import org.apache.lucene.index.FloatVectorValues; |
12 | 13 | import org.apache.lucene.index.KnnVectorValues; |
13 | 14 | import org.apache.lucene.index.NumericDocValues; |
| 15 | +import org.apache.lucene.search.VectorScorer; |
14 | 16 | import org.elasticsearch.test.ESTestCase; |
15 | 17 |
|
16 | 18 | import java.io.IOException; |
@@ -114,4 +116,182 @@ public long cost() { |
114 | 116 | }; |
115 | 117 | } |
116 | 118 |
|
| 119 | + public void testOrdToDocWithSparseVectors() throws IOException { |
| 120 | + // This test simulates a real-world scenario where some documents don't have vector fields |
| 121 | + // After force merge, the ord (ordinal) and docId mapping becomes crucial |
| 122 | + |
| 123 | + // Simulate a scenario where we have 6 documents, but only documents 1, 3, 4 have vectors |
| 124 | + // Doc 0: no vector |
| 125 | + // Doc 1: vector [0.6, 0.8, 0.0, 0.0], magnitude 5.0 -> original [3.0, 4.0, 0.0, 0.0] |
| 126 | + // Doc 2: no vector |
| 127 | + // Doc 3: vector [1.0, 0.0, 0.0, 0.0], magnitude 2.0 -> original [2.0, 0.0, 0.0, 0.0] |
| 128 | + // Doc 4: vector [0.0, 0.0, 0.6, 0.8], magnitude 10.0 -> original [0.0, 0.0, 6.0, 8.0] |
| 129 | + // Doc 5: no vector |
| 130 | + |
| 131 | + // After merge, the vector ordinals will be 0, 1, 2 but they correspond to docIds 1, 3, 4 |
| 132 | + int totalDocs = 6; |
| 133 | + int[] docIdsWithVectors = {1, 3, 4}; // Document IDs that have vectors |
| 134 | + int numVectors = docIdsWithVectors.length; |
| 135 | + |
| 136 | + float[][] normalizedVectors = new float[numVectors][]; |
| 137 | + float[] magnitudes = new float[numVectors]; |
| 138 | + |
| 139 | + normalizedVectors[0] = new float[]{0.6f, 0.8f, 0.0f, 0.0f}; // Doc 1 |
| 140 | + magnitudes[0] = 5.0f; |
| 141 | + |
| 142 | + normalizedVectors[1] = new float[]{1.0f, 0.0f, 0.0f, 0.0f}; // Doc 3 |
| 143 | + magnitudes[1] = 2.0f; |
| 144 | + |
| 145 | + normalizedVectors[2] = new float[]{0.0f, 0.0f, 0.6f, 0.8f}; // Doc 4 |
| 146 | + magnitudes[2] = 10.0f; |
| 147 | + |
| 148 | + // Expected original vectors after denormalization |
| 149 | + float[][] expectedVectors = new float[numVectors][]; |
| 150 | + expectedVectors[0] = new float[]{3.0f, 4.0f, 0.0f, 0.0f}; // Doc 1 |
| 151 | + expectedVectors[1] = new float[]{2.0f, 0.0f, 0.0f, 0.0f}; // Doc 3 |
| 152 | + expectedVectors[2] = new float[]{0.0f, 0.0f, 6.0f, 8.0f}; // Doc 4 |
| 153 | + |
| 154 | + // Create a custom FloatVectorValues that simulates post-merge sparse vector scenario |
| 155 | + FloatVectorValues sparseVectorValues = new FloatVectorValues() { |
| 156 | + @Override |
| 157 | + public int dimension() { return 4; } |
| 158 | + |
| 159 | + @Override |
| 160 | + public int size() { return numVectors; } |
| 161 | + |
| 162 | + @Override |
| 163 | + public DocIndexIterator iterator() { |
| 164 | + return new DocIndexIterator() { |
| 165 | + private int index = -1; |
| 166 | + |
| 167 | + @Override |
| 168 | + public int docID() { return index; } |
| 169 | + |
| 170 | + @Override |
| 171 | + public int index() { return index; } |
| 172 | + |
| 173 | + @Override |
| 174 | + public int nextDoc() { return advance(index + 1); } |
| 175 | + |
| 176 | + @Override |
| 177 | + public int advance(int target) { |
| 178 | + if (target >= numVectors) return NO_MORE_DOCS; |
| 179 | + return index = target; |
| 180 | + } |
| 181 | + |
| 182 | + @Override |
| 183 | + public long cost() { return numVectors; } |
| 184 | + }; |
| 185 | + } |
| 186 | + |
| 187 | + @Override |
| 188 | + public FloatVectorValues copy() { throw new UnsupportedOperationException(); } |
| 189 | + |
| 190 | + @Override |
| 191 | + public VectorScorer scorer(float[] floats) { throw new UnsupportedOperationException(); } |
| 192 | + |
| 193 | + // This is the key method - it maps ordinals to actual document IDs |
| 194 | + @Override |
| 195 | + public int ordToDoc(int ord) { |
| 196 | + // ord 0 -> docId 1, ord 1 -> docId 3, ord 2 -> docId 4 |
| 197 | + return docIdsWithVectors[ord]; |
| 198 | + } |
| 199 | + |
| 200 | + @Override |
| 201 | + public float[] vectorValue(int ord) { |
| 202 | + return normalizedVectors[ord]; |
| 203 | + } |
| 204 | + }; |
| 205 | + |
| 206 | + // Create magnitudes that correspond to the actual document IDs |
| 207 | + NumericDocValues sparseMagnitudes = new NumericDocValues() { |
| 208 | + private int docId = -1; |
| 209 | + |
| 210 | + @Override |
| 211 | + public long longValue() { |
| 212 | + // Find which vector index corresponds to this docId |
| 213 | + for (int i = 0; i < docIdsWithVectors.length; i++) { |
| 214 | + if (docIdsWithVectors[i] == docId) { |
| 215 | + return Float.floatToRawIntBits(magnitudes[i]); |
| 216 | + } |
| 217 | + } |
| 218 | + return Float.floatToRawIntBits(1.0f); // Default magnitude |
| 219 | + } |
| 220 | + |
| 221 | + @Override |
| 222 | + public boolean advanceExact(int target) { |
| 223 | + docId = target; |
| 224 | + // Check if this docId has a vector |
| 225 | + for (int vectorDocId : docIdsWithVectors) { |
| 226 | + if (vectorDocId == target) { |
| 227 | + return true; |
| 228 | + } |
| 229 | + } |
| 230 | + return false; |
| 231 | + } |
| 232 | + |
| 233 | + @Override |
| 234 | + public int docID() { return docId; } |
| 235 | + |
| 236 | + @Override |
| 237 | + public int nextDoc() { return advance(docId + 1); } |
| 238 | + |
| 239 | + @Override |
| 240 | + public int advance(int target) { |
| 241 | + for (int vectorDocId : docIdsWithVectors) { |
| 242 | + if (vectorDocId >= target) { |
| 243 | + docId = vectorDocId; |
| 244 | + return docId; |
| 245 | + } |
| 246 | + } |
| 247 | + return NO_MORE_DOCS; |
| 248 | + } |
| 249 | + |
| 250 | + @Override |
| 251 | + public long cost() { return totalDocs; } |
| 252 | + }; |
| 253 | + |
| 254 | + // Test the fixed version (with ordToDoc) |
| 255 | + DenormalizedCosineFloatVectorValues vectorValues = new DenormalizedCosineFloatVectorValues( |
| 256 | + sparseVectorValues, |
| 257 | + sparseMagnitudes |
| 258 | + ); |
| 259 | + |
| 260 | + // Test that ordToDoc method properly maps ordinals to document IDs |
| 261 | + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); |
| 262 | + |
| 263 | + for (int ord = 0; ord < numVectors; ord++) { |
| 264 | + iterator.advance(ord); |
| 265 | + |
| 266 | + // Verify that ordToDoc works correctly |
| 267 | + int expectedDocId = docIdsWithVectors[ord]; |
| 268 | + int actualDocId = vectorValues.ordToDoc(ord); |
| 269 | + assertEquals( |
| 270 | + "ordToDoc should correctly map ord " + ord + " to docId " + expectedDocId, |
| 271 | + expectedDocId, |
| 272 | + actualDocId |
| 273 | + ); |
| 274 | + |
| 275 | + // Get the denormalized vector - this relies on ordToDoc working correctly |
| 276 | + float[] actualVector = vectorValues.vectorValue(iterator.index()); |
| 277 | + float actualMagnitude = vectorValues.magnitude(); |
| 278 | + |
| 279 | + // Verify the denormalized vector is correct |
| 280 | + assertArrayEquals( |
| 281 | + "Vector at ord " + ord + " (docId " + expectedDocId + ") should be correctly denormalized", |
| 282 | + expectedVectors[ord], |
| 283 | + actualVector, |
| 284 | + 1e-6f |
| 285 | + ); |
| 286 | + |
| 287 | + // Verify the magnitude is correct |
| 288 | + assertEquals( |
| 289 | + "Magnitude at ord " + ord + " (docId " + expectedDocId + ") should be correct", |
| 290 | + magnitudes[ord], |
| 291 | + actualMagnitude, |
| 292 | + 1e-6f |
| 293 | + ); |
| 294 | + } |
| 295 | + } |
| 296 | + |
117 | 297 | } |
0 commit comments