-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathdatabase_loader.py
More file actions
358 lines (300 loc) · 12.4 KB
/
database_loader.py
File metadata and controls
358 lines (300 loc) · 12.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
"""Database loading and collection management utilities."""
from __future__ import annotations
import time
from pathlib import Path
from typing import Any, Callable, Mapping, Optional, Sequence
import weaviate
import weaviate.collections.classes.config as wvcc
from .database_registry import resolve_spec
from ..dataset import in_memory_dataset_loader
from ..utils import (
get_weaviate_client,
get_provider_headers,
load_config,
parse_embedding_model,
pretty_print_in_memory_document,
add_tag_to_name,
)
# Public API: used by external callers to build temporary collections with an explicit text embedding model.
def get_vector_config(embedding_model: Optional[str] = None) -> Any:
    """
    Factory function to create vectorizer config based on provider.

    The provider prefix is normalized with `_normalize_provider_name`, the
    same alias rules `_resolve_provider_headers` applies when collecting API
    headers. Spellings such as "voyage/..." or "Gemini/..." that receive
    headers are therefore also accepted here, instead of being rejected.

    Args:
        embedding_model: Model string in format "provider/model" (e.g., "cohere/embed-4")
            or just "model" for weaviate. If None (or empty), uses default weaviate.

    Returns:
        Vectorizer configuration object.

    Raises:
        ValueError: If the normalized provider is not one of the supported providers.
    """
    if not embedding_model:
        return wvcc.Configure.Vectors.text2vec_weaviate(
            vectorize_collection_name=False,
        )

    raw_provider, model_name = parse_embedding_model(embedding_model)
    # Normalize aliases/case so this agrees with header resolution.
    provider = _normalize_provider_name(raw_provider)

    if provider == "weaviate":
        return wvcc.Configure.Vectors.text2vec_weaviate(
            model=model_name,
            vectorize_collection_name=False,
        )
    if provider == "cohere":
        return wvcc.Configure.Vectors.text2vec_cohere(model=model_name)
    if provider == "voyageai":
        return wvcc.Configure.Vectors.text2vec_voyageai(model=model_name)
    if provider == "google":
        return wvcc.Configure.Vectors.multi2vec_google_gemini(model=model_name)
    raise ValueError(
        f"Unsupported embedding provider: '{raw_provider}'. "
        f"Supported providers: ['weaviate', 'cohere', 'voyageai', 'google']"
    )
# Public API: used by embedding-comparison flows and scripts to create/populate tagged collections.
def create_collection_with_vector_config(
    client: Optional[weaviate.WeaviateClient] = None,
    dataset_name: Optional[str] = None,
    tag: str = "Default",
    embedding_model: Optional[str] = None,
    weaviate_client: Optional[weaviate.WeaviateClient] = None,
) -> None:
    """
    Create and populate a collection with a specified embedding model.

    Used for embedding model comparison where temporary collections
    are created with different models. The collection is always dropped
    and recreated, and no alias is created or updated.

    Args:
        client: Connected Weaviate client.
        dataset_name: Name of the dataset to load.
        tag: Suffix to add to the collection name.
        embedding_model: Embedding model to use. If None, uses default.
        weaviate_client: Backward-compatible alias for `client`.

    Raises:
        ValueError: If no client is given, two different clients are given,
            `dataset_name` is missing, or the embedding provider is unsupported.
    """
    if client is None:
        client = weaviate_client
    elif weaviate_client is not None and weaviate_client is not client:
        raise ValueError("Pass either `client` or `weaviate_client`, not both")
    if client is None:
        raise ValueError("A Weaviate client is required")
    if dataset_name is None:
        raise ValueError("dataset_name is required")

    # Resolve the vectorizer before loading data. An unsupported provider
    # then fails immediately instead of after the dataset is in memory.
    vector_config = get_vector_config(embedding_model)

    print(f"Loading dataset '{dataset_name}'...")
    objects = _load_documents(dataset_name)

    spec = resolve_spec(dataset_name)
    alias_collection_name = spec.name_fn(dataset_name)
    collection_name = add_tag_to_name(alias_collection_name, tag)

    model_info = f" with model {embedding_model}" if embedding_model else " with default model"
    print(f"Creating collection '{collection_name}'{model_info}...")
    _drop_and_create_collection(
        client,
        collection_name,
        properties=spec.properties,
        vector_config=vector_config,
        recreate=True,
        multi_tenancy_config=spec.multi_tenancy_config,
    )

    print(f"Populating collection with {len(objects)} objects...")
    _batch_insert(
        client,
        collection=collection_name,
        items=objects,
        item_to_props=spec.item_to_props,
        tenant_id_field=spec.tenant_id_field,
    )
    print(f"Collection '{collection_name}' ready!\n")
