Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 7731133

Browse files
fix: FIM not caching correctly non-python files
Closes: #405 The implemented cache was not working correctly because of the way the context is added for FIM requests in programming languages other than Python. The context for the LLM is provided as single-line comments. In Python this means lines which start with the character `#`. Other languages use different starting sequences for single-line comments, e.g. JavaScript uses `//`. This PR changes the regex to detect the paths for other languages.
1 parent b71bfb5 commit 7731133

File tree

6 files changed

+331
-123
lines changed

6 files changed

+331
-123
lines changed

src/codegate/codegate_logging.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ def _missing_(cls, value: str) -> Optional["LogFormat"]:
5050

5151
def add_origin(logger, log_method, event_dict):
    """Structlog processor that propagates a bound 'origin' into the event.

    If the logger carries a private ``_context`` mapping with a truthy
    ``origin`` value and the event dict does not already contain one,
    copy it onto the event dict. The (possibly updated) event dict is
    returned for the next processor in the chain.
    """
    origin_missing = "origin" not in event_dict
    if origin_missing and hasattr(logger, "_context"):
        bound_origin = logger._context.get("origin")
        if bound_origin:
            event_dict["origin"] = bound_origin
    return event_dict
5858

5959

src/codegate/db/connection.py

Lines changed: 15 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
import asyncio
2-
import hashlib
32
import json
4-
import re
5-
from datetime import timedelta
63
from pathlib import Path
74
from typing import List, Optional
85

@@ -11,7 +8,7 @@
118
from sqlalchemy import text
129
from sqlalchemy.ext.asyncio import create_async_engine
1310

