Skip to content

Commit cfb6e2d

Browse files
update tests, docs, and formatting/linting
1 parent 9dabfbd commit cfb6e2d

File tree

10 files changed

+185
-1760
lines changed

10 files changed

+185
-1760
lines changed

docs/user_guide/01_getting_started.ipynb

Lines changed: 136 additions & 100 deletions
Large diffs are not rendered by default.

docs/user_guide/data_validation.ipynb

Lines changed: 0 additions & 1102 deletions
This file was deleted.

redisvl/index/index.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def from_yaml(cls, schema_path: str, **kwargs):
163163
164164
from redisvl.index import SearchIndex
165165
166-
index = SearchIndex.from_yaml("schemas/schema.yaml")
166+
index = SearchIndex.from_yaml("schemas/schema.yaml", redis_url="redis://localhost:6379")
167167
"""
168168
schema = IndexSchema.from_yaml(schema_path)
169169
return cls(schema=schema, **kwargs)
@@ -191,7 +191,7 @@ def from_dict(cls, schema_dict: Dict[str, Any], **kwargs):
191191
"fields": [
192192
{"name": "doc-id", "type": "tag"}
193193
]
194-
})
194+
}, redis_url="redis://localhost:6379")
195195
196196
"""
197197
schema = IndexSchema.from_dict(schema_dict)
@@ -235,10 +235,14 @@ class SearchIndex(BaseSearchIndex):
235235
from redisvl.index import SearchIndex
236236
237237
# initialize the index object with schema from file
238-
index = SearchIndex.from_yaml("schemas/schema.yaml", redis_url="redis://localhost:6379")
238+
index = SearchIndex.from_yaml(
239+
"schemas/schema.yaml",
240+
redis_url="redis://localhost:6379",
241+
validate_on_load=True
242+
)
239243
240244
# create the index
241-
index.create(overwrite=True)
245+
index.create(overwrite=True, drop=False)
242246
243247
# data is an iterable of dictionaries
244248
index.load(data)
@@ -387,11 +391,6 @@ def connect(self, redis_url: Optional[str] = None, **kwargs):
387391
ValueError: If the Redis URL is not provided nor accessible
388392
through the `REDIS_URL` environment variable.
389393
ModuleNotFoundError: If required Redis modules are not installed.
390-
391-
.. code-block:: python
392-
393-
index.connect(redis_url="redis://localhost:6379")
394-
395394
"""
396395
self.__redis_client = RedisConnectionFactory.get_redis_connection(
397396
redis_url=redis_url, **kwargs
@@ -411,16 +410,6 @@ def set_client(self, redis_client: redis.Redis, **kwargs):
411410
412411
Raises:
413412
TypeError: If the provided client is not valid.
414-
415-
.. code-block:: python
416-
417-
import redis
418-
from redisvl.index import SearchIndex
419-
420-
client = redis.Redis.from_url("redis://localhost:6379")
421-
index = SearchIndex.from_yaml("schemas/schema.yaml")
422-
index.set_client(client)
423-
424413
"""
425414
RedisConnectionFactory.validate_sync_redis(redis_client)
426415
self.__redis_client = redis_client
@@ -799,11 +788,12 @@ class AsyncSearchIndex(BaseSearchIndex):
799788
# initialize the index object with schema from file
800789
index = AsyncSearchIndex.from_yaml(
801790
"schemas/schema.yaml",
802-
redis_url="redis://localhost:6379"
791+
redis_url="redis://localhost:6379",
792+
validate_on_load=True
803793
)
804794
805795
# create the index
806-
await index.create(overwrite=True)
796+
await index.create(overwrite=True, drop=False)
807797
808798
# data is an iterable of dictionaries
809799
await index.load(data)

redisvl/index/storage.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ async def _aget(client: AsyncRedis, key: str) -> Dict[str, Any]:
143143
"""
144144
raise NotImplementedError
145145

146-
def validate(self, obj: Dict[str, Any]) -> Dict[str, Any]:
146+
def _validate(self, obj: Dict[str, Any]) -> Dict[str, Any]:
147147
"""
148148
Validate an object against the schema using Pydantic-based validation.
149149
@@ -161,7 +161,7 @@ def validate(self, obj: Dict[str, Any]) -> Dict[str, Any]:
161161

