Skip to content

Commit 558514d

Browse files
committed
Add Gemini provider support: settings defaults, OpenAI-compatible backend, registry wiring, and tests
1 parent dc6880a commit 558514d

5 files changed

Lines changed: 158 additions & 1 deletion

File tree

src/memu/app/settings.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,15 @@ def set_provider_defaults(self) -> "LLMConfig":
135135
self.api_key = "XAI_API_KEY"
136136
if self.chat_model == "gpt-4o-mini":
137137
self.chat_model = "grok-2-latest"
138+
elif self.provider == "gemini":
139+
if self.base_url == "https://api.openai.com/v1":
140+
self.base_url = "https://generativelanguage.googleapis.com/v1beta/openai"
141+
if self.api_key == "OPENAI_API_KEY":
142+
self.api_key = "GEMINI_API_KEY"
143+
if self.chat_model == "gpt-4o-mini":
144+
self.chat_model = "gemini-2.0-flash"
145+
if self.embed_model == "text-embedding-3-small":
146+
self.embed_model = "gemini-embedding-001"
138147
return self
139148

140149

src/memu/llm/backends/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from memu.llm.backends.base import LLMBackend
from memu.llm.backends.doubao import DoubaoLLMBackend
from memu.llm.backends.gemini import GeminiLLMBackend
from memu.llm.backends.grok import GrokBackend
from memu.llm.backends.openai import OpenAILLMBackend
from memu.llm.backends.openrouter import OpenRouterLLMBackend

# Public re-exports of all backend implementations, kept in alphabetical order.
__all__ = ["DoubaoLLMBackend", "GeminiLLMBackend", "GrokBackend", "LLMBackend", "OpenAILLMBackend", "OpenRouterLLMBackend"]

src/memu/llm/backends/gemini.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from __future__ import annotations
2+
3+
from memu.llm.backends.openai import OpenAILLMBackend
4+
5+
6+
class GeminiLLMBackend(OpenAILLMBackend):
    """Backend for Google Gemini via its OpenAI-compatible API endpoint.

    Inherits all payload construction and response parsing from
    ``OpenAILLMBackend``; only the registry name differs.
    """

    # Registry key: matches provider="gemini" in settings and LLM_BACKENDS.
    name = "gemini"
    # Gemini's OpenAI-compatible chat endpoint is the same as OpenAI's
    # NOTE(review): this matches the inherited value on OpenAILLMBackend —
    # presumably kept explicit for clarity; confirm before removing.
    summary_endpoint = "/chat/completions"

src/memu/llm/http_client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from memu.llm.backends.base import LLMBackend
1313
from memu.llm.backends.doubao import DoubaoLLMBackend
14+
from memu.llm.backends.gemini import GeminiLLMBackend
1415
from memu.llm.backends.grok import GrokBackend
1516
from memu.llm.backends.openai import OpenAILLMBackend
1617
from memu.llm.backends.openrouter import OpenRouterLLMBackend
@@ -72,6 +73,7 @@ def parse_embedding_response(self, data: dict[str, Any]) -> list[list[float]]:
7273
# Maps provider name -> zero-arg backend factory.  Each backend class serves
# as its own factory, and its `name` class attribute supplies the key.
LLM_BACKENDS: dict[str, Callable[[], LLMBackend]] = {
    OpenAILLMBackend.name: OpenAILLMBackend,
    DoubaoLLMBackend.name: DoubaoLLMBackend,
    GeminiLLMBackend.name: GeminiLLMBackend,
    GrokBackend.name: GrokBackend,
    OpenRouterLLMBackend.name: OpenRouterLLMBackend,
}
@@ -291,6 +293,7 @@ def _load_embedding_backend(self, provider: str) -> _EmbeddingBackend:
291293
_OpenAIEmbeddingBackend.name: _OpenAIEmbeddingBackend,
292294
_DoubaoEmbeddingBackend.name: _DoubaoEmbeddingBackend,
293295
"grok": _OpenAIEmbeddingBackend,
296+
"gemini": _OpenAIEmbeddingBackend,
294297
_OpenRouterEmbeddingBackend.name: _OpenRouterEmbeddingBackend,
295298
}
296299
factory = backends.get(provider)

tests/llm/test_gemini_provider.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import os
2+
import unittest
3+
from unittest.mock import AsyncMock, MagicMock, patch
4+
5+
from memu.app.settings import LLMConfig
6+
from memu.llm.backends.gemini import GeminiLLMBackend
7+
from memu.llm.http_client import HTTPLLMClient, LLM_BACKENDS, _OpenAIEmbeddingBackend
8+
9+
10+
class TestGeminiSettings(unittest.TestCase):
    """Unit tests for the Gemini defaults applied by LLMConfig."""

    def test_settings_defaults(self):
        """provider='gemini' should set Gemini-specific defaults."""
        cfg = LLMConfig(provider="gemini")
        expected = {
            "base_url": "https://generativelanguage.googleapis.com/v1beta/openai",
            "api_key": "GEMINI_API_KEY",
            "chat_model": "gemini-2.0-flash",
            "embed_model": "gemini-embedding-001",
        }
        for attr, want in expected.items():
            self.assertEqual(getattr(cfg, attr), want)

    def test_explicit_values_not_overridden(self):
        """Explicit values should not be replaced by defaults."""
        cfg = LLMConfig(
            provider="gemini",
            chat_model="gemini-2.5-flash",
            embed_model="gemini-embedding-001",
            api_key="my-real-key",
        )
        self.assertEqual(cfg.chat_model, "gemini-2.5-flash")
        self.assertEqual(cfg.embed_model, "gemini-embedding-001")
        self.assertEqual(cfg.api_key, "my-real-key")
