18
18
_singleton_lock = Lock ()
19
19
20
20
21
+ class MuxMatchingError (Exception ):
22
+ """An exception for muxing matching errors."""
23
+
24
+ pass
25
+
26
+
21
27
async def get_muxing_rules_registry ():
22
28
"""Returns a singleton instance of the muxing rules registry."""
23
29
@@ -48,9 +54,9 @@ def __init__(
48
54
class MuxingRuleMatcher (ABC ):
49
55
"""Base class for matching muxing rules."""
50
56
51
- def __init__ (self , route : ModelRoute , matcher_blob : str ):
57
+ def __init__ (self , route : ModelRoute , mux_rule : mux_models . MuxRule ):
52
58
self ._route = route
53
- self ._matcher_blob = matcher_blob
59
+ self ._mux_rule = mux_rule
54
60
55
61
@abstractmethod
56
62
def match (self , thing_to_match : mux_models .ThingToMatchMux ) -> bool :
@@ -67,18 +73,20 @@ class MuxingMatcherFactory:
67
73
"""Factory for creating muxing matchers."""
68
74
69
75
@staticmethod
70
- def create (mux_rule : db_models .MuxRule , route : ModelRoute ) -> MuxingRuleMatcher :
76
+ def create (db_mux_rule : db_models .MuxRule , route : ModelRoute ) -> MuxingRuleMatcher :
71
77
"""Create a muxing matcher for the given endpoint and model."""
72
78
73
79
factory : Dict [mux_models .MuxMatcherType , MuxingRuleMatcher ] = {
74
80
mux_models .MuxMatcherType .catch_all : CatchAllMuxingRuleMatcher ,
75
81
mux_models .MuxMatcherType .filename_match : FileMuxingRuleMatcher ,
76
- mux_models .MuxMatcherType .request_type_match : RequestTypeMuxingRuleMatcher ,
82
+ mux_models .MuxMatcherType .fim_filename : RequestTypeAndFileMuxingRuleMatcher ,
83
+ mux_models .MuxMatcherType .chat_filename : RequestTypeAndFileMuxingRuleMatcher ,
77
84
}
78
85
79
86
try :
80
87
# Initialize the MuxingRuleMatcher
81
- return factory [mux_rule .matcher_type ](route , mux_rule .matcher_blob )
88
+ mux_rule = mux_models .MuxRule .from_db_mux_rule (db_mux_rule )
89
+ return factory [mux_rule .matcher_type ](route , mux_rule )
82
90
except KeyError :
83
91
raise ValueError (f"Unknown matcher type: { mux_rule .matcher_type } " )
84
92
@@ -103,47 +111,63 @@ def _extract_request_filenames(self, detected_client: ClientType, data: dict) ->
103
111
return body_extractor .extract_unique_filenames (data )
104
112
except BodyCodeSnippetExtractorError as e :
105
113
logger .error (f"Error extracting filenames from request: { e } " )
106
- return set ()
114
+ raise MuxMatchingError ("Error extracting filenames from request" )
115
+
116
+ def _is_matcher_in_filenames (self , detected_client : ClientType , data : dict ) -> bool :
117
+ """
118
+ Check if the matcher is in the request filenames.
119
+ """
120
+ # Empty matcher_blob means we match everything
121
+ if not self ._mux_rule .matcher :
122
+ return True
123
+ filenames_to_match = self ._extract_request_filenames (detected_client , data )
124
+ # _mux_rule.matcher can be a filename or a file extension. We match if any of the filenames
125
+ # match the rule.
126
+ is_filename_match = any (
127
+ self ._mux_rule .matcher == filename or filename .endswith (self ._mux_rule .matcher )
128
+ for filename in filenames_to_match
129
+ )
130
+ return is_filename_match
107
131
108
132
def match (self , thing_to_match : mux_models .ThingToMatchMux ) -> bool :
109
133
"""
110
- Retun True if there is a filename in the request that matches the matcher_blob.
111
- The matcher_blob is either an extension (e.g. .py) or a filename (e.g. main.py).
134
+ Return True if the matcher is in one of the request filenames.
112
135
"""
113
- # If there is no matcher_blob, we don't match
114
- if not self ._matcher_blob :
115
- return False
116
- filenames_to_match = self ._extract_request_filenames (
136
+ is_rule_matched = self ._is_matcher_in_filenames (
117
137
thing_to_match .client_type , thing_to_match .body
118
138
)
119
- is_filename_match = any (self ._matcher_blob in filename for filename in filenames_to_match )
120
- if is_filename_match :
121
- logger .info (
122
- "Filename rule matched" , filenames = filenames_to_match , matcher = self ._matcher_blob
123
- )
124
- return is_filename_match
139
+ if is_rule_matched :
140
+ logger .info ("Filename rule matched" , matcher = self ._mux_rule .matcher )
141
+ return is_rule_matched
125
142
126
143
127
- class RequestTypeMuxingRuleMatcher (MuxingRuleMatcher ):
128
- """A catch all muxing rule matcher."""
144
+ class RequestTypeAndFileMuxingRuleMatcher (FileMuxingRuleMatcher ):
145
+ """A request type and file muxing rule matcher."""
146
+
147
+ def _is_request_type_match (self , is_fim_request : bool ) -> bool :
148
+ """
149
+ Check if the request type matches the MuxMatcherType.
150
+ """
151
+ incoming_request_type = "fim_filename" if is_fim_request else "chat_filename"
152
+ if incoming_request_type == self ._mux_rule .matcher_type :
153
+ return True
154
+ return False
129
155
130
156
def match (self , thing_to_match : mux_models .ThingToMatchMux ) -> bool :
131
157
"""
132
- Return True if the request type matches the matcher_blob.
133
- The matcher_blob is either "fim" or "chat" .
158
+ Return True if the matcher is in one of the request filenames and
159
+ if the request type matches the MuxMatcherType .
134
160
"""
135
- # If there is no matcher_blob, we don't match
136
- if not self ._matcher_blob :
137
- return False
138
- incoming_request_type = "fim" if thing_to_match .is_fim_request else "chat"
139
- is_request_type_match = self ._matcher_blob == incoming_request_type
140
- if is_request_type_match :
161
+ is_rule_matched = self ._is_matcher_in_filenames (
162
+ thing_to_match .client_type , thing_to_match .body
163
+ ) and self ._is_request_type_match (thing_to_match .is_fim_request )
164
+ if is_rule_matched :
141
165
logger .info (
142
- "Request type rule matched" ,
143
- matcher = self ._matcher_blob ,
144
- request_type = incoming_request_type ,
166
+ "Request type and rule matched" ,
167
+ matcher = self ._mux_rule . matcher ,
168
+ is_fim_request = thing_to_match . is_fim_request ,
145
169
)
146
- return is_request_type_match
170
+ return is_rule_matched
147
171
148
172
149
173
class MuxingRulesinWorkspaces :
0 commit comments