Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions opencontext/llm/llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ async def _openai_chat_completion_stream_async(self, messages: List[Dict[str, An

def _request_embedding(self, text: str, **kwargs) -> List[float]:
try:
if self.provider == LLMProvider.OPENAI.value:
if self.provider != LLMProvider.DOUBAO.value:
response = self.client.embeddings.create(model=self.model, input=[text])
embedding = response.data[0].embedding
else:
Expand Down Expand Up @@ -313,14 +313,15 @@ def _request_embedding(self, text: str, **kwargs) -> List[float]:

async def _request_embedding_async(self, text: str, **kwargs) -> List[float]:
try:
if self.provider == LLMProvider.OPENAI.value:
response = await self.async_client.embeddings.create(model=self.model, input=[text])
embedding = response.data[0].embedding
else:
if self.provider == LLMProvider.DOUBAO.value:
# Only ark has multimodal_embeddings
response = self.client.multimodal_embeddings.create(
model=self.model, input=[{"type": "text", "text": text}]
)
embedding = response.data.embedding
else:
response = await self.async_client.embeddings.create(model=self.model, input=[text])
embedding = response.data[0].embedding

# Record token usage
if hasattr(response, "usage") and response.usage:
Expand Down Expand Up @@ -476,20 +477,20 @@ def _extract_error_summary(error: Any) -> str:

elif self.llm_type == LLMType.EMBEDDING:
# Test with a simple text
if self.provider == LLMProvider.OPENAI.value:
response = self.client.embeddings.create(model=self.model, input=["test"])
if response.data and len(response.data) > 0 and response.data[0].embedding:
return True, "Embedding model validation successful"
else:
return False, "Embedding model returned empty response"
else:
if self.provider == LLMProvider.DOUBAO.value:
response = self.client.multimodal_embeddings.create(
model=self.model, input=[{"type": "text", "text": "test"}]
)
if response.data and response.data.embedding:
return True, "Embedding model validation successful"
else:
return False, "Embedding model returned empty response"
else:
response = self.client.embeddings.create(model=self.model, input=["test"])
if response.data and len(response.data) > 0 and response.data[0].embedding:
return True, "Embedding model validation successful"
else:
return False, "Embedding model returned empty response"
else:
return False, f"Unsupported LLM type: {self.llm_type}"

Expand Down