Move device info validation to device registry (#96465)

* Move device info validation to device registry

* Don't move DeviceInfo

* Fix type annotation

* Don't block adding device for unknown config entry

* Fix test

* Remove use of locals()

* Improve error message
This commit is contained in:
Erik Montnemery
2023-07-14 14:55:17 +02:00
committed by GitHub
parent 3b32dcb613
commit 614f3c6a15
6 changed files with 159 additions and 126 deletions

View File

@@ -6,13 +6,14 @@ from collections.abc import Coroutine, ValuesView
import logging
import time
from typing import TYPE_CHECKING, Any, TypeVar, cast
from urllib.parse import urlparse
import attr
from homeassistant.backports.enum import StrEnum
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, RequiredParameterMissing
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util.json import format_unserializable_data
import homeassistant.util.uuid as uuid_util
@@ -26,6 +27,7 @@ if TYPE_CHECKING:
from homeassistant.config_entries import ConfigEntry
from . import entity_registry
from .entity import DeviceInfo
_LOGGER = logging.getLogger(__name__)
@@ -60,6 +62,39 @@ DISABLED_CONFIG_ENTRY = DeviceEntryDisabler.CONFIG_ENTRY.value
DISABLED_INTEGRATION = DeviceEntryDisabler.INTEGRATION.value
DISABLED_USER = DeviceEntryDisabler.USER.value
DEVICE_INFO_TYPES = {
# Device info is categorized by finding the first device info type which has all
# the keys of the device info. The link device info type must be kept first
# to make it preferred over primary.
"link": {
"connections",
"identifiers",
},
"primary": {
"configuration_url",
"connections",
"entry_type",
"hw_version",
"identifiers",
"manufacturer",
"model",
"name",
"suggested_area",
"sw_version",
"via_device",
},
"secondary": {
"connections",
"default_manufacturer",
"default_model",
"default_name",
# Used by Fritz
"via_device",
},
}
DEVICE_INFO_KEYS = set.union(*(itm for itm in DEVICE_INFO_TYPES.values()))
class DeviceEntryType(StrEnum):
"""Device entry type."""
@@ -67,6 +102,66 @@ class DeviceEntryType(StrEnum):
SERVICE = "service"
class DeviceInfoError(HomeAssistantError):
"""Raised when device info is invalid."""
def __init__(self, domain: str, device_info: DeviceInfo, message: str) -> None:
"""Initialize error."""
super().__init__(
f"Invalid device info {device_info} for '{domain}' config entry: {message}",
)
self.device_info = device_info
self.domain = domain
def _validate_device_info(
config_entry: ConfigEntry | None,
device_info: DeviceInfo,
) -> str:
"""Process a device info."""
keys = set(device_info)
# If no keys or not enough info to match up, abort
if not device_info.get("connections") and not device_info.get("identifiers"):
raise DeviceInfoError(
config_entry.domain if config_entry else "unknown",
device_info,
"device info must include at least one of identifiers or connections",
)
device_info_type: str | None = None
# Find the first device info type which has all keys in the device info
for possible_type, allowed_keys in DEVICE_INFO_TYPES.items():
if keys <= allowed_keys:
device_info_type = possible_type
break
if device_info_type is None:
raise DeviceInfoError(
config_entry.domain if config_entry else "unknown",
device_info,
(
"device info needs to either describe a device, "
"link to existing device or provide extra information."
),
)
if (config_url := device_info.get("configuration_url")) is not None:
if type(config_url) is not str or urlparse(config_url).scheme not in [
"http",
"https",
"homeassistant",
]:
raise DeviceInfoError(
config_entry.domain if config_entry else "unknown",
device_info,
f"invalid configuration_url '{config_url}'",
)
return device_info_type
@attr.s(slots=True, frozen=True)
class DeviceEntry:
"""Device Registry Entry."""
@@ -338,7 +433,7 @@ class DeviceRegistry:
*,
config_entry_id: str,
configuration_url: str | None | UndefinedType = UNDEFINED,
connections: set[tuple[str, str]] | None = None,
connections: set[tuple[str, str]] | None | UndefinedType = UNDEFINED,
default_manufacturer: str | None | UndefinedType = UNDEFINED,
default_model: str | None | UndefinedType = UNDEFINED,
default_name: str | None | UndefinedType = UNDEFINED,
@@ -346,22 +441,47 @@ class DeviceRegistry:
disabled_by: DeviceEntryDisabler | None | UndefinedType = UNDEFINED,
entry_type: DeviceEntryType | None | UndefinedType = UNDEFINED,
hw_version: str | None | UndefinedType = UNDEFINED,
identifiers: set[tuple[str, str]] | None = None,
identifiers: set[tuple[str, str]] | None | UndefinedType = UNDEFINED,
manufacturer: str | None | UndefinedType = UNDEFINED,
model: str | None | UndefinedType = UNDEFINED,
name: str | None | UndefinedType = UNDEFINED,
suggested_area: str | None | UndefinedType = UNDEFINED,
sw_version: str | None | UndefinedType = UNDEFINED,
via_device: tuple[str, str] | None = None,
via_device: tuple[str, str] | None | UndefinedType = UNDEFINED,
) -> DeviceEntry:
"""Get device. Create if it doesn't exist."""
if not identifiers and not connections:
raise RequiredParameterMissing(["identifiers", "connections"])
if identifiers is None:
# Reconstruct a DeviceInfo dict from the arguments.
# When we upgrade to Python 3.12, we can change this method to instead
# accept kwargs typed as a DeviceInfo dict (PEP 692)
device_info: DeviceInfo = {}
for key, val in (
("configuration_url", configuration_url),
("connections", connections),
("default_manufacturer", default_manufacturer),
("default_model", default_model),
("default_name", default_name),
("entry_type", entry_type),
("hw_version", hw_version),
("identifiers", identifiers),
("manufacturer", manufacturer),
("model", model),
("name", name),
("suggested_area", suggested_area),
("sw_version", sw_version),
("via_device", via_device),
):
if val is UNDEFINED:
continue
device_info[key] = val # type: ignore[literal-required]
config_entry = self.hass.config_entries.async_get_entry(config_entry_id)
device_info_type = _validate_device_info(config_entry, device_info)
if identifiers is None or identifiers is UNDEFINED:
identifiers = set()
if connections is None:
if connections is None or connections is UNDEFINED:
connections = set()
else:
connections = _normalize_connections(connections)
@@ -378,6 +498,13 @@ class DeviceRegistry:
config_entry_id, connections, identifiers
)
self.devices[device.id] = device
# If creating a new device, default to the config entry name
if (
device_info_type == "primary"
and (not name or name is UNDEFINED)
and config_entry
):
name = config_entry.title
if default_manufacturer is not UNDEFINED and device.manufacturer is None:
manufacturer = default_manufacturer
@@ -388,7 +515,7 @@ class DeviceRegistry:
if default_name is not UNDEFINED and device.name is None:
name = default_name
if via_device is not None:
if via_device is not None and via_device is not UNDEFINED:
via = self.async_get_device(identifiers={via_device})
via_device_id: str | UndefinedType = via.id if via else UNDEFINED
else: