Commit 1e16309

feat(llm): add Grok (xAI) integration (#179) (#236)
Co-authored-by: An Kaisen <51148505+ankaisen@users.noreply.github.com>
1 parent cba667a

7 files changed

Lines changed: 193 additions & 1 deletion


docs/integrations/grok.md

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@

# Grok (xAI) Integration

MemU supports **Grok**, the AI model from xAI, as a first-class LLM provider.

## Prerequisites

1. **xAI Account:** You need an active account with [xAI](https://x.ai/).
2. **API Key:** Obtain an API key from the [xAI Console](https://console.x.ai/).

## Configuration

To enable Grok, you need to set the `XAI_API_KEY` environment variable.

### Environment Variable

```bash
export XAI_API_KEY="your-xai-api-key-here"
```

PowerShell:

```powershell
$env:XAI_API_KEY="your-xai-api-key-here"
```

## Usage

To use Grok as your LLM provider, switch the `provider` setting to `grok`. This can be done in your configuration file or when initializing the application.

### Python Example

```python
from memu.app.settings import LLMConfig

# Configure MemU to use Grok
config = LLMConfig(
    provider="grok",
    # The default API key env var is XAI_API_KEY
    # The default model is grok-2-latest
)

print(f"Using provider: {config.provider}")
print(f"Base URL: {config.base_url}")
print(f"Chat Model: {config.chat_model}")
```

## Models Supported

We currently support the following Grok models:

* **grok-2-latest** (Default)

The integration automatically sets the base URL to `https://api.x.ai/v1`.
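At runtime the application resolves the configured env-var name to the actual key value. A minimal sketch of that lookup (the `resolve_api_key` helper is illustrative, not a MemU API):

```python
import os


def resolve_api_key(env_var: str = "XAI_API_KEY") -> str:
    # Read the key named by the config's api_key field from the environment
    key = os.environ.get(env_var)
    if not key:
        raise RuntimeError(f"Environment variable {env_var} is not set")
    return key


# Demonstration only: set the variable in-process
os.environ["XAI_API_KEY"] = "xai-demo-key"
print(resolve_api_key())  # xai-demo-key
```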

docs/providers/grok.md

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@

# Grok (xAI) Provider

memU includes first-class support for [Grok](https://grok.x.ai/), allowing you to leverage xAI's powerful language models directly within your application.

## Prerequisites

To use this provider, you must have an active xAI account.

1. Navigate to the [xAI Console](https://console.x.ai/).
2. Sign up or log in.
3. Create a new **API Key** in the API Keys section.

## Configuration

The integration is designed to work out of the box with minimal configuration.

### Environment Variables

Set the following environment variable in your `.env` file or system environment:

```bash
XAI_API_KEY=xai-YOUR_API_KEY_HERE
```

### Defaults

When you select the `grok` provider, memU automatically configures the following defaults:

* **Base URL**: `https://api.x.ai/v1`
* **Model**: `grok-2-latest`

## Usage Example

You can enable the Grok provider by setting the `provider` field to `"grok"` in your application configuration.

### Using Python Configuration

```python
from memu.app.settings import LLMConfig
from memu.app.service import MemoryService

# Configure the LLM provider to use Grok
llm_config = LLMConfig(provider="grok")

# Initialize the service
service = MemoryService(llm_config=llm_config)
print(f"Service initialized with model: {llm_config.chat_model}")
# Output: Service initialized with model: grok-2-latest
```

## Troubleshooting

### Connection Issues

If you are unable to connect to the xAI API:

1. Verify that your `XAI_API_KEY` is set correctly and has not expired.
2. Ensure that the `base_url` resolves to `https://api.x.ai/v1`. Manual overrides in your settings may conflict with the default.

### Model Availability

If you receive a `404` or "Model not found" error, xAI may have updated their model names. You can override the model manually in the config if needed:

```python
config = LLMConfig(
    provider="grok",
    chat_model="grok-beta",  # Example override
)
```

src/memu/app/settings.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -114,6 +114,18 @@ class LLMConfig(BaseModel):
         description="Maximum batch size for embedding API calls (used by SDK client backends).",
     )
 
+    @model_validator(mode="after")
+    def set_provider_defaults(self) -> "LLMConfig":
+        if self.provider == "grok":
+            # If values match the OpenAI defaults, switch them to Grok defaults
+            if self.base_url == "https://api.openai.com/v1":
+                self.base_url = "https://api.x.ai/v1"
+            if self.api_key == "OPENAI_API_KEY":
+                self.api_key = "XAI_API_KEY"
+            if self.chat_model == "gpt-4o-mini":
+                self.chat_model = "grok-2-latest"
+        return self
+
 
 class BlobConfig(BaseModel):
     provider: str = Field(default="local")
```
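The validator only rewrites fields that still hold the OpenAI defaults, so explicit user overrides survive. The same guard-the-defaults pattern in a self-contained sketch (a plain dataclass stand-in, not the actual pydantic `LLMConfig`):

```python
from dataclasses import dataclass


@dataclass
class MiniLLMConfig:
    provider: str = "openai"
    base_url: str = "https://api.openai.com/v1"
    api_key: str = "OPENAI_API_KEY"
    chat_model: str = "gpt-4o-mini"

    def __post_init__(self) -> None:
        # Swap to Grok defaults only where the OpenAI defaults are untouched,
        # so explicit user-supplied values are preserved
        if self.provider == "grok":
            if self.base_url == "https://api.openai.com/v1":
                self.base_url = "https://api.x.ai/v1"
            if self.api_key == "OPENAI_API_KEY":
                self.api_key = "XAI_API_KEY"
            if self.chat_model == "gpt-4o-mini":
                self.chat_model = "grok-2-latest"


cfg = MiniLLMConfig(provider="grok")
print(cfg.base_url)       # https://api.x.ai/v1
custom = MiniLLMConfig(provider="grok", chat_model="grok-beta")
print(custom.chat_model)  # grok-beta (explicit override preserved)
```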

src/memu/llm/backends/__init__.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,6 +1,7 @@
 from memu.llm.backends.base import LLMBackend
 from memu.llm.backends.doubao import DoubaoLLMBackend
+from memu.llm.backends.grok import GrokBackend
 from memu.llm.backends.openai import OpenAILLMBackend
 from memu.llm.backends.openrouter import OpenRouterLLMBackend
 
-__all__ = ["DoubaoLLMBackend", "LLMBackend", "OpenAILLMBackend", "OpenRouterLLMBackend"]
+__all__ = ["DoubaoLLMBackend", "GrokBackend", "LLMBackend", "OpenAILLMBackend", "OpenRouterLLMBackend"]
```

src/memu/llm/backends/grok.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@

```python
from __future__ import annotations

from memu.llm.backends.openai import OpenAILLMBackend


class GrokBackend(OpenAILLMBackend):
    """Backend for Grok (xAI) LLM API."""

    name = "grok"
    # Grok uses the same payload structure as OpenAI, so this backend
    # inherits build_summary_payload, parse_summary_response, etc.
```
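Because xAI's API is OpenAI-compatible, the subclass only needs to override the registry `name`; all payload handling is inherited. The pattern in miniature (simplified stand-in classes, not the real MemU backends):

```python
class OpenAIStyleBackend:
    name = "openai"

    def parse_summary_response(self, data: dict) -> str:
        # Standard OpenAI chat-completions response shape
        return data["choices"][0]["message"]["content"]


class GrokStyleBackend(OpenAIStyleBackend):
    """Same wire format as OpenAI; only the registry name differs."""

    name = "grok"


resp = {"choices": [{"message": {"content": "hello", "role": "assistant"}}]}
print(GrokStyleBackend().parse_summary_response(resp))  # hello
```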

src/memu/llm/http_client.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -10,6 +10,7 @@
 
 from memu.llm.backends.base import LLMBackend
 from memu.llm.backends.doubao import DoubaoLLMBackend
+from memu.llm.backends.grok import GrokBackend
 from memu.llm.backends.openai import OpenAILLMBackend
 from memu.llm.backends.openrouter import OpenRouterLLMBackend
 
@@ -66,6 +67,7 @@ def parse_embedding_response(self, data: dict[str, Any]) -> list[list[float]]:
 LLM_BACKENDS: dict[str, Callable[[], LLMBackend]] = {
     OpenAILLMBackend.name: OpenAILLMBackend,
     DoubaoLLMBackend.name: DoubaoLLMBackend,
+    GrokBackend.name: GrokBackend,
     OpenRouterLLMBackend.name: OpenRouterLLMBackend,
 }
 
@@ -244,6 +246,7 @@ def _load_embedding_backend(self, provider: str) -> _EmbeddingBackend:
     backends: dict[str, type[_EmbeddingBackend]] = {
         _OpenAIEmbeddingBackend.name: _OpenAIEmbeddingBackend,
         _DoubaoEmbeddingBackend.name: _DoubaoEmbeddingBackend,
+        "grok": _OpenAIEmbeddingBackend,
         _OpenRouterEmbeddingBackend.name: _OpenRouterEmbeddingBackend,
     }
     factory = backends.get(provider)
```
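Provider dispatch is a dictionary lookup from provider name to a zero-argument backend factory. A self-contained sketch of that registry (class names simplified, not the real MemU types):

```python
from typing import Callable


class Backend:
    name = "base"


class OpenAIBackend(Backend):
    name = "openai"


class GrokBackend(Backend):
    name = "grok"


# Map provider names to factories, keyed by each class's own name attribute
BACKENDS: dict[str, Callable[[], Backend]] = {
    OpenAIBackend.name: OpenAIBackend,
    GrokBackend.name: GrokBackend,
}


def load_backend(provider: str) -> Backend:
    factory = BACKENDS.get(provider)
    if factory is None:
        raise ValueError(f"Unknown LLM provider: {provider}")
    return factory()


print(type(load_backend("grok")).__name__)  # GrokBackend
```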

tests/llm/test_grok_provider.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@

```python
import unittest
from unittest.mock import patch

from memu.app.settings import LLMConfig
from memu.llm.backends.grok import GrokBackend
from memu.llm.openai_sdk import OpenAISDKClient


class TestGrokProvider(unittest.IsolatedAsyncioTestCase):
    def test_settings_defaults(self):
        """Test that setting provider='grok' sets the correct defaults."""
        config = LLMConfig(provider="grok")
        self.assertEqual(config.base_url, "https://api.x.ai/v1")
        self.assertEqual(config.api_key, "XAI_API_KEY")
        self.assertEqual(config.chat_model, "grok-2-latest")

    @patch("memu.llm.openai_sdk.AsyncOpenAI")
    async def test_client_initialization_with_grok_config(self, mock_async_openai):
        """Test that OpenAISDKClient initializes with the Grok base URL when configured."""
        # Setup config
        config = LLMConfig(provider="grok")

        # Instantiate client with Grok config.
        # We simulate what the application factory would do: pass the config values.
        client = OpenAISDKClient(
            base_url=config.base_url,
            api_key="fake-key",  # In a real app, this would be os.getenv(config.api_key)
            chat_model=config.chat_model,
            embed_model=config.embed_model,
        )

        # Assert AsyncOpenAI was called with the correct base_url
        mock_async_openai.assert_called_with(api_key="fake-key", base_url="https://api.x.ai/v1")

        # Verify client attributes
        self.assertEqual(client.chat_model, "grok-2-latest")

    def test_grok_backend_payload_parsing(self):
        """Test that GrokBackend parses responses correctly (inherited from OpenAI)."""
        backend = GrokBackend()

        # Simulate a typical OpenAI-compatible response
        dummy_response = {"choices": [{"message": {"content": "Grok response content", "role": "assistant"}}]}

        result = backend.parse_summary_response(dummy_response)
        self.assertEqual(result, "Grok response content")
```
