diff --git a/tests/test_quirks.py b/tests/test_quirks.py index be6969b0c..d11f6e78b 100644 --- a/tests/test_quirks.py +++ b/tests/test_quirks.py @@ -1,5 +1,9 @@ import asyncio +import importlib.util import itertools +import pathlib +import pkgutil +import sys from typing import Final import pytest @@ -1031,3 +1035,116 @@ class TestQuirk(zigpy.quirks.CustomDevice): assert len(request_mock.mock_calls) == 3 assert all(c == request_mock.mock_calls[0] for c in request_mock.mock_calls) + + +def test_purge_custom_quirks(tmp_path: pathlib.Path, app_mock) -> None: + def load_quirks(): + for importer, modname, _ in pkgutil.walk_packages(path=[str(tmp_path)]): + spec = importer.find_spec(modname) + module = importlib.util.module_from_spec(spec) + sys.modules[modname] = module + spec.loader.exec_module(module) + + (tmp_path / "quirk1.py").write_text(""" +import zigpy.quirks +from zigpy.zcl.clusters.general import LevelControl +from zigpy.const import ( + SIG_ENDPOINTS, + SIG_EP_INPUT, + SIG_EP_OUTPUT, + SIG_EP_PROFILE, + SIG_EP_TYPE, + SIG_MODELS_INFO, +) + +class CustomLevel1(zigpy.quirks.CustomCluster, LevelControl): + pass + +class TestQuirk1(zigpy.quirks.CustomDevice): + signature = { + SIG_MODELS_INFO: (("manufacturer1", "model1"),), + SIG_ENDPOINTS: { + 1: { + SIG_EP_PROFILE: 255, + SIG_EP_TYPE: 255, + SIG_EP_INPUT: [3], + SIG_EP_OUTPUT: [6], + } + }, + } + + replacement = { + SIG_ENDPOINTS: { + 1: { + SIG_EP_PROFILE: 255, + SIG_EP_TYPE: 255, + SIG_EP_INPUT: [3, CustomLevel1], + SIG_EP_OUTPUT: [6], + } + }, + }""") + + (tmp_path / "quirk2.py").write_text(""" +import zigpy.quirks + +from zigpy.quirks.v2 import QuirkBuilder +from zigpy.zcl import ClusterType +from zigpy.zcl.clusters.general import LevelControl + +class CustomLevel2(zigpy.quirks.CustomCluster, LevelControl): + pass + +QuirkBuilder("manufacturer2", "model2").adds( + cluster=CustomLevel2, + cluster_type=ClusterType.Server, + endpoint_id=1, +).add_to_registry() +""") + + dev1 = zigpy.device.Device( + app_mock, t.EUI64.convert("11:11:11:11:11:11:11:11"), 0x1234 + ) + dev1.add_endpoint(1) + dev1[1].profile_id = 255 + dev1[1].device_type = 255 + dev1.model = "model1" + dev1.manufacturer = "manufacturer1" + dev1[1].add_input_cluster(3) + dev1[1].add_output_cluster(6) + + dev2 = zigpy.device.Device( + app_mock, t.EUI64.convert("22:22:22:22:22:22:22:22"), 0x5678 + ) + dev2.add_endpoint(1) + dev2[1].profile_id = 255 + dev2[1].device_type = 255 + dev2.model = "model2" + dev2.manufacturer = "manufacturer2" + dev2[1].add_input_cluster(3) + dev2[1].add_output_cluster(6) + + registry = zigpy.quirks.DEVICE_REGISTRY + + assert not registry._registry.get("manufacturer1", {}).get("model1", []) + assert not registry._registry_v2.get(("manufacturer2", "model2"), set()) + + load_quirks() + + assert registry._registry.get("manufacturer1", {}).get("model1", []) + assert registry._registry_v2.get(("manufacturer2", "model2"), set()) + + assert type(registry.get_device(dev1)).__name__ == "TestQuirk1" + assert registry.get_device(dev2).quirk_metadata.quirk_file.name == "quirk2.py" + + # Only quirks from the passed directory are purged so this is a no-op + registry.purge_custom_quirks(tmp_path / "some_other_dir") + assert registry._registry.get("manufacturer1", {}).get("model1", []) + assert registry._registry_v2.get(("manufacturer2", "model2"), set()) + + # Now we really remove them + registry.purge_custom_quirks(tmp_path) + assert not registry._registry.get("manufacturer1", {}).get("model1", []) + assert not registry._registry_v2.get(("manufacturer2", "model2"), set()) + + assert registry.get_device(dev1) is dev1 + assert registry.get_device(dev2) is dev2 diff --git a/tests/test_quirks_v2.py b/tests/test_quirks_v2.py index 5f07193a6..747123194 100644 --- a/tests/test_quirks_v2.py +++ b/tests/test_quirks_v2.py @@ -124,9 +124,10 @@ class AttributeDefs(BaseAttributeDefs): # pylint: disable=too-few-public-method assert quirked in registry # this would need to be updated if the line number of the call to QuirkBuilder # changes in this test in the future - assert quirked.quirk_metadata.quirk_location.endswith( - "zigpy/tests/test_quirks_v2.py]-line:103" + assert str(quirked.quirk_metadata.quirk_file).endswith( + "zigpy/tests/test_quirks_v2.py" ) + assert quirked.quirk_metadata.quirk_file_line == 103 ep = quirked.endpoints[1] diff --git a/zigpy/quirks/__init__.py b/zigpy/quirks/__init__.py index bb2ac1a3f..188c5a0e5 100644 --- a/zigpy/quirks/__init__.py +++ b/zigpy/quirks/__init__.py @@ -31,7 +31,7 @@ _LOGGER = logging.getLogger(__name__) -_DEVICE_REGISTRY = DeviceRegistry() +DEVICE_REGISTRY = _DEVICE_REGISTRY = DeviceRegistry() _uninitialized_device_message_handlers = [] diff --git a/zigpy/quirks/registry.py b/zigpy/quirks/registry.py index 2101012be..0ed3813bc 100644 --- a/zigpy/quirks/registry.py +++ b/zigpy/quirks/registry.py @@ -3,8 +3,10 @@ from __future__ import annotations import collections +import inspect import itertools import logging +import pathlib import typing from typing import TYPE_CHECKING @@ -36,6 +38,36 @@ def __init__(self, *args, **kwargs) -> None: collections.defaultdict(set) ) + def purge_custom_quirks(self, custom_quirks_root: pathlib.Path) -> None: + # If zhaquirks aren't being used, we can't tell if a quirk is custom or not + for model_registry in self._registry.values(): + for quirks in model_registry.values(): + to_remove = [] + + for quirk in quirks: + module = inspect.getmodule(quirk) + assert module is not None # All quirks should have modules + + quirk_module = pathlib.Path(module.__file__) + + if quirk_module.is_relative_to(custom_quirks_root): + to_remove.append(quirk) + + for quirk in to_remove: + _LOGGER.debug("Removing stale custom v1 quirk: %s", quirk) + quirks.remove(quirk) + + for registry in self._registry_v2.values(): + to_remove = [] + + for entry in registry: + if entry.quirk_file.is_relative_to(custom_quirks_root): + to_remove.append(entry) + + for entry in to_remove: + _LOGGER.debug("Removing stale custom v2 quirk: %s", entry) + registry.remove(entry) + def add_to_registry(self, custom_device: CustomDeviceType) -> None: """Add a device to the registry""" models_info = custom_device.signature.get(SIG_MODELS_INFO) diff --git a/zigpy/quirks/v2/__init__.py b/zigpy/quirks/v2/__init__.py index 044f78f5d..5616bb9d3 100644 --- a/zigpy/quirks/v2/__init__.py +++ b/zigpy/quirks/v2/__init__.py @@ -6,6 +6,7 @@ from enum import Enum import inspect import logging +import pathlib import typing from typing import TYPE_CHECKING, Any @@ -351,7 +352,8 @@ class ManufacturerModelMetadata: class QuirksV2RegistryEntry: """Quirks V2 registry entry.""" - quirk_location: str = attrs.field(default=None, eq=False) + quirk_file: str = attrs.field(default=None, eq=False) + quirk_file_line: int = attrs.field(default=None, eq=False) manufacturer_model_metadata: tuple[ManufacturerModelMetadata] = attrs.field( factory=tuple ) @@ -423,9 +425,8 @@ def __init__( stack: list[inspect.FrameInfo] = inspect.stack() caller: inspect.FrameInfo = stack[1] - self.quirk_location: str | None = ( - f"file[{caller.filename}]-line:{caller.lineno}" - ) + self.quirk_file = pathlib.Path(caller.filename) + self.quirk_file_line = caller.lineno self.also_applies_to(manufacturer, model) UNBUILT_QUIRK_BUILDERS.append(self) @@ -854,7 +855,8 @@ def add_to_registry(self) -> QuirksV2RegistryEntry: """Build the quirks v2 registry entry.""" quirk: QuirksV2RegistryEntry = QuirksV2RegistryEntry( # type: ignore[call-arg] manufacturer_model_metadata=tuple(self.manufacturer_model_metadata), - quirk_location=self.quirk_location, + quirk_file=self.quirk_file, + quirk_file_line=self.quirk_file_line, filters=tuple(self.filters), custom_device_class=self.custom_device_class, device_node_descriptor=self.device_node_descriptor,