Skip to content

Commit 0d52e99

Browse files
committed
Implement config entry overrides in entity_service_call
1 parent d1a7e64 commit 0d52e99

File tree

2 files changed

+93
-34
lines changed

2 files changed

+93
-34
lines changed

homeassistant/core.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import asyncio
1010
from collections import UserDict, defaultdict
1111
from collections.abc import (
12-
Awaitable,
1312
Callable,
1413
Collection,
1514
Coroutine,
@@ -129,7 +128,7 @@
129128
type EntityServiceResponse = dict[str, ServiceResponse]
130129
type ConfigEntryServiceCallback = Callable[
131130
[ConfigEntry, set[Entity], ServiceCall],
132-
Awaitable[dict[str, EntityServiceResponse] | None],
131+
Coroutine[None, None, EntityServiceResponse | None],
133132
]
134133

135134

homeassistant/helpers/service.py

Lines changed: 92 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
HassJob,
3636
HassJobType,
3737
HomeAssistant,
38+
Service,
3839
ServiceCall,
3940
ServiceResponse,
4041
SupportsResponse,
@@ -728,11 +729,15 @@ def _filter_entities(
728729
entity_device_classes: Iterable[str | None] | None,
729730
required_features: Iterable[int] | None,
730731
referenced: target_helpers.SelectedEntities | None,
732+
service_obj: Service,
731733
domain: str,
732734
service: str,
733-
) -> list[Entity]:
735+
) -> tuple[list[Entity], dict[ConfigEntry, set[Entity]]]:
734736
"""Return a list of entities that pass availability, device class, and features."""
735737
filtered: list[Entity] = []
738+
per_entry_entities: dict[ConfigEntry, set[Entity]] = {
739+
ce: set() for ce in service_obj.overrides
740+
}
736741

737742
for entity in entity_candidates:
738743
if not entity.available:
@@ -754,10 +759,58 @@ def _filter_entities(
754759
if referenced and entity.entity_id in referenced.referenced:
755760
raise ServiceNotSupported(domain, service, entity.entity_id)
756761
continue
762+
ce = entity.platform.config_entry
763+
if ce in per_entry_entities:
764+
per_entry_entities[ce].add(entity)
765+
else:
766+
filtered.append(entity)
757767

758-
filtered.append(entity)
768+
return filtered, per_entry_entities
759769

760-
return filtered
770+
771+
async def _service_call_wrapper(
772+
*,
773+
hass: HomeAssistant,
774+
entities: set[Entity],
775+
handler: ConfigEntryServiceCallback | str | HassJob,
776+
config_entry: ConfigEntry | None = None,
777+
call: ServiceCall,
778+
data: dict | ServiceCall,
779+
) -> EntityServiceResponse:
780+
"""Execute a service call for a set of entities, either via normal handler or override.
781+
782+
Returns a dict mapping entities to ServiceResponse, None, or BaseException.
783+
All entities are included in the returned dict.
784+
785+
Raises:
786+
HomeAssistantError: If both or neither of `func` and `override_handler` are provided,
787+
or if `config_entry` is missing for an override.
788+
"""
789+
790+
if not entities:
791+
raise HomeAssistantError("No entities provided for service call")
792+
793+
gating_entity = next(iter(entities))
794+
if callable(handler):
795+
# Override callback path
796+
if config_entry is None:
797+
raise HomeAssistantError(
798+
"`config_entry` must be provided when using an override callback"
799+
)
800+
result: EntityServiceResponse | None = await gating_entity.async_request_call(
801+
handler(config_entry, entities, call)
802+
)
803+
if result is None:
804+
return {entity.entity_id: None for entity in entities}
805+
return {entity.entity_id: result.get(entity.entity_id) for entity in entities}
806+
# Normal entity service path
807+
if len(entities) != 1:
808+
raise HomeAssistantError("Normal service handler expects exactly one entity")
809+
810+
res: ServiceResponse | None = await gating_entity.async_request_call(
811+
_handle_entity_call(hass, gating_entity, handler, data, call.context)
812+
)
813+
return {gating_entity.entity_id: res}
761814

762815

763816
@bind_hass
@@ -826,53 +879,60 @@ async def entity_service_call(
826879
missing.discard(entity.entity_id)
827880
referenced.log_missing(missing, _LOGGER)
828881

829-
entities = _filter_entities(
882+
service_obj = hass.services.async_services_internal()[call.domain][call.service]
883+
entities, per_config_entities = _filter_entities(
830884
entity_candidates,
831885
entity_device_classes,
832886
required_features,
833887
referenced,
888+
service_obj,
834889
call.domain,
835890
call.service,
836891
)
837-
if not entities:
892+
if not entities and not any(per_config_entities.values()):
838893
if return_response:
839894
raise HomeAssistantError(
840895
"Service call requested response data but did not match any entities"
841896
)
842897
return None
898+
# Single entity optimization removed
843899

844-
if len(entities) == 1:
845-
# Single entity case avoids creating task
846-
entity = entities[0]
847-
single_response = await _handle_entity_call(
848-
hass, entity, func, data, call.context
900+
response_data: EntityServiceResponse = {}
901+
# For overrides: each config entry has a handler and set of entities
902+
override_coros = [
903+
_service_call_wrapper(
904+
hass=hass,
905+
entities=entities_set,
906+
handler=service_obj.overrides[ce],
907+
config_entry=ce,
908+
call=call,
909+
data=data,
849910
)
850-
if entity.should_poll:
851-
# Context expires if the turn on commands took a long time.
852-
# Set context again so it's there when we update
853-
entity.async_set_context(call.context)
854-
await entity.async_update_ha_state(True)
855-
return {entity.entity_id: single_response} if return_response else None
856-
857-
# Use asyncio.gather here to ensure the returned results
858-
# are in the same order as the entities list
859-
results: list[ServiceResponse | BaseException] = await asyncio.gather(
860-
*[
861-
entity.async_request_call(
862-
_handle_entity_call(hass, entity, func, data, call.context)
863-
)
864-
for entity in entities
865-
],
866-
return_exceptions=True,
911+
for ce, entities_set in per_config_entities.items()
912+
]
913+
914+
# For normal entities (not overridden)
915+
normal_coros = [
916+
_service_call_wrapper(
917+
hass=hass,
918+
entities={entity},
919+
handler=func,
920+
call=call,
921+
data=data,
922+
)
923+
for entity in entities
924+
]
925+
926+
all_results = await asyncio.gather(
927+
*override_coros, *normal_coros, return_exceptions=True
867928
)
868929

869-
response_data: EntityServiceResponse = {}
870-
for entity, result in zip(entities, results, strict=False):
930+
# Merge results into a single dict
931+
for result in all_results:
871932
if isinstance(result, BaseException):
872933
raise result from None
873-
response_data[entity.entity_id] = result
874-
875-
tasks: list[asyncio.Task[None]] = []
934+
response_data.update(result)
935+
tasks: list[asyncio.Task[None]] = []
876936

877937
for entity in entities:
878938
if not entity.should_poll:

0 commit comments

Comments
 (0)