diff --git a/src/strands/_identifier.py b/src/strands/_identifier.py index e8b12635c..e02d83473 100644 --- a/src/strands/_identifier.py +++ b/src/strands/_identifier.py @@ -9,6 +9,7 @@ class Identifier(enum.Enum): AGENT = "agent" SESSION = "session" + MESSAGE = "message" def validate(id_: str, type_: Identifier) -> str: diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index 9df86e17a..20d9f74e0 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -86,9 +86,16 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> message_id: Index of the message Returns: The filename for the message + + Raises: + ValueError: If message id contains path separators. """ + # Validate message_id to prevent path traversal + message_id_str = str(message_id) + message_id_str = _identifier.validate(message_id_str, _identifier.Identifier.MESSAGE) + agent_path = self._get_agent_path(session_id, agent_id) - return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id}.json") + return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id_str}.json") def _read_file(self, path: str) -> dict[str, Any]: """Read JSON file.""" diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index a89222b7e..a0236c420 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -390,3 +390,18 @@ def test__get_session_path_invalid_session_id(session_id, file_manager): def test__get_agent_path_invalid_agent_id(agent_id, file_manager): with pytest.raises(ValueError, match=f"agent_id={agent_id} | id cannot contain path separators"): file_manager._get_agent_path("session1", agent_id) + + +@pytest.mark.parametrize( + "message_id", + [ + "../../../secret", + "../../attack", + "../escape", + "path/traversal", + ], +) +def test__get_message_path_invalid_message_id(message_id, file_manager): + """Test that message_id with path traversal sequences raises ValueError.""" + with pytest.raises(ValueError, match=f"message_id={message_id} | id cannot contain path separators"): + file_manager._get_message_path("session1", "agent1", message_id)