14-
from codegate.config import Config
11+
from codegate.db.fim_cache import FimCache
1512
from codegate.db.models import Alert, Output, Prompt
1613
from codegate.db.queries import (
1714
AsyncQuerier,
@@ -22,7 +19,7 @@
2219

2320
logger = structlog.get_logger("codegate")
2421
alert_queue = asyncio.Queue()
25-
fim_entries = {}
22+
fim_cache = FimCache()
2623

2724

2825
class DbCodeGate:
@@ -183,47 +180,6 @@ async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
183180
logger.debug(f"Recorded alerts: {recorded_alerts}")
184181
return recorded_alerts
185182

186-
def _extract_request_message(self, request: str) -> Optional[dict]:
187-
"""Extract the user message from the FIM request"""
188-
try:
189-
parsed_request = json.loads(request)
190-
except Exception as e:
191-
logger.exception(f"Failed to extract request message: {request}", error=str(e))
192-
return None
193-
194-
messages = [message for message in parsed_request["messages"] if message["role"] == "user"]
195-
if len(messages) != 1:
196-
logger.warning(f"Expected one user message, found {len(messages)}.")
197-
return None
198-
199-
content_message = messages[0].get("content")
200-
return content_message
201-
202-
def _create_hash_key(self, message: str, provider: str) -> str:
203-
"""Creates a hash key from the message and includes the provider"""
204-
# Try to extract the path from the FIM message. The path is in FIM request in these formats:
205-
# folder/testing_file.py
206-
# Path: file3.py
207-
pattern = r"^#.*?\b([a-zA-Z0-9_\-\/]+\.\w+)\b"
208-
matches = re.findall(pattern, message, re.MULTILINE)
209-
# If no path is found, hash the entire prompt message.
210-
if not matches:
211-
logger.warning("No path found in messages. Creating hash cache from message.")
212-
message_to_hash = f"{message}-{provider}"
213-
else:
214-
# Copilot puts the path at the top of the file. Continue providers contain
215-
# several paths, the one in which the fim is triggered is the last one.
216-
if provider == "copilot":
217-
filepath = matches[0]
218-
else:
219-
filepath = matches[-1]
220-
message_to_hash = f"{filepath}-{provider}"
221-
222-
logger.debug(f"Message to hash: {message_to_hash}")
223-
hashed_content = hashlib.sha256(message_to_hash.encode("utf-8")).hexdigest()
224-
logger.debug(f"Hashed contnet: {hashed_content}")
225-
return hashed_content
226-
227183
def _should_record_context(self, context: Optional[PipelineContext]) -> bool:
228184
"""Check if the context should be recorded in DB"""
229185
if context is None or context.metadata.get("stored_in_db", False):
@@ -237,37 +193,22 @@ def _should_record_context(self, context: Optional[PipelineContext]) -> bool:
237193
if context.input_request.type != "fim":
238194
return True
239195

240-
# Couldn't process the user message. Skip creating a mapping entry.
241-
message = self._extract_request_message(context.input_request.request)
242-
if message is None:
243-
logger.warning(f"Couldn't read FIM message: {message}. Will not record to DB.")
244-
return False
245-
246-
hash_key = self._create_hash_key(message, context.input_request.provider)
247-
old_timestamp = fim_entries.get(hash_key, None)
248-
if old_timestamp is None:
249-
fim_entries[hash_key] = context.input_request.timestamp
250-
return True
196+
return fim_cache.could_store_fim_request(context)
251197

252-
elapsed_seconds = (context.input_request.timestamp - old_timestamp).total_seconds()
253-
if elapsed_seconds < Config.get_config().max_fim_hash_lifetime:
198+
async def record_context(self, context: Optional[PipelineContext]) -> None:
199+
try:
200+
if not self._should_record_context(context):
201+
return
202+
await self.record_request(context.input_request)
203+
await self.record_outputs(context.output_responses)
204+
await self.record_alerts(context.alerts_raised)
205+
context.metadata["stored_in_db"] = True
254206
logger.info(
255-
f"Skipping DB context recording. "
256-
f"Elapsed time since last FIM cache: {timedelta(seconds=elapsed_seconds)}."
207+
f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
208+
f"Alerts: {len(context.alerts_raised)}."
257209
)
258-
return False
259-
260-
async def record_context(self, context: Optional[PipelineContext]) -> None:
261-
if not self._should_record_context(context):
262-
return
263-
await self.record_request(context.input_request)
264-
await self.record_outputs(context.output_responses)
265-
await self.record_alerts(context.alerts_raised)
266-
context.metadata["stored_in_db"] = True
267-
logger.info(
268-
f"Recorded context in DB. Output chunks: {len(context.output_responses)}. "
269-
f"Alerts: {len(context.alerts_raised)}."
270-
)
210+
except Exception as e:
211+
logger.error(f"Failed to record context: {context}.", error=str(e))
271212

272213

273214
class DbReader(DbCodeGate):

src/codegate/db/fim_cache.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import datetime
2+
import hashlib
3+
import json
4+
import re
5+
from typing import Dict, List, Optional
6+
7+
import structlog
8+
from pydantic import BaseModel
9+
10+
from codegate.config import Config
11+
from codegate.db.models import Alert
12+
from codegate.pipeline.base import AlertSeverity, PipelineContext
13+
14+
logger = structlog.get_logger("codegate")
15+
16+
17+
class CachedFim(BaseModel):
    # Timestamp of the FIM request that created/refreshed this entry;
    # used to decide whether the cached entry has expired.
    timestamp: datetime.datetime
    # Critical-severity alerts raised for the cached request; a later
    # request with MORE critical alerts is treated as new.
    critical_alerts: List[Alert]
21+
22+
23+
class FimCache:
    """In-memory cache used to deduplicate FIM (fill-in-the-middle) requests.

    A request is keyed by the file path found in its comment-formatted context
    (plus the provider name). A request is worth recording in the DB when it is
    not cached yet, its cached entry has expired, or it raised more critical
    alerts than the cached one did.
    """

    def __init__(self):
        # Maps sha256 hash key -> cached FIM entry.
        self.cache: Dict[str, CachedFim] = {}

    def _extract_message_from_fim_request(self, request: str) -> Optional[str]:
        """Extract the single user message from a JSON-encoded FIM request.

        Returns None when the request is not valid JSON, is not a JSON object,
        or does not contain exactly one message with role "user".
        """
        try:
            parsed_request = json.loads(request)
        except Exception as e:
            logger.error(f"Failed to extract request message: {request}", error=str(e))
            return None

        if not isinstance(parsed_request, dict):
            logger.warning(f"Expected a dictionary, got {type(parsed_request)}.")
            return None

        messages = [
            message
            for message in parsed_request.get("messages", [])
            if isinstance(message, dict) and message.get("role", "") == "user"
        ]
        if len(messages) != 1:
            logger.warning(f"Expected one user message, found {len(messages)}.")
            return None

        return messages[0].get("content")

    def _match_filepath(self, message: str, provider: str) -> Optional[str]:
        # Try to extract the path from the FIM message. The path is in the FIM
        # request as a single-line comment, whose prefix varies per language:
        # folder/testing_file.py
        # Path: file3.py
        # // Path: file3.js <-- Javascript
        pattern = r"^(#|//|<!--|--|%|;).*?\b([a-zA-Z0-9_\-\/]+\.\w+)\b"
        matches = re.findall(pattern, message, re.MULTILINE)
        if not matches:
            # Caller falls back to hashing the entire prompt message.
            return None

        # Extract only the paths (2nd group from the match)
        paths = [match[1] for match in matches]

        # Copilot puts the path at the top of the file. Continue providers contain
        # several paths, the one in which the fim is triggered is the last one.
        if provider == "copilot":
            return paths[0]
        return paths[-1]

    def _calculate_hash_key(self, message: str, provider: str) -> str:
        """Creates a hash key from the message and includes the provider."""
        filepath = self._match_filepath(message, provider)
        if filepath is None:
            logger.warning("No path found in messages. Creating hash key from message.")
            message_to_hash = f"{message}-{provider}"
        else:
            message_to_hash = f"{filepath}-{provider}"

        logger.debug(f"Message to hash: {message_to_hash}")
        hashed_content = hashlib.sha256(message_to_hash.encode("utf-8")).hexdigest()
        # Typo fix: this log message previously read "Hashed contnet".
        logger.debug(f"Hashed content: {hashed_content}")
        return hashed_content

    def _critical_alerts(self, context: PipelineContext) -> List[Alert]:
        """Return the critical-severity alerts raised in the context."""
        return [
            alert
            for alert in context.alerts_raised
            if alert.trigger_category == AlertSeverity.CRITICAL.value
        ]

    def _add_cache_entry(self, hash_key: str, context: PipelineContext) -> None:
        """Add (or overwrite) the cache entry for the given hash key."""
        self.cache[hash_key] = CachedFim(
            timestamp=context.input_request.timestamp,
            critical_alerts=self._critical_alerts(context),
        )
        logger.info(f"Added cache entry for hash key: {hash_key}")

    def _are_new_alerts_present(self, context: PipelineContext, cached_entry: CachedFim) -> bool:
        """Check if the context raised more critical alerts than were cached."""
        return len(self._critical_alerts(context)) > len(cached_entry.critical_alerts)

    def _is_cached_entry_old(self, context: PipelineContext, cached_entry: CachedFim) -> bool:
        """Check if the cached entry is older than the configured lifetime."""
        elapsed_seconds = (context.input_request.timestamp - cached_entry.timestamp).total_seconds()
        return elapsed_seconds > Config.get_config().max_fim_hash_lifetime

    def could_store_fim_request(self, context: PipelineContext) -> bool:
        """Decide whether the FIM request in the context should be recorded.

        Returns True (and refreshes the cache) when the request is not cached,
        its cache entry expired, or new critical alerts appeared; False when
        an equivalent request was seen recently.
        """
        # Couldn't process the user message. Skip creating a mapping entry.
        message = self._extract_message_from_fim_request(context.input_request.request)
        if message is None:
            logger.warning(f"Couldn't read FIM message: {message}. Will not record to DB.")
            return False

        hash_key = self._calculate_hash_key(message, context.input_request.provider)
        cached_entry = self.cache.get(hash_key, None)
        if (
            cached_entry is None
            or self._is_cached_entry_old(context, cached_entry)
            or self._are_new_alerts_present(context, cached_entry)
        ):
            self._add_cache_entry(hash_key, context)
            return True

        logger.info(f"FIM entry already in cache: {hash_key}.")
        return False

src/codegate/providers/copilot/provider.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import asyncio
22
import re
33
import ssl
4-
from src.codegate.codegate_logging import setup_logging
5-
import structlog
64
from dataclasses import dataclass
75
from typing import Dict, List, Optional, Tuple, Union
86
from urllib.parse import unquote, urljoin, urlparse
97

8+
import structlog
109
from litellm.types.utils import Delta, ModelResponse, StreamingChoices
1110

1211
from codegate.ca.codegate_ca import CertificateAuthority
@@ -22,6 +21,7 @@
2221
CopilotPipeline,
2322
)
2423
from codegate.providers.copilot.streaming import SSEProcessor
24+
from src.codegate.codegate_logging import setup_logging
2525

