Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 479b04d

Browse files
authored
Merge pull request #315 from jhrozek/copilot_fim_pipeline
Copilot chats are sent through an input pipeline
2 parents 11515a5 + 68bb0b2 commit 479b04d

File tree

7 files changed

+167
-41
lines changed

7 files changed

+167
-41
lines changed

src/codegate/llm_utils/extractor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ async def extract_packages(
2424
model: str = None,
2525
base_url: Optional[str] = None,
2626
api_key: Optional[str] = None,
27-
extra_headers: Optional[Dict[str, str]] = None
27+
extra_headers: Optional[Dict[str, str]] = None,
2828
) -> List[str]:
2929
"""Extract package names from the given content."""
3030
system_prompt = Config.get_config().prompts.lookup_packages
@@ -51,7 +51,7 @@ async def extract_ecosystem(
5151
model: str = None,
5252
base_url: Optional[str] = None,
5353
api_key: Optional[str] = None,
54-
extra_headers: Optional[Dict[str, str]] = None
54+
extra_headers: Optional[Dict[str, str]] = None,
5555
) -> List[str]:
5656
"""Extract ecosystem from the given content."""
5757
system_prompt = Config.get_config().prompts.lookup_ecosystem

src/codegate/llm_utils/llmclient.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ async def _complete_litellm(
137137
temperature=request["temperature"],
138138
base_url=base_url,
139139
response_format=request["response_format"],
140-
extra_headers=extra_headers
140+
extra_headers=extra_headers,
141141
)
142142
content = response["choices"][0]["message"]["content"]
143143

src/codegate/pipeline/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ async def process_request(
224224
model: str,
225225
api_key: Optional[str] = None,
226226
api_base: Optional[str] = None,
227-
extra_headers: Optional[Dict[str, str]] = None
227+
extra_headers: Optional[Dict[str, str]] = None,
228228
) -> PipelineResult:
229229
"""Process a request through all pipeline steps"""
230230
self.context.sensitive = PipelineSensitiveData(
@@ -273,7 +273,7 @@ async def process_request(
273273
model: str,
274274
api_key: Optional[str] = None,
275275
api_base: Optional[str] = None,
276-
extra_headers: Optional[Dict[str, str]] = None
276+
extra_headers: Optional[Dict[str, str]] = None,
277277
) -> PipelineResult:
278278
"""Create a new pipeline instance and process the request"""
279279
instance = self.create_instance()

src/codegate/pipeline/codegate_context_retriever/codegate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ async def __lookup_packages(self, user_query: str, context: PipelineContext):
6666
model=context.sensitive.model,
6767
api_key=context.sensitive.api_key,
6868
base_url=context.sensitive.api_base,
69-
extra_headers=context.metadata.get('extra_headers', None),
69+
extra_headers=context.metadata.get("extra_headers", None),
7070
)
7171

7272
logger.info(f"Packages in user query: {packages}")
@@ -80,7 +80,7 @@ async def __lookup_ecosystem(self, user_query: str, context: PipelineContext):
8080
model=context.sensitive.model,
8181
api_key=context.sensitive.api_key,
8282
base_url=context.sensitive.api_base,
83-
extra_headers=context.metadata.get('extra_headers', None),
83+
extra_headers=context.metadata.get("extra_headers", None),
8484
)
8585

8686
logger.info(f"Ecosystem in user query: {ecosystem}")

