Skip to content

Introduce action as indicator to indicate what one wants to do with the model #300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/platform/src/Action.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
<?php

/*
* This file is part of the Symfony package.
*
* (c) Fabien Potencier <[email protected]>
*
* For the full copyright and license information, please view the LICENSE
* file that was distributed with this source code.
*/

namespace Symfony\AI\Platform;

/**
* @author Joshua Behrens <[email protected]>
*/
enum Action: string
{
case CHAT = 'chat';
case CALCULATE_EMBEDDINGS = 'embeddings';
case COMPLETE_CHAT = 'chat-completion';
}
14 changes: 12 additions & 2 deletions src/platform/src/Bridge/Albert/EmbeddingsModelClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

namespace Symfony\AI\Platform\Bridge\Albert;

use Symfony\AI\Platform\Action;
use Symfony\AI\Platform\Bridge\OpenAi\Embeddings;
use Symfony\AI\Platform\Exception\InvalidActionArgumentException;
use Symfony\AI\Platform\Exception\InvalidArgumentException;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\ModelClientInterface;
Expand All @@ -33,13 +35,21 @@ public function __construct(
'' !== $baseUrl || throw new InvalidArgumentException('The base URL must not be empty.');
}

public function supports(Model $model): bool
public function supports(Model $model, Action $action): bool
{
if (Action::CALCULATE_EMBEDDINGS !== $action) {
return false;
}

return $model instanceof Embeddings;
}

