Skip to content

Commit c3ef647

Browse files
Merge pull request emrgnt-cmplxty#146 from maks-ivanov/feature/polish-search
Feature/polish search
2 parents 3f2996a + 8c17944 commit c3ef647

File tree

3 files changed

+45
-22
lines changed

3 files changed

+45
-22
lines changed

automata/tools/search/local_types.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import re
22
from dataclasses import dataclass
33
from enum import Enum
4-
from typing import Any, Dict, List, Optional
4+
from os import PathLike
5+
from typing import Any, Dict, List, Optional, Union
56

67
from automata.tools.search.scip_pb2 import Descriptor as DescriptorProto
78

@@ -161,6 +162,7 @@ def __eq__(self, other):
161162

162163
@dataclass
163164
class SymbolReference:
165+
symbol: Symbol
164166
line_number: int
165167
roles: Dict[str, Any]
166168

@@ -179,3 +181,6 @@ def __eq__(self, other):
179181
elif isinstance(other, str):
180182
return self.path == other
181183
return False
184+
185+
186+
StrPath = Union[str, PathLike]

automata/tools/search/main.py

+21-12
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import os
22
from argparse import ArgumentParser
3+
from typing import Dict, cast
34

45
# from automata.tools.search.call_graph import CallGraph
6+
# from automata.tools.search.local_types import Descriptor
57
from automata.tools.search.symbol_converter import SymbolConverter
68
from automata.tools.search.symbol_graph import SymbolGraph
79
from automata.tools.search.symbol_parser import parse_uri_to_symbol
@@ -25,17 +27,17 @@
2527
# Dump all available files in the symbol graph
2628
print("-" * 200)
2729
print("Fetching all available files in SymbolGraph")
28-
file_nodes = symbol_graph.get_all_files()
29-
for file_node in file_nodes:
30-
print("File >> %s" % (file_node))
30+
files = symbol_graph.get_all_files()
31+
for file in files:
32+
print("File Path >> %s" % (file.path))
3133
print("-" * 200)
3234

33-
# Dump all available symbols at the test path
35+
# Dump all available symbols defined along test path
3436
print("-" * 200)
35-
print("Fetching all available symbols along %s" % (test_path))
37+
print("Fetching all defined symbols along %s" % (test_path))
3638
available_symbols = symbol_graph.get_defined_symbols_along_path(test_path)
3739
for symbol in available_symbols:
38-
print("Available Symbol >> %s" % (symbol))
40+
print("Defined Symbol >> %s" % (symbol))
3941

4042
print("-" * 200)
4143

@@ -48,23 +50,30 @@
4850
# Find references of the test symbol
4951
print("-" * 200)
5052
print("Searching for references of the symbol %s" % (test_symbol))
51-
search_result = symbol_searcher.process_query("type:symbol %s" % (test_symbol.uri))
52-
print("References: ", search_result)
53+
search_result_0: Dict = cast(
54+
Dict, symbol_searcher.process_query("type:symbol %s" % (test_symbol.uri))
55+
)
56+
for file_path in search_result_0.keys():
57+
print("File Path >> %s" % (file_path))
58+
for reference in search_result_0[file_path]:
59+
print("Reference >> %s" % (reference))
5360
print("-" * 200)
5461

5562
# Find source code for the test symbol
5663
print("-" * 200)
5764
print("Searching for source code for symbol %s" % (test_symbol))
58-
search_result = symbol_searcher.process_query("type:source %s" % (test_symbol.uri))
59-
print("Source Code: ", search_result)
65+
search_result_1: str = cast(
66+
str, symbol_searcher.process_query("type:source %s" % (test_symbol.uri))
67+
)
68+
print("Source Code: ", search_result_1)
6069
print("-" * 200)
6170

