|
11 | 11 | from typing import Any, Dict, List, Optional |
12 | 12 | from aiohttp import ClientSession, ClientTimeout |
13 | 13 | from async_timeout import timeout |
| 14 | +from datetime import datetime, timedelta |
14 | 15 |
|
15 | 16 | from homeassistant.core import HomeAssistant |
16 | 17 | from homeassistant.exceptions import HomeAssistantError |
@@ -250,89 +251,135 @@ async def _create_gemini_completion( |
250 | 251 | temperature: float, |
251 | 252 | max_tokens: int, |
252 | 253 | ) -> Dict[str, Any]: |
253 | | - """Create completion using Gemini API.""" |
254 | | - # Extract API key from headers (Bearer token) |
255 | | - api_key = self.headers.get("Authorization", "").replace("Bearer ", "") |
256 | | - url = f"{self.endpoint}/models/{model}:generateContent?key={api_key}" |
| 254 | + """Create completion using Gemini API with google-genai library. |
257 | 255 |
|
258 | | - # Convert messages to Gemini format |
259 | | - contents = [] |
260 | | - system_instruction = "" |
| 256 | + Args: |
| 257 | + model: The model name to use |
| 258 | + messages: List of message dictionaries with role and content |
| 259 | + temperature: Sampling temperature between 0.0 and 2.0 |
| 260 | + max_tokens: Maximum number of tokens to generate |
261 | 261 |
|
262 | | - # Process messages |
263 | | - for msg in messages: |
264 | | - if msg['role'] == 'system': |
265 | | - system_instruction += msg['content'] + "\n" |
266 | | - else: |
267 | | - # Convert role: 'user' stays 'user', anything else becomes 'model' |
268 | | - role = "user" if msg['role'] == 'user' else "model" |
269 | | - contents.append({ |
270 | | - "role": role, |
271 | | - "parts": [{"text": msg['content']}] |
272 | | - }) |
273 | | - |
274 | | - # Ensure contents starts with a user message if not empty |
275 | | - if contents and contents[0]["role"] != "user": |
276 | | - # Add a placeholder user message |
277 | | - contents.insert(0, { |
278 | | - "role": "user", |
279 | | - "parts": [{"text": "I need your assistance."}] |
280 | | - }) |
281 | | - |
282 | | - # Ensure contents is not empty |
283 | | - if not contents: |
284 | | - contents.append({ |
285 | | - "role": "user", |
286 | | - "parts": [{"text": "I need your assistance."}] |
287 | | - }) |
288 | | - |
289 | | - # Create payload with snake_case keys as required by Gemini API |
290 | | - payload = { |
291 | | - "contents": contents, |
292 | | - "generation_config": { # Changed from camelCase to snake_case |
293 | | - "temperature": temperature, |
294 | | - "max_output_tokens": max_tokens # Changed from camelCase to snake_case |
295 | | - } |
296 | | - } |
| 262 | + Returns: |
| 263 | + Dictionary with response content and token usage |
| 264 | + """ |
| 265 | + try: |
| 266 | + # Импортируем библиотеку в отдельном потоке, чтобы избежать блокировки event loop |
| 267 | + def import_genai(): |
| 268 | + from google import genai |
| 269 | + return genai |
297 | 270 |
|
298 | | - if system_instruction: |
299 | | - payload["system_instruction"] = { # Changed from camelCase to snake_case |
300 | | - "parts": [{"text": system_instruction.strip()}] |
301 | | - } |
| 271 | + genai = await asyncio.to_thread(import_genai) |
302 | 272 |
|
303 | | - try: |
304 | | - data = await self._make_request(url, payload) |
| 273 | + # Extract API key from headers (Bearer token) |
| 274 | + api_key = self.headers.get("Authorization", "").replace("Bearer ", "") |
305 | 275 |
|
306 | | - # Safely extract response data |
307 | | - candidates = data.get("candidates", []) |
308 | | - if not candidates: |
309 | | - raise HomeAssistantError("Gemini API returned no candidates") |
| 276 | + # Создаем клиент в отдельном потоке |
| 277 | + def create_client(): |
| 278 | + if self.endpoint and self.endpoint != "https://generativelanguage.googleapis.com/v1beta": |
| 279 | + return genai.Client(api_key=api_key, transport="rest", |
| 280 | + client_options={"api_endpoint": self.endpoint}) |
| 281 | + else: |
| 282 | + return genai.Client(api_key=api_key) |
310 | 283 |
|
311 | | - content = candidates[0].get("content", {}) |
312 | | - parts = content.get("parts", []) |
313 | | - if not parts: |
314 | | - raise HomeAssistantError("Gemini API response contains no content parts") |
| 284 | + client = await asyncio.to_thread(create_client) |
315 | 285 |
|
316 | | - answer_text = parts[0].get("text", "") |
| 286 | + # Process messages to extract system instruction and chat history |
| 287 | + system_instruction = "" |
| 288 | + contents = [] |
317 | 289 |
|
318 | | - # Safely extract usage data |
319 | | - usage = data.get("usageMetadata", {}) |
320 | | - prompt_tokens = usage.get("promptTokenCount", 0) |
321 | | - completion_tokens = usage.get("candidatesTokenCount", 0) |
322 | | - total_tokens = usage.get("totalTokenCount", prompt_tokens + completion_tokens) |
| 290 | + for msg in messages: |
| 291 | + if msg['role'] == 'system': |
| 292 | + system_instruction += msg['content'] + "\n" |
| 293 | + else: |
| 294 | + # For chat history, we need to convert to the format Gemini expects |
| 295 | + role = "user" if msg['role'] == 'user' else "model" |
| 296 | + contents.append({ |
| 297 | + "role": role, |
| 298 | + "parts": [{"text": msg['content']}] |
| 299 | + }) |
| 300 | + |
| 301 | + # Create configuration |
| 302 | + def create_config(): |
| 303 | + from google.genai import types |
| 304 | + config = types.GenerateContentConfig( |
| 305 | + temperature=temperature, |
| 306 | + max_output_tokens=max_tokens, |
| 307 | + ) |
| 308 | + |
| 309 | + # Add system instruction if present |
| 310 | + if system_instruction: |
| 311 | + config.system_instruction = system_instruction.strip() |
| 312 | + |
| 313 | + return config |
| 314 | + |
| 315 | + config = await asyncio.to_thread(create_config) |
| 316 | + |
| 317 | + # Выполняем запрос в отдельном потоке |
| 318 | + def generate_content(): |
| 319 | + # For single message without history, use generate_content |
| 320 | + if len(contents) <= 1: |
| 321 | + # If we have no content yet, create a simple prompt |
| 322 | + if not contents: |
| 323 | + prompt = "I need your assistance." |
| 324 | + else: |
| 325 | + prompt = contents[0]["parts"][0]["text"] |
| 326 | + |
| 327 | + return client.models.generate_content( |
| 328 | + model=model, |
| 329 | + contents=prompt, |
| 330 | + config=config |
| 331 | + ) |
| 332 | + else: |
| 333 | + # For multi-turn conversations, use chat |
| 334 | + chat = client.chats.create(model=model, config=config) |
| 335 | + |
| 336 | + # Send all messages in sequence |
| 337 | + for content in contents: |
| 338 | + if content["role"] == "user": |
| 339 | + response = chat.send_message(content["parts"][0]["text"]) |
| 340 | + # We don't send assistant messages as they're already part of the history |
| 341 | + |
| 342 | + return response |
| 343 | + |
| 344 | + response = await asyncio.to_thread(generate_content) |
| 345 | + |
| 346 | + # Extract response text |
| 347 | + def extract_response(): |
| 348 | + response_text = response.text if hasattr(response, 'text') else "" |
| 349 | + |
| 350 | + # Try to get token usage if available |
| 351 | + usage = {} |
| 352 | + if hasattr(response, 'usage_metadata'): |
| 353 | + usage = { |
| 354 | + "prompt_tokens": getattr(response.usage_metadata, 'prompt_token_count', 0), |
| 355 | + "completion_tokens": getattr(response.usage_metadata, 'candidates_token_count', 0), |
| 356 | + "total_tokens": getattr(response.usage_metadata, 'total_token_count', 0) |
| 357 | + } |
| 358 | + else: |
| 359 | + # Estimate token count as fallback |
| 360 | + usage = { |
| 361 | + "prompt_tokens": len(" ".join([m["content"] for m in messages]).split()) // 3, |
| 362 | + "completion_tokens": len(response_text.split()) // 3, |
| 363 | + "total_tokens": 0 # Will be calculated below |
| 364 | + } |
| 365 | + usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"] |
| 366 | + |
| 367 | + return response_text, usage |
| 368 | + |
| 369 | + response_text, usage = await asyncio.to_thread(extract_response) |
323 | 370 |
|
324 | 371 | return { |
325 | 372 | "choices": [{ |
326 | 373 | "message": { |
327 | | - "content": answer_text |
| 374 | + "content": response_text |
328 | 375 | } |
329 | 376 | }], |
330 | | - "usage": { |
331 | | - "prompt_tokens": prompt_tokens, |
332 | | - "completion_tokens": completion_tokens, |
333 | | - "total_tokens": total_tokens |
334 | | - } |
| 377 | + "usage": usage |
335 | 378 | } |
| 379 | + |
| 380 | + except ImportError as e: |
| 381 | + _LOGGER.error(f"Google Gemini library not installed: {str(e)}") |
| 382 | + raise HomeAssistantError(f"Missing dependency: {str(e)}. Please install google-genai.") |
336 | 383 | except Exception as e: |
337 | 384 | _LOGGER.error(f"Gemini API error: {str(e)}") |
338 | 385 | raise HomeAssistantError(f"Gemini API error: {str(e)}") |
|
0 commit comments