Skip to content

Commit 3d2a9e9

Browse files
committed
WIP on precommit
1 parent 9890ae2 commit 3d2a9e9

File tree

6 files changed

+134
-31
lines changed

6 files changed

+134
-31
lines changed

.pre-commit-config.yaml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
repos:
2+
- repo: https://github.com/astral-sh/ruff-pre-commit
3+
rev: v0.3.2 # Use the latest version
4+
hooks:
5+
# Run the linter
6+
- id: ruff
7+
args: [--fix]
8+
# Run the formatter
9+
- id: ruff-format
10+
11+
- repo: https://github.com/pre-commit/pre-commit-hooks
12+
rev: v4.5.0
13+
hooks:
14+
- id: trailing-whitespace
15+
- id: end-of-file-fixer
16+
- id: check-yaml
17+
- id: check-added-large-files

models.py

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
from typing import Dict, List, Optional, Union, Any
2-
import numpy as np
3-
from pydantic import BaseModel, Field
4-
from openai import AsyncOpenAI, ChatCompletion
5-
import os
61
import logging
2+
import os
73
from enum import Enum
8-
import asyncio
9-
from redis.asyncio import ConnectionPool, Redis
4+
from typing import Any
5+
106
import anthropic
7+
import numpy as np
8+
from openai import AsyncOpenAI
9+
from pydantic import BaseModel, Field
10+
1111

1212
# Setup logging
1313
logger = logging.getLogger(__name__)
@@ -23,16 +23,16 @@ class MemoryMessage(BaseModel):
2323
class MemoryMessagesAndContext(BaseModel):
2424
"""Request payload for adding messages to memory"""
2525

26-
messages: List[MemoryMessage]
27-
context: Optional[str] = None
26+
messages: list[MemoryMessage]
27+
context: str | None = None
2828

2929

3030
class MemoryResponse(BaseModel):
3131
"""Response containing messages and context"""
3232

33-
messages: List[MemoryMessage]
34-
context: Optional[str] = None
35-
tokens: Optional[int] = None
33+
messages: list[MemoryMessage]
34+
context: str | None = None
35+
tokens: int | None = None
3636

3737

3838
class SearchPayload(BaseModel):
@@ -59,27 +59,27 @@ class RedisearchResult(BaseModel):
5959
role: str
6060
content: str
6161
dist: float
62-
63-
62+
63+
6464
class SearchResults(BaseModel):
6565
"""Results from a redisearch query"""
6666

67-
docs: List[RedisearchResult]
67+
docs: list[RedisearchResult]
6868
total: int
6969

7070

7171
class NamespaceQuery(BaseModel):
7272
"""Query parameters for namespace"""
7373

74-
namespace: Optional[str] = None
74+
namespace: str | None = None
7575

7676

7777
class GetSessionsQuery(BaseModel):
7878
"""Query parameters for getting sessions"""
7979

8080
page: int = Field(default=1)
8181
size: int = Field(default=20)
82-
namespace: Optional[str] = None
82+
namespace: str | None = None
8383

8484

8585
class ModelProvider(str, Enum):
@@ -260,7 +260,7 @@ def get_model_config(model_name: str) -> ModelConfig:
260260
class ChatResponse:
261261
"""Unified wrapper for chat responses from different providers"""
262262

263-
def __init__(self, choices: List[Any], usage: Dict[str, int]):
263+
def __init__(self, choices: list[Any], usage: dict[str, int]):
264264
self.choices = choices or []
265265
self.usage = usage or {"total_tokens": 0}
266266

@@ -319,7 +319,7 @@ async def create_chat_completion(self, model: str, prompt: str) -> ChatResponse:
319319
logger.error(f"Error creating chat completion with Anthropic: {e}")
320320
raise
321321

