Index in-progress flows to avoid linear search (#58146)

Co-authored-by: Steven Looman <steven.looman@gmail.com>
This commit is contained in:
J. Nick Koston
2021-10-22 07:19:49 -10:00
committed by GitHub
parent fa56be7cc0
commit 3b7dce8b95
11 changed files with 190 additions and 64 deletions

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import abc
import asyncio
from collections.abc import Mapping
from collections.abc import Iterable, Mapping
from types import MappingProxyType
from typing import Any, TypedDict
import uuid
@@ -78,6 +78,23 @@ class FlowResult(TypedDict, total=False):
options: Mapping[str, Any]
@callback
def _async_flow_handler_to_flow_result(
flows: Iterable[FlowHandler], include_uninitialized: bool
) -> list[FlowResult]:
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
return [
{
"flow_id": flow.flow_id,
"handler": flow.handler,
"context": flow.context,
"step_id": flow.cur_step["step_id"] if flow.cur_step else None,
}
for flow in flows
if include_uninitialized or flow.cur_step is not None
]
class FlowManager(abc.ABC):
"""Manage all the flows that are in progress."""
@@ -89,7 +106,8 @@ class FlowManager(abc.ABC):
self.hass = hass
self._initializing: dict[str, list[asyncio.Future]] = {}
self._initialize_tasks: dict[str, list[asyncio.Task]] = {}
self._progress: dict[str, Any] = {}
self._progress: dict[str, FlowHandler] = {}
self._handler_progress_index: dict[str, set[str]] = {}
async def async_wait_init_flow_finish(self, handler: str) -> None:
"""Wait till all flows in progress are initialized."""
@@ -127,24 +145,39 @@ class FlowManager(abc.ABC):
"""Check if an existing matching flow is in progress with the same handler, context, and data."""
return any(
flow
for flow in self._progress.values()
if flow.handler == handler
and flow.context["source"] == context["source"]
and flow.init_data == data
for flow in self._async_progress_by_handler(handler)
if flow.context["source"] == context["source"] and flow.init_data == data
)
@callback
def async_get(self, flow_id: str) -> FlowResult | None:
"""Return a flow in progress as a partial FlowResult."""
if (flow := self._progress.get(flow_id)) is None:
raise UnknownFlow
return _async_flow_handler_to_flow_result([flow], False)[0]
@callback
def async_progress(self, include_uninitialized: bool = False) -> list[FlowResult]:
"""Return the flows in progress."""
"""Return the flows in progress as a partial FlowResult."""
return _async_flow_handler_to_flow_result(
self._progress.values(), include_uninitialized
)
@callback
def async_progress_by_handler(
self, handler: str, include_uninitialized: bool = False
) -> list[FlowResult]:
"""Return the flows in progress by handler as a partial FlowResult."""
return _async_flow_handler_to_flow_result(
self._async_progress_by_handler(handler), include_uninitialized
)
@callback
def _async_progress_by_handler(self, handler: str) -> list[FlowHandler]:
"""Return the flows in progress by handler."""
return [
{
"flow_id": flow.flow_id,
"handler": flow.handler,
"context": flow.context,
"step_id": flow.cur_step["step_id"] if flow.cur_step else None,
}
for flow in self._progress.values()
if include_uninitialized or flow.cur_step is not None
self._progress[flow_id]
for flow_id in self._handler_progress_index.get(handler, {})
]
async def async_init(
@@ -187,7 +220,7 @@ class FlowManager(abc.ABC):
flow.flow_id = uuid.uuid4().hex
flow.context = context
flow.init_data = data
self._progress[flow.flow_id] = flow
self._async_add_flow_progress(flow)
result = await self._async_handle_step(flow, flow.init_step, data, init_done)
return flow, result
@@ -205,6 +238,7 @@ class FlowManager(abc.ABC):
raise UnknownFlow
cur_step = flow.cur_step
assert cur_step is not None
if cur_step.get("data_schema") is not None and user_input is not None:
user_input = cur_step["data_schema"](user_input)
@@ -245,8 +279,24 @@ class FlowManager(abc.ABC):
@callback
def async_abort(self, flow_id: str) -> None:
"""Abort a flow."""
if self._progress.pop(flow_id, None) is None:
self._async_remove_flow_progress(flow_id)
@callback
def _async_add_flow_progress(self, flow: FlowHandler) -> None:
"""Add a flow to in progress."""
self._progress[flow.flow_id] = flow
self._handler_progress_index.setdefault(flow.handler, set()).add(flow.flow_id)
@callback
def _async_remove_flow_progress(self, flow_id: str) -> None:
"""Remove a flow from in progress."""
flow = self._progress.pop(flow_id, None)
if flow is None:
raise UnknownFlow
handler = flow.handler
self._handler_progress_index[handler].remove(flow.flow_id)
if not self._handler_progress_index[handler]:
del self._handler_progress_index[handler]
async def _async_handle_step(
self,
@@ -259,7 +309,7 @@ class FlowManager(abc.ABC):
method = f"async_step_{step_id}"
if not hasattr(flow, method):
self._progress.pop(flow.flow_id)
self._async_remove_flow_progress(flow.flow_id)
if step_done:
step_done.set_result(None)
raise UnknownStep(
@@ -310,7 +360,7 @@ class FlowManager(abc.ABC):
return result
# Abort and Success results both finish the flow
self._progress.pop(flow.flow_id)
self._async_remove_flow_progress(flow.flow_id)
return result
@@ -319,7 +369,7 @@ class FlowHandler:
"""Handle the configuration flow of a component."""
# Set by flow manager
cur_step: dict[str, str] | None = None
cur_step: dict[str, Any] | None = None
# While not purely typed, it makes typehinting more useful for us
# and removes the need for constant None checks or asserts.