30+
31+
32+
class TestGeminiBackend(unittest.TestCase):
    """Unit tests for GeminiLLMBackend payload construction and parsing."""

    def setUp(self):
        self.backend = GeminiLLMBackend()

    def test_backend_registered(self):
        """GeminiLLMBackend must be in the LLM_BACKENDS registry."""
        self.assertIn("gemini", LLM_BACKENDS)
        factory = LLM_BACKENDS["gemini"]
        self.assertIsInstance(factory(), GeminiLLMBackend)

    def test_summary_endpoint(self):
        self.assertEqual(self.backend.summary_endpoint, "/chat/completions")

    def test_build_summary_payload(self):
        built = self.backend.build_summary_payload(
            text="Hello world",
            system_prompt="Be concise.",
            chat_model="gemini-2.0-flash",
            max_tokens=100,
        )
        # System prompt becomes the first message; the user text the second.
        self.assertEqual(built["model"], "gemini-2.0-flash")
        self.assertEqual(built["messages"][0]["role"], "system")
        self.assertEqual(built["messages"][1]["content"], "Hello world")
        self.assertEqual(built["max_tokens"], 100)

    def test_parse_summary_response(self):
        raw = {"choices": [{"message": {"content": "Gemini reply", "role": "assistant"}}]}
        self.assertEqual(self.backend.parse_summary_response(raw), "Gemini reply")

    def test_build_vision_payload(self):
        built = self.backend.build_vision_payload(
            prompt="Describe this image",
            base64_image="abc123",
            mime_type="image/png",
            system_prompt=None,
            chat_model="gemini-2.0-flash",
            max_tokens=None,
        )
        self.assertEqual(built["model"], "gemini-2.0-flash")
        parts = built["messages"][0]["content"]
        image_part = next(part for part in parts if part["type"] == "image_url")
        self.assertIn("data:image/png;base64,abc123", image_part["image_url"]["url"])
74+
75+
76+
class TestGeminiHTTPClient(unittest.TestCase):
    """Unit tests for wiring provider='gemini' through HTTPLLMClient."""

    @staticmethod
    def _make_client():
        # Both tests need an identically-configured client; build it in one place.
        return HTTPLLMClient(
            base_url="https://generativelanguage.googleapis.com/v1beta/openai",
            api_key="fake-key",
            chat_model="gemini-2.0-flash",
            provider="gemini",
            embed_model="gemini-embedding-001",
        )

    def test_client_loads_gemini_backend(self):
        """HTTPLLMClient with provider='gemini' should load GeminiLLMBackend."""
        client = self._make_client()
        self.assertIsInstance(client.backend, GeminiLLMBackend)
        self.assertIsInstance(client.embedding_backend, _OpenAIEmbeddingBackend)
        self.assertEqual(client.embed_model, "gemini-embedding-001")

    def test_embedding_endpoint(self):
        client = self._make_client()
        self.assertEqual(client.embedding_endpoint, "embeddings")
99+
100+
101+
class TestGeminiLiveAPI(unittest.IsolatedAsyncioTestCase):
    """Live tests — skipped if GEMINI_API_KEY is not set."""

    def setUp(self):
        # Skip the whole class early when no credential is available, so CI
        # environments without a Gemini key still pass.
        self.api_key = os.getenv("GEMINI_API_KEY")
        if not self.api_key:
            self.skipTest("GEMINI_API_KEY not set")
        self.client = HTTPLLMClient(
            base_url="https://generativelanguage.googleapis.com/v1beta/openai",
            api_key=self.api_key,
            chat_model="gemini-2.5-flash",
            provider="gemini",
            embed_model="gemini-embedding-001",
        )

    async def test_chat(self):
        # chat() returns a 2-tuple; only the first (text) element is checked.
        response, _ = await self.client.chat("Say hello in one word.")
        self.assertIsInstance(response, str)
        self.assertGreater(len(response), 0)

    async def test_summarize(self):
        response, _ = await self.client.summarize("The sky is blue and the grass is green.")
        self.assertIsInstance(response, str)
        self.assertGreater(len(response), 0)

    async def test_embed(self):
        vectors, _ = await self.client.embed(["Hello world", "Gemini embeddings"])
        self.assertEqual(len(vectors), 2)
        self.assertEqual(len(vectors[0]), 3072)  # gemini-embedding-001 dimension
130+
131+
132+
# Allow running this test module directly without a test runner.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)