Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update tplink config to include aes keys #125685

Merged
merged 4 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 84 additions & 34 deletions homeassistant/components/tplink/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from homeassistant.const import (
CONF_ALIAS,
CONF_AUTHENTICATION,
CONF_DEVICE,
CONF_HOST,
CONF_MAC,
CONF_MODEL,
Expand All @@ -44,8 +45,12 @@
from homeassistant.helpers.typing import ConfigType

from .const import (
CONF_AES_KEYS,
CONF_CONFIG_ENTRY_MINOR_VERSION,
CONF_CONNECTION_PARAMETERS,
CONF_CREDENTIALS_HASH,
CONF_DEVICE_CONFIG,
CONF_USES_HTTP,
CONNECT_TIMEOUT,
DISCOVERY_TIMEOUT,
DOMAIN,
Expand Down Expand Up @@ -85,9 +90,7 @@ def async_trigger_discovery(
CONF_ALIAS: device.alias or mac_alias(device.mac),
CONF_HOST: device.host,
CONF_MAC: formatted_mac,
CONF_DEVICE_CONFIG: device.config.to_dict(
exclude_credentials=True,
),
CONF_DEVICE: device,
},
)

Expand Down Expand Up @@ -136,25 +139,27 @@ async def async_setup_entry(hass: HomeAssistant, entry: TPLinkConfigEntry) -> bo
host: str = entry.data[CONF_HOST]
credentials = await get_credentials(hass)
entry_credentials_hash = entry.data.get(CONF_CREDENTIALS_HASH)
entry_use_http = entry.data.get(CONF_USES_HTTP, False)
entry_aes_keys = entry.data.get(CONF_AES_KEYS)

config: DeviceConfig | None = None
if config_dict := entry.data.get(CONF_DEVICE_CONFIG):
conn_params: Device.ConnectionParameters | None = None
if conn_params_dict := entry.data.get(CONF_CONNECTION_PARAMETERS):
try:
config = DeviceConfig.from_dict(config_dict)
conn_params = Device.ConnectionParameters.from_dict(conn_params_dict)
except KasaException:
_LOGGER.warning(
"Invalid connection type dict for %s: %s", host, config_dict
"Invalid connection parameters dict for %s: %s", host, conn_params_dict
)

if not config:
config = DeviceConfig(host)
else:
config.host = host

config.timeout = CONNECT_TIMEOUT
if config.uses_http is True:
config.http_client = create_async_tplink_clientsession(hass)

client = create_async_tplink_clientsession(hass) if entry_use_http else None
config = DeviceConfig(
host,
timeout=CONNECT_TIMEOUT,
http_client=client,
aes_keys=entry_aes_keys,
)
if conn_params:
config.connection_type = conn_params
# If we have in memory credentials use them otherwise check for credentials_hash
if credentials:
config.credentials = credentials
Expand All @@ -173,14 +178,15 @@ async def async_setup_entry(hass: HomeAssistant, entry: TPLinkConfigEntry) -> bo
raise ConfigEntryNotReady from ex

device_credentials_hash = device.credentials_hash
device_config_dict = device.config.to_dict(exclude_credentials=True)
# Do not store the credentials hash inside the device_config
device_config_dict.pop(CONF_CREDENTIALS_HASH, None)

# We not need to update the connection parameters or the use_http here
# because if they were wrong we would have failed to connect.
# Discovery will update those if necessary.
updates: dict[str, Any] = {}
if device_credentials_hash and device_credentials_hash != entry_credentials_hash:
updates[CONF_CREDENTIALS_HASH] = device_credentials_hash
if device_config_dict != config_dict:
updates[CONF_DEVICE_CONFIG] = device_config_dict
if entry_aes_keys != device.config.aes_keys:
updates[CONF_AES_KEYS] = device.config.aes_keys
if entry.data.get(CONF_ALIAS) != device.alias:
updates[CONF_ALIAS] = device.alias
if entry.data.get(CONF_MODEL) != device.model:
Expand Down Expand Up @@ -307,12 +313,20 @@ def _device_id_is_mac_or_none(mac: str, device_ids: Iterable[str]) -> str | None

async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
"""Migrate old entry."""
version = config_entry.version
minor_version = config_entry.minor_version

_LOGGER.debug("Migrating from version %s.%s", version, minor_version)

if version == 1 and minor_version < 3:
entry_version = config_entry.version
entry_minor_version = config_entry.minor_version
# having a condition to check for the current version allows
# tests to be written per migration step.
config_flow_minor_version = CONF_CONFIG_ENTRY_MINOR_VERSION

new_minor_version = 3
if (
entry_version == 1
and entry_minor_version < new_minor_version <= config_flow_minor_version
):
_LOGGER.debug(
"Migrating from version %s.%s", entry_version, entry_minor_version
)
# Previously entities on child devices added themselves to the parent
# device and set their device id as identifiers along with mac
# as a connection which creates a single device entry linked by all
Expand Down Expand Up @@ -359,28 +373,64 @@ async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) ->
new_identifiers,
)