162162
def _preprocess_and_validate_objects(
163163
self,
164-
objects: List[Any],
164+
objects: Iterable[Any],
165165
id_field: Optional[str] = None,
166166
keys: Optional[Iterable[str]] = None,
167167
preprocess: Optional[Callable] = None,
@@ -201,7 +201,7 @@ def _preprocess_and_validate_objects(
201201

202202
# Schema validation if enabled
203203
if validate:
204-
processed_obj = self.validate(processed_obj)
204+
processed_obj = self._validate(processed_obj)
205205

206206
# Store valid object with its key for writing
207207
prepared_objects.append((key, processed_obj))
@@ -263,7 +263,7 @@ def write(
263263

264264
# Pass 1: Preprocess and validate all objects
265265
prepared_objects = self._preprocess_and_validate_objects(
266-
objects,
266+
list(objects), # Convert Iterable to List
267267
id_field=id_field,
268268
keys=keys,
269269
preprocess=preprocess,
@@ -342,7 +342,7 @@ async def awrite(
342342

343343
# Pass 1: Preprocess and validate all objects
344344
prepared_objects = self._preprocess_and_validate_objects(
345-
objects,
345+
list(objects), # Convert Iterable to List
346346
id_field=id_field,
347347
keys=keys,
348348
preprocess=preprocess,

redisvl/schema/validation.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import json
99
import re
1010
import warnings
11-
from typing import Any, Dict, List, Optional, Type, Union
11+
from typing import Any, Dict, List, Optional, Type, Union, cast
1212

1313
from pydantic import BaseModel, Field, field_validator
1414

@@ -53,7 +53,7 @@ def get_model_for_schema(cls, schema: IndexSchema) -> Type[BaseModel]:
5353
@classmethod
5454
def _map_field_to_pydantic_type(
5555
cls, field: BaseField, storage_type: StorageType
56-
) -> Type:
56+
) -> Type[Any]:
5757
"""
5858
Map Redis field types to appropriate Pydantic types.
5959
@@ -72,14 +72,17 @@ def _map_field_to_pydantic_type(
7272
elif field.type == FieldTypes.TAG:
7373
return str
7474
elif field.type == FieldTypes.NUMERIC:
75-
return Union[int, float]
75+
return Union[int, float] # type: ignore
7676
elif field.type == FieldTypes.GEO:
7777
return str
7878
elif field.type == FieldTypes.VECTOR:
7979
# For JSON storage, vectors are always lists
8080
if storage_type == StorageType.JSON:
8181
# For int data types, vectors must be ints, otherwise floats
82-
if field.attrs.datatype in (VectorDataType.INT8, VectorDataType.UINT8):
82+
if field.attrs.datatype in ( # type: ignore
83+
VectorDataType.INT8,
84+
VectorDataType.UINT8,
85+
):
8386
return List[int]
8487
return List[float]
8588
else:
@@ -103,8 +106,8 @@ def _create_model(cls, schema: IndexSchema) -> Type[BaseModel]:
103106
storage_type = schema.index.storage_type
104107

105108
# Create annotations dictionary for the dynamic model
106-
annotations = {}
107-
class_dict = {}
109+
annotations: Dict[str, Any] = {}
110+
class_dict: Dict[str, Any] = {}
108111

109112
# Build annotations and field metadata
110113
for field_name, field in schema.fields.items():
@@ -154,6 +157,8 @@ def _disallow_bool(cls, value):
154157

155158
# Register validators for VECTOR fields
156159
elif field.type == FieldTypes.VECTOR:
160+
dims = field.attrs.dims # type: ignore
161+
datatype = field.attrs.datatype # type: ignore
157162

158163
def make_vector_validator(
159164
fname: str, dims: int, datatype: VectorDataType
@@ -190,7 +195,7 @@ def _validate_vector(cls, value):
190195
return _validate_vector
191196

192197
class_dict[f"validate_{field_name}"] = make_vector_validator(
193-
field_name, field.attrs.dims, field.attrs.datatype
198+
field_name, dims, datatype
194199
)
195200

196201
# Create class dictionary with annotations and field metadata

tests/integration/test_async_search_index.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from redis import Redis as SyncRedis
55
from redis.asyncio import Redis
66

7-
from redisvl.exceptions import RedisSearchError
7+
from redisvl.exceptions import RedisSearchError, RedisVLError
88
from redisvl.index import AsyncSearchIndex
99
from redisvl.query import VectorQuery
1010
from redisvl.redis.utils import convert_bytes
@@ -267,7 +267,7 @@ async def test_search_index_load_preprocess(async_index):
267267
await async_index.create(overwrite=True, drop=True)
268268
data = [{"id": "1", "test": "foo"}]
269269

270-
async def preprocess(record):
270+
def preprocess(record):
271271
record["test"] = "bar"
272272
return record
273273

@@ -279,10 +279,10 @@ async def preprocess(record):
279279
== "bar"
280280
)
281281

282-
async def bad_preprocess(record):
282+
def bad_preprocess(record):
283283
return 1
284284

285-
with pytest.raises(ValueError):
285+
with pytest.raises(RedisVLError):
286286
await async_index.load(data, id_field="id", preprocess=bad_preprocess)
287287

288288

@@ -298,7 +298,7 @@ async def test_no_id_field(async_index):
298298
bad_data = [{"wrong_key": "1", "value": "test"}]
299299

300300
# catch missing / invalid id_field
301-
with pytest.raises(ValueError):
301+
with pytest.raises(RedisVLError):
302302
await async_index.load(bad_data, id_field="key")
303303

304304

tests/integration/test_search_index.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from redisvl.exceptions import RedisSearchError
5+
from redisvl.exceptions import RedisSearchError, RedisVLError
66
from redisvl.index import SearchIndex
77
from redisvl.query import VectorQuery
88
from redisvl.redis.utils import convert_bytes
@@ -265,7 +265,7 @@ def preprocess(record):
265265
def bad_preprocess(record):
266266
return 1
267267

268-
with pytest.raises(ValueError):
268+
with pytest.raises(RedisVLError):
269269
index.load(data, id_field="id", preprocess=bad_preprocess)
270270

271271

@@ -274,7 +274,7 @@ def test_no_id_field(index):
274274
bad_data = [{"wrong_key": "1", "value": "test"}]
275275

276276
# catch missing / invalid id_field
277-
with pytest.raises(ValueError):
277+
with pytest.raises(RedisVLError):
278278
index.load(bad_data, id_field="key")
279279

280280

0 commit comments

Comments
 (0)