Skip to content

Commit 5d52f5f

Browse files
committed
feat: add role filtering to message history get_recent method (#349)
Add role parameter to get_recent() and get_relevant() methods in both MessageHistory and SemanticMessageHistory classes to enable filtering messages by role type. Features: - Support single role filtering: role="system" - Support multiple role filtering: role=["system", "user"] - Valid roles: "system", "user", "llm", "tool" - Backward compatible: role=None returns all messages - Works with existing parameters (top_k, session_tag, raw, etc.) - Comprehensive validation with clear error messages The implementation maintains full backward compatibility while enabling users to retrieve only specific message types like system prompts.
1 parent 7607554 commit 5d52f5f

File tree

4 files changed

+515
-8
lines changed

4 files changed

+515
-8
lines changed

redisvl/extensions/message_history/base_history.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def get_recent(
6060
as_text: bool = False,
6161
raw: bool = False,
6262
session_tag: Optional[str] = None,
63+
role: Optional[Union[str, List[str]]] = None,
6364
) -> Union[List[str], List[Dict[str, str]]]:
6465
"""Retrieve the recent conversation history in sequential order.
6566
@@ -72,13 +73,17 @@ def get_recent(
7273
prompt and response
7374
session_tag (str): Tag to be added to entries to link to a specific
7475
conversation session. Defaults to instance ULID.
76+
role (Optional[Union[str, List[str]]]): Filter messages by role(s).
77+
Can be a single role string ("system", "user", "llm", "tool") or
78+
a list of roles. If None, all roles are returned.
7579
7680
Returns:
7781
Union[str, List[str]]: A single string transcription of the messages
7882
or list of strings if as_text is false.
7983
8084
Raises:
81-
ValueError: If top_k is not an integer greater than or equal to 0.
85+
ValueError: If top_k is not an integer greater than or equal to 0,
86+
or if role contains invalid values.
8287
"""
8388
raise NotImplementedError
8489

redisvl/extensions/message_history/message_history.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def get_recent(
119119
as_text: bool = False,
120120
raw: bool = False,
121121
session_tag: Optional[str] = None,
122+
role: Optional[Union[str, List[str]]] = None,
122123
) -> Union[List[str], List[Dict[str, str]]]:
123124
"""Retrieve the recent message history in sequential order.
124125
@@ -130,17 +131,45 @@ def get_recent(
130131
prompt and response.
131132
session_tag (Optional[str]): Tag of the entries linked to a specific
132133
conversation session. Defaults to instance ULID.
134+
role (Optional[Union[str, List[str]]]): Filter messages by role(s).
135+
Can be a single role string ("system", "user", "llm", "tool") or
136+
a list of roles. If None, all roles are returned.
133137
134138
Returns:
135139
Union[str, List[str]]: A single string transcription of the messages
136140
or list of strings if as_text is false.
137141
138142
Raises:
139-
ValueError: if top_k is not an integer greater than or equal to 0.
143+
ValueError: if top_k is not an integer greater than or equal to 0,
144+
or if role contains invalid values.
140145
"""
141146
if type(top_k) != int or top_k < 0:
142147
raise ValueError("top_k must be an integer greater than or equal to 0")
143148

149+
# Validate and process role parameter
150+
if role is not None:
151+
valid_roles = {"system", "user", "llm", "tool"}
152+
153+
# Handle single role string
154+
if isinstance(role, str):
155+
if role not in valid_roles:
156+
raise ValueError(
157+
f"Invalid role '{role}'. Valid roles are: {valid_roles}"
158+
)
159+
roles_to_filter = [role]
160+
# Handle list of roles
161+
elif isinstance(role, list):
162+
if not role: # Empty list
163+
raise ValueError("roles cannot be empty")
164+
for r in role:
165+
if r not in valid_roles:
166+
raise ValueError(
167+
f"Invalid role '{r}'. Valid roles are: {valid_roles}"
168+
)
169+
roles_to_filter = role
170+
else:
171+
raise ValueError("role must be a string or list of strings")
172+
144173
return_fields = [
145174
ID_FIELD_NAME,
146175
SESSION_FIELD_NAME,
@@ -157,8 +186,22 @@ def get_recent(
157186
else self._default_session_filter
158187
)
159188

189+
# Combine session filter with role filter if provided
190+
filter_expression = session_filter
191+
if role is not None:
192+
if len(roles_to_filter) == 1:
193+
role_filter = Tag(ROLE_FIELD_NAME) == roles_to_filter[0]
194+
else:
195+
# Multiple roles - use OR logic
196+
role_filters = [Tag(ROLE_FIELD_NAME) == r for r in roles_to_filter]
197+
role_filter = role_filters[0]
198+
for rf in role_filters[1:]:
199+
role_filter = role_filter | rf
200+
201+
filter_expression = session_filter & role_filter
202+
160203
query = FilterQuery(
161-
filter_expression=session_filter,
204+
filter_expression=filter_expression,
162205
return_fields=return_fields,
163206
num_results=top_k,
164207
)

redisvl/extensions/message_history/semantic_history.py

Lines changed: 91 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def get_relevant(
173173
session_tag: Optional[str] = None,
174174
raw: bool = False,
175175
distance_threshold: Optional[float] = None,
176+
role: Optional[Union[str, List[str]]] = None,
176177
) -> Union[List[str], List[Dict[str, str]]]:
177178
"""Searches the message history for information semantically related to
178179
the specified prompt.
@@ -195,18 +196,46 @@ def get_relevant(
195196
if no relevant context is found.
196197
raw (bool): Whether to return the full Redis hash entry or just the
197198
message.
199+
role (Optional[Union[str, List[str]]]): Filter messages by role(s).
200+
Can be a single role string ("system", "user", "llm", "tool") or
201+
a list of roles. If None, all roles are returned.
198202
199203
Returns:
200204
Union[List[str], List[Dict[str,str]]: Either a list of strings, or a
201205
list of prompts and responses in JSON containing the most relevant.
202206
203-
Raises ValueError: if top_k is not an integer greater or equal to 0.
207+
Raises ValueError: if top_k is not an integer greater or equal to 0,
208+
or if role contains invalid values.
204209
"""
205210
if type(top_k) != int or top_k < 0:
206211
raise ValueError("top_k must be an integer greater than or equal to -1")
207212
if top_k == 0:
208213
return []
209214

215+
# Validate and process role parameter
216+
if role is not None:
217+
valid_roles = {"system", "user", "llm", "tool"}
218+
219+
# Handle single role string
220+
if isinstance(role, str):
221+
if role not in valid_roles:
222+
raise ValueError(
223+
f"Invalid role '{role}'. Valid roles are: {valid_roles}"
224+
)
225+
roles_to_filter = [role]
226+
# Handle list of roles
227+
elif isinstance(role, list):
228+
if not role: # Empty list
229+
raise ValueError("roles cannot be empty")
230+
for r in role:
231+
if r not in valid_roles:
232+
raise ValueError(
233+
f"Invalid role '{r}'. Valid roles are: {valid_roles}"
234+
)
235+
roles_to_filter = role
236+
else:
237+
raise ValueError("role must be a string or list of strings")
238+
210239
# override distance threshold
211240
distance_threshold = distance_threshold or self._distance_threshold
212241

@@ -225,21 +254,35 @@ def get_relevant(
225254
else self._default_session_filter
226255
)
227256

257+
# Combine session filter with role filter if provided
258+
filter_expression = session_filter
259+
if role is not None:
260+
if len(roles_to_filter) == 1:
261+
role_filter = Tag(ROLE_FIELD_NAME) == roles_to_filter[0]
262+
else:
263+
# Multiple roles - use OR logic
264+
role_filters = [Tag(ROLE_FIELD_NAME) == r for r in roles_to_filter]
265+
role_filter = role_filters[0]
266+
for rf in role_filters[1:]:
267+
role_filter = role_filter | rf
268+
269+
filter_expression = session_filter & role_filter
270+
228271
query = RangeQuery(
229272
vector=self._vectorizer.embed(prompt),
230273
vector_field_name=MESSAGE_VECTOR_FIELD_NAME,
231274
return_fields=return_fields,
232275
distance_threshold=distance_threshold,
233276
num_results=top_k,
234277
return_score=True,
235-
filter_expression=session_filter,
278+
filter_expression=filter_expression,
236279
dtype=self._vectorizer.dtype,
237280
)
238281
messages = self._index.query(query)
239282

240283
# if we don't find semantic matches fallback to returning recent context
241284
if not messages and fall_back:
242-
return self.get_recent(as_text=as_text, top_k=top_k, raw=raw)
285+
return self.get_recent(as_text=as_text, top_k=top_k, raw=raw, role=role)
243286
if raw:
244287
return messages
245288
return self._format_context(messages, as_text)
@@ -250,6 +293,7 @@ def get_recent(
250293
as_text: bool = False,
251294
raw: bool = False,
252295
session_tag: Optional[str] = None,
296+
role: Optional[Union[str, List[str]]] = None,
253297
) -> Union[List[str], List[Dict[str, str]]]:
254298
"""Retrieve the recent message history in sequential order.
255299
@@ -261,17 +305,45 @@ def get_recent(
261305
prompt and response
262306
session_tag (Optional[str]): Tag of the entries linked to a specific
263307
conversation session. Defaults to instance ULID.
308+
role (Optional[Union[str, List[str]]]): Filter messages by role(s).
309+
Can be a single role string ("system", "user", "llm", "tool") or
310+
a list of roles. If None, all roles are returned.
264311
265312
Returns:
266313
Union[str, List[str]]: A single string transcription of the session
267314
or list of strings if as_text is false.
268315
269316
Raises:
270-
ValueError: if top_k is not an integer greater than or equal to 0.
317+
ValueError: if top_k is not an integer greater than or equal to 0,
318+
or if role contains invalid values.
271319
"""
272320
if type(top_k) != int or top_k < 0:
273321
raise ValueError("top_k must be an integer greater than or equal to 0")
274322

323+
# Validate and process role parameter
324+
if role is not None:
325+
valid_roles = {"system", "user", "llm", "tool"}
326+
327+
# Handle single role string
328+
if isinstance(role, str):
329+
if role not in valid_roles:
330+
raise ValueError(
331+
f"Invalid role '{role}'. Valid roles are: {valid_roles}"
332+
)
333+
roles_to_filter = [role]
334+
# Handle list of roles
335+
elif isinstance(role, list):
336+
if not role: # Empty list
337+
raise ValueError("roles cannot be empty")
338+
for r in role:
339+
if r not in valid_roles:
340+
raise ValueError(
341+
f"Invalid role '{r}'. Valid roles are: {valid_roles}"
342+
)
343+
roles_to_filter = role
344+
else:
345+
raise ValueError("role must be a string or list of strings")
346+
275347
return_fields = [
276348
ID_FIELD_NAME,
277349
SESSION_FIELD_NAME,
@@ -288,8 +360,22 @@ def get_recent(
288360
else self._default_session_filter
289361
)
290362

363+
# Combine session filter with role filter if provided
364+
filter_expression = session_filter
365+
if role is not None:
366+
if len(roles_to_filter) == 1:
367+
role_filter = Tag(ROLE_FIELD_NAME) == roles_to_filter[0]
368+
else:
369+
# Multiple roles - use OR logic
370+
role_filters = [Tag(ROLE_FIELD_NAME) == r for r in roles_to_filter]
371+
role_filter = role_filters[0]
372+
for rf in role_filters[1:]:
373+
role_filter = role_filter | rf
374+
375+
filter_expression = session_filter & role_filter
376+
291377
query = FilterQuery(
292-
filter_expression=session_filter,
378+
filter_expression=filter_expression,
293379
return_fields=return_fields,
294380
num_results=top_k,
295381
)

0 commit comments

Comments
 (0)