@@ -173,6 +173,7 @@ def get_relevant(
173173 session_tag : Optional [str ] = None ,
174174 raw : bool = False ,
175175 distance_threshold : Optional [float ] = None ,
176+ role : Optional [Union [str , List [str ]]] = None ,
176177 ) -> Union [List [str ], List [Dict [str , str ]]]:
177178 """Searches the message history for information semantically related to
178179 the specified prompt.
@@ -195,18 +196,46 @@ def get_relevant(
195196 if no relevant context is found.
196197 raw (bool): Whether to return the full Redis hash entry or just the
197198 message.
199+ role (Optional[Union[str, List[str]]]): Filter messages by role(s).
200+ Can be a single role string ("system", "user", "llm", "tool") or
201+ a list of roles. If None, all roles are returned.
198202
199203 Returns:
200204 Union[List[str], List[Dict[str,str]]: Either a list of strings, or a
201205 list of prompts and responses in JSON containing the most relevant.
202206
203- Raises ValueError: if top_k is not an integer greater or equal to 0.
207+ Raises ValueError: if top_k is not an integer greater or equal to 0,
208+ or if role contains invalid values.
204209 """
205210 if type (top_k ) != int or top_k < 0 :
206211 raise ValueError ("top_k must be an integer greater than or equal to -1" )
207212 if top_k == 0 :
208213 return []
209214
215+ # Validate and process role parameter
216+ if role is not None :
217+ valid_roles = {"system" , "user" , "llm" , "tool" }
218+
219+ # Handle single role string
220+ if isinstance (role , str ):
221+ if role not in valid_roles :
222+ raise ValueError (
223+ f"Invalid role '{ role } '. Valid roles are: { valid_roles } "
224+ )
225+ roles_to_filter = [role ]
226+ # Handle list of roles
227+ elif isinstance (role , list ):
228+ if not role : # Empty list
229+ raise ValueError ("roles cannot be empty" )
230+ for r in role :
231+ if r not in valid_roles :
232+ raise ValueError (
233+ f"Invalid role '{ r } '. Valid roles are: { valid_roles } "
234+ )
235+ roles_to_filter = role
236+ else :
237+ raise ValueError ("role must be a string or list of strings" )
238+
210239 # override distance threshold
211240 distance_threshold = distance_threshold or self ._distance_threshold
212241
@@ -225,21 +254,35 @@ def get_relevant(
225254 else self ._default_session_filter
226255 )
227256
257+ # Combine session filter with role filter if provided
258+ filter_expression = session_filter
259+ if role is not None :
260+ if len (roles_to_filter ) == 1 :
261+ role_filter = Tag (ROLE_FIELD_NAME ) == roles_to_filter [0 ]
262+ else :
263+ # Multiple roles - use OR logic
264+ role_filters = [Tag (ROLE_FIELD_NAME ) == r for r in roles_to_filter ]
265+ role_filter = role_filters [0 ]
266+ for rf in role_filters [1 :]:
267+ role_filter = role_filter | rf
268+
269+ filter_expression = session_filter & role_filter
270+
228271 query = RangeQuery (
229272 vector = self ._vectorizer .embed (prompt ),
230273 vector_field_name = MESSAGE_VECTOR_FIELD_NAME ,
231274 return_fields = return_fields ,
232275 distance_threshold = distance_threshold ,
233276 num_results = top_k ,
234277 return_score = True ,
235- filter_expression = session_filter ,
278+ filter_expression = filter_expression ,
236279 dtype = self ._vectorizer .dtype ,
237280 )
238281 messages = self ._index .query (query )
239282
240283 # if we don't find semantic matches fallback to returning recent context
241284 if not messages and fall_back :
242- return self .get_recent (as_text = as_text , top_k = top_k , raw = raw )
285+ return self .get_recent (as_text = as_text , top_k = top_k , raw = raw , role = role )
243286 if raw :
244287 return messages
245288 return self ._format_context (messages , as_text )
@@ -250,6 +293,7 @@ def get_recent(
250293 as_text : bool = False ,
251294 raw : bool = False ,
252295 session_tag : Optional [str ] = None ,
296+ role : Optional [Union [str , List [str ]]] = None ,
253297 ) -> Union [List [str ], List [Dict [str , str ]]]:
254298 """Retrieve the recent message history in sequential order.
255299
@@ -261,17 +305,45 @@ def get_recent(
261305 prompt and response
262306 session_tag (Optional[str]): Tag of the entries linked to a specific
263307 conversation session. Defaults to instance ULID.
308+ role (Optional[Union[str, List[str]]]): Filter messages by role(s).
309+ Can be a single role string ("system", "user", "llm", "tool") or
310+ a list of roles. If None, all roles are returned.
264311
265312 Returns:
266313 Union[str, List[str]]: A single string transcription of the session
267314 or list of strings if as_text is false.
268315
269316 Raises:
270- ValueError: if top_k is not an integer greater than or equal to 0.
317+ ValueError: if top_k is not an integer greater than or equal to 0,
318+ or if role contains invalid values.
271319 """
272320 if type (top_k ) != int or top_k < 0 :
273321 raise ValueError ("top_k must be an integer greater than or equal to 0" )
274322
323+ # Validate and process role parameter
324+ if role is not None :
325+ valid_roles = {"system" , "user" , "llm" , "tool" }
326+
327+ # Handle single role string
328+ if isinstance (role , str ):
329+ if role not in valid_roles :
330+ raise ValueError (
331+ f"Invalid role '{ role } '. Valid roles are: { valid_roles } "
332+ )
333+ roles_to_filter = [role ]
334+ # Handle list of roles
335+ elif isinstance (role , list ):
336+ if not role : # Empty list
337+ raise ValueError ("roles cannot be empty" )
338+ for r in role :
339+ if r not in valid_roles :
340+ raise ValueError (
341+ f"Invalid role '{ r } '. Valid roles are: { valid_roles } "
342+ )
343+ roles_to_filter = role
344+ else :
345+ raise ValueError ("role must be a string or list of strings" )
346+
275347 return_fields = [
276348 ID_FIELD_NAME ,
277349 SESSION_FIELD_NAME ,
@@ -288,8 +360,22 @@ def get_recent(
288360 else self ._default_session_filter
289361 )
290362
363+ # Combine session filter with role filter if provided
364+ filter_expression = session_filter
365+ if role is not None :
366+ if len (roles_to_filter ) == 1 :
367+ role_filter = Tag (ROLE_FIELD_NAME ) == roles_to_filter [0 ]
368+ else :
369+ # Multiple roles - use OR logic
370+ role_filters = [Tag (ROLE_FIELD_NAME ) == r for r in roles_to_filter ]
371+ role_filter = role_filters [0 ]
372+ for rf in role_filters [1 :]:
373+ role_filter = role_filter | rf
374+
375+ filter_expression = session_filter & role_filter
376+
291377 query = FilterQuery (
292- filter_expression = session_filter ,
378+ filter_expression = filter_expression ,
293379 return_fields = return_fields ,
294380 num_results = top_k ,
295381 )
0 commit comments