minor_version = 3
hass.config_entries.async_update_entry(config_entry, minor_version=3)
hass.config_entries.async_update_entry(
config_entry, minor_version=new_minor_version
)

_LOGGER.debug("Migration to version %s.%s complete", version, minor_version)
_LOGGER.debug(
"Migration to version %s.%s complete", entry_version, new_minor_version
)

if version == 1 and minor_version == 3:
new_minor_version = 4
if (
entry_version == 1
and entry_minor_version < new_minor_version <= config_flow_minor_version
):
# credentials_hash stored in the device_config should be moved to data.
updates: dict[str, Any] = {}
if config_dict := config_entry.data.get(CONF_DEVICE_CONFIG):
assert isinstance(config_dict, dict)
if credentials_hash := config_dict.pop(CONF_CREDENTIALS_HASH, None):
updates[CONF_CREDENTIALS_HASH] = credentials_hash
updates[CONF_DEVICE_CONFIG] = config_dict
minor_version = 4
hass.config_entries.async_update_entry(
config_entry,
data={
**config_entry.data,
**updates,
},
minor_version=minor_version,
minor_version=new_minor_version,
)
_LOGGER.debug(
"Migration to version %s.%s complete", entry_version, new_minor_version
)
_LOGGER.debug("Migration to version %s.%s complete", version, minor_version)

