Skip to content

Commit 24fe6d0

Browse files
feat: add EmbeddingGenerator for entity embedding creation
Implement EmbeddingGenerator class to generate and manage embeddings for any embeddable entity in the system. Key features: - Generate embeddings using configurable embedding models - Compute content and configuration hashes for deduplication - Automatic detection and marking of stale embeddings - Support for multiple entity types with dynamic model lookup - Proper error handling for invalid entity types and missing entities
1 parent fddfa45 commit 24fe6d0

File tree

1 file changed

+192
-0
lines changed
  • apps/backend/src/rhesis/backend/app/services/embedding

1 file changed

+192
-0
lines changed
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
"""Embedding generator for creating embeddings from entities."""
2+
3+
import hashlib
4+
import json
5+
from typing import Any, Dict, List, Optional
6+
7+
from sqlalchemy.orm import Session
8+
9+
from rhesis.backend.app import models
10+
from rhesis.backend.app.models.embedding import EmbeddingStatus
11+
from rhesis.backend.app.models.model import Model
12+
from rhesis.backend.app.utils.crud_utils import get_item
13+
from rhesis.backend.logging import logger
14+
15+
16+
class EmbeddingGenerator:
17+
"""Generate embedding for any embeddable entity."""
18+
19+
def __init__(self, db: Session):
20+
self.db = db
21+
22+
def _get_entity(self, entity_id: str, entity_type: str, organization_id: str) -> Any:
23+
"""Get entity from database."""
24+
25+
try:
26+
model_class = getattr(models, entity_type)
27+
except AttributeError:
28+
raise ValueError(f"Entity type {entity_type} not found")
29+
30+
entity = get_item(self.db, model_class, entity_id, organization_id)
31+
if not entity:
32+
raise ValueError(f"Entity not found: {entity_id}")
33+
return entity
34+
35+
def _compute_hash(self, data: str | dict) -> str:
36+
"""Compute SHA-256 hash of input data."""
37+
38+
# Convert dict to stable string representation
39+
if isinstance(data, dict):
40+
data_str = json.dumps(data, sort_keys=True)
41+
else:
42+
data_str = data
43+
44+
return hashlib.sha256(data_str.encode('utf-8')).hexdigest()
45+
46+
def _generate_embedding_vector(
47+
self,
48+
searchable_text: str,
49+
provider: str,
50+
model_name: str,
51+
api_key: str,
52+
dimension: int,
53+
) -> List[float]:
54+
"""Generate embedding for a searchable text."""
55+
from rhesis.sdk.models.factory import EmbedderConfig, get_embedder
56+
57+
config = EmbedderConfig(
58+
provider=provider,
59+
model_name=model_name,
60+
api_key=api_key,
61+
dimensions=dimension,
62+
)
63+
try:
64+
embedder = get_embedder(config=config)
65+
except ValueError as e:
66+
raise ValueError(f"Failed to create embedder: {e}")
67+
68+
try:
69+
embedding = embedder.generate(searchable_text)
70+
except Exception as e:
71+
raise ValueError(f"Failed to generate embedding: {e}")
72+
73+
return embedding
74+
75+
def generate(
76+
self,
77+
entity_id: str,
78+
entity_type: str,
79+
organization_id: str,
80+
user_id: str,
81+
model_id: str,
82+
entity: Optional[Any] = None,
83+
) -> Dict[str, Any]:
84+
"""
85+
Generate embedding for any embeddable entity.
86+
87+
Args:
88+
entity_id: ID of the entity to embed
89+
entity_type: Type of entity (Test, Source, etc.)
90+
organization_id: Organization context
91+
user_id: User context
92+
model_id: ID of the embedding model to use
93+
entity: Optional entity object (avoids re-fetch if provided)
94+
95+
Returns:
96+
Dictionary with generation result
97+
"""
98+
# If entity object provided, use it (sync path -> no extra DB query)
99+
if not entity:
100+
entity = self._get_entity(entity_id, entity_type, organization_id)
101+
102+
if not hasattr(entity, "to_searchable_text"):
103+
raise ValueError(f"Entity {entity_type} does not support embedding")
104+
105+
# Fetch model to get all configuration
106+
model = self.db.query(Model).filter(Model.id == model_id).first()
107+
if not model:
108+
raise ValueError(f"Model not found: {model_id}")
109+
110+
# Extract model details
111+
provider = model.provider_type.type_value if model.provider_type else None
112+
model_name = model.model_name
113+
dimension = model.dimension
114+
115+
# Get searchable text from entity
116+
searchable_text = entity.to_searchable_text()
117+
118+
# Create configuration for this embedding
119+
config = {
120+
"provider": provider,
121+
"model_name": model_name,
122+
"dimension": dimension,
123+
"model_id": model_id,
124+
}
125+
126+
# Compute hashes for deduplication
127+
config_hash = self._compute_hash(config)
128+
text_hash = self._compute_hash(searchable_text)
129+
130+
# Check if embedding already exists (same text/config)
131+
existing_embedding = self.db.query(models.Embedding).filter(
132+
models.Embedding.entity_id == entity_id,
133+
models.Embedding.entity_type == entity_type,
134+
models.Embedding.organization_id == organization_id,
135+
models.Embedding.config_hash == config_hash,
136+
models.Embedding.text_hash == text_hash,
137+
models.Embedding.status == EmbeddingStatus.ACTIVE.value,
138+
).first()
139+
140+
if existing_embedding:
141+
logger.info(f"Embedding already exists for {entity_type}:{entity_id}")
142+
return {"status": "success", "embedding_id": str(existing_embedding.id)}
143+
144+
# Mark old embeddings as stale (different text/config)
145+
stale_count = (
146+
self.db.query(models.Embedding)
147+
.filter(
148+
models.Embedding.entity_id == entity_id,
149+
models.Embedding.entity_type == entity_type,
150+
models.Embedding.organization_id == organization_id,
151+
models.Embedding.status == EmbeddingStatus.ACTIVE.value,
152+
)
153+
.update({"status": EmbeddingStatus.STALE.value})
154+
)
155+
156+
if stale_count > 0:
157+
logger.info(f"Marked {stale_count} old embeddings as stale")
158+
159+
self.db.flush()
160+
161+
# Generate the embedding vector
162+
embedding_vector = self._generate_embedding_vector(
163+
searchable_text, provider, model_name, model.key, dimension
164+
)
165+
166+
# Create and store the embedding
167+
new_embedding = models.Embedding(
168+
entity_id=entity_id,
169+
entity_type=entity_type,
170+
model_id=model_id,
171+
embedding_config=config,
172+
config_hash=config_hash,
173+
searchable_text=searchable_text,
174+
text_hash=text_hash,
175+
organization_id=organization_id,
176+
user_id=user_id,
177+
status=EmbeddingStatus.ACTIVE.value,
178+
)
179+
180+
# Use the property setter which automatically selects the right column
181+
new_embedding.embedding = embedding_vector
182+
183+
self.db.add(new_embedding)
184+
self.db.commit()
185+
self.db.refresh(new_embedding)
186+
187+
logger.info(
188+
f"Successfully generated embedding for {entity_type}:{entity_id}, "
189+
f"dimension={dimension}"
190+
)
191+
192+
return {"status": "success", "embedding_id": str(new_embedding.id)}

0 commit comments

Comments
 (0)