322-
async def create_embedding(self, query_vec: List[str]) -> np.ndarray:
322+
async def create_embedding(self, query_vec: list[str]) -> np.ndarray:
323323
"""
324324
Create embeddings for the given texts
325325
Note: Anthropic doesn't offer an embedding API, so we'll use OpenAI's
@@ -345,22 +345,27 @@ def __init__(self, api_key: str | None = None, base_url: str | None = None):
345345

346346
if openai_api_base:
347347
self.completion_client = AsyncOpenAI(
348-
api_key=openai_api_key, base_url=openai_api_base
348+
api_key=openai_api_key,
349+
base_url=openai_api_base,
349350
)
350351
self.embedding_client = AsyncOpenAI(
351-
api_key=openai_api_key, base_url=openai_api_base
352+
api_key=openai_api_key,
353+
base_url=openai_api_base,
352354
)
353355
else:
354356
self.completion_client = AsyncOpenAI(api_key=openai_api_key)
355357
self.embedding_client = AsyncOpenAI(api_key=openai_api_key)
356358

357359
async def create_chat_completion(
358-
self, model: str, progressive_prompt: str
360+
self,
361+
model: str,
362+
progressive_prompt: str,
359363
) -> ChatResponse:
360364
"""Create a chat completion using the OpenAI API"""
361365
try:
362366
response = await self.completion_client.chat.completions.create(
363-
model=model, messages=[{"role": "user", "content": progressive_prompt}]
367+
model=model,
368+
messages=[{"role": "user", "content": progressive_prompt}],
364369
)
365370

366371
# Convert to unified format
@@ -380,7 +385,7 @@ async def create_chat_completion(
380385
logger.error(f"Error creating chat completion with OpenAI: {e}")
381386
raise
382387

383-
async def create_embedding(self, query_vec: List[str]) -> np.ndarray:
388+
async def create_embedding(self, query_vec: list[str]) -> np.ndarray:
384389
"""Create embeddings for the given texts"""
385390
try:
386391
embeddings = []
@@ -391,7 +396,8 @@ async def create_embedding(self, query_vec: List[str]) -> np.ndarray:
391396
for i in range(0, len(query_vec), batch_size):
392397
batch = query_vec[i : i + batch_size]
393398
response = await self.embedding_client.embeddings.create(
394-
model=embedding_model, input=batch
399+
model=embedding_model,
400+
input=batch,
395401
)
396402
batch_embeddings = [item.embedding for item in response.data]
397403
embeddings.extend(batch_embeddings)
@@ -408,13 +414,13 @@ class ModelClientFactory:
408414
@staticmethod
409415
async def get_client(
410416
model_name: str,
411-
) -> Union[OpenAIClientWrapper, AnthropicClientWrapper]:
417+
) -> OpenAIClientWrapper | AnthropicClientWrapper:
412418
"""Get the appropriate client for a model"""
413419
model_config = get_model_config(model_name)
414420

415421
if model_config.provider == ModelProvider.OPENAI:
416422
return OpenAIClientWrapper(api_key=os.environ.get("OPENAI_API_KEY"))
417-
elif model_config.provider == ModelProvider.ANTHROPIC:
423+
if model_config.provider == ModelProvider.ANTHROPIC:
418424
return AnthropicClientWrapper(api_key=os.environ.get("ANTHROPIC_API_KEY"))
419-
else:
420-
raise ValueError(f"Unsupported model provider: {model_config.provider}")
425+
426+
raise ValueError(f"Unsupported model provider: {model_config.provider}")

pyproject.toml

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
[tool.ruff]
2+
# Exclude a variety of commonly ignored directories
3+
exclude = [
4+
".git",
5+
".github",
6+
".pytest_cache",
7+
"__pycache__",
8+
"env",
9+
"venv",
10+
".venv",
11+
"*.egg-info",
12+
]
13+
14+
# Same as Black
15+
line-length = 88
16+
17+
# Assume Python 3.10
18+
target-version = "py310"
19+
20+
[tool.ruff.lint]
21+
# Enable various rules
22+
select = ["E", "F", "B", "I", "N", "UP", "C4", "PT", "RET", "SIM", "TID"]
23+
# Exclude COM812 which conflicts with the formatter
24+
ignore = ["COM812"]
25+
26+
# Allow unused variables when underscore-prefixed
27+
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
28+
29+
# Fix code when possible
30+
fixable = ["ALL"]
31+
unfixable = []
32+
33+
[tool.ruff.lint.mccabe]
34+
# Flag functions with high cyclomatic complexity
35+
max-complexity = 10
36+
37+
[tool.ruff.lint.isort]
38+
# Group imports by type and organize them alphabetically
39+
known-first-party = ["redis-memory-server"]
40+
section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"]
41+
combine-as-imports = true
42+
lines-after-imports = 2
43+
44+
[tool.ruff.lint.flake8-tidy-imports]
45+
ban-relative-imports = "all"
46+
47+
[tool.ruff.format]
48+
# Use double quotes for strings
49+
quote-style = "double"
50+
# Use spaces for indentation
51+
indent-style = "space"

requirements-dev.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
pytest
22
pytest-asyncio
33
testcontainers
4+
pre-commit
5+
ruff>=0.3.0

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ python-multipart>=0.0.6
1010
tiktoken>=0.5.1
1111
structlog>=23.2.0
1212
async-timeout>=4.0.3
13-
httpx>=0.25.1
14-
numpy>=
13+
httpx>=0.25.1
14+
numpy>=2.2.3
1515
pydantic-settings>=2.8.1

setup-dev.sh

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#!/bin/bash
2+
3+
# Exit on error
4+
set -e
5+
6+
echo "Setting up development environment..."
7+
8+
# Create a virtual environment if it doesn't exist
9+
if [ ! -d "env" ]; then
10+
echo "Creating virtual environment..."
11+
python -m venv env
12+
fi
13+
14+
# Activate the virtual environment
15+
source env/bin/activate
16+
17+
# Install dependencies
18+
echo "Installing development dependencies..."
19+
pip install -r requirements.txt
20+
pip install -r requirements-dev.txt
21+
22+
# Set up pre-commit hooks
23+
echo "Setting up pre-commit hooks..."
24+
pre-commit install
25+
26+
echo "Development environment setup complete!"
27+
echo "You can now run 'pre-commit run --all-files' to check all files in the repository."

0 commit comments

Comments
 (0)