Add async_get_active_reauth_flows helper for config entries (#81881)

* Add `async_get_active_reauth_flows` helper for config entries

* Code review

* Code review + tests
This commit is contained in:
Aaron Bach
2022-11-09 15:36:50 -07:00
committed by GitHub
parent 0941ed076c
commit adf84b0c62
3 changed files with 53 additions and 22 deletions

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio
from collections import ChainMap
from collections.abc import Callable, Coroutine, Iterable, Mapping
from collections.abc import Callable, Coroutine, Generator, Iterable, Mapping
from contextvars import ContextVar
from enum import Enum
import functools
@@ -19,6 +19,7 @@ from .backports.enum import StrEnum
from .components import persistent_notification
from .const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP, Platform
from .core import CALLBACK_TYPE, CoreState, Event, HomeAssistant, callback
from .data_entry_flow import FlowResult
from .exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady, HomeAssistantError
from .helpers import device_registry, entity_registry, storage
from .helpers.dispatcher import async_dispatcher_connect, async_dispatcher_send
@@ -662,12 +663,7 @@ class ConfigEntry:
data: dict[str, Any] | None = None,
) -> None:
"""Start a reauth flow."""
if any(
flow
for flow in hass.config_entries.flow.async_progress_by_handler(self.domain)
if flow["context"].get("source") == SOURCE_REAUTH
and flow["context"].get("entry_id") == self.entry_id
):
if any(self.async_get_active_flows(hass, {SOURCE_REAUTH})):
# Reauth flow already in progress for this entry
return
@@ -685,6 +681,18 @@ class ConfigEntry:
)
)
@callback
def async_get_active_flows(
self, hass: HomeAssistant, sources: set[str]
) -> Generator[FlowResult, None, None]:
"""Get any active flows of certain sources for this entry."""
return (
flow
for flow in hass.config_entries.flow.async_progress_by_handler(self.domain)
if flow["context"].get("source") in sources
and flow["context"].get("entry_id") == self.entry_id
)
@callback
def async_create_task(
self, hass: HomeAssistant, target: Coroutine[Any, Any, _R]