new_minor_version = 5
if (
entry_version == 1
and entry_minor_version < new_minor_version <= config_flow_minor_version
):
# complete device config no longer to be stored, only required
# attributes like connection parameters and aes_keys
updates = {}
entry_data = {
k: v for k, v in config_entry.data.items() if k != CONF_DEVICE_CONFIG
}
if config_dict := config_entry.data.get(CONF_DEVICE_CONFIG):
assert isinstance(config_dict, dict)
if connection_parameters := config_dict.get("connection_type"):
updates[CONF_CONNECTION_PARAMETERS] = connection_parameters
if (use_http := config_dict.get(CONF_USES_HTTP)) is not None:
updates[CONF_USES_HTTP] = use_http
hass.config_entries.async_update_entry(
config_entry,
data={
**entry_data,
**updates,
},
minor_version=new_minor_version,
)
_LOGGER.debug(
"Migration to version %s.%s complete", entry_version, new_minor_version
)
return True
77 changes: 42 additions & 35 deletions homeassistant/components/tplink/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@
set_credentials,
)
from .const import (
CONF_CONNECTION_TYPE,
CONF_AES_KEYS,
CONF_CONFIG_ENTRY_MINOR_VERSION,
CONF_CONNECTION_PARAMETERS,
CONF_CREDENTIALS_HASH,
CONF_DEVICE_CONFIG,
CONF_USES_HTTP,
CONNECT_TIMEOUT,
DOMAIN,
)
Expand All @@ -64,7 +66,7 @@ class TPLinkConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for tplink."""

VERSION = 1
MINOR_VERSION = 4
MINOR_VERSION = CONF_CONFIG_ENTRY_MINOR_VERSION
reauth_entry: ConfigEntry | None = None

def __init__(self) -> None:
Expand All @@ -87,46 +89,51 @@ async def async_step_integration_discovery(
return await self._async_handle_discovery(
discovery_info[CONF_HOST],
discovery_info[CONF_MAC],
discovery_info[CONF_DEVICE_CONFIG],
discovery_info[CONF_DEVICE],
)

@callback
def _get_config_updates(
self, entry: ConfigEntry, host: str, config: dict
self, entry: ConfigEntry, host: str, device: Device | None
) -> dict | None:
"""Return updates if the host or device config has changed."""
entry_data = entry.data
entry_config_dict = entry_data.get(CONF_DEVICE_CONFIG)
if entry_config_dict == config and entry_data[CONF_HOST] == host:
updates: dict[str, Any] = {}
new_connection_params = False
if entry_data[CONF_HOST] != host:
updates[CONF_HOST] = host
if device:
device_conn_params_dict = device.config.connection_type.to_dict()
entry_conn_params_dict = entry_data.get(CONF_CONNECTION_PARAMETERS)
if device_conn_params_dict != entry_conn_params_dict:
new_connection_params = True
updates[CONF_CONNECTION_PARAMETERS] = device_conn_params_dict
updates[CONF_USES_HTTP] = device.config.uses_http
if not updates:
return None
updates = {**entry.data, CONF_DEVICE_CONFIG: config, CONF_HOST: host}
updates = {**entry.data, **updates}
# If the connection parameters have changed the credentials_hash will be invalid.
if (
entry_config_dict
and isinstance(entry_config_dict, dict)
and entry_config_dict.get(CONF_CONNECTION_TYPE)
!= config.get(CONF_CONNECTION_TYPE)
):
if new_connection_params:
updates.pop(CONF_CREDENTIALS_HASH, None)
_LOGGER.debug(
"Connection type changed for %s from %s to: %s",
host,
entry_config_dict.get(CONF_CONNECTION_TYPE),
config.get(CONF_CONNECTION_TYPE),
entry_conn_params_dict,
device_conn_params_dict,
)
return updates

@callback
def _update_config_if_entry_in_setup_error(
self, entry: ConfigEntry, host: str, config: dict
self, entry: ConfigEntry, host: str, device: Device | None
) -> ConfigFlowResult | None:
"""If discovery encounters a device that is in SETUP_ERROR or SETUP_RETRY update the device config."""
if entry.state not in (
ConfigEntryState.SETUP_ERROR,
ConfigEntryState.SETUP_RETRY,
):
return None
if updates := self._get_config_updates(entry, host, config):
if updates := self._get_config_updates(entry, host, device):
return self.async_update_reload_and_abort(
entry,
data=updates,
Expand All @@ -135,19 +142,15 @@ def _update_config_if_entry_in_setup_error(
return None

async def _async_handle_discovery(
self, host: str, formatted_mac: str, config: dict | None = None
self, host: str, formatted_mac: str, device: Device | None = None
) -> ConfigFlowResult:
"""Handle any discovery."""
current_entry = await self.async_set_unique_id(
formatted_mac, raise_on_progress=False
)
if (
config
and current_entry
and (
result := self._update_config_if_entry_in_setup_error(
current_entry, host, config
)
if current_entry and (
result := self._update_config_if_entry_in_setup_error(
current_entry, host, device
)
):
return result
Expand All @@ -159,9 +162,13 @@ async def _async_handle_discovery(
return self.async_abort(reason="already_in_progress")
credentials = await get_credentials(self.hass)
try:
await self._async_try_discover_and_update(
host, credentials, raise_on_progress=True
)
if device:
self._discovered_device = device
await self._async_try_connect(device, credentials)
else:
await self._async_try_discover_and_update(
host, credentials, raise_on_progress=True
)
except AuthenticationError:
return await self.async_step_discovery_auth_confirm()
except KasaException:
Expand Down Expand Up @@ -381,14 +388,15 @@ def _async_create_entry_from_device(self, device: Device) -> ConfigFlowResult:
# This is only ever called after a successful device update so we know that
# the credential_hash is correct and should be saved.
self._abort_if_unique_id_configured(updates={CONF_HOST: device.host})
data = {
data: dict[str, Any] = {
CONF_HOST: device.host,
CONF_ALIAS: device.alias,
CONF_MODEL: device.model,
CONF_DEVICE_CONFIG: device.config.to_dict(
exclude_credentials=True,
),
CONF_CONNECTION_PARAMETERS: device.config.connection_type.to_dict(),
CONF_USES_HTTP: device.config.uses_http,
}
if device.config.aes_keys:
data[CONF_AES_KEYS] = device.config.aes_keys
if device.credentials_hash:
data[CONF_CREDENTIALS_HASH] = device.credentials_hash
return self.async_create_entry(
Expand Down Expand Up @@ -494,8 +502,7 @@ async def async_step_reauth_confirm(
placeholders["error"] = str(ex)
else:
await set_credentials(self.hass, username, password)
config = device.config.to_dict(exclude_credentials=True)
if updates := self._get_config_updates(reauth_entry, host, config):
if updates := self._get_config_updates(reauth_entry, host, device):
self.hass.config_entries.async_update_entry(
reauth_entry, data=updates
)
Expand Down
6 changes: 5 additions & 1 deletion homeassistant/components/tplink/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@

CONF_DEVICE_CONFIG: Final = "device_config"
CONF_CREDENTIALS_HASH: Final = "credentials_hash"
CONF_CONNECTION_TYPE: Final = "connection_type"
CONF_CONNECTION_PARAMETERS: Final = "connection_parameters"
CONF_USES_HTTP: Final = "uses_http"
CONF_AES_KEYS: Final = "aes_keys"

CONF_CONFIG_ENTRY_MINOR_VERSION: Final = 5

PLATFORMS: Final = [
Platform.BINARY_SENSOR,
Expand Down
Loading