Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/strands/session/file_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,13 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) ->
message_id: Index of the message
Returns:
The filename for the message

Raises:
ValueError: If message_id is not an integer.
"""
if not isinstance(message_id, int):
raise ValueError(f"message_id=<{message_id}> | message id must be an integer")

agent_path = self._get_agent_path(session_id, agent_id)
return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id}.json")

Expand Down
7 changes: 6 additions & 1 deletion src/strands/session/s3_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,16 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) ->
session_id: ID of the session
agent_id: ID of the agent
message_id: Index of the message
**kwargs: Additional keyword arguments for future extensibility.

Returns:
The key for the message

Raises:
ValueError: If message_id is not an integer.
"""
if not isinstance(message_id, int):
raise ValueError(f"message_id=<{message_id}> | message id must be an integer")

agent_path = self._get_agent_path(session_id, agent_id)
return f"{agent_path}messages/{MESSAGE_PREFIX}{message_id}.json"

Expand Down
22 changes: 20 additions & 2 deletions tests/strands/session/test_file_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,14 +224,14 @@ def test_read_messages_with_new_agent(file_manager, sample_session, sample_agent
file_manager.create_session(sample_session)
file_manager.create_agent(sample_session.session_id, sample_agent)

result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message")
result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, 999)

assert result is None


def test_read_nonexistent_message(file_manager, sample_session, sample_agent):
"""Test reading a message that doesnt exist."""
result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message")
result = file_manager.read_message(sample_session.session_id, sample_agent.agent_id, 999)
assert result is None


Expand Down Expand Up @@ -390,3 +390,21 @@ def test__get_session_path_invalid_session_id(session_id, file_manager):
def test__get_agent_path_invalid_agent_id(agent_id, file_manager):
with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"):
file_manager._get_agent_path("session1", agent_id)


@pytest.mark.parametrize(
"message_id",
[
"../../../secret",
"../../attack",
"../escape",
"path/traversal",
"not_an_int",
None,
[],
],
)
def test__get_message_path_invalid_message_id(message_id, file_manager):
"""Test that message_id that is not an integer raises ValueError."""
with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"):
file_manager._get_message_path("session1", "agent1", message_id)
20 changes: 19 additions & 1 deletion tests/strands/session/test_s3_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def test_read_nonexistent_message(s3_manager, sample_session, sample_agent, samp
s3_manager.create_agent(sample_session.session_id, sample_agent)

# Read message
result = s3_manager.read_message(sample_session.session_id, sample_agent.agent_id, "nonexistent_message")
result = s3_manager.read_message(sample_session.session_id, sample_agent.agent_id, 999)

assert result is None

Expand Down Expand Up @@ -356,3 +356,21 @@ def test__get_session_path_invalid_session_id(session_id, s3_manager):
def test__get_agent_path_invalid_agent_id(agent_id, s3_manager):
with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"):
s3_manager._get_agent_path("session1", agent_id)


@pytest.mark.parametrize(
"message_id",
[
"../../../secret",
"../../attack",
"../escape",
"path/traversal",
"not_an_int",
None,
[],
],
)
def test__get_message_path_invalid_message_id(message_id, s3_manager):
"""Test that message_id that is not an integer raises ValueError."""
with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"):
s3_manager._get_message_path("session1", "agent1", message_id)
Loading