Skip to content

Commit 43cbac2

Browse files
committed
feat: Add ability to edit provider, API key, and endpoint in existing integrations
- Extended OptionsFlowHandler with two-step configuration flow - Step 1: Select provider (OpenAI, Anthropic, DeepSeek, Gemini) - Step 2: Configure API key, endpoint, model, and other settings - Auto-reload integration on options change - When switching providers, show appropriate default endpoint and model - Updated translations for all 8 languages
1 parent 986c78d commit 43cbac2

File tree

10 files changed

+330
-86
lines changed

10 files changed

+330
-86
lines changed

custom_components/ha_text_ai/__init__.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -312,21 +312,35 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
312312
_LOGGER.debug(f"Setting up HA Text AI entry: {entry.data}")
313313

314314
try:
315-
if CONF_API_PROVIDER not in entry.data:
315+
# Get provider from data or options (options takes precedence)
316+
config = {**entry.data, **entry.options}
317+
api_provider = config.get(CONF_API_PROVIDER)
318+
319+
if not api_provider:
316320
_LOGGER.error("API provider not specified")
317321
raise ConfigEntryNotReady("API provider is required")
318322

319-
# Get configuration (merge data with options to apply any runtime changes)
320-
config = {**entry.data, **entry.options}
321323
session = aiohttp_client.async_get_clientsession(hass)
322-
api_provider = config.get(CONF_API_PROVIDER)
323-
model = config.get(CONF_MODEL, DEFAULT_MODEL)
324-
endpoint = config.get(
325-
CONF_API_ENDPOINT,
326-
DEFAULT_OPENAI_ENDPOINT if api_provider == API_PROVIDER_OPENAI
327-
else DEFAULT_ANTHROPIC_ENDPOINT
328-
).rstrip('/')
329-
api_key = entry.data[CONF_API_KEY] # API key stays in data, not in options
324+
325+
# Get default endpoint based on provider
326+
default_endpoint = {
327+
API_PROVIDER_OPENAI: DEFAULT_OPENAI_ENDPOINT,
328+
API_PROVIDER_ANTHROPIC: DEFAULT_ANTHROPIC_ENDPOINT,
329+
API_PROVIDER_DEEPSEEK: DEFAULT_DEEPSEEK_ENDPOINT,
330+
API_PROVIDER_GEMINI: DEFAULT_GEMINI_ENDPOINT,
331+
}.get(api_provider, DEFAULT_OPENAI_ENDPOINT)
332+
333+
# Get default model based on provider
334+
default_model = (
335+
DEFAULT_DEEPSEEK_MODEL if api_provider == API_PROVIDER_DEEPSEEK else
336+
DEFAULT_GEMINI_MODEL if api_provider == API_PROVIDER_GEMINI else
337+
DEFAULT_MODEL
338+
)
339+
340+
model = config.get(CONF_MODEL, default_model)
341+
endpoint = config.get(CONF_API_ENDPOINT, default_endpoint).rstrip('/')
342+
# API key can now be updated via options
343+
api_key = config.get(CONF_API_KEY, entry.data.get(CONF_API_KEY))
330344
instance_name = entry.data.get(CONF_NAME, entry.entry_id)
331345
request_interval = config.get(CONF_REQUEST_INTERVAL, DEFAULT_REQUEST_INTERVAL)
332346
api_timeout = config.get(CONF_API_TIMEOUT, DEFAULT_API_TIMEOUT)
@@ -386,6 +400,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
386400
# Set up platforms
387401
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
388402

403+
# Register update listener for options changes
404+
entry.async_on_unload(entry.add_update_listener(async_update_options))
405+
389406
_LOGGER.debug(f"Setup completed for {instance_name}")
390407

391408
return True
@@ -394,6 +411,12 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
394411
_LOGGER.exception(f"Error setting up HA Text AI: {err}")
395412
raise
396413

414+
async def async_update_options(hass: HomeAssistant, entry: ConfigEntry) -> None:
415+
"""Handle options update - reload the config entry."""
416+
_LOGGER.info("Options updated for %s, reloading integration", entry.title)
417+
await hass.config_entries.async_reload(entry.entry_id)
418+
419+
397420
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
398421
"""Unload a config entry."""
399422
try:

custom_components/ha_text_ai/config_flow.py

Lines changed: 207 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -479,74 +479,223 @@ def async_get_options_flow(config_entry: config_entries.ConfigEntry) -> config_e
479479
class OptionsFlowHandler(config_entries.OptionsFlow):
480480
"""Handle options flow."""
481481

