Skip to content

Commit ebc92a2

Browse files
committed
Make sure provider is passed alongside with model instead of autocalculated
1 parent 9aff1c6 commit ebc92a2

File tree

8 files changed

+669
-112
lines changed

8 files changed

+669
-112
lines changed

config/test.exs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,17 @@ config :junit_formatter,
160160
config :trento,
161161
:flaky_tests_detection,
162162
enabled?: System.get_env("WRITE_JUNIT") == "1"
163+
164+
config :trento, :ai,
165+
providers: [
166+
provider1: [
167+
models: [
168+
"model1"
169+
]
170+
],
171+
provider2: [
172+
models: [
173+
"model1"
174+
]
175+
]
176+
]

lib/trento/ai/llm_registry.ex

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,27 @@ defmodule Trento.AI.LLMRegistry do
4343
end)
4444
end
4545

46+
@doc """
47+
Checks if a given model is supported by a specific provider.
48+
"""
49+
@spec model_supported_by_provider?(bitstring(), atom()) :: boolean()
50+
def model_supported_by_provider?(model, provider) do
51+
model in get_provider_models(provider)
52+
end
53+
4654
@doc """
4755
Checks if a given model is supported by any provider.
4856
"""
4957
@spec model_supported?(bitstring()) :: boolean()
5058
def model_supported?(model), do: model in get_provider_models(:all)
5159

60+
@doc """
61+
Checks if a given provider is supported.
62+
"""
63+
@spec provider_supported?(atom()) :: boolean()
64+
def provider_supported?(provider) when is_atom(provider), do: provider in providers()
65+
66+
def provider_supported?(_), do: false
67+
5268
defp get_ai_providers_config, do: Keyword.get(ApplicationConfigLoader.load(), :providers, [])
5369
end

lib/trento/ai/user_configuration.ex

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ defmodule Trento.AI.UserConfiguration do
1414
schema "ai_configurations" do
1515
field :model, :string
1616
field :provider, Ecto.Enum, values: LLMRegistry.providers()
17+
# field :provider, :string
1718
field :api_key, EncryptedBinary, redact: true
1819

1920
belongs_to :user, Trento.Users.User, primary_key: true
@@ -22,29 +23,53 @@ defmodule Trento.AI.UserConfiguration do
2223
end
2324

2425
def changeset(ai_configuration, attrs) do
25-
updated_attrs = maybe_update_provider(attrs)
26-
2726
ai_configuration
28-
|> cast(updated_attrs, [:user_id, :model, :provider, :api_key])
29-
|> validate_required([:user_id, :model, :api_key])
30-
|> validate_change(:model, &validate_model/2)
27+
|> cast(attrs, [:user_id, :model, :provider, :api_key])
28+
|> validate_required([:user_id, :model, :provider, :api_key])
29+
|> validate_change(:provider, &validate_provider/2)
30+
|> validate_model()
3131
|> unique_constraint(:user_id,
3232
name: :ai_configurations_pkey,
3333
message: "User already has a configuration"
3434
)
3535
|> foreign_key_constraint(:user_id, message: "User does not exist")
3636
end
3737

38-
defp maybe_update_provider(%{model: model} = attrs),
39-
do: Map.put(attrs, :provider, LLMRegistry.get_model_provider(model))
40-
41-
defp maybe_update_provider(attrs), do: attrs
42-
43-
defp validate_model(_model_field_atom, model) do
44-
if LLMRegistry.model_supported?(model) do
38+
defp validate_provider(_provider_field_atom, provider) do
39+
if LLMRegistry.provider_supported?(provider) do
4540
[]
4641
else
47-
[model: {"is not supported", validation: :ai_model_validity}]
42+
[provider: {"is not supported", validation: :ai_provider_validity}]
4843
end
4944
end
45+
46+
defp validate_model(%{errors: [provider: _]} = changeset), do: changeset
47+
48+
defp validate_model(changeset) do
49+
provider = get_field(changeset, :provider)
50+
model = get_field(changeset, :model)
51+
52+
changeset
53+
|> force_change(:model, model)
54+
|> force_change(:provider, provider)
55+
|> validate_change(:model, fn _model_atom, _model ->
56+
model_supported? = LLMRegistry.model_supported?(model)
57+
model_supported_by_provider? = LLMRegistry.model_supported_by_provider?(model, provider)
58+
59+
case {model_supported?, model_supported_by_provider?} do
60+
{true, true} ->
61+
[]
62+
63+
{true, false} ->
64+
[
65+
model:
66+
{"is not supported by the specified provider",
67+
validation: :ai_model_provider_mismatch}
68+
]
69+
70+
{false, _} ->
71+
[model: {"is not supported", validation: :ai_model_validity}]
72+
end
73+
end)
74+
end
5075
end

