@@ -479,74 +479,223 @@ def async_get_options_flow(config_entry: config_entries.ConfigEntry) -> config_e
479479class OptionsFlowHandler (config_entries .OptionsFlow ):
480480 """Handle options flow."""
481481
482- async def async_step_init (self , user_input : Optional [Dict [str , Any ]] = None ) -> FlowResult :
483- """Manage the options."""
484- if user_input is not None :
485- return self .async_create_entry (title = "" , data = user_input )
486-
487- current_data = {** self .config_entry .data , ** self .config_entry .options }
488- provider = current_data .get (CONF_API_PROVIDER )
482+ def __init__ (self ) -> None :
483+ """Initialize options flow."""
484+ self ._errors = {}
485+ self ._selected_provider = None
489486
490- default_model = (
487+ def _get_default_endpoint (self , provider : str ) -> str :
488+ """Get default endpoint for provider."""
489+ return {
490+ API_PROVIDER_OPENAI : DEFAULT_OPENAI_ENDPOINT ,
491+ API_PROVIDER_ANTHROPIC : DEFAULT_ANTHROPIC_ENDPOINT ,
492+ API_PROVIDER_DEEPSEEK : DEFAULT_DEEPSEEK_ENDPOINT ,
493+ API_PROVIDER_GEMINI : DEFAULT_GEMINI_ENDPOINT ,
494+ }.get (provider , DEFAULT_OPENAI_ENDPOINT )
495+
496+ def _get_default_model (self , provider : str ) -> str :
497+ """Get default model for provider."""
498+ return (
491499 DEFAULT_DEEPSEEK_MODEL if provider == API_PROVIDER_DEEPSEEK else
492500 DEFAULT_GEMINI_MODEL if provider == API_PROVIDER_GEMINI else
493501 DEFAULT_MODEL
494502 )
495503
504+ def _get_api_headers (self , api_key : str , provider : str ) -> Dict [str , str ]:
505+ """Get API headers based on provider."""
506+ if provider == API_PROVIDER_ANTHROPIC :
507+ return {
508+ "x-api-key" : api_key ,
509+ "anthropic-version" : "2023-06-01" ,
510+ "Content-Type" : "application/json"
511+ }
512+ return {
513+ "Authorization" : f"Bearer { api_key } " ,
514+ "Content-Type" : "application/json"
515+ }
516+
517+ async def _async_validate_api (self , provider : str , api_key : str , endpoint : str ) -> bool :
518+ """Validate API connection."""
519+ try :
520+ if not api_key :
521+ self ._errors ["base" ] = "invalid_auth"
522+ return False
523+
524+ # For Gemini, just check if API key is present
525+ if provider == API_PROVIDER_GEMINI :
526+ return True
527+
528+ session = async_get_clientsession (self .hass )
529+ headers = self ._get_api_headers (api_key , provider )
530+ endpoint = endpoint .rstrip ('/' )
531+
532+ check_url = (
533+ f"{ endpoint } /v1/models" if provider == API_PROVIDER_ANTHROPIC
534+ else f"{ endpoint } /models"
535+ )
536+
537+ async with session .get (check_url , headers = headers ) as response :
538+ if response .status == 401 :
539+ self ._errors ["base" ] = "invalid_auth"
540+ return False
541+ elif response .status not in [200 , 404 ]:
542+ self ._errors ["base" ] = "cannot_connect"
543+ return False
544+ return True
545+
546+ except Exception as err :
547+ _LOGGER .error ("API validation error: %s" , str (err ))
548+ self ._errors ["base" ] = "cannot_connect"
549+ return False
550+
551+ async def async_step_init (self , user_input : Optional [Dict [str , Any ]] = None ) -> FlowResult :
552+ """Handle provider selection step."""
553+ current_data = {** self .config_entry .data , ** self .config_entry .options }
554+ current_provider = current_data .get (CONF_API_PROVIDER , API_PROVIDER_OPENAI )
555+
556+ if user_input is not None :
557+ self ._selected_provider = user_input .get (CONF_API_PROVIDER , current_provider )
558+ return await self .async_step_settings ()
559+
496560 return self .async_show_form (
497561 step_id = "init" ,
498562 data_schema = vol .Schema ({
499- vol .Optional (
500- CONF_MODEL ,
501- default = current_data .get (CONF_MODEL , default_model )
502- ): str ,
503- vol .Optional (
504- CONF_TEMPERATURE ,
505- default = current_data .get (CONF_TEMPERATURE , DEFAULT_TEMPERATURE )
506- ): vol .All (
507- vol .Coerce (float ),
508- vol .Range (min = MIN_TEMPERATURE , max = MAX_TEMPERATURE )
509- ),
510- vol .Optional (
511- CONF_MAX_TOKENS ,
512- default = current_data .get (CONF_MAX_TOKENS , DEFAULT_MAX_TOKENS )
513- ): vol .All (
514- vol .Coerce (int ),
515- vol .Range (min = MIN_MAX_TOKENS , max = MAX_MAX_TOKENS )
516- ),
517- vol .Optional (
518- CONF_REQUEST_INTERVAL ,
519- default = current_data .get (CONF_REQUEST_INTERVAL , DEFAULT_REQUEST_INTERVAL )
520- ): vol .All (
521- vol .Coerce (float ),
522- vol .Range (min = MIN_REQUEST_INTERVAL )
523- ),
524- vol .Optional (
525- CONF_API_TIMEOUT ,
526- default = current_data .get (CONF_API_TIMEOUT , DEFAULT_API_TIMEOUT )
527- ): vol .All (
528- vol .Coerce (int ),
529- vol .Range (min = MIN_API_TIMEOUT , max = MAX_API_TIMEOUT )
530- ),
531- vol .Optional (
532- CONF_CONTEXT_MESSAGES ,
533- default = current_data .get (
534- CONF_CONTEXT_MESSAGES ,
535- DEFAULT_CONTEXT_MESSAGES
563+ vol .Required (
564+ CONF_API_PROVIDER ,
565+ default = current_provider
566+ ): selector .SelectSelector (
567+ selector .SelectSelectorConfig (
568+ options = API_PROVIDERS ,
569+ translation_key = "api_provider"
536570 )
537- ): vol .All (
538- vol .Coerce (int ),
539- vol .Range (min = 1 , max = 20 )
540571 ),
541- vol .Optional (
542- CONF_MAX_HISTORY_SIZE ,
543- default = current_data .get (
544- CONF_MAX_HISTORY_SIZE ,
545- DEFAULT_MAX_HISTORY
546- )
547- ): vol .All (
548- vol .Coerce (int ),
549- vol .Range (min = 1 , max = 100 )
572+ }),
573+ description_placeholders = {
574+ "current_provider" : current_provider
575+ }
576+ )
577+
578+ async def async_step_settings (self , user_input : Optional [Dict [str , Any ]] = None ) -> FlowResult :
579+ """Handle settings configuration step."""
580+ self ._errors = {}
581+ current_data = {** self .config_entry .data , ** self .config_entry .options }
582+ provider = self ._selected_provider or current_data .get (CONF_API_PROVIDER , API_PROVIDER_OPENAI )
583+
584+ # Determine if provider changed to show appropriate defaults
585+ provider_changed = provider != current_data .get (CONF_API_PROVIDER )
586+
587+ # Use new defaults if provider changed, otherwise use current values
588+ if provider_changed :
589+ default_endpoint = self ._get_default_endpoint (provider )
590+ default_model = self ._get_default_model (provider )
591+ else :
592+ default_endpoint = current_data .get (CONF_API_ENDPOINT , self ._get_default_endpoint (provider ))
593+ default_model = current_data .get (CONF_MODEL , self ._get_default_model (provider ))
594+
595+ if user_input is not None :
596+ # Validate API connection
597+ api_key = user_input .get (CONF_API_KEY , current_data .get (CONF_API_KEY , "" ))
598+ endpoint = user_input .get (CONF_API_ENDPOINT , default_endpoint )
599+
600+ if await self ._async_validate_api (provider , api_key , endpoint ):
601+ # Merge with provider selection
602+ final_data = {
603+ CONF_API_PROVIDER : provider ,
604+ ** user_input
605+ }
606+ return self .async_create_entry (title = "" , data = final_data )
607+
608+ # Show form again with errors
609+ return self .async_show_form (
610+ step_id = "settings" ,
611+ data_schema = self ._get_settings_schema (
612+ provider = provider ,
613+ current_data = current_data ,
614+ user_input = user_input ,
615+ default_endpoint = default_endpoint ,
616+ default_model = default_model ,
550617 ),
551- })
618+ errors = self ._errors
619+ )
620+
621+ return self .async_show_form (
622+ step_id = "settings" ,
623+ data_schema = self ._get_settings_schema (
624+ provider = provider ,
625+ current_data = current_data ,
626+ user_input = None ,
627+ default_endpoint = default_endpoint ,
628+ default_model = default_model ,
629+ ),
630+ description_placeholders = {
631+ "provider" : provider
632+ }
552633 )
634+
635+ def _get_settings_schema (
636+ self ,
637+ provider : str ,
638+ current_data : Dict [str , Any ],
639+ user_input : Optional [Dict [str , Any ]],
640+ default_endpoint : str ,
641+ default_model : str ,
642+ ) -> vol .Schema :
643+ """Build settings schema."""
644+ data = user_input or current_data
645+
646+ return vol .Schema ({
647+ vol .Required (
648+ CONF_API_KEY ,
649+ default = data .get (CONF_API_KEY , "" )
650+ ): str ,
651+ vol .Required (
652+ CONF_API_ENDPOINT ,
653+ default = data .get (CONF_API_ENDPOINT , default_endpoint )
654+ ): str ,
655+ vol .Required (
656+ CONF_MODEL ,
657+ default = data .get (CONF_MODEL , default_model )
658+ ): str ,
659+ vol .Optional (
660+ CONF_TEMPERATURE ,
661+ default = data .get (CONF_TEMPERATURE , DEFAULT_TEMPERATURE )
662+ ): vol .All (
663+ vol .Coerce (float ),
664+ vol .Range (min = MIN_TEMPERATURE , max = MAX_TEMPERATURE )
665+ ),
666+ vol .Optional (
667+ CONF_MAX_TOKENS ,
668+ default = data .get (CONF_MAX_TOKENS , DEFAULT_MAX_TOKENS )
669+ ): vol .All (
670+ vol .Coerce (int ),
671+ vol .Range (min = MIN_MAX_TOKENS , max = MAX_MAX_TOKENS )
672+ ),
673+ vol .Optional (
674+ CONF_REQUEST_INTERVAL ,
675+ default = data .get (CONF_REQUEST_INTERVAL , DEFAULT_REQUEST_INTERVAL )
676+ ): vol .All (
677+ vol .Coerce (float ),
678+ vol .Range (min = MIN_REQUEST_INTERVAL )
679+ ),
680+ vol .Optional (
681+ CONF_API_TIMEOUT ,
682+ default = data .get (CONF_API_TIMEOUT , DEFAULT_API_TIMEOUT )
683+ ): vol .All (
684+ vol .Coerce (int ),
685+ vol .Range (min = MIN_API_TIMEOUT , max = MAX_API_TIMEOUT )
686+ ),
687+ vol .Optional (
688+ CONF_CONTEXT_MESSAGES ,
689+ default = data .get (CONF_CONTEXT_MESSAGES , DEFAULT_CONTEXT_MESSAGES )
690+ ): vol .All (
691+ vol .Coerce (int ),
692+ vol .Range (min = 1 , max = 20 )
693+ ),
694+ vol .Optional (
695+ CONF_MAX_HISTORY_SIZE ,
696+ default = data .get (CONF_MAX_HISTORY_SIZE , DEFAULT_MAX_HISTORY )
697+ ): vol .All (
698+ vol .Coerce (int ),
699+ vol .Range (min = 1 , max = 100 )
700+ ),
701+ })
0 commit comments