6271
# Find exact matches for abbrievated test symbol
6372
print("-" * 200)
6473
abbv_test_symbol = "AutomataAgentConfig"
6574
print("Searching for exact matches of the filter %s" % (abbv_test_symbol))
66-
search_result = symbol_searcher.process_query("type:exact %s" % (abbv_test_symbol))
67-
print("Search result: ", search_result)
75+
search_result_2 = symbol_searcher.process_query("type:exact %s" % (abbv_test_symbol))
76+
print("Search result: ", search_result_2)
6877
print("-" * 200)
6978

7079
# Perform a find and replace on the test find symbol below

automata/tools/search/symbol_graph.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import networkx as nx
55

6-
from automata.tools.search.local_types import File, Symbol, SymbolReference
6+
from automata.tools.search.local_types import File, StrPath, Symbol, SymbolReference
77
from automata.tools.search.scip_pb2 import Index, SymbolRole
88
from automata.tools.search.symbol_converter import SymbolConverter
99
from automata.tools.search.symbol_parser import parse_uri_to_symbol
@@ -13,7 +13,10 @@
1313

1414
class SymbolGraph:
1515
def __init__(
16-
self, index_path: str, symbol_converter: SymbolConverter, do_shortened_symbols: bool = True
16+
self,
17+
index_path: StrPath,
18+
symbol_converter: SymbolConverter,
19+
do_shortened_symbols: bool = True,
1720
):
1821
"""
1922
Initialize SymbolGraph with the path of an index protobuf file.
@@ -33,7 +36,11 @@ def get_all_files(self) -> List[File]:
3336
3437
:return: List of all file nodes.
3538
"""
36-
return [data for _, data in self._graph.nodes(data=True) if data.get("label") == "file"]
39+
return [
40+
data.get("file")
41+
for _, data in self._graph.nodes(data=True)
42+
if data.get("label") == "file"
43+
]
3744

3845
def get_all_defined_symbols(self) -> List[Symbol]:
3946
"""
@@ -54,9 +61,9 @@ def get_symbol_references(self, symbol: Symbol) -> Dict[str, List[SymbolReferenc
5461
:return: List of tuples (file, symbol details)
5562
"""
5663
search_results = [
57-
(file_path, symbol_reference)
58-
for file_path, symbol_reference, label in self._graph.out_edges(symbol, data=True)
59-
if label == "reference"
64+
(file_path, data.get("symbol_reference"))
65+
for file_path, _, data in self._graph.out_edges(symbol, data=True)
66+
if data.get("label") == "reference"
6067
]
6168
result_dict: Dict[str, List[SymbolReference]] = {}
6269

@@ -141,11 +148,12 @@ def _build_symbol_graph(self, index: Index) -> nx.MultiDiGraph:
141148
G = nx.MultiDiGraph()
142149
for document in index.documents:
143150
# Add File Vertices
144-
document_path = document.relative_path
151+
document_path: StrPath = document.relative_path
152+
145153
G.add_node(
146154
document_path,
147155
file=File(path=document.relative_path, occurrences=document.occurrences),
148-
label="file_path",
156+
label="file",
149157
)
150158

151159
for symbol_information in document.symbols:
@@ -167,6 +175,7 @@ def _build_symbol_graph(self, index: Index) -> nx.MultiDiGraph:
167175
occurrence_range = tuple(occurrence_information.range)
168176
occurrence_roles = self._get_symbol_roles_dict(occurrence_information.symbol_roles)
169177
occurrence_reference = SymbolReference(
178+
symbol=occurrence_symbol,
170179
line_number=occurrence_range[0],
171180
roles=occurrence_roles,
172181
)
@@ -195,7 +204,7 @@ def _get_symbol_roles_dict(role) -> Dict[str, bool]:
195204
return result
196205

197206
@staticmethod
198-
def _load_index_protobuf(path: str) -> Index:
207+
def _load_index_protobuf(path: StrPath) -> Index:
199208
"""
200209
Load an index from a protobuf file.
201210

0 commit comments

Comments
 (0)