Skip to content

Commit 6ad0439

Browse files
author
Nitin Kanukolanu
committed
Unit tests for basic SVS Vamana addition
1 parent b4d8b1f commit 6ad0439

File tree

2 files changed

+507
-0
lines changed

2 files changed

+507
-0
lines changed

tests/unit/test_fields.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
GeoField,
1212
HNSWVectorField,
1313
NumericField,
14+
SVSVectorField,
1415
TagField,
1516
TextField,
1617
)
@@ -72,6 +73,24 @@ def create_hnsw_vector_field(**kwargs):
7273
return HNSWVectorField(**defaults)
7374

7475

76+
def create_svs_vector_field(**kwargs):
77+
defaults = {
78+
"name": "example_svsvectorfield",
79+
"attrs": {
80+
"dims": 128,
81+
"algorithm": "SVS-VAMANA",
82+
"datatype": "float32",
83+
"distance_metric": "cosine",
84+
"graph_max_degree": 40,
85+
"construction_window_size": 250,
86+
"search_window_size": 20,
87+
"epsilon": 0.01,
88+
},
89+
}
90+
defaults["attrs"].update(kwargs)
91+
return SVSVectorField(**defaults)
92+
93+
7594
# Tests for field schema creation and validation
7695
@pytest.mark.parametrize(
7796
"schema_func,field_class",
@@ -422,3 +441,198 @@ def test_field_factory_with_new_attributes():
422441
)
423442
assert isinstance(vector_field, FlatVectorField)
424443
assert vector_field.attrs.index_missing == True
444+
445+
446+
# ==================== SVS-VAMANA TESTS ====================
447+
448+
449+
def test_svs_vector_field_creation():
450+
"""Test basic SVS-VAMANA vector field creation."""
451+
svs_field = create_svs_vector_field()
452+
assert svs_field.name == "example_svsvectorfield"
453+
assert svs_field.attrs.algorithm == "SVS-VAMANA"
454+
assert svs_field.attrs.dims == 128
455+
assert svs_field.attrs.datatype.value == "FLOAT32"
456+
assert svs_field.attrs.distance_metric.value == "COSINE"
457+
assert svs_field.attrs.graph_max_degree == 40
458+
assert svs_field.attrs.construction_window_size == 250
459+
assert svs_field.attrs.search_window_size == 20
460+
assert svs_field.attrs.epsilon == 0.01
461+
462+
463+
def test_svs_vector_field_as_redis_field():
464+
"""Test SVS-VAMANA field conversion to Redis field."""
465+
svs_field = create_svs_vector_field()
466+
redis_field = svs_field.as_redis_field()
467+
468+
assert isinstance(redis_field, RedisVectorField)
469+
assert redis_field.name == "example_svsvectorfield"
470+
471+
# Check that SVS-VAMANA specific parameters are in args
472+
assert "GRAPH_MAX_DEGREE" in redis_field.args
473+
assert "CONSTRUCTION_WINDOW_SIZE" in redis_field.args
474+
assert "SEARCH_WINDOW_SIZE" in redis_field.args
475+
assert "EPSILON" in redis_field.args
476+
477+
478+
def test_svs_vector_field_default_params():
479+
"""Test SVS-VAMANA field with default parameters."""
480+
svs_field = SVSVectorField(
481+
name="test_vector",
482+
attrs={
483+
"dims": 768,
484+
"algorithm": "SVS-VAMANA",
485+
"datatype": "float32",
486+
"distance_metric": "cosine",
487+
},
488+
)
489+
490+
# Check defaults are applied
491+
assert svs_field.attrs.graph_max_degree == 40
492+
assert svs_field.attrs.construction_window_size == 250
493+
assert svs_field.attrs.search_window_size == 20
494+
assert svs_field.attrs.epsilon == 0.01
495+
assert svs_field.attrs.compression is None
496+
assert svs_field.attrs.reduce is None
497+
assert svs_field.attrs.training_threshold is None
498+
499+
500+
def test_svs_vector_field_with_custom_graph_params():
501+
"""Test SVS-VAMANA field with custom graph parameters."""
502+
svs_field = create_svs_vector_field(
503+
graph_max_degree=64,
504+
construction_window_size=500,
505+
search_window_size=40,
506+
epsilon=0.02,
507+
)
508+
509+
redis_field = svs_field.as_redis_field()
510+
511+
# Verify custom parameters are set
512+
assert redis_field.args[redis_field.args.index("GRAPH_MAX_DEGREE") + 1] == 64
513+
assert (
514+
redis_field.args[redis_field.args.index("CONSTRUCTION_WINDOW_SIZE") + 1] == 500
515+
)
516+
assert redis_field.args[redis_field.args.index("SEARCH_WINDOW_SIZE") + 1] == 40
517+
assert redis_field.args[redis_field.args.index("EPSILON") + 1] == 0.02
518+
519+
520+
def test_svs_vector_field_with_lvq4_compression():
521+
"""Test SVS-VAMANA field with LVQ4 compression."""
522+
svs_field = create_svs_vector_field(compression="LVQ4")
523+
redis_field = svs_field.as_redis_field()
524+
525+
assert "COMPRESSION" in redis_field.args
526+
assert redis_field.args[redis_field.args.index("COMPRESSION") + 1] == "LVQ4"
527+
528+
529+
def test_svs_vector_field_with_lvq8_compression():
530+
"""Test SVS-VAMANA field with LVQ8 compression."""
531+
svs_field = create_svs_vector_field(compression="LVQ8")
532+
redis_field = svs_field.as_redis_field()
533+
534+
assert "COMPRESSION" in redis_field.args
535+
assert redis_field.args[redis_field.args.index("COMPRESSION") + 1] == "LVQ8"
536+
537+
538+
def test_svs_vector_field_with_leanvec_compression():
539+
"""Test SVS-VAMANA field with LeanVec4x8 compression."""
540+
svs_field = create_svs_vector_field(compression="LeanVec4x8")
541+
redis_field = svs_field.as_redis_field()
542+
543+
assert "COMPRESSION" in redis_field.args
544+
assert redis_field.args[redis_field.args.index("COMPRESSION") + 1] == "LeanVec4x8"
545+
546+
547+
def test_svs_vector_field_with_leanvec_and_reduce():
548+
"""Test SVS-VAMANA field with LeanVec compression and reduce parameter."""
549+
svs_field = create_svs_vector_field(dims=768, compression="LeanVec4x8", reduce=384)
550+
redis_field = svs_field.as_redis_field()
551+
552+
assert "COMPRESSION" in redis_field.args
553+
assert redis_field.args[redis_field.args.index("COMPRESSION") + 1] == "LeanVec4x8"
554+
assert "REDUCE" in redis_field.args
555+
assert redis_field.args[redis_field.args.index("REDUCE") + 1] == 384
556+
557+
558+
def test_svs_vector_field_with_training_threshold():
559+
"""Test SVS-VAMANA field with training_threshold parameter."""
560+
svs_field = create_svs_vector_field(compression="LVQ4", training_threshold=10000)
561+
redis_field = svs_field.as_redis_field()
562+
563+
assert "TRAINING_THRESHOLD" in redis_field.args
564+
assert redis_field.args[redis_field.args.index("TRAINING_THRESHOLD") + 1] == 10000
565+
566+
567+
def test_svs_vector_field_reduce_with_lvq4_raises_error():
568+
"""Test that reduce parameter with LVQ4 compression raises ValueError."""
569+
with pytest.raises(
570+
ValueError, match="reduce parameter is only supported with LeanVec"
571+
):
572+
create_svs_vector_field(dims=768, compression="LVQ4", reduce=384)
573+
574+
575+
def test_svs_vector_field_reduce_with_lvq8_raises_error():
576+
"""Test that reduce parameter with LVQ8 compression raises ValueError."""
577+
with pytest.raises(
578+
ValueError, match="reduce parameter is only supported with LeanVec"
579+
):
580+
create_svs_vector_field(dims=768, compression="LVQ8", reduce=384)
581+
582+
583+
def test_svs_vector_field_reduce_without_compression_raises_error():
584+
"""Test that reduce parameter without compression raises ValueError."""
585+
with pytest.raises(ValueError, match="reduce parameter requires compression"):
586+
create_svs_vector_field(dims=768, reduce=384)
587+
588+
589+
def test_svs_vector_field_reduce_greater_than_dims_raises_error():
590+
"""Test that reduce >= dims raises ValueError."""
591+
with pytest.raises(ValueError, match="reduce.*must be less than dims"):
592+
create_svs_vector_field(dims=768, compression="LeanVec4x8", reduce=768)
593+
594+
595+
def test_svs_vector_field_reduce_equal_to_dims_raises_error():
596+
"""Test that reduce == dims raises ValueError."""
597+
with pytest.raises(ValueError, match="reduce.*must be less than dims"):
598+
create_svs_vector_field(dims=768, compression="LeanVec4x8", reduce=768)
599+
600+
601+
def test_svs_vector_field_invalid_datatype_raises_error():
602+
"""Test that invalid datatype (not float16/float32) raises ValueError."""
603+
with pytest.raises(Exception, match="SVS-VAMANA only supports FLOAT16 and FLOAT32"):
604+
create_svs_vector_field(datatype="float64")
605+
606+
607+
def test_svs_vector_field_float16_datatype():
608+
"""Test SVS-VAMANA field with float16 datatype."""
609+
svs_field = create_svs_vector_field(datatype="float16")
610+
redis_field = svs_field.as_redis_field()
611+
612+
assert "TYPE" in redis_field.args
613+
assert redis_field.args[redis_field.args.index("TYPE") + 1] == "FLOAT16"
614+
615+
616+
def test_svs_vector_field_all_compression_types():
617+
"""Test all valid compression types for SVS-VAMANA."""
618+
compression_types = ["LVQ4", "LVQ4x4", "LVQ4x8", "LVQ8", "LeanVec4x8", "LeanVec8x8"]
619+
620+
for compression in compression_types:
621+
svs_field = create_svs_vector_field(compression=compression)
622+
redis_field = svs_field.as_redis_field()
623+
624+
assert "COMPRESSION" in redis_field.args
625+
assert (
626+
redis_field.args[redis_field.args.index("COMPRESSION") + 1] == compression
627+
)
628+
629+
630+
def test_svs_vector_field_leanvec8x8_with_reduce():
631+
"""Test SVS-VAMANA field with LeanVec8x8 compression and reduce."""
632+
svs_field = create_svs_vector_field(dims=1024, compression="LeanVec8x8", reduce=512)
633+
redis_field = svs_field.as_redis_field()
634+
635+
assert "COMPRESSION" in redis_field.args
636+
assert redis_field.args[redis_field.args.index("COMPRESSION") + 1] == "LeanVec8x8"
637+
assert "REDUCE" in redis_field.args
638+
assert redis_field.args[redis_field.args.index("REDUCE") + 1] == 512

0 commit comments

Comments
 (0)