|
12 | 12 | namespace Symfony\AI\Platform\Bridge\Venice; |
13 | 13 |
|
14 | 14 | use Symfony\AI\Platform\Capability; |
15 | | -use Symfony\AI\Platform\Model; |
16 | | -use Symfony\AI\Platform\ModelCatalog\AbstractModelCatalog; |
| 15 | +use Symfony\AI\Platform\Exception\InvalidArgumentException; |
| 16 | +use Symfony\AI\Platform\ModelCatalog\ModelCatalogInterface; |
| 17 | +use Symfony\Contracts\HttpClient\HttpClientInterface; |
17 | 18 |
|
18 | 19 | /** |
19 | 20 | * @author Guillaume Loulier <personal@guillaumeloulier.fr> |
20 | 21 | */ |
21 | | -final class ModelCatalog extends AbstractModelCatalog |
| 22 | +final class ModelCatalog implements ModelCatalogInterface |
22 | 23 | { |
23 | | - /** |
24 | | - * @param array<string, array{class: class-string<Model>, capabilities: list<Capability>}> $additionalModels |
25 | | - */ |
26 | | - public function __construct(array $additionalModels = []) |
| 24 | + public function __construct( |
| 25 | + private readonly HttpClientInterface $httpClient, |
| 26 | + ) { |
| 27 | + } |
| 28 | + |
| 29 | + public function getModel(string $modelName): Venice |
27 | 30 | { |
28 | | - $defaultModels = [ |
29 | | - 'venice-uncensored' => [ |
| 31 | + $models = $this->getModels(); |
| 32 | + |
| 33 | + if (!\array_key_exists($modelName, $models)) { |
| 34 | + throw new InvalidArgumentException(\sprintf('The model "%s" cannot be retrieved from the API.', $modelName)); |
| 35 | + } |
| 36 | + |
| 37 | + if ([] === $models[$modelName]['capabilities']) { |
| 38 | + throw new InvalidArgumentException(\sprintf('The model "%s" is not supported, please check the Venice API.', $modelName)); |
| 39 | + } |
| 40 | + |
| 41 | + return new Venice($modelName, $models[$modelName]['capabilities']); |
| 42 | + } |
| 43 | + |
| 44 | + public function getModels(): array |
| 45 | + { |
| 46 | + $results = $this->httpClient->request('GET', 'models', [ |
| 47 | + 'query' => [ |
| 48 | + 'type' => 'all', |
| 49 | + ], |
| 50 | + ]); |
| 51 | + |
| 52 | + $models = $results->toArray(); |
| 53 | + |
| 54 | + if ([] === $models['data']) { |
| 55 | + return []; |
| 56 | + } |
| 57 | + |
| 58 | + $payload = static fn (array $model): array => match ($model['type']) { |
| 59 | + 'asr' => [ |
30 | 60 | 'class' => Venice::class, |
31 | 61 | 'capabilities' => [ |
32 | | - Capability::INPUT_MESSAGES, |
| 62 | + Capability::SPEECH_RECOGNITION, |
| 63 | + Capability::INPUT_TEXT, |
33 | 64 | ], |
34 | 65 | ], |
35 | | - 'tts-kokoro' => [ |
| 66 | + 'embedding' => [ |
36 | 67 | 'class' => Venice::class, |
37 | 68 | 'capabilities' => [ |
38 | | - Capability::TEXT_TO_SPEECH, |
| 69 | + Capability::EMBEDDINGS, |
| 70 | + Capability::INPUT_TEXT, |
39 | 71 | ], |
40 | 72 | ], |
41 | | - 'nvidia/parakeet-tdt-0.6b-v3' => [ |
| 73 | + 'text' => [ |
42 | 74 | 'class' => Venice::class, |
43 | 75 | 'capabilities' => [ |
44 | | - Capability::SPEECH_TO_TEXT, |
| 76 | + Capability::INPUT_TEXT, |
| 77 | + Capability::INPUT_MESSAGES, |
45 | 78 | ], |
46 | 79 | ], |
47 | | - 'openai/whisper-large-v3' => [ |
| 80 | + 'tts' => [ |
48 | 81 | 'class' => Venice::class, |
49 | 82 | 'capabilities' => [ |
50 | | - Capability::SPEECH_TO_TEXT, |
| 83 | + Capability::TEXT_TO_SPEECH, |
| 84 | + Capability::INPUT_TEXT, |
51 | 85 | ], |
52 | 86 | ], |
53 | | - 'text-embedding-bge-m3' => [ |
| 87 | + 'video' => [ |
54 | 88 | 'class' => Venice::class, |
55 | 89 | 'capabilities' => [ |
56 | | - Capability::EMBEDDINGS, |
| 90 | + Capability::IMAGE_TO_VIDEO, |
| 91 | + Capability::INPUT_IMAGE, |
57 | 92 | ], |
58 | 93 | ], |
59 | | - ]; |
| 94 | + }; |
60 | 95 |
|
61 | | - $this->models = [ |
62 | | - ...$defaultModels, |
63 | | - ...$additionalModels, |
64 | | - ]; |
| 96 | + return array_combine( |
| 97 | + array_map(static fn (array $model): string => $model['id'], $models['data']), |
| 98 | + array_map(static fn (array $model): array => $payload($model), $models['data']), |
| 99 | + ); |
65 | 100 | } |
66 | 101 | } |
0 commit comments