diff --git a/code2seq/data/path_context.py b/code2seq/data/path_context.py
index 4429d5e..0610f9e 100644
--- a/code2seq/data/path_context.py
+++ b/code2seq/data/path_context.py
@@ -1,37 +1,47 @@
 from dataclasses import dataclass
-from typing import Iterable, Tuple, Optional, Sequence
+from typing import Iterable, Tuple, Optional, Sequence, List, cast
 
 import torch
 
 
 @dataclass
 class Path:
-    from_token: torch.Tensor  # [max token parts]
-    path_node: torch.Tensor  # [path length]
-    to_token: torch.Tensor  # [max token parts]
+    from_token: List[int]  # [max token parts]
+    path_node: List[int]  # [path length]
+    to_token: List[int]  # [max token parts]
 
 
 @dataclass
 class LabeledPathContext:
-    label: torch.Tensor  # [max label parts]
+    label: List[int]  # [max label parts]
     path_contexts: Sequence[Path]
 
 
+def transpose(list_of_lists: List[List[int]]) -> List[List[int]]:
+    return [cast(List[int], it) for it in zip(*list_of_lists)]
+
+
 class BatchedLabeledPathContext:
     def __init__(self, all_samples: Sequence[Optional[LabeledPathContext]]):
         samples = [s for s in all_samples if s is not None]
 
         # [max label parts; batch size]
-        self.labels = torch.cat([s.label.unsqueeze(1) for s in samples], dim=1)
+        self.labels = torch.tensor(transpose([s.label for s in samples]), dtype=torch.long)
         # [batch size]
         self.contexts_per_label = torch.tensor([len(s.path_contexts) for s in samples])
 
         # [max token parts; n contexts]
-        self.from_token = torch.cat([path.from_token.unsqueeze(1) for s in samples for path in s.path_contexts], dim=1)
+        self.from_token = torch.tensor(
+            transpose([path.from_token for s in samples for path in s.path_contexts]), dtype=torch.long
+        )
         # [path length; n contexts]
-        self.path_nodes = torch.cat([path.path_node.unsqueeze(1) for s in samples for path in s.path_contexts], dim=1)
+        self.path_nodes = torch.tensor(
+            transpose([path.path_node for s in samples for path in s.path_contexts]), dtype=torch.long
+        )
         # [max token parts; n contexts]
-        self.to_token = torch.cat([path.to_token.unsqueeze(1) for s in samples for path in s.path_contexts], dim=1)
+        self.to_token = torch.tensor(
+            transpose([path.to_token for s in samples for path in s.path_contexts]), dtype=torch.long
+        )
 
     def __len__(self) -> int:
         return len(self.contexts_per_label)
@@ -53,8 +63,8 @@ def move_to_device(self, device: torch.device):
 
 @dataclass
 class TypedPath(Path):
-    from_type: torch.Tensor  # [max type parts]
-    to_type: torch.Tensor  # [max type parts]
+    from_type: List[int]  # [max type parts]
+    to_type: List[int]  # [max type parts]
 
 
 @dataclass
@@ -67,6 +77,10 @@ def __init__(self, all_samples: Sequence[Optional[LabeledTypedPathContext]]):
         super().__init__(all_samples)
         samples = [s for s in all_samples if s is not None]
         # [max type parts; n contexts]
-        self.from_type = torch.cat([path.from_type.unsqueeze(1) for s in samples for path in s.path_contexts], dim=1)
+        self.from_type = torch.tensor(
+            transpose([path.from_type for s in samples for path in s.path_contexts]), dtype=torch.long
+        )
         # [max type parts; n contexts]
-        self.to_type = torch.cat([path.to_type.unsqueeze(1) for s in samples for path in s.path_contexts], dim=1)
+        self.to_type = torch.tensor(
+            transpose([path.to_type for s in samples for path in s.path_contexts]), dtype=torch.long
+        )
diff --git a/code2seq/data/path_context_dataset.py b/code2seq/data/path_context_dataset.py
index 45b4472..62fc67f 100644
--- a/code2seq/data/path_context_dataset.py
+++ b/code2seq/data/path_context_dataset.py
@@ -2,7 +2,6 @@
 from random import shuffle
 from typing import Dict, List, Optional
 
-import torch
 from commode_utils.filesystem import get_lines_offsets, get_line_by_offset
 from omegaconf import DictConfig
 from torch.utils.data import Dataset
@@ -63,34 +62,29 @@ def __getitem__(self, index) -> Optional[LabeledPathContext]:
         return LabeledPathContext(label, paths)
 
     @staticmethod
-    def tokenize_class(raw_class: str, vocab: Dict[str, int]) -> torch.Tensor:
-        return torch.tensor([vocab[raw_class]], dtype=torch.long)
+    def tokenize_class(raw_class: str, vocab: Dict[str, int]) -> List[int]:
+        return [vocab[raw_class]]
 
     @staticmethod
-    def tokenize_label(raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> torch.Tensor:
+    def tokenize_label(raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]:
         sublabels = raw_label.split(PathContextDataset._separator)
         max_parts = max_parts or len(sublabels)
         label_unk = vocab[Vocabulary.UNK]
 
-        label = torch.full((max_parts + 1,), vocab[Vocabulary.PAD], dtype=torch.long)
-        label[0] = vocab[Vocabulary.SOS]
-        sub_tokens_ids = [vocab.get(st, label_unk) for st in sublabels[:max_parts]]
-        label[1 : len(sub_tokens_ids) + 1] = torch.tensor(sub_tokens_ids)
-
+        label = [vocab[Vocabulary.SOS]] + [vocab.get(st, label_unk) for st in sublabels[:max_parts]]
         if len(sublabels) < max_parts:
-            label[len(sublabels) + 1] = vocab[Vocabulary.EOS]
-
+            label.append(vocab[Vocabulary.EOS])
+        label += [vocab[Vocabulary.PAD]] * (max_parts + 1 - len(label))
         return label
 
     @staticmethod
-    def tokenize_token(token: str, vocab: Dict[str, int], max_parts: Optional[int]) -> torch.Tensor:
+    def tokenize_token(token: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]:
         sub_tokens = token.split(PathContextDataset._separator)
         max_parts = max_parts or len(sub_tokens)
         token_unk = vocab[Vocabulary.UNK]
 
-        result = torch.full((max_parts,), vocab[Vocabulary.PAD], dtype=torch.long)
-        sub_tokens_ids = [vocab.get(st, token_unk) for st in sub_tokens[:max_parts]]
-        result[: len(sub_tokens_ids)] = torch.tensor(sub_tokens_ids)
+        result = [vocab.get(st, token_unk) for st in sub_tokens[:max_parts]]
+        result += [vocab[Vocabulary.PAD]] * (max_parts - len(result))
         return result
 
     def _get_path(self, raw_path: List[str]) -> Path:
diff --git a/setup.py b/setup.py
index 8817d81..b3b979b 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,6 @@
 from setuptools import setup, find_packages
 
-VERSION = "1.0.0"
+VERSION = "1.0.1"
 
 with open("README.md") as readme_file:
     readme = readme_file.read()
diff --git a/tests/test_tokenization.py b/tests/test_tokenization.py
index 8a9d7f6..e2cd5c3 100644
--- a/tests/test_tokenization.py
+++ b/tests/test_tokenization.py
@@ -13,23 +13,23 @@ def test_tokenize_label(self):
         raw_label = "my|super|label"
         tokenized = PathContextDataset.tokenize_label(raw_label, self.vocab, 5)
         # my super
-        correct = torch.tensor([2, 4, 5, 1, 3, 0], dtype=torch.long)
+        correct = [2, 4, 5, 1, 3, 0]
 
-        torch.testing.assert_equal(tokenized, correct)
+        self.assertListEqual(tokenized, correct)
 
     def test_tokenize_class(self):
         raw_class = "super"
         tokenized = PathContextDataset.tokenize_class(raw_class, self.vocab)
-        correct = torch.tensor([5], dtype=torch.long)
+        correct = [5]
 
-        torch.testing.assert_equal(tokenized, correct)
+        self.assertListEqual(tokenized, correct)
 
     def test_tokenize_token(self):
         raw_token = "my|super|token"
         tokenized = PathContextDataset.tokenize_token(raw_token, self.vocab, 5)
-        correct = torch.tensor([4, 5, 1, 0, 0], dtype=torch.long)
+        correct = [4, 5, 1, 0, 0]
 
-        torch.testing.assert_equal(tokenized, correct)
+        self.assertListEqual(tokenized, correct)
 
 
 if __name__ == "__main__":
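
Note for reviewers: the standalone sketch below illustrates how the reworked tokenization and batching fit together after this patch. It is an approximation, not code from the diff: the vocabulary ids (PAD=0, UNK=1, SOS=2, EOS=3, my=4, super=5) are assumed to mirror the fixture implied by tests/test_tokenization.py, the "|" separator stands in for PathContextDataset._separator, and transpose is re-implemented with list() instead of cast() since both yield row data that torch.tensor accepts.

from typing import Dict, List

import torch

# Assumed vocabulary; ids chosen to match the expected values in the tests above.
vocab: Dict[str, int] = {"<PAD>": 0, "<UNK>": 1, "<SOS>": 2, "<EOS>": 3, "my": 4, "super": 5}


def tokenize_label(raw_label: str, vocab: Dict[str, int], max_parts: int) -> List[int]:
    # Mirrors PathContextDataset.tokenize_label after this change: SOS, then sub-token ids,
    # EOS if the label is shorter than max_parts, then PAD up to a fixed length of max_parts + 1.
    sublabels = raw_label.split("|")
    unk = vocab["<UNK>"]
    label = [vocab["<SOS>"]] + [vocab.get(st, unk) for st in sublabels[:max_parts]]
    if len(sublabels) < max_parts:
        label.append(vocab["<EOS>"])
    label += [vocab["<PAD>"]] * (max_parts + 1 - len(label))
    return label


def transpose(list_of_lists: List[List[int]]) -> List[List[int]]:
    # Rows -> columns, as in path_context.transpose.
    return [list(it) for it in zip(*list_of_lists)]


labels = [tokenize_label("my|super|label", vocab, 5), tokenize_label("my", vocab, 5)]
# labels == [[2, 4, 5, 1, 3, 0], [2, 4, 3, 0, 0, 0]]

# BatchedLabeledPathContext.labels is built the same way: one torch.tensor call over the
# transposed equal-length lists replaces the old per-sample unsqueeze(1) + torch.cat.
batched = torch.tensor(transpose(labels), dtype=torch.long)
print(batched.shape)  # torch.Size([6, 2]) -> [max label parts; batch size]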