Skip to content

Commit 5f7e4d5

Browse files
authored
Fix documentation for nested enums (#351)
1 parent 1aaf772 commit 5f7e4d5

File tree

3 files changed

+54
-23
lines changed

3 files changed

+54
-23
lines changed

src/betterproto/plugin/parser.py

+17-23
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
import itertools
21
import pathlib
32
import sys
43
from typing import (
5-
TYPE_CHECKING,
6-
Iterator,
4+
Generator,
75
List,
86
Set,
97
Tuple,
@@ -13,7 +11,6 @@
1311
from betterproto.lib.google.protobuf import (
1412
DescriptorProto,
1513
EnumDescriptorProto,
16-
FieldDescriptorProto,
1714
FileDescriptorProto,
1815
ServiceDescriptorProto,
1916
)
@@ -40,35 +37,32 @@
4037
)
4138

4239

43-
if TYPE_CHECKING:
44-
from google.protobuf.descriptor import Descriptor
45-
46-
4740
def traverse(
48-
proto_file: FieldDescriptorProto,
49-
) -> "itertools.chain[Tuple[Union[str, EnumDescriptorProto], List[int]]]":
41+
proto_file: FileDescriptorProto,
42+
) -> Generator[
43+
Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None
44+
]:
5045
# Todo: Keep information about nested hierarchy
5146
def _traverse(
52-
path: List[int], items: List["EnumDescriptorProto"], prefix=""
53-
) -> Iterator[Tuple[Union[str, EnumDescriptorProto], List[int]]]:
47+
path: List[int],
48+
items: Union[List[EnumDescriptorProto], List[DescriptorProto]],
49+
prefix: str = "",
50+
) -> Generator[
51+
Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None
52+
]:
5453
for i, item in enumerate(items):
5554
# Adjust the name since we flatten the hierarchy.
5655
# Todo: don't change the name, but include full name in returned tuple
5756
item.name = next_prefix = f"{prefix}_{item.name}"
58-
yield item, path + [i]
57+
yield item, [*path, i]
5958

6059
if isinstance(item, DescriptorProto):
61-
for enum in item.enum_type:
62-
enum.name = f"{next_prefix}_{enum.name}"
63-
yield enum, path + [i, 4]
60+
# Get nested types.
61+
yield from _traverse([*path, i, 4], item.enum_type, next_prefix)
62+
yield from _traverse([*path, i, 3], item.nested_type, next_prefix)
6463

65-
if item.nested_type:
66-
for n, p in _traverse(path + [i, 3], item.nested_type, next_prefix):
67-
yield n, p
68-
69-
return itertools.chain(
70-
_traverse([5], proto_file.enum_type), _traverse([4], proto_file.message_type)
71-
)
64+
yield from _traverse([5], proto_file.enum_type)
65+
yield from _traverse([4], proto_file.message_type)
7266

7367

7468
def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:

tests/inputs/nestedtwice/nestedtwice.proto

+12
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,39 @@ syntax = "proto3";
22

33
package nestedtwice;
44

5+
/* Test doc. */
56
message Test {
7+
/* Top doc. */
68
message Top {
9+
/* Middle doc. */
710
message Middle {
11+
/* TopMiddleBottom doc.*/
812
message TopMiddleBottom {
13+
// TopMiddleBottom.a doc.
914
string a = 1;
1015
}
16+
/* EnumBottom doc. */
1117
enum EnumBottom{
18+
/* EnumBottom.A doc. */
1219
A = 0;
1320
B = 1;
1421
}
22+
/* Bottom doc. */
1523
message Bottom {
24+
/* Bottom.foo doc. */
1625
string foo = 1;
1726
}
1827
reserved 1;
28+
/* Middle.bottom doc. */
1929
repeated Bottom bottom = 2;
2030
repeated EnumBottom enumBottom=3;
2131
repeated TopMiddleBottom topMiddleBottom=4;
2232
bool bar = 5;
2333
}
34+
/* Top.name doc. */
2435
string name = 1;
2536
Middle middle = 2;
2637
}
38+
/* Test.top doc. */
2739
Top top = 1;
2840
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import pytest
2+
3+
from tests.output_betterproto.nestedtwice import (
4+
Test,
5+
TestTop,
6+
TestTopMiddle,
7+
TestTopMiddleBottom,
8+
TestTopMiddleEnumBottom,
9+
TestTopMiddleTopMiddleBottom,
10+
)
11+
12+
13+
@pytest.mark.parametrize(
14+
("cls", "expected_comment"),
15+
[
16+
(Test, "Test doc."),
17+
(TestTopMiddleEnumBottom, "EnumBottom doc."),
18+
(TestTop, "Top doc."),
19+
(TestTopMiddle, "Middle doc."),
20+
(TestTopMiddleTopMiddleBottom, "TopMiddleBottom doc."),
21+
(TestTopMiddleBottom, "Bottom doc."),
22+
],
23+
)
24+
def test_comment(cls, expected_comment):
25+
assert cls.__doc__ == expected_comment

0 commit comments

Comments
 (0)