src/codegate/pipeline/secrets/secrets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ async def process(
176176
PipelineResult containing the processed request and context with redaction metadata
177177
"""
178178

179-
if 'messages' not in request:
179+
if "messages" not in request:
180180
return PipelineResult(request=request, context=context)
181181

182182
secrets_manager = context.sensitive.manager

src/codegate/providers/copilot/pipeline.py

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,21 @@ def _request_id(headers: list[str]) -> str:
4444

4545
@staticmethod
4646
def _get_copilot_headers(headers: Dict[str, str]) -> Dict[str, str]:
47-
copilot_header_names = ['copilot-integration-id', 'editor-plugin-version', 'editor-version',
48-
'openai-intent', 'openai-organization', 'user-agent',
49-
'vscode-machineid', 'vscode-sessionid', 'x-github-api-version',
50-
'x-request-id']
47+
copilot_header_names = [
48+
"copilot-integration-id",
49+
"editor-plugin-version",
50+
"editor-version",
51+
"openai-intent",
52+
"openai-organization",
53+
"user-agent",
54+
"vscode-machineid",
55+
"vscode-sessionid",
56+
"x-github-api-version",
57+
"x-request-id",
58+
]
5159
copilot_headers = {}
5260
for a_name in copilot_header_names:
53-
copilot_headers[a_name] = headers.get(a_name, '')
61+
copilot_headers[a_name] = headers.get(a_name, "")
5462

5563
return copilot_headers
5664

@@ -59,15 +67,23 @@ async def process_body(self, headers: list[str], body: bytes) -> bytes:
5967
try:
6068
normalized_body = self.normalizer.normalize(body)
6169

70+
headers_dict = {}
71+
for header in headers:
72+
try:
73+
name, value = header.split(":", 1)
74+
headers_dict[name.strip().lower()] = value.strip()
75+
except ValueError:
76+
continue
77+
6278
pipeline = self.create_pipeline()
6379
result = await pipeline.process_request(
6480
request=normalized_body,
6581
provider=self.provider_name,
6682
prompt_id=self._request_id(headers),
6783
model=normalized_body.get("model", "gpt-4o-mini"),
68-
api_key = headers.get('authorization','').replace('Bearer ', ''),
69-
api_base = "https://" + headers.get('host', ''),
70-
extra_headers=CopilotPipeline._get_copilot_headers(headers)
84+
api_key=headers_dict.get("authorization", "").replace("Bearer ", ""),
85+
api_base="https://" + headers_dict.get("host", ""),
86+
extra_headers=CopilotPipeline._get_copilot_headers(headers_dict),
7187
)
7288

7389
if result.request:
@@ -101,14 +117,42 @@ def denormalize(self, request_from_pipeline: ChatCompletionRequest) -> bytes:
101117
return json.dumps(normalized_json_body).encode()
102118

103119

120+
class CopilotChatNormalizer:
121+
"""
122+
A custom normalizer for the chat format used by Copilot
123+
The requests are already in the OpenAI format, we just need
124+
to unmarshall them and marshall them back.
125+
"""
126+
127+
def normalize(self, body: bytes) -> ChatCompletionRequest:
128+
json_body = json.loads(body)
129+
return ChatCompletionRequest(**json_body)
130+
131+
def denormalize(self, request_from_pipeline: ChatCompletionRequest) -> bytes:
132+
return json.dumps(request_from_pipeline).encode()
133+
134+
104135
class CopilotFimPipeline(CopilotPipeline):
105136
"""
106137
A pipeline for the FIM format used by Copilot. Combines the normalizer for the FIM
107138
format and the FIM pipeline used by all providers.
108139
"""
109140

110141
def _create_normalizer(self):
111-
return CopilotFimNormalizer() # Uses your custom normalizer
142+
return CopilotFimNormalizer()
112143

113144
def create_pipeline(self):
114145
return self.pipeline_factory.create_fim_pipeline()
146+
147+
148+
class CopilotChatPipeline(CopilotPipeline):
149+
"""
150+
A pipeline for the Chat format used by Copilot. Combines the normalizer for the chat
151+
format and the input pipeline used by all providers.
152+
"""
153+
154+
def _create_normalizer(self):
155+
return CopilotChatNormalizer()
156+
157+
def create_pipeline(self):
158+
return self.pipeline_factory.create_input_pipeline()

src/codegate/providers/copilot/provider.py

Lines changed: 106 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@
1313
from codegate.pipeline.factory import PipelineFactory
1414
from codegate.pipeline.secrets.manager import SecretsManager
1515
from codegate.providers.copilot.mapping import VALIDATED_ROUTES
16-
from codegate.providers.copilot.pipeline import CopilotFimPipeline
16+
from codegate.providers.copilot.pipeline import (
17+
CopilotChatPipeline,
18+
CopilotFimPipeline,
19+
CopilotPipeline,
20+
)
1721

1822
logger = structlog.get_logger("codegate")
1923

@@ -38,6 +42,61 @@ class HttpRequest:
3842
headers: List[str]
3943
original_path: str
4044
target: Optional[str] = None
45+
body: Optional[bytes] = None
46+
47+
def reconstruct(self) -> bytes:
48+
"""Reconstruct HTTP request from stored details"""
49+
headers = "\r\n".join(self.headers)
50+
request_line = f"{self.method} /{self.path} {self.version}\r\n"
51+
header_block = f"{request_line}{headers}\r\n\r\n"
52+
53+
# Convert header block to bytes and combine with body
54+
result = header_block.encode("utf-8")
55+
if self.body:
56+
result += self.body
57+
58+
return result
59+
60+
61+
def extract_path(full_path: str) -> str:
62+
"""Extract clean path from full URL or path string"""
63+
logger.debug(f"Extracting path from {full_path}")
64+
if full_path.startswith(("http://", "https://")):
65+
parsed = urlparse(full_path)
66+
path = parsed.path
67+
if parsed.query:
68+
path = f"{path}?{parsed.query}"
69+
return path.lstrip("/")
70+
return full_path.lstrip("/")
71+
72+
73+
def http_request_from_bytes(data: bytes) -> Optional[HttpRequest]:
74+
"""
75+
Parse HTTP request details from raw bytes data.
76+
TODO: Make safer by checking for valid HTTP request format, check
77+
if there is a method if there are headers, etc.
78+
"""
79+
if b"\r\n\r\n" not in data:
80+
return None
81+
82+
headers_end = data.index(b"\r\n\r\n")
83+
headers = data[:headers_end].split(b"\r\n")
84+
85+
request = headers[0].decode("utf-8")
86+
method, full_path, version = request.split(" ")
87+
88+
body_start = data.index(b"\r\n\r\n") + 4
89+
body = data[body_start:]
90+
91+
return HttpRequest(
92+
method=method,
93+
path=extract_path(full_path),
94+
version=version,
95+
headers=[header.decode("utf-8") for header in headers[1:]],
96+
original_path=full_path,
97+
target=full_path if method == "CONNECT" else None,
98+
body=body,
99+
)
41100

42101

43102
class CopilotProvider(asyncio.Protocol):
@@ -63,20 +122,26 @@ def __init__(self, loop: asyncio.AbstractEventLoop):
63122
self.pipeline_factory = PipelineFactory(SecretsManager())
64123
self.context_tracking: Optional[PipelineContext] = None
65124

66-
def _select_pipeline(self):
67-
if (
68-
self.request.method == "POST"
69-
and self.request.path == "v1/engines/copilot-codex/completions"
70-
):
125+
def _select_pipeline(self, method: str, path: str) -> Optional[CopilotPipeline]:
126+
if method == "POST" and path == "v1/engines/copilot-codex/completions":
71127
logger.debug("Selected CopilotFimStrategy")
72128
return CopilotFimPipeline(self.pipeline_factory)
129+
if method == "POST" and path == "chat/completions":
130+
logger.debug("Selected CopilotChatStrategy")
131+
return CopilotChatPipeline(self.pipeline_factory)
73132

74133
logger.debug("No pipeline strategy selected")
75134
return None
76135

77-
async def _body_through_pipeline(self, headers: list[str], body: bytes) -> bytes:
136+
async def _body_through_pipeline(
137+
self,
138+
method: str,
139+
path: str,
140+
headers: list[str],
141+
body: bytes,
142+
) -> bytes:
78143
logger.debug(f"Processing body through pipeline: {len(body)} bytes")
79-
strategy = self._select_pipeline()
144+
strategy = self._select_pipeline(method, path)
80145
if strategy is None:
81146
# if we didn't select any strategy that would change the request
82147
# let's just pass through the body as-is
@@ -89,7 +154,12 @@ async def _request_to_target(self, headers: list[str], body: bytes):
89154
).encode()
90155
logger.debug(f"Request Line: {request_line}")
91156

92-
body = await self._body_through_pipeline(headers, body)
157+
body = await self._body_through_pipeline(
158+
self.request.method,
159+
self.request.path,
160+
headers,
161+
body,
162+
)
93163

94164
for header in headers:
95165
if header.lower().startswith("content-length:"):
@@ -113,18 +183,6 @@ def connection_made(self, transport: asyncio.Transport) -> None:
113183
self.peername = transport.get_extra_info("peername")
114184
logger.debug(f"Client connected from {self.peername}")
115185

116-
@staticmethod
117-
def extract_path(full_path: str) -> str:
118-
"""Extract clean path from full URL or path string"""
119-
logger.debug(f"Extracting path from {full_path}")
120-
if full_path.startswith(("http://", "https://")):
121-
parsed = urlparse(full_path)
122-
path = parsed.path
123-
if parsed.query:
124-
path = f"{path}?{parsed.query}"
125-
return path.lstrip("/")
126-
return full_path.lstrip("/")
127-
128186
def get_headers_dict(self) -> Dict[str, str]:
129187
"""Convert raw headers to dictionary format"""
130188
headers_dict = {}
@@ -161,7 +219,7 @@ def parse_headers(self) -> bool:
161219

162220
self.request = HttpRequest(
163221
method=method,
164-
path=self.extract_path(full_path),
222+
path=extract_path(full_path),
165223
version=version,
166224
headers=[header.decode("utf-8") for header in headers[1:]],
167225
original_path=full_path,
@@ -179,9 +237,33 @@ def _check_buffer_size(self, new_data: bytes) -> bool:
179237
"""Check if adding new data would exceed buffer size limit"""
180238
return len(self.buffer) + len(new_data) <= MAX_BUFFER_SIZE
181239

182-
def _forward_data_to_target(self, data: bytes) -> None:
240+
async def _forward_data_through_pipeline(self, data: bytes) -> bytes:
241+
http_request = http_request_from_bytes(data)
242+
if not http_request:
243+
# we couldn't parse this into an HTTP request, so we just pass through
244+
return data
245+
246+
http_request.body = await self._body_through_pipeline(
247+
http_request.method,
248+
http_request.path,
249+
http_request.headers,
250+
http_request.body,
251+
)
252+
253+
for header in http_request.headers:
254+
if header.lower().startswith("content-length:"):
255+
http_request.headers.remove(header)
256+
break
257+
http_request.headers.append(f"Content-Length: {len(http_request.body)}")
258+
259+
pipeline_data = http_request.reconstruct()
260+
261+
return pipeline_data
262+
263+
async def _forward_data_to_target(self, data: bytes) -> None:
183264
"""Forward data to target if connection is established"""
184265
if self.target_transport and not self.target_transport.is_closing():
266+
data = await self._forward_data_through_pipeline(data)
185267
self.target_transport.write(data)
186268

187269
def data_received(self, data: bytes) -> None:
@@ -201,7 +283,7 @@ def data_received(self, data: bytes) -> None:
201283
else:
202284
asyncio.create_task(self.handle_http_request())
203285
else:
204-
self._forward_data_to_target(data)
286+
asyncio.create_task(self._forward_data_to_target(data))
205287

206288
except Exception as e:
207289
logger.error(f"Error processing received data: {e}")

0 commit comments

Comments (0)