Skip to content

SageMakerAIModel provider flag for provider-specific formatting. #679

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions src/strands/models/llamaapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ def get_config(self) -> LlamaConfig:
"""
return self.config

def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
@staticmethod
def _format_request_message_content(content: ContentBlock) -> dict[str, Any]:
"""Format a LlamaAPI content block.

- NOTE: "reasoningContent" and "video" are not supported currently.
Expand Down Expand Up @@ -116,7 +117,8 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An

raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")

def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]:
@staticmethod
def _format_request_message_tool_call(tool_use: ToolUse) -> dict[str, Any]:
"""Format a Llama API tool call.

Args:
Expand All @@ -133,7 +135,8 @@ def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]
"id": tool_use["toolUseId"],
}

def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]:
@staticmethod
def _format_request_tool_message(tool_result: ToolResult) -> dict[str, Any]:
"""Format a Llama API tool message.

Args:
Expand All @@ -153,10 +156,11 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any
return {
"role": "tool",
"tool_call_id": tool_result["toolUseId"],
"content": [self._format_request_message_content(content) for content in contents],
"content": [LlamaAPIModel._format_request_message_content(content) for content in contents],
}

def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
@classmethod
def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
"""Format a LlamaAPI compatible messages array.

Args:
Expand All @@ -174,17 +178,17 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s

formatted_contents: list[dict[str, Any]] | dict[str, Any] | str = ""
formatted_contents = [
self._format_request_message_content(content)
cls._format_request_message_content(content=content)
for content in contents
if not any(block_type in content for block_type in ["toolResult", "toolUse"])
]
formatted_tool_calls = [
self._format_request_message_tool_call(content["toolUse"])
cls._format_request_message_tool_call(tool_use=content["toolUse"])
for content in contents
if "toolUse" in content
]
formatted_tool_messages = [
self._format_request_tool_message(content["toolResult"])
cls._format_request_tool_message(tool_result=content["toolResult"])
for content in contents
if "toolResult" in content
]
Expand Down Expand Up @@ -220,7 +224,7 @@ def format_request(
format.
"""
request = {
"messages": self._format_request_messages(messages, system_prompt),
"messages": self.format_request_messages(messages, system_prompt),
"model": self.config["model_id"],
"stream": True,
"tools": [
Expand Down
20 changes: 12 additions & 8 deletions src/strands/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ def get_config(self) -> MistralConfig:
"""
return self.config

def _format_request_message_content(self, content: ContentBlock) -> Union[str, dict[str, Any]]:
@staticmethod
def _format_request_message_content(content: ContentBlock) -> Union[str, dict[str, Any]]:
"""Format a Mistral content block.

Args:
Expand Down Expand Up @@ -141,7 +142,8 @@ def _format_request_message_content(self, content: ContentBlock) -> Union[str, d

raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")

def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]:
@staticmethod
def _format_request_message_tool_call(tool_use: ToolUse) -> dict[str, Any]:
"""Format a Mistral tool call.

Args:
Expand All @@ -159,7 +161,8 @@ def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]
"type": "function",
}

def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]:
@staticmethod
def _format_request_tool_message(tool_result: ToolResult) -> dict[str, Any]:
"""Format a Mistral tool message.

Args:
Expand All @@ -184,7 +187,8 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any
"tool_call_id": tool_result["toolUseId"],
}

def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
@classmethod
def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
"""Format a Mistral compatible messages array.

Args:
Expand All @@ -209,13 +213,13 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s

for content in contents:
if "text" in content:
formatted_content = self._format_request_message_content(content)
formatted_content = cls._format_request_message_content(content)
if isinstance(formatted_content, str):
text_contents.append(formatted_content)
elif "toolUse" in content:
tool_calls.append(self._format_request_message_tool_call(content["toolUse"]))
tool_calls.append(cls._format_request_message_tool_call(content["toolUse"]))
elif "toolResult" in content:
tool_messages.append(self._format_request_tool_message(content["toolResult"]))
tool_messages.append(cls._format_request_tool_message(content["toolResult"]))

if text_contents or tool_calls:
formatted_message: dict[str, Any] = {
Expand Down Expand Up @@ -251,7 +255,7 @@ def format_request(
"""
request: dict[str, Any] = {
"model": self.config["model_id"],
"messages": self._format_request_messages(messages, system_prompt),
"messages": self.format_request_messages(messages, system_prompt),
}

if "max_tokens" in self.config:
Expand Down
Loading