Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit faee45f

Browse files
Implement muxing routing cache (#851)
* Implement muxing routing cache This allows us to have a parsed and in-memory representation of the routing rule engine. The intent is to reduce calls to the database. This is more expensive to maintain since we need to refresh the cache on every operation pertaining to models, endpoints, workspaces, and muxes themselves. Signed-off-by: Juan Antonio Osorio <[email protected]> * Add logic to repopulate mux cache in case a model changes. Signed-off-by: Juan Antonio Osorio <[email protected]> * Add TODO Signed-off-by: Juan Antonio Osorio <[email protected]> * Changes to integrate mux cache with mux router * Changed Lock primites to asyncio --------- Signed-off-by: Juan Antonio Osorio <[email protected]> Co-authored-by: Alejandro Ponce <[email protected]>
1 parent 16a3e13 commit faee45f

File tree

10 files changed

+295
-101
lines changed

10 files changed

+295
-101
lines changed

src/codegate/cli.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from codegate.providers.copilot.provider import CopilotProvider
2222
from codegate.server import init_app
2323
from codegate.storage.utils import restore_storage_backup
24+
from codegate.workspaces import crud as wscrud
2425

2526

2627
class UvicornServer:
@@ -341,6 +342,8 @@ def serve( # noqa: C901
341342

342343
registry = app.provider_registry
343344
loop.run_until_complete(provendcrud.initialize_provider_endpoints(registry))
345+
wsc = wscrud.WorkspaceCrud()
346+
loop.run_until_complete(wsc.initialize_mux_registry())
344347

345348
# Run the server
346349
try:

src/codegate/db/connection.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
GetPromptWithOutputsRow,
2121
GetWorkspaceByNameConditions,
2222
MuxRule,
23-
MuxRuleProviderEndpoint,
2423
Output,
2524
Prompt,
2625
ProviderAuthMaterial,
@@ -711,6 +710,22 @@ async def get_provider_endpoint_by_id(self, provider_id: str) -> Optional[Provid
711710
)
712711
return provider[0] if provider else None
713712

713+
async def get_auth_material_by_provider_id(
714+
self, provider_id: str
715+
) -> Optional[ProviderAuthMaterial]:
716+
sql = text(
717+
"""
718+
SELECT id as provider_endpoint_id, auth_type, auth_blob
719+
FROM provider_endpoints
720+
WHERE id = :provider_endpoint_id
721+
"""
722+
)
723+
conditions = {"provider_endpoint_id": provider_id}
724+
auth_material = await self._exec_select_conditions_to_pydantic(
725+
ProviderAuthMaterial, sql, conditions, should_raise=True
726+
)
727+
return auth_material[0] if auth_material else None
728+
714729
async def get_provider_endpoints(self) -> List[ProviderEndpoint]:
715730
sql = text(
716731
"""
@@ -778,26 +793,6 @@ async def get_muxes_by_workspace(self, workspace_id: str) -> List[MuxRule]:
778793
)
779794
return muxes
780795

781-
async def get_muxes_with_provider_by_workspace(
782-
self, workspace_id: str
783-
) -> List[MuxRuleProviderEndpoint]:
784-
sql = text(
785-
"""
786-
SELECT m.id, m.provider_endpoint_id, m.provider_model_name, m.workspace_id,
787-
m.matcher_type, m.matcher_blob, m.priority, m.created_at, m.updated_at,
788-
pe.provider_type, pe.endpoint, pe.auth_type, pe.auth_blob
789-
FROM muxes m
790-
INNER JOIN provider_endpoints pe ON pe.id = m.provider_endpoint_id
791-
WHERE m.workspace_id = :workspace_id
792-
ORDER BY priority ASC
793-
"""
794-
)
795-
conditions = {"workspace_id": workspace_id}
796-
muxes = await self._exec_select_conditions_to_pydantic(
797-
MuxRuleProviderEndpoint, sql, conditions, should_raise=True
798-
)
799-
return muxes
800-
801796

802797
def init_db_sync(db_path: Optional[str] = None):
803798
"""DB will be initialized in the constructor in case it doesn't exist."""

src/codegate/db/models.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -199,19 +199,3 @@ class MuxRule(BaseModel):
199199
priority: int
200200
created_at: Optional[datetime.datetime] = None
201201
updated_at: Optional[datetime.datetime] = None
202-
203-
204-
class MuxRuleProviderEndpoint(BaseModel):
205-
id: str
206-
provider_endpoint_id: str
207-
provider_model_name: str
208-
workspace_id: str
209-
matcher_type: str
210-
matcher_blob: str
211-
priority: int
212-
created_at: Optional[datetime.datetime] = None
213-
updated_at: Optional[datetime.datetime] = None
214-
provider_type: ProviderType
215-
endpoint: str
216-
auth_type: str
217-
auth_blob: str

src/codegate/muxing/__init__.py

Whitespace-only changes.

src/codegate/mux/adapter.py renamed to src/codegate/muxing/adapter.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from ollama import ChatResponse
1111

1212
from codegate.db import models as db_models
13+
from codegate.muxing import rulematcher
1314
from codegate.providers.ollama.adapter import OLlamaToModel
1415

1516
logger = structlog.get_logger("codegate")
@@ -55,21 +56,17 @@ def _from_anthropic_to_openai(self, anthropic_body: dict) -> dict:
5556
del new_body["system"]
5657
return new_body
5758

58-
def _get_provider_formatted_url(
59-
self, mux_and_provider: db_models.MuxRuleProviderEndpoint
60-
) -> str:
59+
def _get_provider_formatted_url(self, model_route: rulematcher.ModelRoute) -> str:
6160
"""Get the provider formatted URL to use in base_url. Note this value comes from DB"""
62-
if mux_and_provider.provider_type == db_models.ProviderType.openai:
63-
return f"{mux_and_provider.endpoint}/v1"
64-
return mux_and_provider.endpoint
61+
if model_route.endpoint.provider_type == db_models.ProviderType.openai:
62+
return f"{model_route.endpoint.endpoint}/v1"
63+
return model_route.endpoint.endpoint
6564

66-
def _set_destination_info(
67-
self, data: dict, mux_and_provider: db_models.MuxRuleProviderEndpoint
68-
) -> dict:
65+
def _set_destination_info(self, data: dict, model_route: rulematcher.ModelRoute) -> dict:
6966
"""Set the destination provider info."""
7067
new_data = copy.deepcopy(data)
71-
new_data["model"] = mux_and_provider.provider_model_name
72-
new_data["base_url"] = self._get_provider_formatted_url(mux_and_provider)
68+
new_data["model"] = model_route.model.name
69+
new_data["base_url"] = self._get_provider_formatted_url(model_route)
7370
return new_data
7471

7572
def _identify_provider(self, data: dict) -> db_models.ProviderType:
@@ -79,22 +76,20 @@ def _identify_provider(self, data: dict) -> db_models.ProviderType:
7976
else:
8077
return db_models.ProviderType.openai
8178

82-
def map_body_to_dest(
83-
self, mux_and_provider: db_models.MuxRuleProviderEndpoint, data: dict
84-
) -> dict:
79+
def map_body_to_dest(self, model_route: rulematcher.ModelRoute, data: dict) -> dict:
8580
"""
8681
Map the body to the destination provider.
8782
8883
We only need to transform the body if the destination or origin provider is Anthropic.
8984
"""
9085
origin_prov = self._identify_provider(data)
91-
if mux_and_provider.provider_type == db_models.ProviderType.anthropic:
86+
if model_route.endpoint.provider_type == db_models.ProviderType.anthropic:
9287
if origin_prov != db_models.ProviderType.anthropic:
9388
data = self._from_openai_to_antrhopic(data)
9489
else:
9590
if origin_prov == db_models.ProviderType.anthropic:
9691
data = self._from_anthropic_to_openai(data)
97-
return self._set_destination_info(data, mux_and_provider)
92+
return self._set_destination_info(data, model_route)
9893

9994

10095
class StreamChunkFormatter:

src/codegate/mux/router.py renamed to src/codegate/muxing/router.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import structlog
44
from fastapi import APIRouter, HTTPException, Request
55

6-
from codegate.mux.adapter import BodyAdapter, ResponseAdapter
6+
from codegate.muxing import rulematcher
7+
from codegate.muxing.adapter import BodyAdapter, ResponseAdapter
78
from codegate.providers.registry import ProviderRegistry
89
from codegate.workspaces.crud import WorkspaceCrud
910

@@ -53,27 +54,27 @@ async def route_to_dest_provider(
5354
body = await request.body()
5455
data = json.loads(body)
5556

57+
mux_registry = await rulematcher.get_muxing_rules_registry()
5658
try:
57-
active_ws_muxes = await self._ws_crud.get_active_workspace_muxes()
59+
# TODO: For future releases we will have to idenify a thing_to_match
60+
# and use our registry to get the correct muxes for the active workspace
61+
model_route = await mux_registry.get_match_for_active_workspace(thing_to_match=None)
5862
except Exception as e:
5963
logger.error(f"Error getting active workspace muxes: {e}")
6064
raise HTTPException(str(e))
6165

62-
# TODO: Implement the muxing logic here. For the moment we will assume
63-
# that we have a single mux, i.e. a single destination provider.
64-
if len(active_ws_muxes) == 0:
65-
raise HTTPException(status_code=404, detail="No active workspace muxes found")
66-
mux_and_provider = active_ws_muxes[0]
66+
if not model_route:
67+
raise HTTPException("No rule found for the active workspace", status_code=404)
6768

6869
# Parse the input data and map it to the destination provider format
6970
rest_of_path = self._ensure_path_starts_with_slash(rest_of_path)
70-
new_data = self._body_adapter.map_body_to_dest(mux_and_provider, data)
71-
provider = self._provider_registry.get_provider(mux_and_provider.provider_type)
72-
api_key = mux_and_provider.auth_blob
71+
new_data = self._body_adapter.map_body_to_dest(model_route, data)
72+
provider = self._provider_registry.get_provider(model_route.endpoint.provider_type)
73+
api_key = model_route.auth_material.auth_blob
7374

7475
# Send the request to the destination provider. It will run the pipeline
7576
response = await provider.process_request(new_data, api_key, rest_of_path)
7677
# Format the response to the client always using the OpenAI format
7778
return self._response_adapter.format_response_to_client(
78-
response, mux_and_provider.provider_type
79+
response, model_route.endpoint.provider_type
7980
)

src/codegate/muxing/rulematcher.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import copy
2+
from abc import ABC, abstractmethod
3+
from asyncio import Lock
4+
from typing import List, Optional
5+
6+
from codegate.db import models as db_models
7+
8+
_muxrules_sgtn = None
9+
10+
_singleton_lock = Lock()
11+
12+
13+
async def get_muxing_rules_registry():
14+
"""Returns a singleton instance of the muxing rules registry."""
15+
16+
global _muxrules_sgtn
17+
18+
if _muxrules_sgtn is None:
19+
async with _singleton_lock:
20+
if _muxrules_sgtn is None:
21+
_muxrules_sgtn = MuxingRulesinWorkspaces()
22+
23+
return _muxrules_sgtn
24+
25+
26+
class ModelRoute:
27+
"""A route for a model."""
28+
29+
def __init__(
30+
self,
31+
model: db_models.ProviderModel,
32+
endpoint: db_models.ProviderEndpoint,
33+
auth_material: db_models.ProviderAuthMaterial,
34+
):
35+
self.model = model
36+
self.endpoint = endpoint
37+
self.auth_material = auth_material
38+
39+
40+
class MuxingRuleMatcher(ABC):
41+
"""Base class for matching muxing rules."""
42+
43+
def __init__(self, route: ModelRoute):
44+
self._route = route
45+
46+
@abstractmethod
47+
def match(self, thing_to_match) -> bool:
48+
"""Return True if the rule matches the thing_to_match."""
49+
pass
50+
51+
def destination(self) -> ModelRoute:
52+
"""Return the destination of the rule."""
53+
54+
return self._route
55+
56+
57+
class MuxingMatcherFactory:
58+
"""Factory for creating muxing matchers."""
59+
60+
@staticmethod
61+
def create(mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher:
62+
"""Create a muxing matcher for the given endpoint and model."""
63+
64+
factory = {
65+
"catch_all": CatchAllMuxingRuleMatcher,
66+
}
67+
68+
try:
69+
return factory[mux_rule.matcher_type](route)
70+
except KeyError:
71+
raise ValueError(f"Unknown matcher type: {mux_rule.matcher_type}")
72+
73+
74+
class CatchAllMuxingRuleMatcher(MuxingRuleMatcher):
75+
"""A catch all muxing rule matcher."""
76+
77+
def match(self, thing_to_match) -> bool:
78+
return True
79+
80+
81+
class MuxingRulesinWorkspaces:
82+
"""A thread safe dictionary to store the muxing rules in workspaces."""
83+
84+
def __init__(self) -> None:
85+
super().__init__()
86+
self._lock = Lock()
87+
self._active_workspace = ""
88+
self._ws_rules = {}
89+
90+
async def get_ws_rules(self, workspace_name: str) -> List[MuxingRuleMatcher]:
91+
"""Get the rules for the given workspace."""
92+
async with self._lock:
93+
return copy.deepcopy(self._ws_rules.get(workspace_name, []))
94+
95+
async def set_ws_rules(self, workspace_name: str, rules: List[MuxingRuleMatcher]) -> None:
96+
"""Set the rules for the given workspace."""
97+
async with self._lock:
98+
self._ws_rules[workspace_name] = rules
99+
100+
async def delete_ws_rules(self, workspace_name: str) -> None:
101+
"""Delete the rules for the given workspace."""
102+
async with self._lock:
103+
del self._ws_rules[workspace_name]
104+
105+
async def set_active_workspace(self, workspace_name: str) -> None:
106+
"""Set the active workspace."""
107+
self._active_workspace = workspace_name
108+
109+
async def get_registries(self) -> List[str]:
110+
"""Get the list of workspaces."""
111+
async with self._lock:
112+
return list(self._ws_rules.keys())
113+
114+
async def get_match_for_active_workspace(self, thing_to_match) -> Optional[ModelRoute]:
115+
"""Get the first match for the given thing_to_match."""
116+
117+
# We iterate over all the rules and return the first match
118+
# Since we already do a deepcopy in __getitem__, we don't need to lock here
119+
try:
120+
rules = await self.get_ws_rules(self._active_workspace)
121+
for rule in rules:
122+
if rule.match(thing_to_match):
123+
return rule.destination()
124+
return None
125+
except KeyError:
126+
raise RuntimeError("No rules found for the active workspace")

src/codegate/providers/crud/crud.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from codegate.db.connection import DbReader, DbRecorder
1313
from codegate.providers.base import BaseProvider
1414
from codegate.providers.registry import ProviderRegistry, get_provider_registry
15+
from codegate.workspaces import crud as workspace_crud
1516

1617
logger = structlog.get_logger("codegate")
1718

@@ -32,6 +33,7 @@ class ProviderCrud:
3233
def __init__(self):
3334
self._db_reader = DbReader()
3435
self._db_writer = DbRecorder()
36+
self._ws_crud = workspace_crud.WorkspaceCrud()
3537

3638
async def list_endpoints(self) -> List[apimodelsv1.ProviderEndpoint]:
3739
"""List all the endpoints."""
@@ -176,6 +178,9 @@ async def update_endpoint(
176178
)
177179
)
178180

181+
# a model might have been deleted, let's repopulate the cache
182+
await self._ws_crud.repopulate_mux_cache()
183+
179184
return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)
180185

181186
async def configure_auth_material(
@@ -208,6 +213,8 @@ async def delete_endpoint(self, provider_id: UUID):
208213

209214
await self._db_writer.delete_provider_endpoint(dbendpoint)
210215

216+
await self._ws_crud.repopulate_mux_cache()
217+
211218
async def models_by_provider(self, provider_id: UUID) -> List[apimodelsv1.ModelByProvider]:
212219
"""Get the models by provider."""
213220

src/codegate/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from codegate import __description__, __version__
1212
from codegate.api.v1 import v1
1313
from codegate.db.models import ProviderType
14-
from codegate.mux.router import MuxRouter
14+
from codegate.muxing.router import MuxRouter
1515
from codegate.pipeline.factory import PipelineFactory
1616
from codegate.providers.anthropic.provider import AnthropicProvider
1717
from codegate.providers.llamacpp.provider import LlamaCppProvider

0 commit comments

Comments
 (0)