2626
setup_logging()
2727
logger = structlog.get_logger("codegate").bind(origin="copilot_proxy")
@@ -206,7 +206,7 @@ async def _request_to_target(self, headers: list[str], body: bytes):
206206
logger.debug("=" * 40)
207207

208208
for i in range(0, len(body), CHUNK_SIZE):
209-
chunk = body[i: i + CHUNK_SIZE]
209+
chunk = body[i : i + CHUNK_SIZE]
210210
self.target_transport.write(chunk)
211211

212212
def connection_made(self, transport: asyncio.Transport) -> None:
@@ -269,9 +269,7 @@ def _check_buffer_size(self, new_data: bytes) -> bool:
269269
"""Check if adding new data would exceed buffer size limit"""
270270
return len(self.buffer) + len(new_data) <= MAX_BUFFER_SIZE
271271

272-
async def _forward_data_through_pipeline(
273-
self, data: bytes
274-
) -> Union[HttpRequest, HttpResponse]:
272+
async def _forward_data_through_pipeline(self, data: bytes) -> Union[HttpRequest, HttpResponse]:
275273
http_request = http_request_from_bytes(data)
276274
if not http_request:
277275
# we couldn't parse this into an HTTP request, so we just pass through
@@ -287,7 +285,7 @@ async def _forward_data_through_pipeline(
287285

288286
if context and context.shortcut_response:
289287
# Send shortcut response
290-
data_prefix = b'data:'
288+
data_prefix = b"data:"
291289
http_response = HttpResponse(
292290
http_request.version,
293291
200,
@@ -299,7 +297,7 @@ async def _forward_data_through_pipeline(
299297
"Content-Type: application/json",
300298
"Transfer-Encoding: chunked",
301299
],
302-
data_prefix + body
300+
data_prefix + body,
303301
)
304302
return http_response
305303

@@ -639,7 +637,7 @@ async def get_target_url(path: str) -> Optional[str]:
639637
# Check for prefix match
640638
for route in VALIDATED_ROUTES:
641639
# For prefix matches, keep the rest of the path
642-
remaining_path = path[len(route.path):]
640+
remaining_path = path[len(route.path) :]
643641
logger.debug(f"Remaining path: {remaining_path}")
644642
# Make sure we don't end up with double slashes
645643
if remaining_path and remaining_path.startswith("/"):
@@ -793,7 +791,7 @@ def data_received(self, data: bytes) -> None:
793791
self._proxy_transport_write(headers)
794792
logger.debug(f"Headers sent: {headers}")
795793

796-
data = data[header_end + 4:]
794+
data = data[header_end + 4 :]
797795

798796
self._process_chunk(data)
799797

tests/db/test_connection.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

0 commit comments

Comments
 (0)