Skip to content

Commit 0fe079c

Browse files
feat: add option to keep tool messages (#323)
Adds an option to the Toolbox/ChainProcessor to keep tool messages by avoiding to clone the original messageBag Fixes #321 Co-authored-by: Philip Heimböck <[email protected]>
1 parent b95023e commit 0fe079c

File tree

3 files changed

+89
-1
lines changed

3 files changed

+89
-1
lines changed

README.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,34 @@ $eventDispatcher->addListener(ToolCallsExecuted::class, function (ToolCallsExecu
347347
});
348348
```
349349

350+
#### Keeping Tool Messages
351+
352+
Sometimes you might wish to keep the tool messages (`AssistantMessage` containing the `toolCalls` and `ToolCallMessage` containing the response) in the context.
353+
Enable the `keepToolMessages` flag of the toolbox' `ChainProcessor` to ensure those messages will be added to your `MessageBag`.
354+
355+
```php
356+
use PhpLlm\LlmChain\Chain\Toolbox\ChainProcessor;
357+
use PhpLlm\LlmChain\Chain\Toolbox\Toolbox;
358+
359+
// Platform & LLM instantiation
360+
$messages = new MessageBag(
361+
Message::forSystem(<<<PROMPT
362+
Please answer all user questions only using the similary_search tool. Do not add information and if you cannot
363+
find an answer, say so.
364+
PROMPT),
365+
Message::ofUser('...') // The user's question.
366+
);
367+
368+
$yourTool = new YourTool();
369+
370+
$toolbox = Toolbox::create($yourTool);
371+
$toolProcessor = new ChainProcessor($toolbox, keepToolMessages: true);
372+
373+
$chain = new Chain($platform, $llm, inputProcessor: [$toolProcessor], outputProcessor: [$toolProcessor]);
374+
$response = $chain->call($messages);
375+
// $messages will now include the tool messages
376+
```
377+
350378
#### Code Examples (with built-in tools)
351379
352380
1. [Brave Tool](examples/toolbox/brave.php)

src/Chain/Toolbox/ChainProcessor.php

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public function __construct(
3333
private readonly ToolboxInterface $toolbox,
3434
private readonly ToolResultConverter $resultConverter = new ToolResultConverter(),
3535
private readonly ?EventDispatcherInterface $eventDispatcher = null,
36+
private readonly bool $keepToolMessages = false,
3637
) {
3738
}
3839

@@ -86,7 +87,7 @@ private function isFlatStringArray(array $tools): bool
8687
private function handleToolCallsCallback(Output $output): \Closure
8788
{
8889
return function (ToolCallResponse $response, ?AssistantMessage $streamedAssistantResponse = null) use ($output): ResponseInterface {
89-
$messages = clone $output->messages;
90+
$messages = $this->keepToolMessages ? $output->messages : clone $output->messages;
9091

9192
if (null !== $streamedAssistantResponse && '' !== $streamedAssistantResponse->content) {
9293
$messages->add($streamedAssistantResponse);

tests/Chain/Toolbox/ChainProcessorTest.php

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,19 @@
44

55
namespace PhpLlm\LlmChain\Tests\Chain\Toolbox;
66

7+
use PhpLlm\LlmChain\Chain\ChainInterface;
78
use PhpLlm\LlmChain\Chain\Exception\MissingModelSupportException;
89
use PhpLlm\LlmChain\Chain\Input;
10+
use PhpLlm\LlmChain\Chain\Output;
911
use PhpLlm\LlmChain\Chain\Toolbox\ChainProcessor;
1012
use PhpLlm\LlmChain\Chain\Toolbox\ToolboxInterface;
1113
use PhpLlm\LlmChain\Platform\Capability;
14+
use PhpLlm\LlmChain\Platform\Message\AssistantMessage;
1215
use PhpLlm\LlmChain\Platform\Message\MessageBag;
16+
use PhpLlm\LlmChain\Platform\Message\ToolCallMessage;
1317
use PhpLlm\LlmChain\Platform\Model;
18+
use PhpLlm\LlmChain\Platform\Response\ToolCall;
19+
use PhpLlm\LlmChain\Platform\Response\ToolCallResponse;
1420
use PhpLlm\LlmChain\Platform\Tool\ExecutionReference;
1521
use PhpLlm\LlmChain\Platform\Tool\Tool;
1622
use PHPUnit\Framework\Attributes\CoversClass;
@@ -20,7 +26,10 @@
2026

2127
#[CoversClass(ChainProcessor::class)]
2228
#[UsesClass(Input::class)]
29+
#[UsesClass(Output::class)]
2330
#[UsesClass(Tool::class)]
31+
#[UsesClass(ToolCall::class)]
32+
#[UsesClass(ToolCallResponse::class)]
2433
#[UsesClass(ExecutionReference::class)]
2534
#[UsesClass(MessageBag::class)]
2635
#[UsesClass(MissingModelSupportException::class)]
@@ -87,4 +96,54 @@ public function processInputWithUnsupportedToolCallingWillThrowException(): void
8796

8897
$chainProcessor->processInput($input);
8998
}
99+
100+
#[Test]
101+
public function processOutputWithToolCallResponseKeepingMessages(): void
102+
{
103+
$toolbox = $this->createMock(ToolboxInterface::class);
104+
$toolbox->expects($this->once())->method('execute')->willReturn('Test response');
105+
106+
$model = new Model('gpt-4', [Capability::TOOL_CALLING]);
107+
108+
$messageBag = new MessageBag();
109+
110+
$response = new ToolCallResponse(new ToolCall('id1', 'tool1', ['arg1' => 'value1']));
111+
112+
$chain = $this->createStub(ChainInterface::class);
113+
114+
$chainProcessor = new ChainProcessor($toolbox, keepToolMessages: true);
115+
$chainProcessor->setChain($chain);
116+
117+
$output = new Output($model, $response, $messageBag, []);
118+
119+
$chainProcessor->processOutput($output);
120+
121+
self::assertCount(2, $messageBag);
122+
self::assertInstanceOf(AssistantMessage::class, $messageBag->getMessages()[0]);
123+
self::assertInstanceOf(ToolCallMessage::class, $messageBag->getMessages()[1]);
124+
}
125+
126+
#[Test]
127+
public function processOutputWithToolCallResponseForgettingMessages(): void
128+
{
129+
$toolbox = $this->createMock(ToolboxInterface::class);
130+
$toolbox->expects($this->once())->method('execute')->willReturn('Test response');
131+
132+
$model = new Model('gpt-4', [Capability::TOOL_CALLING]);
133+
134+
$messageBag = new MessageBag();
135+
136+
$response = new ToolCallResponse(new ToolCall('id1', 'tool1', ['arg1' => 'value1']));
137+
138+
$chain = $this->createStub(ChainInterface::class);
139+
140+
$chainProcessor = new ChainProcessor($toolbox, keepToolMessages: false);
141+
$chainProcessor->setChain($chain);
142+
143+
$output = new Output($model, $response, $messageBag, []);
144+
145+
$chainProcessor->processOutput($output);
146+
147+
self::assertCount(0, $messageBag);
148+
}
90149
}

0 commit comments

Comments
 (0)