Skip to content

Refactor pyerverse Association Logic #10397

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions pylint/pyreverse/diagrams.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def extract_relationships(self) -> None:
obj.attrs = self.get_attrs(node)
obj.methods = self.get_methods(node)
obj.shape = "class"

# inheritance link
for par_node in node.ancestors(recurs=False):
try:
Expand All @@ -234,24 +235,38 @@ def extract_relationships(self) -> None:
except KeyError:
continue

# associations & aggregations links
for name, values in list(node.aggregations_type.items()):
# Track processed attributes to avoid duplicates
processed_attrs = set()

# Composition links
for name, values in list(node.compositions_type.items()):
for value in values:
self.assign_association_relationship(
value, obj, name, "aggregation"
value, obj, name, "composition"
)
processed_attrs.add(name)

# Aggregation links
for name, values in list(node.aggregations_type.items()):
if name not in processed_attrs:
for value in values:
self.assign_association_relationship(
value, obj, name, "aggregation"
)
processed_attrs.add(name)

# Association links
associations = node.associations_type.copy()

for name, values in node.locals_type.items():
if name not in associations:
associations[name] = values

for name, values in associations.items():
for value in values:
self.assign_association_relationship(
value, obj, name, "association"
)
if name not in processed_attrs:
for value in values:
self.assign_association_relationship(
value, obj, name, "association"
)