public function request(Model $model, array|string $payload, array $options = []): RawResultInterface
public function request(Model $model, Action $action, array|string $payload, array $options = []): RawResultInterface
{
if (Action::CALCULATE_EMBEDDINGS !== $action) {
return throw new InvalidActionArgumentException($model, $action, [Action::CALCULATE_EMBEDDINGS]);
}

return new RawHttpResult($this->httpClient->request('POST', \sprintf('%s/embeddings', $this->baseUrl), [
'auth_bearer' => $this->apiKey,
'json' => \is_array($payload) ? array_merge($payload, $options) : $payload,
Expand Down
14 changes: 12 additions & 2 deletions src/platform/src/Bridge/Albert/GptModelClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

namespace Symfony\AI\Platform\Bridge\Albert;

use Symfony\AI\Platform\Action;
use Symfony\AI\Platform\Bridge\OpenAi\Gpt;
use Symfony\AI\Platform\Exception\InvalidActionArgumentException;
use Symfony\AI\Platform\Exception\InvalidArgumentException;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\ModelClientInterface;
Expand All @@ -38,13 +40,21 @@ public function __construct(
'' !== $baseUrl || throw new InvalidArgumentException('The base URL must not be empty.');
}

public function supports(Model $model): bool
public function supports(Model $model, Action $action): bool
{
if (Action::COMPLETE_CHAT !== $action && Action::CHAT !== $action) {
return false;
}

return $model instanceof Gpt;
}

public function request(Model $model, array|string $payload, array $options = []): RawResultInterface
public function request(Model $model, Action $action, array|string $payload, array $options = []): RawResultInterface
{
if (Action::COMPLETE_CHAT !== $action && Action::CHAT !== $action) {
return throw new InvalidActionArgumentException($model, $action, [Action::CHAT, Action::COMPLETE_CHAT]);
}

return new RawHttpResult($this->httpClient->request('POST', \sprintf('%s/chat/completions', $this->baseUrl), [
'auth_bearer' => $this->apiKey,
'json' => \is_array($payload) ? array_merge($payload, $options) : $payload,
Expand Down
14 changes: 12 additions & 2 deletions src/platform/src/Bridge/Anthropic/ModelClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

namespace Symfony\AI\Platform\Bridge\Anthropic;

use Symfony\AI\Platform\Action;
use Symfony\AI\Platform\Exception\InvalidActionArgumentException;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\ModelClientInterface;
use Symfony\AI\Platform\Result\RawHttpResult;
Expand All @@ -32,13 +34,21 @@ public function __construct(
$this->httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient);
}

public function supports(Model $model): bool
public function supports(Model $model, Action $action): bool
{
if (Action::CHAT !== $action) {
return false;
}

return $model instanceof Claude;
}

public function request(Model $model, array|string $payload, array $options = []): RawHttpResult
public function request(Model $model, Action $action, array|string $payload, array $options = []): RawHttpResult
{
if (Action::CHAT !== $action) {
return throw new InvalidActionArgumentException($model, $action, [Action::CHAT]);
}

if (isset($options['tools'])) {
$options['tool_choice'] = ['type' => 'auto'];
}
Expand Down
7 changes: 6 additions & 1 deletion src/platform/src/Bridge/Anthropic/ResultConverter.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

namespace Symfony\AI\Platform\Bridge\Anthropic;

use Symfony\AI\Platform\Action;
use Symfony\AI\Platform\Exception\RuntimeException;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\Result\RawHttpResult;
Expand All @@ -31,8 +32,12 @@
*/
class ResultConverter implements ResultConverterInterface
{
public function supports(Model $model): bool
public function supports(Model $model, Action $action): bool
{
if (Action::COMPLETE_CHAT !== $action && Action::CHAT !== $action) {
return false;
}

return $model instanceof Claude;
}

Expand Down
14 changes: 12 additions & 2 deletions src/platform/src/Bridge/Azure/Meta/LlamaModelClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

namespace Symfony\AI\Platform\Bridge\Azure\Meta;

use Symfony\AI\Platform\Action;
use Symfony\AI\Platform\Bridge\Meta\Llama;
use Symfony\AI\Platform\Exception\InvalidActionArgumentException;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\ModelClientInterface;
use Symfony\AI\Platform\Result\RawHttpResult;
Expand All @@ -29,13 +31,21 @@ public function __construct(
) {
}

public function supports(Model $model): bool
public function supports(Model $model, Action $action): bool
{
if (Action::COMPLETE_CHAT !== $action && Action::CHAT !== $action) {
return false;
}

return $model instanceof Llama;
}

public function request(Model $model, array|string $payload, array $options = []): RawHttpResult
public function request(Model $model, Action $action, array|string $payload, array $options = []): RawHttpResult
{
if (Action::COMPLETE_CHAT !== $action && Action::CHAT !== $action) {
return throw new InvalidActionArgumentException($model, $action, [Action::CHAT, Action::COMPLETE_CHAT]);
}

$url = \sprintf('https://%s/chat/completions', $this->baseUrl);

return new RawHttpResult($this->httpClient->request('POST', $url, [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

namespace Symfony\AI\Platform\Bridge\Azure\Meta;

use Symfony\AI\Platform\Action;
use Symfony\AI\Platform\Bridge\Meta\Llama;
use Symfony\AI\Platform\Exception\RuntimeException;
use Symfony\AI\Platform\Model;
Expand All @@ -23,8 +24,12 @@
*/
final readonly class LlamaResultConverter implements ResultConverterInterface
{
public function supports(Model $model): bool
public function supports(Model $model, Action $action): bool
{
if (Action::COMPLETE_CHAT !== $action && Action::CHAT !== $action) {
return false;
}

return $model instanceof Llama;
}

Expand Down
14 changes: 12 additions & 2 deletions src/platform/src/Bridge/Azure/OpenAi/EmbeddingsModelClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

namespace Symfony\AI\Platform\Bridge\Azure\OpenAi;

use Symfony\AI\Platform\Action;
use Symfony\AI\Platform\Bridge\OpenAi\Embeddings;
use Symfony\AI\Platform\Exception\InvalidActionArgumentException;
use Symfony\AI\Platform\Exception\InvalidArgumentException;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\ModelClientInterface;
Expand Down Expand Up @@ -41,13 +43,21 @@ public function __construct(
'' !== $apiKey || throw new InvalidArgumentException('The API key must not be empty.');
}

public function supports(Model $model): bool
public function supports(Model $model, Action $action): bool
{
if (Action::CALCULATE_EMBEDDINGS !== $action) {
return false;
}

return $model instanceof Embeddings;
}

public function request(Model $model, array|string $payload, array $options = []): RawHttpResult
public function request(Model $model, Action $action, array|string $payload, array $options = []): RawHttpResult
{
if (Action::CALCULATE_EMBEDDINGS !== $action) {
return throw new InvalidActionArgumentException($model, $action, [Action::CALCULATE_EMBEDDINGS]);
}

$url = \sprintf('https://%s/openai/deployments/%s/embeddings', $this->baseUrl, $this->deployment);

return new RawHttpResult($this->httpClient->request('POST', $url, [
Expand Down
14 changes: 12 additions & 2 deletions src/platform/src/Bridge/Azure/OpenAi/GptModelClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

namespace Symfony\AI\Platform\Bridge\Azure\OpenAi;

use Symfony\AI\Platform\Action;
use Symfony\AI\Platform\Bridge\OpenAi\Gpt;
use Symfony\AI\Platform\Exception\InvalidActionArgumentException;
use Symfony\AI\Platform\Exception\InvalidArgumentException;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\ModelClientInterface;
Expand Down Expand Up @@ -41,13 +43,21 @@ public function __construct(
'' !== $apiKey || throw new InvalidArgumentException('The API key must not be empty.');
}

public function supports(Model $model): bool
public function supports(Model $model, Action $action): bool
{
if (Action::COMPLETE_CHAT !== $action && Action::CHAT !== $action) {
return false;
}

return $model instanceof Gpt;
}

public function request(Model $model, object|array|string $payload, array $options = []): RawHttpResult
public function request(Model $model, Action $action, object|array|string $payload, array $options = []): RawHttpResult
{
if (Action::COMPLETE_CHAT !== $action && Action::CHAT !== $action) {
return throw new InvalidActionArgumentException($model, $action, [Action::CHAT, Action::COMPLETE_CHAT]);
}

$url = \sprintf('https://%s/openai/deployments/%s/chat/completions', $this->baseUrl, $this->deployment);

return new RawHttpResult($this->httpClient->request('POST', $url, [
Expand Down
14 changes: 12 additions & 2 deletions src/platform/src/Bridge/Azure/OpenAi/WhisperModelClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@

namespace Symfony\AI\Platform\Bridge\Azure\OpenAi;

use Symfony\AI\Platform\Action;
use Symfony\AI\Platform\Bridge\OpenAi\Whisper;
use Symfony\AI\Platform\Bridge\OpenAi\Whisper\Task;
use Symfony\AI\Platform\Exception\InvalidActionArgumentException;
use Symfony\AI\Platform\Exception\InvalidArgumentException;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\ModelClientInterface;
Expand Down Expand Up @@ -42,13 +44,21 @@ public function __construct(
'' !== $apiKey || throw new InvalidArgumentException('The API key must not be empty.');
}

public function supports(Model $model): bool
public function supports(Model $model, Action $action): bool
{
if (Action::CHAT !== $action) {
return false;
}

return $model instanceof Whisper;
}

public function request(Model $model, array|string $payload, array $options = []): RawHttpResult
public function request(Model $model, Action $action, array|string $payload, array $options = []): RawHttpResult
{
if (Action::CHAT !== $action) {
return throw new InvalidActionArgumentException($model, $action, [Action::CHAT]);
}

$task = $options['task'] ?? Task::TRANSCRIPTION;
$endpoint = Task::TRANSCRIPTION === $task ? 'transcriptions' : 'translations';
$url = \sprintf('https://%s/openai/deployments/%s/audio/%s', $this->baseUrl, $this->deployment, $endpoint);
Expand Down
14 changes: 12 additions & 2 deletions src/platform/src/Bridge/Bedrock/Anthropic/ClaudeModelClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
use AsyncAws\BedrockRuntime\BedrockRuntimeClient;
use AsyncAws\BedrockRuntime\Input\InvokeModelRequest;
use AsyncAws\BedrockRuntime\Result\InvokeModelResponse;
use Symfony\AI\Platform\Action;
use Symfony\AI\Platform\Bridge\Anthropic\Claude;
use Symfony\AI\Platform\Bridge\Bedrock\RawBedrockResult;
use Symfony\AI\Platform\Exception\InvalidActionArgumentException;
use Symfony\AI\Platform\Exception\RuntimeException;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\ModelClientInterface;
Expand All @@ -34,13 +36,21 @@ public function __construct(
) {
}

public function supports(Model $model): bool
public function supports(Model $model, Action $action): bool
{
if (Action::COMPLETE_CHAT !== $action && Action::CHAT !== $action) {
return false;
}

return $model instanceof Claude;
}

public function request(Model $model, array|string $payload, array $options = []): RawBedrockResult
public function request(Model $model, Action $action, array|string $payload, array $options = []): RawBedrockResult
{
if (Action::COMPLETE_CHAT !== $action && Action::CHAT !== $action) {
return throw new InvalidActionArgumentException($model, $action, [Action::CHAT, Action::COMPLETE_CHAT]);
}

unset($payload['model']);

if (isset($options['tools'])) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

namespace Symfony\AI\Platform\Bridge\Bedrock\Anthropic;

use Symfony\AI\Platform\Action;
use Symfony\AI\Platform\Bridge\Anthropic\Claude;
use Symfony\AI\Platform\Bridge\Bedrock\RawBedrockResult;
use Symfony\AI\Platform\Exception\RuntimeException;
Expand All @@ -26,8 +27,12 @@
*/
final readonly class ClaudeResultConverter implements ResultConverterInterface
{
public function supports(Model $model): bool
public function supports(Model $model, Action $action): bool
{
if (Action::COMPLETE_CHAT !== $action && Action::CHAT !== $action) {
return false;
}

return $model instanceof Claude;
}

Expand Down
14 changes: 12 additions & 2 deletions src/platform/src/Bridge/Bedrock/Meta/LlamaModelClient.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@

use AsyncAws\BedrockRuntime\BedrockRuntimeClient;
use AsyncAws\BedrockRuntime\Input\InvokeModelRequest;
use Symfony\AI\Platform\Action;
use Symfony\AI\Platform\Bridge\Bedrock\RawBedrockResult;
use Symfony\AI\Platform\Bridge\Meta\Llama;
use Symfony\AI\Platform\Exception\InvalidActionArgumentException;
use Symfony\AI\Platform\Model;
use Symfony\AI\Platform\ModelClientInterface;

Expand All @@ -28,13 +30,21 @@ public function __construct(
) {
}

public function supports(Model $model): bool
public function supports(Model $model, Action $action): bool
{
if (Action::COMPLETE_CHAT !== $action && Action::CHAT !== $action) {
return false;
}

return $model instanceof Llama;
}

public function request(Model $model, array|string $payload, array $options = []): RawBedrockResult
public function request(Model $model, Action $action, array|string $payload, array $options = []): RawBedrockResult
{
if (Action::COMPLETE_CHAT !== $action && Action::CHAT !== $action) {
return throw new InvalidActionArgumentException($model, $action, [Action::CHAT, Action::COMPLETE_CHAT]);
}

return new RawBedrockResult($this->bedrockRuntimeClient->invokeModel(new InvokeModelRequest([
'modelId' => $this->getModelId($model),
'contentType' => 'application/json',
Expand Down
Loading
Loading