diff --git a/tests/test_quirks_v2.py b/tests/test_quirks_v2.py index 7c25795c5..5f07193a6 100644 --- a/tests/test_quirks_v2.py +++ b/tests/test_quirks_v2.py @@ -322,6 +322,48 @@ async def test_quirks_v2_with_node_descriptor(device_mock): assert quirked.node_desc == node_descriptor +async def test_quirks_v2_replace_occurrences(device_mock): + """Test adding a quirk that replaces all occurrences of a cluster.""" + registry = DeviceRegistry() + + device_mock[1].add_output_cluster(Identify.cluster_id) + + device_mock.add_endpoint(2) + device_mock[2].profile_id = 255 + device_mock[2].device_type = 255 + device_mock[2].add_input_cluster(Identify.cluster_id) + + device_mock.add_endpoint(3) + device_mock[3].profile_id = 255 + device_mock[3].device_type = 255 + device_mock[3].add_output_cluster(Identify.cluster_id) + + class CustomIdentifyCluster(CustomCluster, Identify): + """Custom identify cluster for testing quirks v2.""" + + ( + QuirkBuilder(device_mock.manufacturer, device_mock.model, registry=registry) + .replace_cluster_occurrences(CustomIdentifyCluster) + .add_to_registry() + ) + + quirked: CustomDeviceV2 = registry.get_device(device_mock) + assert isinstance(quirked, CustomDeviceV2) + + assert isinstance( + quirked.endpoints[1].in_clusters[Identify.cluster_id], CustomIdentifyCluster + ) + assert isinstance( + quirked.endpoints[1].out_clusters[Identify.cluster_id], CustomIdentifyCluster + ) + assert isinstance( + quirked.endpoints[2].in_clusters[Identify.cluster_id], CustomIdentifyCluster + ) + assert isinstance( + quirked.endpoints[3].out_clusters[Identify.cluster_id], CustomIdentifyCluster + ) + + async def test_quirks_v2_skip_configuration(device_mock): """Test adding a quirk that skips configuration to the registry.""" registry = DeviceRegistry() diff --git a/zigpy/quirks/v2/__init__.py b/zigpy/quirks/v2/__init__.py index 0a0f8d397..4051923ca 100644 --- a/zigpy/quirks/v2/__init__.py +++ b/zigpy/quirks/v2/__init__.py @@ -224,24 +224,30 @@ def __call__(self, device: CustomDeviceV2) -> None: class ReplaceClusterOccurrencesMetadata: """Replaces metadata for replacing all occurrences of a cluster on a device.""" - cluster_types: list[ClusterType] = attrs.field() + cluster_types: tuple[ClusterType] = attrs.field() cluster: type[Cluster | CustomCluster] = attrs.field() def __call__(self, device: CustomDeviceV2) -> None: """Process the replace.""" for endpoint in device.endpoints.values(): + if isinstance(endpoint, ZDO): + continue if ( ClusterType.Server in self.cluster_types and self.cluster.cluster_id in endpoint.in_clusters ): endpoint.in_clusters.pop(self.cluster.cluster_id) - endpoint.add_input_cluster(self.cluster.cluster_id, self.cluster) + endpoint.add_input_cluster( + self.cluster.cluster_id, self.cluster(endpoint) + ) if ( ClusterType.Client in self.cluster_types and self.cluster.cluster_id in endpoint.out_clusters ): endpoint.out_clusters.pop(self.cluster.cluster_id) - endpoint.add_output_cluster(self.cluster.cluster_id, self.cluster) + endpoint.add_output_cluster( + self.cluster.cluster_id, self.cluster(endpoint, is_server=False) + ) @attrs.define(frozen=True, kw_only=True, repr=True) @@ -599,7 +605,7 @@ def replace_cluster_occurrences( types.append(ClusterType.Client) self.replaces_cluster_occurrences_metadata.append( ReplaceClusterOccurrencesMetadata( # type: ignore[call-arg] - cluster_types=types, + cluster_types=tuple(types), cluster=replacement_cluster_class, ) )