From c50f2a57fac7a9e0e8e4da33821d18ce41547beb Mon Sep 17 00:00:00 2001 From: Egor Spirin Date: Mon, 15 Nov 2021 20:55:39 +0300 Subject: [PATCH 1/2] Update requirements, use rich progress bar instead of tqdm --- code2seq/code2class_wrapper.py | 2 -- code2seq/code2seq_wrapper.py | 2 -- code2seq/data/path_context_data_module.py | 17 +++++++------- .../data/typed_path_context_data_module.py | 6 ++--- code2seq/model/code2seq.py | 22 +++++++++++++------ code2seq/typed_code2seq_wrapper.py | 2 -- code2seq/utils/common.py | 2 +- code2seq/utils/train.py | 18 +++++++-------- config/code2seq-java-med.yaml | 2 +- config/code2seq-java-test.yaml | 1 - requirements.txt | 6 ++--- setup.py | 8 +++---- 12 files changed, 43 insertions(+), 45 deletions(-) diff --git a/code2seq/code2class_wrapper.py b/code2seq/code2class_wrapper.py index 0fcb0de..c2bc422 100644 --- a/code2seq/code2class_wrapper.py +++ b/code2seq/code2class_wrapper.py @@ -27,8 +27,6 @@ def train_code2class(config: DictConfig): # Load data module data_module = PathContextDataModule(config.data_folder, config.data, is_class=True) - data_module.prepare_data() - data_module.setup() # Load model code2class = Code2Class(config.model, config.optimizer, data_module.vocabulary) diff --git a/code2seq/code2seq_wrapper.py b/code2seq/code2seq_wrapper.py index aafb08f..9ef3505 100644 --- a/code2seq/code2seq_wrapper.py +++ b/code2seq/code2seq_wrapper.py @@ -27,8 +27,6 @@ def train_code2seq(config: DictConfig): # Load data module data_module = PathContextDataModule(config.data_folder, config.data) - data_module.prepare_data() - data_module.setup() # Load model code2seq = Code2Seq(config.model, config.optimizer, data_module.vocabulary, config.train.teacher_forcing) diff --git a/code2seq/data/path_context_data_module.py b/code2seq/data/path_context_data_module.py index 55ebb0b..da4d2c9 100644 --- a/code2seq/data/path_context_data_module.py +++ b/code2seq/data/path_context_data_module.py @@ -18,8 +18,6 @@ class PathContextDataModule(LightningDataModule): _val = "val" _test = "test" - _vocabulary: Optional[Vocabulary] = None - def __init__(self, data_dir: str, config: DictConfig, is_class: bool = False): super().__init__() self._config = config @@ -27,6 +25,8 @@ def __init__(self, data_dir: str, config: DictConfig, is_class: bool = False): self._name = basename(data_dir) self._is_class = is_class + self._vocabulary = self.setup_vocabulary() + @property def vocabulary(self) -> Vocabulary: if self._vocabulary is None: @@ -41,14 +41,12 @@ def prepare_data(self): raise ValueError(f"Config doesn't contain url for, can't download it automatically") download_dataset(self._config.url, self._data_dir, self._name) - def setup(self, stage: Optional[str] = None): - if not exists(join(self._data_dir, Vocabulary.vocab_filename)): + def setup_vocabulary(self) -> Vocabulary: + vocabulary_path = join(self._data_dir, Vocabulary.vocab_filename) + if not exists(vocabulary_path): print("Can't find vocabulary, collect it from train holdout") build_from_scratch(join(self._data_dir, f"{self._train}.c2s"), Vocabulary) - vocabulary_path = join(self._data_dir, Vocabulary.vocab_filename) - self._vocabulary = Vocabulary( - vocabulary_path, self._config.labels_count, self._config.tokens_count, self._is_class - ) + return Vocabulary(vocabulary_path, self._config.labels_count, self._config.tokens_count, self._is_class) @staticmethod def collate_wrapper(batch: List[Optional[LabeledPathContext]]) -> BatchedLabeledPathContext: @@ -88,6 +86,9 @@ def val_dataloader(self, *args, **kwargs) -> 
DataLoader: def test_dataloader(self, *args, **kwargs) -> DataLoader: return self._shared_dataloader(self._test) + def predict_dataloader(self, *args, **kwargs) -> DataLoader: + return self.test_dataloader(*args, **kwargs) + def transfer_batch_to_device( self, batch: BatchedLabeledPathContext, device: torch.device, dataloader_idx: int ) -> BatchedLabeledPathContext: diff --git a/code2seq/data/typed_path_context_data_module.py b/code2seq/data/typed_path_context_data_module.py index 724046c..bbf9fcf 100644 --- a/code2seq/data/typed_path_context_data_module.py +++ b/code2seq/data/typed_path_context_data_module.py @@ -11,7 +11,7 @@ class TypedPathContextDataModule(PathContextDataModule): - _vocabulary: Optional[TypedVocabulary] = None + _vocabulary: TypedVocabulary def __init__(self, data_dir: str, config: DictConfig): super().__init__(data_dir, config) @@ -27,12 +27,12 @@ def _create_dataset(self, holdout_file: str, random_context: bool) -> TypedPathC raise RuntimeError(f"Setup vocabulary before creating data loaders") return TypedPathContextDataset(holdout_file, self._config, self._vocabulary, random_context) - def setup(self, stage: Optional[str] = None): + def setup_vocabulary(self) -> TypedVocabulary: if not exists(join(self._data_dir, TypedVocabulary.vocab_filename)): print("Can't find vocabulary, collect it from train holdout") build_from_scratch(join(self._data_dir, f"{self._train}.c2s"), TypedVocabulary) vocabulary_path = join(self._data_dir, TypedVocabulary.vocab_filename) - self._vocabulary = TypedVocabulary( + return TypedVocabulary( vocabulary_path, self._config.labels_count, self._config.tokens_count, self._config.types_count ) diff --git a/code2seq/model/code2seq.py b/code2seq/model/code2seq.py index bc919a2..603760d 100644 --- a/code2seq/model/code2seq.py +++ b/code2seq/model/code2seq.py @@ -3,6 +3,7 @@ import torch from commode_utils.losses import SequenceCrossEntropyLoss from commode_utils.metrics import SequentialF1Score, ClassificationMetrics +from commode_utils.metrics.chrF import ChrF from commode_utils.modules import LSTMDecoderStep, Decoder from omegaconf import DictConfig from pytorch_lightning import LightningModule @@ -41,6 +42,10 @@ def __init__( f"{holdout}_f1": SequentialF1Score(pad_idx=self.__pad_idx, eos_idx=eos_idx, ignore_idx=ignore_idx) for holdout in ["train", "val", "test"] } + id2label = {v: k for k, v in vocabulary.label_to_id.items()} + metrics.update( + {f"{holdout}_chrf": ChrF(id2label, ignore_idx + [self.__pad_idx, eos_idx]) for holdout in ["val", "test"]} + ) self.__metrics = MetricCollection(metrics) self._encoder = self._get_encoder(model_config) @@ -102,18 +107,18 @@ def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict: target_sequence = batch.labels if step == "train" else None # [seq length; batch size; vocab size] logits, _ = self.logits_from_batch(batch, target_sequence) - loss = self.__loss(logits[1:], batch.labels[1:]) + result = {f"{step}/loss": self.__loss(logits[1:], batch.labels[1:])} with torch.no_grad(): prediction = logits.argmax(-1) metric: ClassificationMetrics = self.__metrics[f"{step}_f1"](prediction, batch.labels) + result.update( + {f"{step}/f1": metric.f1_score, f"{step}/precision": metric.precision, f"{step}/recall": metric.recall} + ) + if step != "train": + result[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"](prediction, batch.labels) - return { - f"{step}/loss": loss, - f"{step}/f1": metric.f1_score, - f"{step}/precision": metric.precision, - f"{step}/recall": metric.recall, - } + return result def 
training_step(self, batch: BatchedLabeledPathContext, batch_idx: int) -> Dict: # type: ignore result = self._shared_step(batch, "train") @@ -143,6 +148,9 @@ def _shared_epoch_end(self, step_outputs: EPOCH_OUTPUT, step: str): f"{step}/recall": metric.recall, } self.__metrics[f"{step}_f1"].reset() + if step != "train": + log[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"].compute() + self.__metrics[f"{step}_chrf"].reset() self.log_dict(log, on_step=False, on_epoch=True) def training_epoch_end(self, step_outputs: EPOCH_OUTPUT): diff --git a/code2seq/typed_code2seq_wrapper.py b/code2seq/typed_code2seq_wrapper.py index 812a5c1..d250e92 100644 --- a/code2seq/typed_code2seq_wrapper.py +++ b/code2seq/typed_code2seq_wrapper.py @@ -27,8 +27,6 @@ def train_typed_code2seq(config: DictConfig): # Load data module data_module = TypedPathContextDataModule(config.data_folder, config.data) - data_module.prepare_data() - data_module.setup() # Load model typed_code2seq = TypedCode2Seq(config.model, config.optimizer, data_module.vocabulary, config.train.teacher_forcing) diff --git a/code2seq/utils/common.py b/code2seq/utils/common.py index 5869f35..d9bc5a3 100644 --- a/code2seq/utils/common.py +++ b/code2seq/utils/common.py @@ -4,7 +4,7 @@ def filter_warnings(): # "The dataloader does not have many workers which may be a bottleneck." filterwarnings("ignore", category=UserWarning, module="pytorch_lightning.utilities.distributed", lineno=50) - filterwarnings("ignore", category=UserWarning, module="pytorch_lightning.trainer.data_loading", lineno=105) + filterwarnings("ignore", category=UserWarning, module="pytorch_lightning.trainer.data_loading", lineno=110) # "Please also save or load the state of the optimizer when saving or loading the scheduler." filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler", lineno=216) # save filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler", lineno=234) # load diff --git a/code2seq/utils/train.py b/code2seq/utils/train.py index a88a86a..814b4ff 100644 --- a/code2seq/utils/train.py +++ b/code2seq/utils/train.py @@ -1,8 +1,10 @@ +from os.path import join + import torch from commode_utils.callback import PrintEpochResultCallback, ModelCheckpointWithUpload from omegaconf import DictConfig, OmegaConf from pytorch_lightning import seed_everything, Trainer, LightningModule, LightningDataModule -from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor +from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, RichProgressBar from pytorch_lightning.loggers import WandbLogger @@ -21,7 +23,7 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict # define model checkpoint callback checkpoint_callback = ModelCheckpointWithUpload( - dirpath=wandb_logger.experiment.dir, + dirpath=join(wandb_logger.experiment.dir, "checkpoints"), filename="{epoch:02d}-val_loss={val/loss:.4f}", monitor="val/loss", every_n_epochs=params.save_every_epoch, @@ -36,6 +38,8 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict gpu = 1 if torch.cuda.is_available() else None # define learning rate logger lr_logger = LearningRateMonitor("step") + # define progress bar callback + progress_bar = RichProgressBar(refresh_rate_per_second=config.progress_bar_refresh_rate) trainer = Trainer( max_epochs=params.n_epochs, gradient_clip_val=params.clip_norm, @@ -44,15 +48,9 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict 
log_every_n_steps=params.log_every_n_steps, logger=wandb_logger, gpus=gpu, - progress_bar_refresh_rate=config.progress_bar_refresh_rate, - callbacks=[ - lr_logger, - early_stopping_callback, - checkpoint_callback, - print_epoch_result_callback, - ], + callbacks=[lr_logger, early_stopping_callback, checkpoint_callback, print_epoch_result_callback, progress_bar], resume_from_checkpoint=config.get("checkpoint", None), ) trainer.fit(model=model, datamodule=data_module) - trainer.test() + trainer.test(datamodule=data_module, ckpt_path="best") diff --git a/config/code2seq-java-med.yaml b/config/code2seq-java-med.yaml index 6f575c0..c425b5b 100644 --- a/config/code2seq-java-med.yaml +++ b/config/code2seq-java-med.yaml @@ -28,7 +28,7 @@ data: random_context: true batch_size: 512 - test_batch_size: 768 + test_batch_size: 512 model: # Encoder diff --git a/config/code2seq-java-test.yaml b/config/code2seq-java-test.yaml index bdcde62..371c2c1 100644 --- a/config/code2seq-java-test.yaml +++ b/config/code2seq-java-test.yaml @@ -3,7 +3,6 @@ data_folder: ../data/code2seq/java-test checkpoint: null seed: 7 -# Training in notebooks (e.g. Google Colab) may crash with too small value progress_bar_refresh_rate: 1 print_config: true diff --git a/requirements.txt b/requirements.txt index 96fc28e..abd2722 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ torch==1.10.0 -pytorch-lightning==1.4.9 -torchmetrics==0.5.1 +pytorch-lightning==1.5.1 +torchmetrics==0.6.0 tqdm==4.62.3 wandb==0.12.6 omegaconf==2.1.1 -commode-utils==0.3.12 +commode-utils==0.4.0 diff --git a/setup.py b/setup.py index be092a4..133825d 100644 --- a/setup.py +++ b/setup.py @@ -6,13 +6,11 @@ readme = readme_file.read() install_requires = [ - "torch>=1.9.0", - "pytorch-lightning~=1.4.2", - "torchmetrics~=0.5.0", - "tqdm~=4.62.1", + "torch>=1.10.0", + "pytorch-lightning~=1.5.0", "wandb~=0.12.0", "omegaconf~=2.1.1", - "commode-utils>=0.3.8", + "commode-utils>=0.4.0", ] setup_args = dict( From 2753108bf70d97a6596613b8775ee915d691b2f1 Mon Sep 17 00:00:00 2001 From: Egor Spirin Date: Mon, 15 Nov 2021 20:56:26 +0300 Subject: [PATCH 2/2] Update version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 133825d..f22a424 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import setup, find_packages -VERSION = "1.1.1" +VERSION = "1.2.0" with open("README.md") as readme_file: readme = readme_file.read()
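
For reference, a minimal sketch of what the progress-bar migration in the first patch looks like from the caller's side under pytorch-lightning 1.5: the Trainer-level progress_bar_refresh_rate flag is replaced by a RichProgressBar callback (whose refresh_rate_per_second argument matches the patch). The helper name fit_with_rich_bar, the max_epochs value, and the default refresh rate are placeholders for illustration only and are not part of either patch; RichProgressBar additionally needs the rich package installed.

import torch
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import RichProgressBar


def fit_with_rich_bar(model: LightningModule, data_module: LightningDataModule, refresh_rate: int = 1) -> None:
    # Hypothetical helper, not part of the patch: wires the callback-based progress bar
    # that replaces Trainer(progress_bar_refresh_rate=...) in pytorch-lightning 1.5.x.
    progress_bar = RichProgressBar(refresh_rate_per_second=refresh_rate)  # requires the `rich` package
    trainer = Trainer(
        max_epochs=10,  # arbitrary value for the sketch
        gpus=1 if torch.cuda.is_available() else None,
        callbacks=[progress_bar],
    )
    trainer.fit(model=model, datamodule=data_module)
    # Evaluate the best checkpoint after training, as the updated train() does.
    trainer.test(datamodule=data_module, ckpt_path="best")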