# Public API: primary entry point used by `scripts/populate-db.py` and package users.
def database_loader(recreate: bool = True, tag: str = "Default") -> None:
    """
    Load dataset from config and populate Weaviate collection (or Engram if use_engram is set).

    Reads ``database_loader_config.yml`` located next to this module.

    In the Weaviate path the documents go into a tagged collection whose name
    is derived from the dataset spec's alias name via `add_tag_to_name`. The
    alias is created or repointed only AFTER the batch insert finishes, so a
    failed or slow load does not switch alias readers over to an empty or
    partially populated collection. If `recreate` drops the collection the
    alias already targets, readers are affected regardless.

    Args:
        recreate: Whether to drop existing collection before creating.
        tag: Suffix to add to collection name.
    """
    config_path = Path(__file__).parent / "database_loader_config.yml"
    config = load_config(config_path)

    if config.get("use_engram", False):
        # Imported lazily so the Engram path is only loaded when enabled.
        from .engram_loader import engram_loader

        engram_loader(
            dataset_name=config.get("dataset_name"),
            engram_base_url=config.get("engram_base_url"),
        )
        return

    # Get provider headers for all configured embedding providers.
    headers = _resolve_provider_headers(
        embedding_providers=config.get("embedding_providers"),
        embedding_models=[
            config.get("text_embedding_model"),
            config.get("image_embedding_model"),
            *_coerce_model_list(config.get("text_embedding_models")),
            *_coerce_model_list(config.get("image_embedding_models")),
        ],
    )
    client = get_weaviate_client(headers=headers)
    try:
        dataset_name: str = config["dataset_name"]
        objects = _load_documents(dataset_name)
        print("\033[92mFirst Document:\033[0m")
        pretty_print_in_memory_document(objects[0])

        spec = resolve_spec(dataset_name)
        alias_collection_name = spec.name_fn(dataset_name)
        collection_name = add_tag_to_name(alias_collection_name, tag)

        print(f"\n\033[96mCreating collection '{collection_name}'...\033[0m")
        _drop_and_create_collection(
            client,
            collection_name,
            properties=spec.properties,
            vector_config=spec.vector_config,
            recreate=recreate,
            multi_tenancy_config=spec.multi_tenancy_config,
        )

        _batch_insert(
            client,
            collection=collection_name,
            items=objects,
            item_to_props=spec.item_to_props,
            tenant_id_field=spec.tenant_id_field,
        )

        # Manage the alias only once the target collection is populated.
        alias_info = client.alias.get(alias_name=alias_collection_name)
        if alias_info is None:
            client.alias.create(
                alias_name=alias_collection_name,
                target_collection=collection_name,
            )
        else:
            client.alias.update(
                alias_name=alias_collection_name,
                new_target_collection=collection_name,
            )
    finally:
        client.close()
# Private helper: low-level create/delete wrapper used internally to keep public loaders concise.
def _drop_and_create_collection(
    client: weaviate.WeaviateClient,
    name: str,
    properties: Sequence[wvcc.Property],
    vector_config: Any,
    recreate: bool = True,
    multi_tenancy_config: Any = None,
) -> None:
    """Ensure collection `name` exists, optionally deleting it first.

    With `recreate` true, an existing collection of that name is deleted and
    then created again. With `recreate` false and the collection already
    present, nothing is changed; the existing configuration is not compared
    against `properties` / `vector_config`.

    `multi_tenancy_config` is forwarded to `collections.create` only when it
    is not None.
    """
    registry = client.collections

    if recreate and registry.exists(name):
        registry.delete(name)

    if registry.exists(name):
        return

    optional_kwargs: dict[str, Any] = (
        {}
        if multi_tenancy_config is None
        else {"multi_tenancy_config": multi_tenancy_config}
    )
    registry.create(
        name=name,
        vector_config=vector_config,
        properties=list(properties),
        **optional_kwargs,
    )
# Private helper: internal batching utility shared by the public loaders above.
def _batch_insert(
    client: weaviate.WeaviateClient,
    collection: str,
    items: Sequence[Mapping[str, Any]],
    item_to_props: Callable[[Mapping[str, Any]], dict[str, Any]],
    batch_size: int = 20,
    verbose: bool = True,
    tenant_id_field: Optional[str] = None,
) -> int:
    """
    Insert items into a Weaviate collection in batches.

    If tenant_id_field is set, each item's tenant is read from that field
    and passed to add_object so the object is inserted into the correct tenant.

    Objects rejected by the server do not raise inside the batch context
    manager; the client records them on ``client.batch.failed_objects``.
    Those failures are always reported, because they mean data was not
    loaded, and they are excluded from the returned count.

    Args:
        client: Connected Weaviate client.
        collection: Target collection name.
        items: Source documents.
        item_to_props: Converts one document into the properties dict to store.
        batch_size: Fixed batch size, also used as the progress-print interval.
        verbose: Print progress and summary lines.
        tenant_id_field: Optional key in each item holding its tenant id.

    Returns:
        The number of items submitted minus the number reported as failed.

    Raises:
        KeyError: If `tenant_id_field` is set and an item lacks that key.
    """
    start = time.perf_counter()
    submitted = 0
    if verbose:
        print(f"Inserting {len(items)} objects into collection '{collection}'...")

    with client.batch.fixed_size(batch_size=batch_size) as batch:
        for submitted, item in enumerate(items, start=1):
            add_kwargs: dict[str, Any] = {
                "collection": collection,
                "properties": item_to_props(item),
            }
            if tenant_id_field is not None:
                add_kwargs["tenant"] = str(item[tenant_id_field])
            batch.add_object(**add_kwargs)
            if verbose and submitted % batch_size == 0:
                elapsed = time.perf_counter() - start
                rate = submitted / max(elapsed, 1e-9)
                print(f"\033[92mInserted {submitted} objects ({elapsed:.1f}s, {rate:.1f} objs/s)\033[0m")

    # Failures are only complete once the context manager has flushed.
    failed = client.batch.failed_objects
    inserted = max(submitted - len(failed), 0)

    if verbose:
        elapsed = time.perf_counter() - start
        rate = inserted / max(elapsed, 1e-9)
        print(f"Inserted {inserted} objects in {elapsed:.2f}s ({rate:.1f} objs/s)")
    if failed:
        print(
            f"\033[91m{len(failed)} of {submitted} objects failed to insert into "
            f"'{collection}'. First error: {failed[0].message}\033[0m"
        )
    return inserted
