|
33 | 33 | "organization": None, # OpenAI organization ID |
34 | 34 | "create_collection_if_missing": False, # Whether to create collection if it doesn't exist |
35 | 35 | "skip_if_exists": True, # Skip storing vCons that already exist in Milvus |
| 36 | + "index_type": "IVF_FLAT", # Vector index type: IVF_FLAT, IVF_SQ8, IVF_PQ, HNSW, ANNOY, etc. |
| 37 | + "metric_type": "L2", # Distance metric: L2 (Euclidean), IP (Inner Product), COSINE |
| 38 | + "nlist": 128, # Number of clusters for IVF indexes |
| 39 | + "m": 16, # HNSW parameter: number of edges per node |
| 40 | + "ef_construction": 200, # HNSW parameter: size of the dynamic candidate list during construction |
36 | 41 | } |
37 | 42 |
|
38 | 43 | def ensure_milvus_connection(host: str, port: str) -> bool: |
@@ -266,13 +271,14 @@ def extract_party_id(vcon: dict) -> str: |
266 | 271 | logger.debug("No usable party identifier found") |
267 | 272 | return "unknown_party" |
268 | 273 |
|
269 | | -def create_collection(collection_name: str, embedding_dim: int) -> Union[Collection, None]: |
| 274 | +def create_collection(collection_name: str, embedding_dim: int, opts: dict) -> Union[Collection, None]: |
270 | 275 | """ |
271 | 276 | Create a new Milvus collection for vCons. |
272 | 277 | |
273 | 278 | Args: |
274 | 279 | collection_name: Name for the new collection |
275 | 280 | embedding_dim: Dimension of the embedding vectors |
| 281 | + opts: Configuration options including index parameters |
276 | 282 | |
277 | 283 | Returns: |
278 | 284 | Collection or None: The created collection or None if failed |
@@ -301,15 +307,53 @@ def create_collection(collection_name: str, embedding_dim: int) -> Union[Collect |
301 | 307 | # Create collection |
302 | 308 | collection = Collection(name=collection_name, schema=schema) |
303 | 309 |
|
304 | | - # Create an IVF_FLAT index for fast vector search |
| 310 | + # Prepare index parameters based on the selected index type |
| 311 | + index_type = opts.get("index_type", "IVF_FLAT") |
| 312 | + metric_type = opts.get("metric_type", "L2") |
| 313 | + |
| 314 | + # Configure index parameters based on index type |
| 315 | + if index_type.startswith("IVF"): # IVF_FLAT, IVF_SQ8, IVF_PQ |
| 316 | + params = {"nlist": opts.get("nlist", 128)} |
| 317 | + |
| 318 | + # Additional params for IVF_PQ |
| 319 | + if index_type == "IVF_PQ": |
| 320 | + # For PQ, m is typically set to 8 or 12 |
| 321 | + params["m"] = opts.get("pq_m", 8) |
| 322 | + # nbits is typically 8 |
| 323 | + params["nbits"] = opts.get("pq_nbits", 8) |
| 324 | + |
| 325 | + elif index_type == "HNSW": |
| 326 | + params = { |
| 327 | + "M": opts.get("m", 16), # Number of edges per node |
| 328 | + "efConstruction": opts.get("ef_construction", 200) # Size of the dynamic candidate list during construction |
| 329 | + } |
| 330 | + |
| 331 | + elif index_type == "ANNOY": |
| 332 | + params = { |
| 333 | + "n_trees": opts.get("n_trees", 50) # Number of trees for ANNOY |
| 334 | + } |
| 335 | + |
| 336 | + elif index_type == "FLAT": |
| 337 | + # FLAT index doesn't need additional parameters |
| 338 | + params = {} |
| 339 | + |
| 340 | + else: |
| 341 | + # Default to IVF_FLAT if index type is not recognized |
| 342 | + logger.warning(f"Unrecognized index type {index_type}, defaulting to IVF_FLAT") |
| 343 | + index_type = "IVF_FLAT" |
| 344 | + params = {"nlist": opts.get("nlist", 128)} |
| 345 | + |
| 346 | + # Create the index |
305 | 347 | index_params = { |
306 | | - "metric_type": "L2", |
307 | | - "index_type": "IVF_FLAT", |
308 | | - "params": {"nlist": 128} |
| 348 | + "metric_type": metric_type, |
| 349 | + "index_type": index_type, |
| 350 | + "params": params |
309 | 351 | } |
| 352 | + |
| 353 | + logger.info(f"Creating index of type {index_type} with metric {metric_type}") |
310 | 354 | collection.create_index(field_name="embedding", index_params=index_params) |
311 | 355 |
|
312 | | - logger.info(f"Created collection '{collection_name}' successfully") |
| 356 | + logger.info(f"Created collection '{collection_name}' successfully with {index_type} index") |
313 | 357 | return collection |
314 | 358 | except Exception as e: |
315 | 359 | logger.error(f"Failed to create collection: {e}") |
@@ -368,7 +412,7 @@ def save(vcon_uuid: str, opts=default_options) -> None: |
368 | 412 | if not utility.has_collection(collection_name): |
369 | 413 | if opts["create_collection_if_missing"]: |
370 | 414 | logger.info(f"Collection {collection_name} does not exist, creating...") |
371 | | - collection = create_collection(collection_name, opts["embedding_dim"]) |
| 415 | + collection = create_collection(collection_name, opts["embedding_dim"], opts) |
372 | 416 | if not collection: |
373 | 417 | error_msg = f"Failed to create collection {collection_name}" |
374 | 418 | logger.error(error_msg) |
|
0 commit comments