Skip to content

Commit 58caf3f

Browse files
improve prompt helper multimodal support (#17831)
1 parent e157ebb commit 58caf3f

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

llama-index-core/llama_index/core/indices/prompt_helper.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
SelectorPromptTemplate,
2929
)
3030
from llama_index.core.prompts.prompt_utils import get_empty_prompt_txt
31-
from llama_index.core.prompts.utils import format_string
31+
from llama_index.core.prompts.utils import format_content_blocks
3232
from llama_index.core.schema import BaseComponent
3333
from llama_index.core.utilities.token_counting import TokenCounter
3434

@@ -198,9 +198,10 @@ def _get_available_chunk_size(
198198
for message in messages:
199199
partial_message = deepcopy(message)
200200

201+
# TODO: This does not count tokens in non-text blocks
201202
prompt_kwargs = prompt.kwargs or {}
202-
partial_message.content = format_string(
203-
partial_message.content or "", **prompt_kwargs
203+
partial_message.blocks = format_content_blocks(
204+
partial_message.blocks, **prompt_kwargs
204205
)
205206

206207
# add to list of partial messages

llama-index-core/llama_index/core/prompts/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import re
33

44
from llama_index.core.base.llms.base import BaseLLM
5+
from llama_index.core.base.llms.types import ContentBlock, TextBlock
56

67

78
class SafeFormatter:
@@ -27,6 +28,21 @@ def format_string(string_to_format: str, **kwargs: str) -> str:
2728
return formatter.format(string_to_format)
2829

2930

31+
def format_content_blocks(
32+
content_blocks: List[ContentBlock], **kwargs: str
33+
) -> List[ContentBlock]:
34+
"""Format content blocks with kwargs."""
35+
formatter = SafeFormatter(format_dict=kwargs)
36+
formatted_blocks: List[ContentBlock] = []
37+
for block in content_blocks:
38+
if isinstance(block, TextBlock):
39+
formatted_blocks.append(TextBlock(text=formatter.format(block.text)))
40+
else:
41+
formatted_blocks.append(block)
42+
43+
return formatted_blocks
44+
45+
3046
def get_template_vars(template_str: str) -> List[str]:
3147
"""Get template variables from a template string."""
3248
variables = []

0 commit comments

Comments
 (0)