lib/trento_web/openapi/v1/schema/ai.ex

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@ defmodule TrentoWeb.OpenApi.V1.Schema.AI do
4848
type: :object,
4949
additionalProperties: false,
5050
properties: %{
51+
provider: %Schema{
52+
type: :string,
53+
description: "AI provider.",
54+
nullable: false,
55+
example: "googleai"
56+
},
5157
model: %Schema{
5258
type: :string,
5359
description: "AI model.",
@@ -62,10 +68,11 @@ defmodule TrentoWeb.OpenApi.V1.Schema.AI do
6268
}
6369
},
6470
example: %{
71+
provider: "googleai",
6572
model: "gemini-2.0-flash",
6673
api_key: "AIza..."
6774
},
68-
required: [:model, :api_key]
75+
required: [:provider, :model, :api_key]
6976
},
7077
struct?: false
7178
)
@@ -82,14 +89,22 @@ defmodule TrentoWeb.OpenApi.V1.Schema.AI do
8289
additionalProperties: false,
8390
minProperties: 1,
8491
properties: %{
92+
provider: %Schema{
93+
type: :string,
94+
description: "AI provider.",
95+
nullable: false,
96+
example: "googleai"
97+
},
8598
model: %Schema{
8699
type: :string,
87100
description: "AI model.",
101+
nullable: false,
88102
example: "gemini-2.0-flash"
89103
},
90104
api_key: %Schema{
91105
type: :string,
92106
description: "AI API key.",
107+
nullable: false,
93108
example: "AIza..."
94109
}
95110
},

test/support/factory.ex

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1461,24 +1461,43 @@ defmodule Trento.Factory do
14611461
}
14621462
end
14631463

1464-
def random_ai_model_factory(_) do
1464+
def random_ai_provider_factory(_) do
14651465
Trento.AI.AICase.stub_config_loader()
14661466

1467-
:all
1467+
Enum.random(LLMRegistry.providers())
1468+
end
1469+
1470+
def random_ai_model_factory(attrs) do
1471+
Trento.AI.AICase.stub_config_loader()
1472+
1473+
attrs
1474+
|> Map.get(:provider, :all)
14681475
|> LLMRegistry.get_provider_models()
14691476
|> Enum.random()
14701477
end
14711478

14721479
def ai_user_configuration_factory(attrs) do
14731480
user_id = Map.get(attrs, :user_id, 1)
1474-
model = Map.get(attrs, :model, build(:random_ai_model))
1481+
provider = Map.get(attrs, :provider, build(:random_ai_provider))
1482+
model = Map.get(attrs, :model, build(:random_ai_model, provider: provider))
14751483

1476-
%UserConfiguration{}
1477-
|> UserConfiguration.changeset(%{
1484+
%UserConfiguration{
14781485
user_id: user_id,
1486+
provider: provider,
14791487
model: model,
14801488
api_key: Faker.String.base64(32)
1481-
})
1482-
|> Ecto.Changeset.apply_changes()
1489+
}
1490+
end
1491+
1492+
def ai_configuration_creation_params_factory(attrs) do
1493+
provider = Map.get(attrs, :provider, build(:random_ai_provider))
1494+
model = Map.get(attrs, :model, build(:random_ai_model, provider: provider))
1495+
api_key = Map.get(attrs, :api_key, Faker.String.base64(32))
1496+
1497+
%{
1498+
provider: provider,
1499+
model: model,
1500+
api_key: api_key
1501+
}
14831502
end
14841503
end

0 commit comments

Comments
 (0)