|
1 |
| -import itertools |
2 | 1 | import pathlib
|
3 | 2 | import sys
|
4 | 3 | from typing import (
|
5 |
| - TYPE_CHECKING, |
6 |
| - Iterator, |
| 4 | + Generator, |
7 | 5 | List,
|
8 | 6 | Set,
|
9 | 7 | Tuple,
|
|
13 | 11 | from betterproto.lib.google.protobuf import (
|
14 | 12 | DescriptorProto,
|
15 | 13 | EnumDescriptorProto,
|
16 |
| - FieldDescriptorProto, |
17 | 14 | FileDescriptorProto,
|
18 | 15 | ServiceDescriptorProto,
|
19 | 16 | )
|
|
40 | 37 | )
|
41 | 38 |
|
42 | 39 |
|
43 |
| -if TYPE_CHECKING: |
44 |
| - from google.protobuf.descriptor import Descriptor |
45 |
| - |
46 |
| - |
47 | 40 | def traverse(
|
48 |
| - proto_file: FieldDescriptorProto, |
49 |
| -) -> "itertools.chain[Tuple[Union[str, EnumDescriptorProto], List[int]]]": |
| 41 | + proto_file: FileDescriptorProto, |
| 42 | +) -> Generator[ |
| 43 | + Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None |
| 44 | +]: |
50 | 45 | # Todo: Keep information about nested hierarchy
|
51 | 46 | def _traverse(
|
52 |
| - path: List[int], items: List["EnumDescriptorProto"], prefix="" |
53 |
| - ) -> Iterator[Tuple[Union[str, EnumDescriptorProto], List[int]]]: |
| 47 | + path: List[int], |
| 48 | + items: Union[List[EnumDescriptorProto], List[DescriptorProto]], |
| 49 | + prefix: str = "", |
| 50 | + ) -> Generator[ |
| 51 | + Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None |
| 52 | + ]: |
54 | 53 | for i, item in enumerate(items):
|
55 | 54 | # Adjust the name since we flatten the hierarchy.
|
56 | 55 | # Todo: don't change the name, but include full name in returned tuple
|
57 | 56 | item.name = next_prefix = f"{prefix}_{item.name}"
|
58 |
| - yield item, path + [i] |
| 57 | + yield item, [*path, i] |
59 | 58 |
|
60 | 59 | if isinstance(item, DescriptorProto):
|
61 |
| - for enum in item.enum_type: |
62 |
| - enum.name = f"{next_prefix}_{enum.name}" |
63 |
| - yield enum, path + [i, 4] |
| 60 | + # Get nested types. |
| 61 | + yield from _traverse([*path, i, 4], item.enum_type, next_prefix) |
| 62 | + yield from _traverse([*path, i, 3], item.nested_type, next_prefix) |
64 | 63 |
|
65 |
| - if item.nested_type: |
66 |
| - for n, p in _traverse(path + [i, 3], item.nested_type, next_prefix): |
67 |
| - yield n, p |
68 |
| - |
69 |
| - return itertools.chain( |
70 |
| - _traverse([5], proto_file.enum_type), _traverse([4], proto_file.message_type) |
71 |
| - ) |
| 64 | + yield from _traverse([5], proto_file.enum_type) |
| 65 | + yield from _traverse([4], proto_file.message_type) |
72 | 66 |
|
73 | 67 |
|
74 | 68 | def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
|
|
0 commit comments