def assign_association_relationship(
self, value: astroid.NodeNG, obj: ClassEntity, name: str, type_relationship: str
Expand Down
8 changes: 7 additions & 1 deletion pylint/pyreverse/dot_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,18 @@ class HTMLLabels(Enum):
# pylint: disable-next=consider-using-namedtuple-or-dataclass
ARROWS: dict[EdgeType, dict[str, str]] = {
EdgeType.INHERITS: {"arrowtail": "none", "arrowhead": "empty"},
EdgeType.ASSOCIATION: {
EdgeType.COMPOSITION: {
"fontcolor": "green",
"arrowtail": "none",
"arrowhead": "diamond",
"style": "solid",
},
EdgeType.ASSOCIATION: {
"fontcolor": "green",
"arrowtail": "none",
"arrowhead": "normal",
"style": "solid",
},
EdgeType.AGGREGATION: {
"fontcolor": "green",
"arrowtail": "none",
Expand Down
59 changes: 51 additions & 8 deletions pylint/pyreverse/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ class Linker(IdGeneratorMixIn, utils.LocalsVisitor):

* aggregations_type
as instance_attrs_type but for aggregations relationships

* compositions_type
as instance_attrs_type but for compositions relationships
"""

def __init__(self, project: Project, tag: bool = False) -> None:
Expand All @@ -122,8 +125,14 @@ def __init__(self, project: Project, tag: bool = False) -> None:
self.tag = tag
# visited project
self.project = project
self.associations_handler = AggregationsHandler()
self.associations_handler.set_next(OtherAssociationsHandler())

# Chain: Composition β†’ Aggregation β†’ Association
self.associations_handler = CompositionsHandler()
aggregation_handler = AggregationsHandler()
association_handler = AssociationsHandler()

self.associations_handler.set_next(aggregation_handler)
aggregation_handler.set_next(association_handler)

def visit_project(self, node: Project) -> None:
"""Visit a pyreverse.utils.Project node.
Expand Down Expand Up @@ -167,6 +176,7 @@ def visit_classdef(self, node: nodes.ClassDef) -> None:
specializations.append(node)
baseobj.specializations = specializations
# resolve instance attributes
node.compositions_type = collections.defaultdict(list)
node.instance_attrs_type = collections.defaultdict(list)
node.aggregations_type = collections.defaultdict(list)
node.associations_type = collections.defaultdict(list)
Expand Down Expand Up @@ -327,28 +337,50 @@ def handle(self, node: nodes.AssignAttr, parent: nodes.ClassDef) -> None:
self._next_handler.handle(node, parent)


class CompositionsHandler(AbstractAssociationHandler):
"""Handle composition relationships where parent creates child objects."""

def handle(self, node: nodes.AssignAttr, parent: nodes.ClassDef) -> None:
if not isinstance(node.parent, (nodes.AnnAssign, nodes.Assign)):
super().handle(node, parent)
return

value = node.parent.value

# Composition: parent creates child (self.x = P())
if isinstance(value, nodes.Call):
current = set(parent.compositions_type[node.attrname])
parent.compositions_type[node.attrname] = list(
current | utils.infer_node(node)
)
return

# Not a composition, pass to next handler
super().handle(node, parent)


class AggregationsHandler(AbstractAssociationHandler):
"""Handle aggregation relationships where parent receives child objects."""

def handle(self, node: nodes.AssignAttr, parent: nodes.ClassDef) -> None:
# Check if we're not in an assignment context
if not isinstance(node.parent, (nodes.AnnAssign, nodes.Assign)):
super().handle(node, parent)
return

value = node.parent.value

# Handle direct name assignments
# Aggregation: parent receives child (self.x = x)
if isinstance(value, astroid.node_classes.Name):
current = set(parent.aggregations_type[node.attrname])
parent.aggregations_type[node.attrname] = list(
current | utils.infer_node(node)
)
return

# Handle comprehensions
# Aggregation: comprehensions (self.x = [P() for ...])
if isinstance(
value, (nodes.ListComp, nodes.DictComp, nodes.SetComp, nodes.GeneratorExp)
):
# Determine the type of the element in the comprehension
if isinstance(value, nodes.DictComp):
element_type = safe_infer(value.value)
else:
Expand All @@ -358,12 +390,23 @@ def handle(self, node: nodes.AssignAttr, parent: nodes.ClassDef) -> None:
parent.aggregations_type[node.attrname] = list(current | {element_type})
return

# Fallback to parent handler
# Type annotation only (x: P) defaults to aggregation
if isinstance(node.parent, nodes.AnnAssign) and node.parent.value is None:
current = set(parent.aggregations_type[node.attrname])
parent.aggregations_type[node.attrname] = list(
current | utils.infer_node(node)
)
return

# Not an aggregation, pass to next handler
super().handle(node, parent)


class OtherAssociationsHandler(AbstractAssociationHandler):
class AssociationsHandler(AbstractAssociationHandler):
"""Handle regular association relationships."""

def handle(self, node: nodes.AssignAttr, parent: nodes.ClassDef) -> None:
# Everything else is a regular association
current = set(parent.associations_type[node.attrname])
parent.associations_type[node.attrname] = list(current | utils.infer_node(node))

Expand Down
3 changes: 2 additions & 1 deletion pylint/pyreverse/mermaidjs_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ class MermaidJSPrinter(Printer):
}
ARROWS: dict[EdgeType, str] = {
EdgeType.INHERITS: "--|>",
EdgeType.ASSOCIATION: "--*",
EdgeType.COMPOSITION: "--*",
EdgeType.ASSOCIATION: "-->",
EdgeType.AGGREGATION: "--o",
EdgeType.USES: "-->",
EdgeType.TYPE_DEPENDENCY: "..>",
Expand Down
3 changes: 2 additions & 1 deletion pylint/pyreverse/plantuml_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ class PlantUmlPrinter(Printer):
}
ARROWS: dict[EdgeType, str] = {
EdgeType.INHERITS: "--|>",
EdgeType.ASSOCIATION: "--*",
EdgeType.ASSOCIATION: "-->",
EdgeType.COMPOSITION: "--*",
EdgeType.AGGREGATION: "--o",
EdgeType.USES: "-->",
EdgeType.TYPE_DEPENDENCY: "..>",
Expand Down
1 change: 1 addition & 0 deletions pylint/pyreverse/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class NodeType(Enum):

class EdgeType(Enum):
INHERITS = "inherits"
COMPOSITION = "composition"
ASSOCIATION = "association"
AGGREGATION = "aggregation"
USES = "uses"
Expand Down
8 changes: 8 additions & 0 deletions pylint/pyreverse/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,14 @@ def write_classes(self, diagram: ClassDiagram) -> None:
label=rel.name,
type_=EdgeType.ASSOCIATION,
)
# generate compositions
for rel in diagram.get_relationships("composition"):
self.printer.emit_edge(
rel.from_object.fig_id,
rel.to_object.fig_id,
label=rel.name,
type_=EdgeType.COMPOSITION,
)
# generate aggregations
for rel in diagram.get_relationships("aggregation"):
if rel.to_object.fig_id in associations[rel.from_object.fig_id]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ classDiagram
}
class P {
}
P --* A : x
P --* C : x
P --> A : x
P --* D : x
P --* E : x
P --o B : x
P --o C : x
10 changes: 5 additions & 5 deletions tests/pyreverse/functional/class_diagrams/aggregation/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,24 @@ class P:
pass

class A:
x: P
x: P # just type hint, no ownership, soassociation

class B:
def __init__(self, x: P):
self.x = x
self.x = x # not instantiated, so aggregation

class C:
x: P

def __init__(self, x: P):
self.x = x
self.x = x # not instantiated, so aggregation

class D:
x: P

def __init__(self):
self.x = P()
self.x = P() # instantiated, so composition

class E:
def __init__(self):
self.x = P()
self.x = P() # instantiated, so composition
Loading