# Private helper: internal header resolver so provider-key logic stays out of public APIs.
def _resolve_provider_headers(
    embedding_models: Sequence[Optional[str]] = (),
    embedding_providers: Optional[str | Sequence[str]] = None,
) -> dict[str, str]:
    """Merge the API headers needed by every configured embedding provider.

    Providers come first from `embedding_providers` (explicit list or
    comma-separated string). They are then inferred from the provider
    prefix of each non-empty entry in `embedding_models`.

    Headers are merged with `dict.update` in that order, so on a key
    collision the later provider's value is kept.
    """

    def _providers_in_order():
        yield from _parse_embedding_providers(embedding_providers)
        for model in embedding_models:
            if model:
                provider_name, _model_name = parse_embedding_model(model)
                yield _normalize_provider_name(provider_name)

    merged: dict[str, str] = {}
    for provider_name in _providers_in_order():
        merged.update(get_provider_headers(provider_name))
    return merged
def _parse_embedding_providers(
    embedding_providers: Optional[str | Sequence[str]],
) -> list[str]:
    """
    Parse embedding providers from config values.

    Accepts either a comma-separated string (e.g. "cohere, voyage") or a
    sequence of strings. Both forms go through the same rules:
      - each entry is normalized with `_normalize_provider_name`;
      - empty entries and "auto" entries are skipped;
      - duplicates are dropped, keeping first-seen order.

    Previously the string form kept "auto" tokens and duplicates, and the
    sequence form kept empty entries.

    Args:
        embedding_providers: None, a comma-separated string, or a sequence
            of provider names. None, "" and "auto" mean "no explicit providers".

    Returns:
        Ordered, de-duplicated canonical provider names.

    Raises:
        ValueError: If a sequence entry is not a string.
    """
    if embedding_providers is None:
        return []

    if isinstance(embedding_providers, str):
        raw = embedding_providers.strip()
        if not raw or raw.lower() == "auto":
            return []
        candidates: Any = raw.split(",")
    else:
        candidates = embedding_providers

    providers: list[str] = []
    for value in candidates:
        if not isinstance(value, str):
            raise ValueError("embedding_providers entries must be strings.")
        normalized = _normalize_provider_name(value)
        # Empty and "auto" entries carry no provider to fetch headers for.
        if not normalized or normalized == "auto":
            continue
        if normalized not in providers:
            providers.append(normalized)
    return providers
def _normalize_provider_name(provider: str) -> str:
"""Normalize provider aliases to canonical provider names."""
normalized = provider.strip().lower()
if normalized in {"voyage", "voyage-ai"}:
return "voyageai"
if normalized in {"gemini", "google-gemini", "google_gemini"}:
return "google"
return normalized
def _coerce_model_list(raw_models: Any) -> list[str]:
"""Normalize optional model config values to a list of non-empty strings."""
if raw_models is None:
return []
if isinstance(raw_models, str):
return [raw_models] if raw_models.strip() else []
if not isinstance(raw_models, Sequence):
raise ValueError("Embedding model lists must be a sequence of strings.")
models: list[str] = []
for value in raw_models:
if not isinstance(value, str):
raise ValueError("Embedding model list entries must be strings.")
if value.strip():
models.append(value)
return models
# Private helper: centralizes dataset validation for all public loader paths.
def _load_documents(dataset_name: str) -> Sequence[Mapping[str, Any]]:
    """Return the corpus documents for `dataset_name`, validating the result.

    Calls `in_memory_dataset_loader` with ``corpus_only=True`` and keeps the
    first element of the pair it returns.

    Raises:
        ValueError: If the loader returns None (dataset not supported) or
            the document collection is empty.
    """
    result = in_memory_dataset_loader(dataset_name, corpus_only=True)
    if result is None:
        raise ValueError(f"Unsupported dataset_name: {dataset_name}")

    documents, _unused = result
    if documents:
        return documents
    raise ValueError(f"Dataset '{dataset_name}' returned zero documents")