Add async_get_active_reauth_flows helper for config entries (#81881)
* Add `async_get_active_reauth_flows` helper for config entries * Code review * Code review + tests
This commit is contained in:
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import ChainMap
|
||||
from collections.abc import Callable, Coroutine, Iterable, Mapping
|
||||
from collections.abc import Callable, Coroutine, Generator, Iterable, Mapping
|
||||
from contextvars import ContextVar
|
||||
from enum import Enum
|
||||
import functools
|
||||
@@ -19,6 +19,7 @@ from .backports.enum import StrEnum
|
||||
from .components import persistent_notification
|
||||
from .const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP, Platform
|
||||
from .core import CALLBACK_TYPE, CoreState, Event, HomeAssistant, callback
|
||||
from .data_entry_flow import FlowResult
|
||||
from .exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady, HomeAssistantError
|
||||
from .helpers import device_registry, entity_registry, storage
|
||||
from .helpers.dispatcher import async_dispatcher_connect, async_dispatcher_send
|
||||
@@ -662,12 +663,7 @@ class ConfigEntry:
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Start a reauth flow."""
|
||||
if any(
|
||||
flow
|
||||
for flow in hass.config_entries.flow.async_progress_by_handler(self.domain)
|
||||
if flow["context"].get("source") == SOURCE_REAUTH
|
||||
and flow["context"].get("entry_id") == self.entry_id
|
||||
):
|
||||
if any(self.async_get_active_flows(hass, {SOURCE_REAUTH})):
|
||||
# Reauth flow already in progress for this entry
|
||||
return
|
||||
|
||||
@@ -685,6 +681,18 @@ class ConfigEntry:
|
||||
)
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_get_active_flows(
|
||||
self, hass: HomeAssistant, sources: set[str]
|
||||
) -> Generator[FlowResult, None, None]:
|
||||
"""Get any active flows of certain sources for this entry."""
|
||||
return (
|
||||
flow
|
||||
for flow in hass.config_entries.flow.async_progress_by_handler(self.domain)
|
||||
if flow["context"].get("source") in sources
|
||||
and flow["context"].get("entry_id") == self.entry_id
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_create_task(
|
||||
self, hass: HomeAssistant, target: Coroutine[Any, Any, _R]
|
||||
|
||||
Reference in New Issue
Block a user