482-
async def async_step_init(self, user_input: Optional[Dict[str, Any]] = None) -> FlowResult:
483-
"""Manage the options."""
484-
if user_input is not None:
485-
return self.async_create_entry(title="", data=user_input)
486-
487-
current_data = {**self.config_entry.data, **self.config_entry.options}
488-
provider = current_data.get(CONF_API_PROVIDER)
482+
def __init__(self) -> None:
483+
"""Initialize options flow."""
484+
self._errors = {}
485+
self._selected_provider = None
489486

490-
default_model = (
487+
def _get_default_endpoint(self, provider: str) -> str:
488+
"""Get default endpoint for provider."""
489+
return {
490+
API_PROVIDER_OPENAI: DEFAULT_OPENAI_ENDPOINT,
491+
API_PROVIDER_ANTHROPIC: DEFAULT_ANTHROPIC_ENDPOINT,
492+
API_PROVIDER_DEEPSEEK: DEFAULT_DEEPSEEK_ENDPOINT,
493+
API_PROVIDER_GEMINI: DEFAULT_GEMINI_ENDPOINT,
494+
}.get(provider, DEFAULT_OPENAI_ENDPOINT)
495+
496+
def _get_default_model(self, provider: str) -> str:
497+
"""Get default model for provider."""
498+
return (
491499
DEFAULT_DEEPSEEK_MODEL if provider == API_PROVIDER_DEEPSEEK else
492500
DEFAULT_GEMINI_MODEL if provider == API_PROVIDER_GEMINI else
493501
DEFAULT_MODEL
494502
)
495503

504+
def _get_api_headers(self, api_key: str, provider: str) -> Dict[str, str]:
505+
"""Get API headers based on provider."""
506+
if provider == API_PROVIDER_ANTHROPIC:
507+
return {
508+
"x-api-key": api_key,
509+
"anthropic-version": "2023-06-01",
510+
"Content-Type": "application/json"
511+
}
512+
return {
513+
"Authorization": f"Bearer {api_key}",
514+
"Content-Type": "application/json"
515+
}
516+
517+
async def _async_validate_api(self, provider: str, api_key: str, endpoint: str) -> bool:
518+
"""Validate API connection."""
519+
try:
520+
if not api_key:
521+
self._errors["base"] = "invalid_auth"
522+
return False
523+
524+
# For Gemini, just check if API key is present
525+
if provider == API_PROVIDER_GEMINI:
526+
return True
527+
528+
session = async_get_clientsession(self.hass)
529+
headers = self._get_api_headers(api_key, provider)
530+
endpoint = endpoint.rstrip('/')
531+
532+
check_url = (
533+
f"{endpoint}/v1/models" if provider == API_PROVIDER_ANTHROPIC
534+
else f"{endpoint}/models"
535+
)
536+
537+
async with session.get(check_url, headers=headers) as response:
538+
if response.status == 401:
539+
self._errors["base"] = "invalid_auth"
540+
return False
541+
elif response.status not in [200, 404]:
542+
self._errors["base"] = "cannot_connect"
543+
return False
544+
return True
545+
546+
except Exception as err:
547+
_LOGGER.error("API validation error: %s", str(err))
548+
self._errors["base"] = "cannot_connect"
549+
return False
550+
551+
async def async_step_init(self, user_input: Optional[Dict[str, Any]] = None) -> FlowResult:
552+
"""Handle provider selection step."""
553+
current_data = {**self.config_entry.data, **self.config_entry.options}
554+
current_provider = current_data.get(CONF_API_PROVIDER, API_PROVIDER_OPENAI)
555+
556+
if user_input is not None:
557+
self._selected_provider = user_input.get(CONF_API_PROVIDER, current_provider)
558+
return await self.async_step_settings()
559+
496560
return self.async_show_form(
497561
step_id="init",
498562
data_schema=vol.Schema({
499-
vol.Optional(
500-
CONF_MODEL,
501-
default=current_data.get(CONF_MODEL, default_model)
502-
): str,
503-
vol.Optional(
504-
CONF_TEMPERATURE,
505-
default=current_data.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
506-
): vol.All(
507-
vol.Coerce(float),
508-
vol.Range(min=MIN_TEMPERATURE, max=MAX_TEMPERATURE)
509-
),
510-
vol.Optional(
511-
CONF_MAX_TOKENS,
512-
default=current_data.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
513-
): vol.All(
514-
vol.Coerce(int),
515-
vol.Range(min=MIN_MAX_TOKENS, max=MAX_MAX_TOKENS)
516-
),
517-
vol.Optional(
518-
CONF_REQUEST_INTERVAL,
519-
default=current_data.get(CONF_REQUEST_INTERVAL, DEFAULT_REQUEST_INTERVAL)
520-
): vol.All(
521-
vol.Coerce(float),
522-
vol.Range(min=MIN_REQUEST_INTERVAL)
523-
),
524-
vol.Optional(
525-
CONF_API_TIMEOUT,
526-
default=current_data.get(CONF_API_TIMEOUT, DEFAULT_API_TIMEOUT)
527-
): vol.All(
528-
vol.Coerce(int),
529-
vol.Range(min=MIN_API_TIMEOUT, max=MAX_API_TIMEOUT)
530-
),
531-
vol.Optional(
532-
CONF_CONTEXT_MESSAGES,
533-
default=current_data.get(
534-
CONF_CONTEXT_MESSAGES,
535-
DEFAULT_CONTEXT_MESSAGES
563+
vol.Required(
564+
CONF_API_PROVIDER,
565+
default=current_provider
566+
): selector.SelectSelector(
567+
selector.SelectSelectorConfig(
568+
options=API_PROVIDERS,
569+
translation_key="api_provider"
536570
)
537-
): vol.All(
538-
vol.Coerce(int),
539-
vol.Range(min=1, max=20)
540571
),
541-
vol.Optional(
542-
CONF_MAX_HISTORY_SIZE,
543-
default=current_data.get(
544-
CONF_MAX_HISTORY_SIZE,
545-
DEFAULT_MAX_HISTORY
546-
)
547-
): vol.All(
548-
vol.Coerce(int),
549-
vol.Range(min=1, max=100)
572+
}),
573+
description_placeholders={
574+
"current_provider": current_provider
575+
}
576+
)
577+
578+
async def async_step_settings(self, user_input: Optional[Dict[str, Any]] = None) -> FlowResult:
579+
"""Handle settings configuration step."""
580+
self._errors = {}
581+
current_data = {**self.config_entry.data, **self.config_entry.options}
582+
provider = self._selected_provider or current_data.get(CONF_API_PROVIDER, API_PROVIDER_OPENAI)
583+
584+
# Determine if provider changed to show appropriate defaults
585+
provider_changed = provider != current_data.get(CONF_API_PROVIDER)
586+
587+
# Use new defaults if provider changed, otherwise use current values
588+
if provider_changed:
589+
default_endpoint = self._get_default_endpoint(provider)
590+
default_model = self._get_default_model(provider)
591+
else:
592+
default_endpoint = current_data.get(CONF_API_ENDPOINT, self._get_default_endpoint(provider))
593+
default_model = current_data.get(CONF_MODEL, self._get_default_model(provider))
594+
595+
if user_input is not None:
596+
# Validate API connection
597+
api_key = user_input.get(CONF_API_KEY, current_data.get(CONF_API_KEY, ""))
598+
endpoint = user_input.get(CONF_API_ENDPOINT, default_endpoint)
599+
600+
if await self._async_validate_api(provider, api_key, endpoint):
601+
# Merge with provider selection
602+
final_data = {
603+
CONF_API_PROVIDER: provider,
604+
**user_input
605+
}
606+
return self.async_create_entry(title="", data=final_data)
607+
608+
# Show form again with errors
609+
return self.async_show_form(
610+
step_id="settings",
611+
data_schema=self._get_settings_schema(
612+
provider=provider,
613+
current_data=current_data,
614+
user_input=user_input,
615+
default_endpoint=default_endpoint,
616+
default_model=default_model,
550617
),
551-
})
618+
errors=self._errors
619+
)
620+
621+
return self.async_show_form(
622+
step_id="settings",
623+
data_schema=self._get_settings_schema(
624+
provider=provider,
625+
current_data=current_data,
626+
user_input=None,
627+
default_endpoint=default_endpoint,
628+
default_model=default_model,
629+
),
630+
description_placeholders={
631+
"provider": provider
632+
}
552633
)
634+
635+
def _get_settings_schema(
636+
self,
637+
provider: str,
638+
current_data: Dict[str, Any],
639+
user_input: Optional[Dict[str, Any]],
640+
default_endpoint: str,
641+
default_model: str,
642+
) -> vol.Schema:
643+
"""Build settings schema."""
644+
data = user_input or current_data
645+
646+
return vol.Schema({
647+
vol.Required(
648+
CONF_API_KEY,
649+
default=data.get(CONF_API_KEY, "")
650+
): str,
651+
vol.Required(
652+
CONF_API_ENDPOINT,
653+
default=data.get(CONF_API_ENDPOINT, default_endpoint)
654+
): str,
655+
vol.Required(
656+
CONF_MODEL,
657+
default=data.get(CONF_MODEL, default_model)
658+
): str,
659+
vol.Optional(
660+
CONF_TEMPERATURE,
661+
default=data.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
662+
): vol.All(
663+
vol.Coerce(float),
664+
vol.Range(min=MIN_TEMPERATURE, max=MAX_TEMPERATURE)
665+
),
666+
vol.Optional(
667+
CONF_MAX_TOKENS,
668+
default=data.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
669+
): vol.All(
670+
vol.Coerce(int),
671+
vol.Range(min=MIN_MAX_TOKENS, max=MAX_MAX_TOKENS)
672+
),
673+
vol.Optional(
674+
CONF_REQUEST_INTERVAL,
675+
default=data.get(CONF_REQUEST_INTERVAL, DEFAULT_REQUEST_INTERVAL)
676+
): vol.All(
677+
vol.Coerce(float),
678+
vol.Range(min=MIN_REQUEST_INTERVAL)
679+
),
680+
vol.Optional(
681+
CONF_API_TIMEOUT,
682+
default=data.get(CONF_API_TIMEOUT, DEFAULT_API_TIMEOUT)
683+
): vol.All(
684+
vol.Coerce(int),
685+
vol.Range(min=MIN_API_TIMEOUT, max=MAX_API_TIMEOUT)
686+
),
687+
vol.Optional(
688+
CONF_CONTEXT_MESSAGES,
689+
default=data.get(CONF_CONTEXT_MESSAGES, DEFAULT_CONTEXT_MESSAGES)
690+
): vol.All(
691+
vol.Coerce(int),
692+
vol.Range(min=1, max=20)
693+
),
694+
vol.Optional(
695+
CONF_MAX_HISTORY_SIZE,
696+
default=data.get(CONF_MAX_HISTORY_SIZE, DEFAULT_MAX_HISTORY)
697+
): vol.All(
698+
vol.Coerce(int),
699+
vol.Range(min=1, max=100)
700+
),
701+
})

custom_components/ha_text_ai/translations/de.json

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,22 @@
7474
"options": {
7575
"step": {
7676
"init": {
77-
"title": "Instanzeinstellungen aktualisieren",
78-
"description": "Ändern Sie die Einstellungen für diese AI-Assistenteninstanz.",
77+
"title": "Anbieter auswählen",
78+
"description": "Wählen Sie den AI-Anbieter für diese Instanz. Die Integration wird nach dem Speichern der Änderungen neu geladen.",
7979
"data": {
80+
"api_provider": "API-Anbieter"
81+
}
82+
},
83+
"settings": {
84+
"title": "Verbindungs- und Modelleinstellungen",
85+
"description": "Konfigurieren Sie API-Anmeldeinformationen und Modellparameter. Änderungen werden nach dem Neuladen der Integration wirksam.",
86+
"data": {
87+
"api_key": "API-Schlüssel",
88+
"api_endpoint": "API-Endpunkt-URL",
8089
"model": "AI-Modell",
8190
"temperature": "Kreativität der Antwort (0-2)",
8291
"max_tokens": "Maximale Länge der Antwort (1-100000)",
83-
"request_interval": "Minimale Anfrageintervall (0,1-60 Sekunden)",
92+
"request_interval": "Minimales Anfrageintervall (0,1-60 Sekunden)",
8493
"api_timeout": "API-Anfrage Timeout in Sekunden (5-600)",
8594
"context_messages": "Anzahl der vorherigen Nachrichten, die im Kontext enthalten sein sollen (1-20)",
8695
"max_history_size": "Maximale Größe des Gesprächsverlaufs (1-100)"

custom_components/ha_text_ai/translations/en.json

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,18 @@
7474
"options": {
7575
"step": {
7676
"init": {
77-
"title": "Update Instance Settings",
78-
"description": "Modify settings for this AI assistant instance.",
77+
"title": "Select Provider",
78+
"description": "Choose the AI provider for this instance. The integration will reload after saving changes.",
7979
"data": {
80+
"api_provider": "API Provider"
81+
}
82+
},
83+
"settings": {
84+
"title": "Connection & Model Settings",
85+
"description": "Configure API credentials and model parameters. Changes will take effect after the integration reloads.",
86+
"data": {
87+
"api_key": "API Key",
88+
"api_endpoint": "API Endpoint URL",
8089
"model": "AI model",
8190
"temperature": "Response creativity (0-2)",
8291
"max_tokens": "Maximum response length (1-100000)",

0 commit comments

Comments
 (0)