Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 606da49

Browse files
committed
Create the pipelines only once in the copilot provider
Since the copilot provider class instance is created once per connection, let's create the pipelines when establishing the connection and reuse them.
1 parent 8dbff63 commit 606da49

File tree

2 files changed

+32
-9
lines changed

2 files changed

+32
-9
lines changed

src/codegate/providers/copilot/pipeline.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class CopilotPipeline(ABC):
2424

2525
def __init__(self, pipeline_factory: PipelineFactory):
2626
self.pipeline_factory = pipeline_factory
27+
self.instance = self._create_pipeline()
2728
self.normalizer = self._create_normalizer()
2829
self.provider_name = "openai"
2930

@@ -33,7 +34,7 @@ def _create_normalizer(self):
3334
pass
3435

3536
@abstractmethod
36-
def create_pipeline(self) -> SequentialPipelineProcessor:
37+
def _create_pipeline(self) -> SequentialPipelineProcessor:
3738
"""Each strategy defines which pipeline to create"""
3839
pass
3940

@@ -84,7 +85,11 @@ def _create_shortcut_response(result: PipelineResult, model: str) -> bytes:
8485
body = response.model_dump_json(exclude_none=True, exclude_unset=True).encode()
8586
return body
8687

87-
async def process_body(self, headers: list[str], body: bytes) -> Tuple[bytes, PipelineContext]:
88+
async def process_body(
89+
self,
90+
headers: list[str],
91+
body: bytes,
92+
) -> Tuple[bytes, PipelineContext | None]:
8893
"""Common processing logic for all strategies"""
8994
try:
9095
normalized_body = self.normalizer.normalize(body)
@@ -97,8 +102,7 @@ async def process_body(self, headers: list[str], body: bytes) -> Tuple[bytes, Pi
97102
except ValueError:
98103
continue
99104

100-
pipeline = self.create_pipeline()
101-
result = await pipeline.process_request(
105+
result = await self.instance.process_request(
102106
request=normalized_body,
103107
provider=self.provider_name,
104108
model=normalized_body.get("model", "gpt-4o-mini"),
@@ -168,10 +172,13 @@ class CopilotFimPipeline(CopilotPipeline):
168172
format and the FIM pipeline used by all providers.
169173
"""
170174

175+
def __init__(self, pipeline_factory: PipelineFactory):
176+
super().__init__(pipeline_factory)
177+
171178
def _create_normalizer(self):
172179
return CopilotFimNormalizer()
173180

174-
def create_pipeline(self) -> SequentialPipelineProcessor:
181+
def _create_pipeline(self) -> SequentialPipelineProcessor:
175182
return self.pipeline_factory.create_fim_pipeline()
176183

177184

@@ -181,8 +188,11 @@ class CopilotChatPipeline(CopilotPipeline):
181188
format and the FIM pipeline used by all providers.
182189
"""
183190

191+
def __init__(self, pipeline_factory: PipelineFactory):
192+
super().__init__(pipeline_factory)
193+
184194
def _create_normalizer(self):
185195
return CopilotChatNormalizer()
186196

187-
def create_pipeline(self) -> SequentialPipelineProcessor:
197+
def _create_pipeline(self) -> SequentialPipelineProcessor:
188198
return self.pipeline_factory.create_input_pipeline()

src/codegate/providers/copilot/provider.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,16 @@ def __init__(self, loop: asyncio.AbstractEventLoop):
150150
self.cert_manager = TLSCertDomainManager(self.ca)
151151
self._closing = False
152152
self.pipeline_factory = PipelineFactory(SecretsManager())
153+
self.input_pipeline: Optional[CopilotPipeline] = None
154+
self.fim_pipeline: Optional[CopilotPipeline] = None
155+
# the context as provided by the pipeline
153156
self.context_tracking: Optional[PipelineContext] = None
154157

158+
def _ensure_pipelines(self):
159+
if not self.input_pipeline or not self.fim_pipeline:
160+
self.input_pipeline = CopilotChatPipeline(self.pipeline_factory)
161+
self.fim_pipeline = CopilotFimPipeline(self.pipeline_factory)
162+
155163
def _select_pipeline(self, method: str, path: str) -> Optional[CopilotPipeline]:
156164
if method != "POST":
157165
logger.debug("Not a POST request, no pipeline selected")
@@ -161,10 +169,10 @@ def _select_pipeline(self, method: str, path: str) -> Optional[CopilotPipeline]:
161169
if path == route.path:
162170
if route.pipeline_type == PipelineType.FIM:
163171
logger.debug("Selected FIM pipeline")
164-
return CopilotFimPipeline(self.pipeline_factory)
172+
return self.fim_pipeline
165173
elif route.pipeline_type == PipelineType.CHAT:
166174
logger.debug("Selected CHAT pipeline")
167-
return CopilotChatPipeline(self.pipeline_factory)
175+
return self.input_pipeline
168176

169177
logger.debug("No pipeline selected")
170178
return None
@@ -181,7 +189,6 @@ async def _body_through_pipeline(
181189
# if we didn't select any strategy that would change the request
182190
# let's just pass through the body as-is
183191
return body, None
184-
logger.debug(f"Processing body through pipeline: {len(body)} bytes")
185192
return await strategy.process_body(headers, body)
186193

187194
async def _request_to_target(self, headers: list[str], body: bytes):
@@ -288,6 +295,9 @@ async def _forward_data_through_pipeline(self, data: bytes) -> Union[HttpRequest
288295
http_request.headers,
289296
http_request.body,
290297
)
298+
# TODO: it's weird that we're overwriting the context.
299+
# Should we set the context once? Maybe when
300+
# creating the pipeline instance?
291301
self.context_tracking = context
292302

293303
if context and context.shortcut_response:
@@ -442,6 +452,7 @@ def data_received(self, data: bytes) -> None:
442452
if not self.headers_parsed:
443453
self.headers_parsed = self.parse_headers()
444454
if self.headers_parsed:
455+
self._ensure_pipelines()
445456
if self.request.method == "CONNECT":
446457
self.handle_connect()
447458
self.buffer.clear()
@@ -756,10 +767,12 @@ def connection_made(self, transport: asyncio.Transport) -> None:
756767

757768
def _ensure_output_processor(self) -> None:
758769
if self.proxy.context_tracking is None:
770+
logger.debug("No context tracking, no need to process pipeline")
759771
# No context tracking, no need to process pipeline
760772
return
761773

762774
if self.sse_processor is not None:
775+
logger.debug("Already initialized, no need to reinitialize")
763776
# Already initialized, no need to reinitialize
764777
return
765778

0 commit comments

Comments
 (0)