-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
56 lines (41 loc) · 1.45 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from functools import partialmethod
from itertools import chain
from typing import Literal, overload
from tqdm import tqdm
from tti_eval.common import EmbeddingDefinition
from tti_eval.constants import PROJECT_PATHS
# The unpatched ``tqdm.__init__``, captured lazily on the first toggle. Every toggle wraps this
# original instead of whatever is currently installed, so repeated disable/enable calls do not
# stack ever-deeper ``partialmethod`` layers (each layer adds a call frame to every new bar).
_original_tqdm_init = None


def _set_tqdm_disable_default(disable: bool) -> None:
    """Install ``disable=<disable>`` as the default keyword for every ``tqdm`` bar created afterwards."""
    global _original_tqdm_init
    if _original_tqdm_init is None:
        _original_tqdm_init = tqdm.__init__
    # A ``disable=`` passed explicitly by the caller still overrides this default.
    tqdm.__init__ = partialmethod(_original_tqdm_init, disable=disable)


def disable_tqdm():
    """Hide ``tqdm`` progress bars constructed from now on (bars already created are unaffected)."""
    _set_tqdm_disable_default(True)


def enable_tqdm():
    """Show ``tqdm`` progress bars constructed from now on, e.g. to undo :func:`disable_tqdm`."""
    _set_tqdm_disable_default(False)
@overload
def read_all_cached_embeddings(as_list: Literal[True]) -> list[EmbeddingDefinition]:
    ...


@overload
def read_all_cached_embeddings(as_list: Literal[False] = False) -> dict[str, list[EmbeddingDefinition]]:
    ...


@overload
def read_all_cached_embeddings(as_list: bool) -> dict[str, list[EmbeddingDefinition]] | list[EmbeddingDefinition]:
    ...


def read_all_cached_embeddings(
    as_list: bool = False,
) -> dict[str, list[EmbeddingDefinition]] | list[EmbeddingDefinition]:
    """
    Reads existing embedding definitions from the cache directory.

    Each sub-directory of ``PROJECT_PATHS.EMBEDDINGS`` is treated as a dataset and each ``.npz``
    file inside it as a cached embedding. The model name is the file stem with its last
    ``_``-separated part removed (``<model>_<suffix>.npz`` -> ``<model>``). Several files
    that map to the same model yield a single definition.

    Args:
        as_list: When True, return one flat list over all datasets instead of a dictionary.

    Returns:
        A dictionary of ``<dataset, [embeddings]>`` where the list is over models, or a flat
        list of all definitions when ``as_list`` is True. If the cache directory does not
        exist, an empty dict or list is returned, matching the requested shape. Datasets and
        models appear in the sorted order of their file-system entries.
    """
    cache_root = PROJECT_PATHS.EMBEDDINGS
    if not cache_root.exists():
        # Honour the requested return type even when nothing has been cached yet.
        return [] if as_list else {}

    defs_dict: dict[str, list[EmbeddingDefinition]] = {}
    for dataset_dir in sorted(cache_root.iterdir()):
        if not dataset_dir.is_dir():
            continue
        # dict.fromkeys de-duplicates like a set but keeps a deterministic (first-seen) order.
        definitions = dict.fromkeys(
            EmbeddingDefinition(dataset=dataset_dir.name, model=model_file.stem.rsplit("_", maxsplit=1)[0])
            for model_file in sorted(dataset_dir.iterdir())
            if model_file.is_file() and model_file.suffix == ".npz"
        )
        defs_dict[dataset_dir.name] = list(definitions)

    if as_list:
        return list(chain.from_iterable(defs_dict.values()))
    return defs_dict
if __name__ == "__main__":
    # Manual inspection entry point: dump the cached embedding definitions grouped by dataset.
    cached_definitions = read_all_cached_embeddings()
    print(cached_definitions)