Index in-progress flows to avoid linear search (#58146)
Co-authored-by: Steven Looman <steven.looman@gmail.com>
This commit is contained in:
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Iterable, Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import Any, TypedDict
|
||||
import uuid
|
||||
@@ -78,6 +78,23 @@ class FlowResult(TypedDict, total=False):
|
||||
options: Mapping[str, Any]
|
||||
|
||||
|
||||
@callback
|
||||
def _async_flow_handler_to_flow_result(
|
||||
flows: Iterable[FlowHandler], include_uninitialized: bool
|
||||
) -> list[FlowResult]:
|
||||
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
|
||||
return [
|
||||
{
|
||||
"flow_id": flow.flow_id,
|
||||
"handler": flow.handler,
|
||||
"context": flow.context,
|
||||
"step_id": flow.cur_step["step_id"] if flow.cur_step else None,
|
||||
}
|
||||
for flow in flows
|
||||
if include_uninitialized or flow.cur_step is not None
|
||||
]
|
||||
|
||||
|
||||
class FlowManager(abc.ABC):
|
||||
"""Manage all the flows that are in progress."""
|
||||
|
||||
@@ -89,7 +106,8 @@ class FlowManager(abc.ABC):
|
||||
self.hass = hass
|
||||
self._initializing: dict[str, list[asyncio.Future]] = {}
|
||||
self._initialize_tasks: dict[str, list[asyncio.Task]] = {}
|
||||
self._progress: dict[str, Any] = {}
|
||||
self._progress: dict[str, FlowHandler] = {}
|
||||
self._handler_progress_index: dict[str, set[str]] = {}
|
||||
|
||||
async def async_wait_init_flow_finish(self, handler: str) -> None:
|
||||
"""Wait till all flows in progress are initialized."""
|
||||
@@ -127,24 +145,39 @@ class FlowManager(abc.ABC):
|
||||
"""Check if an existing matching flow is in progress with the same handler, context, and data."""
|
||||
return any(
|
||||
flow
|
||||
for flow in self._progress.values()
|
||||
if flow.handler == handler
|
||||
and flow.context["source"] == context["source"]
|
||||
and flow.init_data == data
|
||||
for flow in self._async_progress_by_handler(handler)
|
||||
if flow.context["source"] == context["source"] and flow.init_data == data
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_get(self, flow_id: str) -> FlowResult | None:
|
||||
"""Return a flow in progress as a partial FlowResult."""
|
||||
if (flow := self._progress.get(flow_id)) is None:
|
||||
raise UnknownFlow
|
||||
return _async_flow_handler_to_flow_result([flow], False)[0]
|
||||
|
||||
@callback
|
||||
def async_progress(self, include_uninitialized: bool = False) -> list[FlowResult]:
|
||||
"""Return the flows in progress."""
|
||||
"""Return the flows in progress as a partial FlowResult."""
|
||||
return _async_flow_handler_to_flow_result(
|
||||
self._progress.values(), include_uninitialized
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_progress_by_handler(
|
||||
self, handler: str, include_uninitialized: bool = False
|
||||
) -> list[FlowResult]:
|
||||
"""Return the flows in progress by handler as a partial FlowResult."""
|
||||
return _async_flow_handler_to_flow_result(
|
||||
self._async_progress_by_handler(handler), include_uninitialized
|
||||
)
|
||||
|
||||
@callback
|
||||
def _async_progress_by_handler(self, handler: str) -> list[FlowHandler]:
|
||||
"""Return the flows in progress by handler."""
|
||||
return [
|
||||
{
|
||||
"flow_id": flow.flow_id,
|
||||
"handler": flow.handler,
|
||||
"context": flow.context,
|
||||
"step_id": flow.cur_step["step_id"] if flow.cur_step else None,
|
||||
}
|
||||
for flow in self._progress.values()
|
||||
if include_uninitialized or flow.cur_step is not None
|
||||
self._progress[flow_id]
|
||||
for flow_id in self._handler_progress_index.get(handler, {})
|
||||
]
|
||||
|
||||
async def async_init(
|
||||
@@ -187,7 +220,7 @@ class FlowManager(abc.ABC):
|
||||
flow.flow_id = uuid.uuid4().hex
|
||||
flow.context = context
|
||||
flow.init_data = data
|
||||
self._progress[flow.flow_id] = flow
|
||||
self._async_add_flow_progress(flow)
|
||||
result = await self._async_handle_step(flow, flow.init_step, data, init_done)
|
||||
return flow, result
|
||||
|
||||
@@ -205,6 +238,7 @@ class FlowManager(abc.ABC):
|
||||
raise UnknownFlow
|
||||
|
||||
cur_step = flow.cur_step
|
||||
assert cur_step is not None
|
||||
|
||||
if cur_step.get("data_schema") is not None and user_input is not None:
|
||||
user_input = cur_step["data_schema"](user_input)
|
||||
@@ -245,8 +279,24 @@ class FlowManager(abc.ABC):
|
||||
@callback
|
||||
def async_abort(self, flow_id: str) -> None:
|
||||
"""Abort a flow."""
|
||||
if self._progress.pop(flow_id, None) is None:
|
||||
self._async_remove_flow_progress(flow_id)
|
||||
|
||||
@callback
|
||||
def _async_add_flow_progress(self, flow: FlowHandler) -> None:
|
||||
"""Add a flow to in progress."""
|
||||
self._progress[flow.flow_id] = flow
|
||||
self._handler_progress_index.setdefault(flow.handler, set()).add(flow.flow_id)
|
||||
|
||||
@callback
|
||||
def _async_remove_flow_progress(self, flow_id: str) -> None:
|
||||
"""Remove a flow from in progress."""
|
||||
flow = self._progress.pop(flow_id, None)
|
||||
if flow is None:
|
||||
raise UnknownFlow
|
||||
handler = flow.handler
|
||||
self._handler_progress_index[handler].remove(flow.flow_id)
|
||||
if not self._handler_progress_index[handler]:
|
||||
del self._handler_progress_index[handler]
|
||||
|
||||
async def _async_handle_step(
|
||||
self,
|
||||
@@ -259,7 +309,7 @@ class FlowManager(abc.ABC):
|
||||
method = f"async_step_{step_id}"
|
||||
|
||||
if not hasattr(flow, method):
|
||||
self._progress.pop(flow.flow_id)
|
||||
self._async_remove_flow_progress(flow.flow_id)
|
||||
if step_done:
|
||||
step_done.set_result(None)
|
||||
raise UnknownStep(
|
||||
@@ -310,7 +360,7 @@ class FlowManager(abc.ABC):
|
||||
return result
|
||||
|
||||
# Abort and Success results both finish the flow
|
||||
self._progress.pop(flow.flow_id)
|
||||
self._async_remove_flow_progress(flow.flow_id)
|
||||
|
||||
return result
|
||||
|
||||
@@ -319,7 +369,7 @@ class FlowHandler:
|
||||
"""Handle the configuration flow of a component."""
|
||||
|
||||
# Set by flow manager
|
||||
cur_step: dict[str, str] | None = None
|
||||
cur_step: dict[str, Any] | None = None
|
||||
|
||||
# While not purely typed, it makes typehinting more useful for us
|
||||
# and removes the need for constant None checks or asserts.
|
||||
|
||||
Reference in New Issue
Block a user