|
11 | 11 | GeoField, |
12 | 12 | HNSWVectorField, |
13 | 13 | NumericField, |
| 14 | + SVSVectorField, |
14 | 15 | TagField, |
15 | 16 | TextField, |
16 | 17 | ) |
@@ -72,6 +73,24 @@ def create_hnsw_vector_field(**kwargs): |
72 | 73 | return HNSWVectorField(**defaults) |
73 | 74 |
|
74 | 75 |
|
| 76 | +def create_svs_vector_field(**kwargs): |
| 77 | + defaults = { |
| 78 | + "name": "example_svsvectorfield", |
| 79 | + "attrs": { |
| 80 | + "dims": 128, |
| 81 | + "algorithm": "SVS-VAMANA", |
| 82 | + "datatype": "float32", |
| 83 | + "distance_metric": "cosine", |
| 84 | + "graph_max_degree": 40, |
| 85 | + "construction_window_size": 250, |
| 86 | + "search_window_size": 20, |
| 87 | + "epsilon": 0.01, |
| 88 | + }, |
| 89 | + } |
| 90 | + defaults["attrs"].update(kwargs) |
| 91 | + return SVSVectorField(**defaults) |
| 92 | + |
| 93 | + |
75 | 94 | # Tests for field schema creation and validation |
76 | 95 | @pytest.mark.parametrize( |
77 | 96 | "schema_func,field_class", |
@@ -422,3 +441,198 @@ def test_field_factory_with_new_attributes(): |
422 | 441 | ) |
423 | 442 | assert isinstance(vector_field, FlatVectorField) |
424 | 443 | assert vector_field.attrs.index_missing == True |
| 444 | + |
| 445 | + |
| 446 | +# ==================== SVS-VAMANA TESTS ==================== |
| 447 | + |
| 448 | + |
| 449 | +def test_svs_vector_field_creation(): |
| 450 | + """Test basic SVS-VAMANA vector field creation.""" |
| 451 | + svs_field = create_svs_vector_field() |
| 452 | + assert svs_field.name == "example_svsvectorfield" |
| 453 | + assert svs_field.attrs.algorithm == "SVS-VAMANA" |
| 454 | + assert svs_field.attrs.dims == 128 |
| 455 | + assert svs_field.attrs.datatype.value == "FLOAT32" |
| 456 | + assert svs_field.attrs.distance_metric.value == "COSINE" |
| 457 | + assert svs_field.attrs.graph_max_degree == 40 |
| 458 | + assert svs_field.attrs.construction_window_size == 250 |
| 459 | + assert svs_field.attrs.search_window_size == 20 |
| 460 | + assert svs_field.attrs.epsilon == 0.01 |
| 461 | + |
| 462 | + |
| 463 | +def test_svs_vector_field_as_redis_field(): |
| 464 | + """Test SVS-VAMANA field conversion to Redis field.""" |
| 465 | + svs_field = create_svs_vector_field() |
| 466 | + redis_field = svs_field.as_redis_field() |
| 467 | + |
| 468 | + assert isinstance(redis_field, RedisVectorField) |
| 469 | + assert redis_field.name == "example_svsvectorfield" |
| 470 | + |
| 471 | + # Check that SVS-VAMANA specific parameters are in args |
| 472 | + assert "GRAPH_MAX_DEGREE" in redis_field.args |
| 473 | + assert "CONSTRUCTION_WINDOW_SIZE" in redis_field.args |
| 474 | + assert "SEARCH_WINDOW_SIZE" in redis_field.args |
| 475 | + assert "EPSILON" in redis_field.args |
| 476 | + |
| 477 | + |
| 478 | +def test_svs_vector_field_default_params(): |
| 479 | + """Test SVS-VAMANA field with default parameters.""" |
| 480 | + svs_field = SVSVectorField( |
| 481 | + name="test_vector", |
| 482 | + attrs={ |
| 483 | + "dims": 768, |
| 484 | + "algorithm": "SVS-VAMANA", |
| 485 | + "datatype": "float32", |
| 486 | + "distance_metric": "cosine", |
| 487 | + }, |
| 488 | + ) |
| 489 | + |
| 490 | + # Check defaults are applied |
| 491 | + assert svs_field.attrs.graph_max_degree == 40 |
| 492 | + assert svs_field.attrs.construction_window_size == 250 |
| 493 | + assert svs_field.attrs.search_window_size == 20 |
| 494 | + assert svs_field.attrs.epsilon == 0.01 |
| 495 | + assert svs_field.attrs.compression is None |
| 496 | + assert svs_field.attrs.reduce is None |
| 497 | + assert svs_field.attrs.training_threshold is None |
| 498 | + |
| 499 | + |
| 500 | +def test_svs_vector_field_with_custom_graph_params(): |
| 501 | + """Test SVS-VAMANA field with custom graph parameters.""" |
| 502 | + svs_field = create_svs_vector_field( |
| 503 | + graph_max_degree=64, |
| 504 | + construction_window_size=500, |
| 505 | + search_window_size=40, |
| 506 | + epsilon=0.02, |
| 507 | + ) |
| 508 | + |
| 509 | + redis_field = svs_field.as_redis_field() |
| 510 | + |
| 511 | + # Verify custom parameters are set |
| 512 | + assert redis_field.args[redis_field.args.index("GRAPH_MAX_DEGREE") + 1] == 64 |
| 513 | + assert ( |
| 514 | + redis_field.args[redis_field.args.index("CONSTRUCTION_WINDOW_SIZE") + 1] == 500 |
| 515 | + ) |
| 516 | + assert redis_field.args[redis_field.args.index("SEARCH_WINDOW_SIZE") + 1] == 40 |
| 517 | + assert redis_field.args[redis_field.args.index("EPSILON") + 1] == 0.02 |
| 518 | + |
| 519 | + |
| 520 | +def test_svs_vector_field_with_lvq4_compression(): |
| 521 | + """Test SVS-VAMANA field with LVQ4 compression.""" |
| 522 | + svs_field = create_svs_vector_field(compression="LVQ4") |
| 523 | + redis_field = svs_field.as_redis_field() |
| 524 | + |
| 525 | + assert "COMPRESSION" in redis_field.args |
| 526 | + assert redis_field.args[redis_field.args.index("COMPRESSION") + 1] == "LVQ4" |
| 527 | + |
| 528 | + |
| 529 | +def test_svs_vector_field_with_lvq8_compression(): |
| 530 | + """Test SVS-VAMANA field with LVQ8 compression.""" |
| 531 | + svs_field = create_svs_vector_field(compression="LVQ8") |
| 532 | + redis_field = svs_field.as_redis_field() |
| 533 | + |
| 534 | + assert "COMPRESSION" in redis_field.args |
| 535 | + assert redis_field.args[redis_field.args.index("COMPRESSION") + 1] == "LVQ8" |
| 536 | + |
| 537 | + |
| 538 | +def test_svs_vector_field_with_leanvec_compression(): |
| 539 | + """Test SVS-VAMANA field with LeanVec4x8 compression.""" |
| 540 | + svs_field = create_svs_vector_field(compression="LeanVec4x8") |
| 541 | + redis_field = svs_field.as_redis_field() |
| 542 | + |
| 543 | + assert "COMPRESSION" in redis_field.args |
| 544 | + assert redis_field.args[redis_field.args.index("COMPRESSION") + 1] == "LeanVec4x8" |
| 545 | + |
| 546 | + |
| 547 | +def test_svs_vector_field_with_leanvec_and_reduce(): |
| 548 | + """Test SVS-VAMANA field with LeanVec compression and reduce parameter.""" |
| 549 | + svs_field = create_svs_vector_field(dims=768, compression="LeanVec4x8", reduce=384) |
| 550 | + redis_field = svs_field.as_redis_field() |
| 551 | + |
| 552 | + assert "COMPRESSION" in redis_field.args |
| 553 | + assert redis_field.args[redis_field.args.index("COMPRESSION") + 1] == "LeanVec4x8" |
| 554 | + assert "REDUCE" in redis_field.args |
| 555 | + assert redis_field.args[redis_field.args.index("REDUCE") + 1] == 384 |
| 556 | + |
| 557 | + |
| 558 | +def test_svs_vector_field_with_training_threshold(): |
| 559 | + """Test SVS-VAMANA field with training_threshold parameter.""" |
| 560 | + svs_field = create_svs_vector_field(compression="LVQ4", training_threshold=10000) |
| 561 | + redis_field = svs_field.as_redis_field() |
| 562 | + |
| 563 | + assert "TRAINING_THRESHOLD" in redis_field.args |
| 564 | + assert redis_field.args[redis_field.args.index("TRAINING_THRESHOLD") + 1] == 10000 |
| 565 | + |
| 566 | + |
| 567 | +def test_svs_vector_field_reduce_with_lvq4_raises_error(): |
| 568 | + """Test that reduce parameter with LVQ4 compression raises ValueError.""" |
| 569 | + with pytest.raises( |
| 570 | + ValueError, match="reduce parameter is only supported with LeanVec" |
| 571 | + ): |
| 572 | + create_svs_vector_field(dims=768, compression="LVQ4", reduce=384) |
| 573 | + |
| 574 | + |
| 575 | +def test_svs_vector_field_reduce_with_lvq8_raises_error(): |
| 576 | + """Test that reduce parameter with LVQ8 compression raises ValueError.""" |
| 577 | + with pytest.raises( |
| 578 | + ValueError, match="reduce parameter is only supported with LeanVec" |
| 579 | + ): |
| 580 | + create_svs_vector_field(dims=768, compression="LVQ8", reduce=384) |
| 581 | + |
| 582 | + |
| 583 | +def test_svs_vector_field_reduce_without_compression_raises_error(): |
| 584 | + """Test that reduce parameter without compression raises ValueError.""" |
| 585 | + with pytest.raises(ValueError, match="reduce parameter requires compression"): |
| 586 | + create_svs_vector_field(dims=768, reduce=384) |
| 587 | + |
| 588 | + |
| 589 | +def test_svs_vector_field_reduce_greater_than_dims_raises_error(): |
| 590 | + """Test that reduce >= dims raises ValueError.""" |
| 591 | + with pytest.raises(ValueError, match="reduce.*must be less than dims"): |
| 592 | + create_svs_vector_field(dims=768, compression="LeanVec4x8", reduce=768) |
| 593 | + |
| 594 | + |
| 595 | +def test_svs_vector_field_reduce_equal_to_dims_raises_error(): |
| 596 | + """Test that reduce == dims raises ValueError.""" |
| 597 | + with pytest.raises(ValueError, match="reduce.*must be less than dims"): |
| 598 | + create_svs_vector_field(dims=768, compression="LeanVec4x8", reduce=768) |
| 599 | + |
| 600 | + |
| 601 | +def test_svs_vector_field_invalid_datatype_raises_error(): |
| 602 | + """Test that invalid datatype (not float16/float32) raises ValueError.""" |
| 603 | + with pytest.raises(Exception, match="SVS-VAMANA only supports FLOAT16 and FLOAT32"): |
| 604 | + create_svs_vector_field(datatype="float64") |
| 605 | + |
| 606 | + |
| 607 | +def test_svs_vector_field_float16_datatype(): |
| 608 | + """Test SVS-VAMANA field with float16 datatype.""" |
| 609 | + svs_field = create_svs_vector_field(datatype="float16") |
| 610 | + redis_field = svs_field.as_redis_field() |
| 611 | + |
| 612 | + assert "TYPE" in redis_field.args |
| 613 | + assert redis_field.args[redis_field.args.index("TYPE") + 1] == "FLOAT16" |
| 614 | + |
| 615 | + |
| 616 | +def test_svs_vector_field_all_compression_types(): |
| 617 | + """Test all valid compression types for SVS-VAMANA.""" |
| 618 | + compression_types = ["LVQ4", "LVQ4x4", "LVQ4x8", "LVQ8", "LeanVec4x8", "LeanVec8x8"] |
| 619 | + |
| 620 | + for compression in compression_types: |
| 621 | + svs_field = create_svs_vector_field(compression=compression) |
| 622 | + redis_field = svs_field.as_redis_field() |
| 623 | + |
| 624 | + assert "COMPRESSION" in redis_field.args |
| 625 | + assert ( |
| 626 | + redis_field.args[redis_field.args.index("COMPRESSION") + 1] == compression |
| 627 | + ) |
| 628 | + |
| 629 | + |
| 630 | +def test_svs_vector_field_leanvec8x8_with_reduce(): |
| 631 | + """Test SVS-VAMANA field with LeanVec8x8 compression and reduce.""" |
| 632 | + svs_field = create_svs_vector_field(dims=1024, compression="LeanVec8x8", reduce=512) |
| 633 | + redis_field = svs_field.as_redis_field() |
| 634 | + |
| 635 | + assert "COMPRESSION" in redis_field.args |
| 636 | + assert redis_field.args[redis_field.args.index("COMPRESSION") + 1] == "LeanVec8x8" |
| 637 | + assert "REDUCE" in redis_field.args |
| 638 | + assert redis_field.args[redis_field.args.index("REDUCE") + 1] == 512 |
0 commit comments