From 27181f8fc08673d4b3997fc0f12276d0ee8bafac Mon Sep 17 00:00:00 2001 From: umang Date: Wed, 26 May 2021 02:11:35 -0700 Subject: [PATCH] add missing lib folder --- .gitignore | 4 +- lib/.gitignore | 108 +++++++++++++ lib/README.md | 2 + lib/base_model.py | 31 ++++ lib/base_trainer.py | 340 +++++++++++++++++++++++++++++++++++++++ lib/data/__init__.py | 0 lib/data/mnist.py | 35 ++++ lib/data/util.py | 75 +++++++++ lib/standard_nn.py | 117 ++++++++++++++ lib/utils/__init__.py | 0 lib/utils/logging.py | 36 +++++ lib/utils/math.py | 18 +++ lib/utils/optimizer.py | 50 ++++++ lib/utils/os.py | 323 +++++++++++++++++++++++++++++++++++++ lib/utils/samplers.py | 51 ++++++ lib/utils/torch_utils.py | 99 ++++++++++++ 16 files changed, 1287 insertions(+), 2 deletions(-) create mode 100644 lib/.gitignore create mode 100644 lib/README.md create mode 100644 lib/base_model.py create mode 100644 lib/base_trainer.py create mode 100644 lib/data/__init__.py create mode 100644 lib/data/mnist.py create mode 100644 lib/data/util.py create mode 100644 lib/standard_nn.py create mode 100644 lib/utils/__init__.py create mode 100644 lib/utils/logging.py create mode 100644 lib/utils/math.py create mode 100644 lib/utils/optimizer.py create mode 100644 lib/utils/os.py create mode 100644 lib/utils/samplers.py create mode 100644 lib/utils/torch_utils.py diff --git a/.gitignore b/.gitignore index adaccda..9462165 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,7 @@ dist/ downloads/ eggs/ .eggs/ -lib/ + lib64/ parts/ sdist/ @@ -108,4 +108,4 @@ results .idea -result/ \ No newline at end of file +result/ diff --git a/lib/.gitignore b/lib/.gitignore new file mode 100644 index 0000000..1fa2818 --- /dev/null +++ b/lib/.gitignore @@ -0,0 +1,108 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +.vscode +results +.idea diff --git a/lib/README.md b/lib/README.md new file mode 100644 index 0000000..05b31d2 --- /dev/null +++ b/lib/README.md @@ -0,0 +1,2 @@ +# train-lib +Common code for different projects diff --git a/lib/base_model.py b/lib/base_model.py new file mode 100644 index 0000000..6f30ed7 --- /dev/null +++ b/lib/base_model.py @@ -0,0 +1,31 @@ +""" base model""" +import logging + +import numpy as np +import torch.nn as nn + +logger = logging.getLogger() + + +class Base(nn.Module): + """ Base model with some util functions""" + + def stats(self, print_model=True): + # print network model and information about parameters + logger.info("Model info:::") + if print_model: + logger.info(self) + count = 0 + for i in self.parameters(): + count += np.prod(i.shape) + logger.info(f"Total parameters : {count}") + + def to(self, *args, **kwargs): + if kwargs.get("device"): + self.device = kwargs.get("device") + if len(args) > 0: + self.device = args[0] + return super().to(*args, **kwargs) + + def forward(self, x): + raise NotImplementedError() diff --git a/lib/base_trainer.py b/lib/base_trainer.py new file mode 100644 index 0000000..31e1d9f --- /dev/null +++ b/lib/base_trainer.py @@ -0,0 +1,340 @@ +"""trainer code""" +import copy +import logging +import os +from typing import List, Dict, Optional, Callable, Union + +import dill +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from lib.utils.logging import loss_logger_helper + +logger = logging.getLogger() + + +class Trainer: + # This is like skorch but instead of callbacks we use class functions (looks less magic) + # this is an evolving template + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim, + scheduler: torch.optim.lr_scheduler, + result_dir: Optional[str], + statefile: Optional[str] = None, + log_every: int = 100, + save_strategy: Optional[List] = None, + patience: int = 20, + max_epoch: int = 100, + gradient_norm_clip=-1, + stopping_criteria_direction: str = "bigger", + stopping_criteria: Optional[Union[str, Callable]] = "accuracy", + evaluations=None, + **kwargs, + ): + """ + stopping_criteria : can be a function, string or none. 
If string it should match one + of the keys in aux_loss or should be loss, if none we don't invoke early stopping + """ + super().__init__() + + self.result_dir = result_dir + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + self.evaluations = evaluations + self.gradient_norm_clip = gradient_norm_clip + + # training state related params + self.epoch = 0 + self.step = 0 + self.best_criteria = None + self.best_epoch = -1 + + # config related param + self.log_every = log_every + self.save_strategy = save_strategy + self.patience = patience + self.max_epoch = max_epoch + self.stopping_criteria_direction = stopping_criteria_direction + self.stopping_criteria = stopping_criteria + + # TODO: should save config and see if things have changed? + if statefile is not None: + self.load(statefile) + + # init best model + self.best_model = self.model.state_dict() + + # logging stuff + if result_dir is not None: + # we do not need to purge. Purging can delete the validation result + self.summary_writer = SummaryWriter(log_dir=result_dir) + + def load(self, fname: str) -> Dict: + """ + fname: file name to load data from + """ + + data = torch.load(open(fname, "rb"), pickle_module=dill, map_location=self.model.device) + + if getattr(self, "model", None) and data.get("model") is not None: + state_dict = self.model.state_dict() + state_dict.update(data["model"]) + self.model.load_state_dict(state_dict) + + if getattr(self, "optimizer", None) and data.get("optimizer") is not None: + optimizer_dict = self.optimizer.state_dict() + optimizer_dict.update(data["optimizer"]) + self.optimizer.load_state_dict(optimizer_dict) + + if getattr(self, "scheduler", None) and data.get("scheduler") is not None: + scheduler_dict = self.scheduler.state_dict() + scheduler_dict.update(data["scheduler"]) + self.scheduler.load_state_dict(scheduler_dict) + + self.epoch = data["epoch"] + self.step = data["step"] + self.best_criteria = data["best_criteria"] + self.best_epoch = data["best_epoch"] + return data + + def save(self, fname: str, **kwargs): + """ + fname: file name to save to + kwargs: more arguments that we may want to save. 
+ + By default we + - save, + - model, + - optimizer, + - epoch, + - step, + - best_criteria, + - best_epoch + """ + # NOTE: Best model is maintained but is saved automatically depending on save strategy, + # So that It could be loaded outside of the training process + kwargs.update({ + "model" : self.model.state_dict(), + "optimizer" : self.optimizer.state_dict(), + "epoch" : self.epoch, + "step" : self.step, + "best_criteria": self.best_criteria, + "best_epoch" : self.best_epoch, + }) + + if self.scheduler is not None: + kwargs.update({"scheduler": self.scheduler.state_dict()}) + + torch.save(kwargs, open(fname, "wb"), pickle_module=dill) + + # todo : allow to extract predictions + def run_iteration(self, batch, training: bool = True, reduce: bool = True): + """ + batch : batch of data, directly passed to model as is + training: if training set to true else false + reduce: whether to compute loss mean or return the raw vector form + """ + pred = self.model(batch) + loss, aux_loss = self.model.loss(pred, batch, reduce=reduce) + + if training: + loss.backward() + if self.gradient_norm_clip > 0: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_norm_clip) + self.optimizer.step() + self.optimizer.zero_grad() + + return loss, aux_loss + + def compute_criteria(self, loss, aux_loss): + stopping_criteria = self.stopping_criteria + if stopping_criteria is None: + return loss + + if callable(stopping_criteria): + return stopping_criteria(loss, aux_loss) + + if stopping_criteria == "loss": + return loss + + if aux_loss.get(stopping_criteria) is not None: + return aux_loss[stopping_criteria] + + raise Exception(f"{stopping_criteria} not found") + + def train_batch(self, batch, *args, **kwargs): + # This trains the batch + loss, aux_loss = self.run_iteration(batch, training=True, reduce=True) + loss_logger_helper(loss, aux_loss, writer=self.summary_writer, step=self.step, + epoch=self.epoch, + log_every=self.log_every, string="train") + + def train_epoch(self, train_loader, *args, **kwargs): + # This trains the epoch and also calls on batch begin and on batch end + # before and after calling train_batch respectively + self.model.train() + for i, batch in enumerate(train_loader): + self.on_batch_begin(i, batch, *args, **kwargs) + self.train_batch(batch, *args, **kwargs) + self.on_batch_end(i, batch, *args, **kwargs) + self.step += 1 + self.model.eval() + + def on_train_begin(self, train_loader, valid_loader, *args, **kwargs): + # this could be used to add things to class object like scheduler etc + if "init" in self.save_strategy: + if self.epoch == 0: + self.save(f"{self.result_dir}/init_model.pt") + + def on_epoch_begin(self, train_loader, valid_loader, *args, **kwargs): + # This is called when epoch begins + pass + + def on_batch_begin(self, epoch_step, batch, *args, **kwargs): + # This is called when batch begins + pass + + def on_train_end(self, train_loader, valid_loader, *args, **kwargs): + # Called when training finishes. 
For the base trainer we just save the last model
+        if "last" in self.save_strategy:
+            logger.info("Saving the last model")
+            self.save(f"{self.result_dir}/last_model.pt")
+
+    def on_epoch_end(self, train_loader, valid_loader, *args, **kwargs):
+        # Called when an epoch ends.
+        # We run validation and step the scheduler here,
+        # and check whether we have a new best model and save models if needed.
+
+        # call validate
+        loss, aux_loss = self.validate(train_loader, valid_loader, *args, **kwargs)
+        loss_logger_helper(loss, aux_loss, writer=self.summary_writer, step=self.step,
+                           epoch=self.epoch, log_every=self.log_every, string="val",
+                           force_print=True)
+
+        # do scheduler step
+        if self.scheduler is not None:
+            prev_lr = [group['lr'] for group in self.optimizer.param_groups]
+            if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+                criteria = self.compute_criteria(loss, aux_loss)
+                self.scheduler.step(criteria)
+            else:
+                self.scheduler.step()
+            new_lr = [group['lr'] for group in self.optimizer.param_groups]
+
+        # If you don't pass a stopping criteria, it is not computed and the best model is not saved.
+        # Conversely, if you pass one, the best model is tracked and saved.
+        # Pass a large patience to effectively disable early stopping.
+        if self.stopping_criteria is not None:
+            criteria = self.compute_criteria(loss, aux_loss)
+
+            if (
+                    (self.best_criteria is None)
+                    or (
+                    self.stopping_criteria_direction == "bigger" and self.best_criteria < criteria)
+                    or (
+                    self.stopping_criteria_direction == "lower" and self.best_criteria > criteria)
+            ):
+                self.best_criteria = criteria
+                self.best_epoch = self.epoch
+                self.best_model = copy.deepcopy(
+                    {k: v.cpu() for k, v in self.model.state_dict().items()})
+
+                if "best" in self.save_strategy:
+                    logger.info(f"Saving best model at epoch {self.epoch}")
+                    self.save(f"{self.result_dir}/best_model.pt")
+
+        if "epoch" in self.save_strategy:
+            logger.info(f"Saving model at epoch {self.epoch}")
+            self.save(f"{self.result_dir}/{self.epoch}_model.pt")
+
+        if "current" in self.save_strategy:
+            logger.info(f"Saving model at epoch {self.epoch}")
+            self.save(f"{self.result_dir}/current_model.pt")
+
+        # reload the best model when the learning rate was reduced (load_on_reduce == "best")
+        if self.scheduler is not None and not (all(a == b for (a, b) in zip(prev_lr, new_lr))):
+            if getattr(self.scheduler, 'load_on_reduce', None) == "best":
+                logger.info(f"Loading best model at epoch {self.epoch}")
+                # we want to preserve the scheduler
+                old_lrs = list(map(lambda x: x['lr'], self.optimizer.param_groups))
+                old_scheduler_dict = copy.deepcopy(self.scheduler.state_dict())
+
+                best_model_path = None
+                if os.path.exists(f"{self.result_dir}/best_model.pt"):
+                    best_model_path = f"{self.result_dir}/best_model.pt"
+                else:
+                    d = "/".join(self.result_dir.split("/")[:-1])
+                    for directory in os.listdir(d):
+                        if os.path.exists(f"{d}/{directory}/best_model.pt"):
+                            best_model_path = f"{d}/{directory}/best_model.pt"
+
+                if best_model_path is None:
+                    raise FileNotFoundError(
+                        f"Best model not found in {self.result_dir}; please copy it here if it "
+                        f"exists in another folder")
+
+                self.load(best_model_path)
+                # restore the scheduler state and keep the reduced learning rates
+                self.scheduler.load_state_dict(old_scheduler_dict)
+                for idx, lr in enumerate(old_lrs):
+                    self.optimizer.param_groups[idx]['lr'] = lr
+                logger.info(f"Loaded best model and restarting from end of epoch {self.epoch}")
+
+    def on_batch_end(self, epoch_step, batch, *args, **kwargs):
+        # called after a batch is trained
+        pass
+
+    def train(self, train_loader, 
valid_loader, *args, **kwargs): + + self.on_train_begin(train_loader, valid_loader, *args, **kwargs) + while self.epoch < self.max_epoch: + # NOTE: +1 here is more convenient, as now we don't need to do +1 before saving model + # If we don't do +1 before saving model, we will have to redo the last epoch + # So +1 here makes life easy, if we load model at end of e epoch, we will load model + # and start with e+1... smooth + self.epoch += 1 + self.on_epoch_begin(train_loader, valid_loader, *args, **kwargs) + logger.info(f"Starting epoch {self.epoch}") + self.train_epoch(train_loader, *args, **kwargs) + self.on_epoch_end(train_loader, valid_loader, *args, **kwargs) + + if self.epoch - self.best_epoch > self.patience: + logger.info(f"Patience reached stopping training after {self.epoch} epochs") + break + + self.on_train_end(train_loader, valid_loader, *args, **kwargs) + + def validate(self, train_loader, valid_loader, *args, **kwargs): + """ + we expect validate to return mean and other aux losses that we want to log + """ + losses = [] + aux_losses = {} + + self.model.eval() + with torch.no_grad(): + for i, batch in enumerate(valid_loader): + loss, aux_loss = self.run_iteration(batch, training=False, reduce=False) + losses.extend(loss.cpu().tolist()) + + if i == 0: + for k, v in aux_loss.items(): + # when we can't return sample wise statistics, we need to do this + if len(v.shape) == 0: + aux_losses[k] = [v.cpu().tolist()] + else: + aux_losses[k] = v.cpu().tolist() + else: + for k, v in aux_loss.items(): + if len(v.shape) == 0: + aux_losses[k].append(v.cpu().tolist()) + else: + aux_losses[k].extend(v.cpu().tolist()) + return np.mean(losses), {k: np.mean(v) for (k, v) in aux_losses.items()} + + def test(self, train_loader, test_loader, *args, **kwargs): + return self.validate(train_loader, test_loader, *args, **kwargs) diff --git a/lib/data/__init__.py b/lib/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lib/data/mnist.py b/lib/data/mnist.py new file mode 100644 index 0000000..8a2c65f --- /dev/null +++ b/lib/data/mnist.py @@ -0,0 +1,35 @@ +import torch +from torchvision.datasets import MNIST + +from lib.data.util import DATA_FOLDER + + +class Mnist(MNIST): + def __init__(self, root=f"{DATA_FOLDER}/mnist", train=True, transform=None, + target_transform=None, download=True, init_transform=None, + init_target_transform=None, seed=None, fraction=1.0): + super().__init__(root, train=train, transform=transform, target_transform=target_transform, + download=download) + + if seed is not None: + rng_state = torch.get_rng_state() + torch.manual_seed(seed) + + N = len(self.data) + n = None + + if 0 < fraction < 1.0: + n = int(N * fraction) + elif N > fraction > 1: + n = int(fraction) + if n: + indices = torch.randperm(N)[:n] + self.data, self.targets = self.data[indices], self.targets[indices] + + if init_transform: + self.data = self.data = init_transform(self.data) + if init_target_transform: + self.targets = init_target_transform(self.targets) + + if seed is not None: + torch.set_rng_state(rng_state) diff --git a/lib/data/util.py b/lib/data/util.py new file mode 100644 index 0000000..3323646 --- /dev/null +++ b/lib/data/util.py @@ -0,0 +1,75 @@ +""" general utility function for data: mostly transformations """ + +import logging +import os +import random + +import numpy +import numpy as np +from PIL import ImageFilter + +logger = logging.getLogger() +DATA_FOLDER = os.getenv("DATA") if os.getenv("DATA") else "data" + + +def uniform_label_noise(p, labels, seed=None): + if seed 
is not None: + rng_state = numpy.random.get_state() + numpy.random.seed(seed) + + labels = numpy.array(labels.tolist()) + N = len(labels) + lst = numpy.unique(labels) + + # generate random labels + rnd_labels = numpy.random.choice(lst, size=N) + + flip = numpy.random.rand(N) <= p + labels = labels * (1 - flip) + rnd_labels * flip + + if seed is not None: + numpy.random.set_state(rng_state) + + return labels + + +class GaussianBlur(object): + """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" + + def __init__(self, sigma=None): + if sigma is None: + sigma = [0.1, 2.0] + self.sigma = sigma + + def __call__(self, x): + sigma = random.uniform(self.sigma[0], self.sigma[1]) + x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) + return x + + +def lines_to_np_array(lines): + return np.array([[int(i) for i in line.split()] for line in lines]) + + +def load_binary_mnist(): + with open(os.path.join(DATA_FOLDER, "binary-mnist", "binarized_mnist_train.amat")) as f: + lines = f.readlines() + train_data = lines_to_np_array(lines).astype("float32") + with open(os.path.join(DATA_FOLDER, "binary-mnist", "binarized_mnist_valid.amat")) as f: + lines = f.readlines() + validation_data = lines_to_np_array(lines).astype("float32") + with open(os.path.join(DATA_FOLDER, "binary-mnist", "binarized_mnist_test.amat")) as f: + lines = f.readlines() + test_data = lines_to_np_array(lines).astype("float32") + + return {"train": train_data, "valid": validation_data, "test": test_data} + + +def load_mnist(): + import gzip + import _pickle + + train, valid, test = _pickle.load( + gzip.open(os.path.join(DATA_FOLDER, "mnist", "mnist.pkl.gz")), encoding="latin1", + ) + return {"train": train, "valid": valid, "test": test} diff --git a/lib/standard_nn.py b/lib/standard_nn.py new file mode 100644 index 0000000..7b87f53 --- /dev/null +++ b/lib/standard_nn.py @@ -0,0 +1,117 @@ +""" +Utility to create simple sequential networks for classification or regression + +Create feed forward network with different `hidden_sizes` + +Create convolution networks with different `channels` (hidden_size), (2,2) max pooling +""" +import typing + +import numpy +import torch.nn as nn + +from lib.utils.torch_utils import Reshape, infer_shape + + +# TODO : extend to cover pooling sizes, strides etc for conv nets +def get_arch(input_shape: typing.Union[numpy.array, typing.List], output_size: int, + feed_forward: bool, hidden_sizes: typing.List[int], + kernel_size: typing.Union[typing.List[int], int] = 3, + non_linearity: typing.Union[typing.List[str], str, None] = "relu", + norm: typing.Union[typing.List[str], str, None] = None, + pooling: typing.Union[typing.List[str], str, None] = None) -> nn.Module: + + # general assertions + n_layers = len(hidden_sizes) + if n_layers > 0: + if isinstance(non_linearity, list): + assert len(non_linearity) == n_layers, "non linearity list is not same as hidden size" + non_linearities = non_linearity + else: + non_linearities = [non_linearity] * n_layers + + if isinstance(norm, list): + assert len(norm) == n_layers, "norm list is not same as hidden size" + norms = norm + else: + norms = [norm] * n_layers + else: + norms = [] + non_linearities = [] + + modules = [] + + if feed_forward: + modules.append(Reshape()) + insize = int(numpy.prod(input_shape)) + + for nl, no, outsize in zip(non_linearities, norms, hidden_sizes): + modules.append(nn.Linear(insize, outsize)) + + if nl == "relu": + modules.append(nn.ReLU()) + elif nl is None: + pass + else: + raise Exception(f"non-linearity {nl} not 
implemented") + + if no == "bn": + modules.append(nn.BatchNorm1d(outsize)) + elif no is None: + pass + else: + raise Exception(f"norm {no} is not implemented") + + insize = outsize + + modules.append(nn.Linear(insize, output_size)) + return {"net" : nn.Sequential(*modules)} + + # assertion specific to convolutions + assert n_layers >= 1, "Number of layers has to be more than 1 for convolution" + if isinstance(kernel_size, list): + assert len(kernel_size) == n_layers, "kernel size is not same as hidden size" + kernel_sizes = kernel_size + else: + kernel_sizes = [kernel_size] * n_layers + + if isinstance(pooling, list): + assert len(pooling) == n_layers, "pooling size is not same as hidden size" + poolings = pooling + else: + poolings = [pooling] * n_layers + + # convolutional layer with 3x3 convolutions + inchannel = input_shape[0] + for nl, no, outchannel, k, p in zip(non_linearities, norms, hidden_sizes, kernel_sizes, + poolings): + modules.append(nn.Conv2d(inchannel, outchannel, kernel_size=k)) + + if nl == "relu": + modules.append(nn.ReLU()) + elif nl is None: + pass + else: + raise Exception(f"non-linearity {nl} is not implemented") + + if no == "bn": + modules.append(nn.BatchNorm2d(outchannel)) + elif no is None: + pass + else: + raise Exception(f"norm {no} is not implemented") + + if p == "max_pool": + modules.append(nn.MaxPool2d(2)) + + elif p is None: + pass + else: + raise Exception(f"pooling {p} is not implemented") + + inchannel = outchannel + + output_shape = infer_shape(nn.Sequential(*modules).to("cpu"), input_shape) + modules.append(Reshape()) + modules.append(nn.Linear(int(numpy.prod(output_shape)), output_size)) + return {"net" : nn.Sequential(*modules)} diff --git a/lib/utils/__init__.py b/lib/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lib/utils/logging.py b/lib/utils/logging.py new file mode 100644 index 0000000..f97268e --- /dev/null +++ b/lib/utils/logging.py @@ -0,0 +1,36 @@ +""" logging related functionality """ + + +import logging + +from torch.utils.tensorboard import SummaryWriter + +logger = logging.getLogger() + + +def print_verbose(string, verbose): + if verbose: + print(string) + + +def loss_logger_helper( + loss, aux_loss, writer: SummaryWriter, step: int, epoch: int, log_every: int, + string: str = "train", force_print: bool = False, new_line: bool = False +): + # write to tensorboard at every step but only print at log step or when force_print is passed + writer.add_scalar(f"{string}/loss", loss, step) + for k, v in aux_loss.items(): + writer.add_scalar(f"{string}/" + k, v, step) + + if step % log_every == 0 or force_print: + logger.info(f"{string}/loss: ({step}/{epoch}) {loss}") + + if force_print: + if new_line: + for k, v in aux_loss.items(): + logger.info(f"{string}/{k}:{v} ") + else: + str_ = "" + for k, v in aux_loss.items(): + str_ += f"{string}/{k}:{v} " + logger.info(f"{str_}") diff --git a/lib/utils/math.py b/lib/utils/math.py new file mode 100644 index 0000000..25fe860 --- /dev/null +++ b/lib/utils/math.py @@ -0,0 +1,18 @@ +""" Mathematical formulae for different expressions""" + +import torch + +from .torch_utils import EPSILON + + +def echo_mi(f, s): + N = s.shape[0] + s = s.view(N, -1) + return -torch.log(torch.abs(s) + EPSILON).sum(dim=1) + + +def get_echo_clip_factor(num_samples): + max_fx = 1 + d_max = num_samples + + return (2 ** (-23) / max_fx) ** (1.0 / d_max) diff --git a/lib/utils/optimizer.py b/lib/utils/optimizer.py new file mode 100644 index 0000000..811ed40 --- /dev/null +++ b/lib/utils/optimizer.py 
@@ -0,0 +1,50 @@ +from torch import optim, nn + + +def get_optimizer_scheduler(model, optimizer="adam", lr=1e-3, opt_params=None, scheduler=None, + scheduler_params=None): + """ + scheduler_params: + load_on_reduce : best/last/None (if best we load the best model in training so far) + (for this to work, you should save the best model during training) + """ + if scheduler_params is None: + scheduler_params = {} + if opt_params is None: + opt_params = {} + + if isinstance(model, nn.Module): + params = model.parameters() + else: + params = model + if optimizer == "adam": + optimizer = optim.Adam(params, lr=lr, weight_decay=opt_params["weight_decay"]) + elif optimizer == "sgd": + optimizer = optim.SGD(params, lr=lr, weight_decay=opt_params["weight_decay"], + momentum=opt_params["momentum"], nesterov=True) + else: + raise Exception(f"{optimizer} not implemented") + + if scheduler == "step": + scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=scheduler_params["gamma"], + step_size=scheduler_params["step_size"]) + elif scheduler == "multi_step": + scheduler = optim.lr_scheduler.MultiStepLR(optimizer, gamma=scheduler_params["gamma"], + milestones=scheduler_params["milestones"]) + elif scheduler == "cosine": + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_params["T_max"]) + elif scheduler == "reduce_on_plateau": + scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, + mode=scheduler_params["mode"], + patience=scheduler_params["patience"], + factor=scheduler_params["gamma"], + min_lr=1e-7, verbose=True, + threshold=1e-7) + elif scheduler is None: + scheduler = None + else: + raise Exception(f"{scheduler} is not implemented") + + if scheduler_params.get("load_on_reduce") is not None: + setattr(scheduler, "load_on_reduce", scheduler_params.get("load_on_reduce")) + return optimizer, scheduler diff --git a/lib/utils/os.py b/lib/utils/os.py new file mode 100644 index 0000000..c5e2fb0 --- /dev/null +++ b/lib/utils/os.py @@ -0,0 +1,323 @@ +""" general utility functions""" +import argparse +import importlib +import json +import logging +import os +import random +import re +import shutil +import sys +import typing +from argparse import ArgumentParser +from collections.abc import MutableMapping + +import numpy +import torch +from box import Box + +logger = logging.getLogger() + + +def listorstr(inp): + if len(inp) == 1: + return try_cast(inp[0]) + + for i, val in enumerate(inp): + inp[i] = try_cast(val) + return inp + + +def try_cast(text): + """ try to cast to int or float if possible, else return the text itself""" + result = try_int(text, None) + if result is not None: + return result + + result = try_float(text, None) + if result is not None: + return result + + return text + + +def try_float(text, default: typing.Optional[int] = 0.0): + result = default + try: + result = float(text) + except Exception as _: + pass + return result + + +def try_int(text, default: typing.Optional[int] = 0): + result = default + try: + result = int(text) + except Exception as _: + pass + return result + + +def parse_args(parser: ArgumentParser) -> Box: + # get defaults + defaults = {} + # taken from parser_known_args code + # add any action defaults that aren't present + for action in parser._actions: + if action.dest is not argparse.SUPPRESS: + if action.default is not argparse.SUPPRESS: + defaults[action.dest] = action.default + + # add any parser defaults that aren't present + for dest in parser._defaults: + defaults[dest] = parser._defaults[dest] + + # check if there is config & 
read config + args = parser.parse_args() + if vars(args).get("config") is not None: + # load a .py config + configFile = args.config + spec = importlib.util.spec_from_file_location("config", configFile) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + config = module.config + # merge config and override defaults + defaults.update({k: v for k, v in config.items()}) + + # override defaults with command line params + # this will get rid of defaults and only read command line args + parser._defaults = {} + parser._actions = {} + args = parser.parse_args() + defaults.update({k: v for k, v in vars(args).items()}) + + return boxify_dict(defaults) + + +def boxify_dict(config): + """ + this takes a flat dictionary and break it into sub-dictionaries based on "." seperation + a = {"model.a": 1, "model.b" : 2, "alpha" : 3} will return Box({"model" : {"a" :1, + "b" : 2}, alpha:3}) + a = {"model.a": 1, "model.b" : 2, "model" : 3} will throw error + """ + new_config = {} + # iterate over keys and split on "." + for key in config: + if "." in key: + temp_config = new_config + for k in key.split(".")[:-1]: + # create non-existent keys as dictionary recursively + if temp_config.get(k) is None: + temp_config[k] = {} + elif not isinstance(temp_config.get(k), dict): + raise TypeError(f"Key '{k}' has values as well as child") + temp_config = temp_config[k] + temp_config[key.split(".")[-1]] = config[key] + else: + if new_config.get(key) is None: + new_config[key] = config[key] + else: + raise TypeError(f"Key '{key}' has values as well as child") + + return Box(new_config) + + +# https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys +def flatten(d, parent_key='', sep='.'): + items = [] + for k, v in d.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, MutableMapping): + items.extend(flatten(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return Box(dict(items)) + + +def str2bool(v: typing.Union[bool, str, int]) -> bool: + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1", 1): + return True + if v.lower() in ("no", "false", "f", "n", "0", 0): + return False + raise TypeError("Boolean value expected.") + + +def safe_isdir(dir_name): + return os.path.exists(dir_name) and os.path.isdir(dir_name) + + +def safe_makedirs(dir_name): + try: + os.makedirs(dir_name) + except OSError as e: + print(e) + + +def jsonize(x: object) -> typing.Union[str, dict]: + try: + temp = json.dumps(x) + return temp + except Exception as e: + return {} + + +def copy_code(folder_to_copy, out_folder, replace=False): + logger.info(f"copying {folder_to_copy} to {out_folder}") + + if os.path.exists(out_folder): + if not os.path.isdir(out_folder): + logger.error(f"{out_folder} is not a directory") + sys.exit() + else: + logger.info(f"Not deleting existing result folder: {out_folder}") + else: + os.makedirs(out_folder) + + # replace / with _ + folder_name = f'{out_folder}/{re.sub("/", "_", folder_to_copy)}' + + # create a new copy if something already exists + if not replace: + i = 1 + temp = folder_name + while os.path.exists(temp): + temp = f"{folder_name}_{i}" + i += 1 + folder_name = temp + else: + if os.path.exists(folder_name): + if os.path.isdir(folder_name): + shutil.rmtree(folder_name) + else: + raise FileExistsError("There is a file with same name as folder") + + logger.info(f"Copying {folder_to_copy} to {folder_name}") + shutil.copytree(folder_to_copy, folder_name) + + +def 
get_state_params(wandb_use, run_id, result_folder, statefile): + """This searches for model and run id in result folder + The logic is as follows + + if we are not given run_id there are four cases: + - we want to restart the wandb run but too lazy to look up run-id or/and statefile + - we want a new wandb fresh run + - we are not using wandb at all and need to restart + - we are not using wandb and need a fresh run + + Case 1/3: + - If we want to restart the run, we expect the result_folder name to end with + /run_. + - In this case, if we are using wandb then we need to go inside wandb folder, list all + directory and pick up run id and (or) statefile + - If we are not using wandb we just look for model inside the run_ folder and + return statefile, run id as none + + case 2/4: + if not 1/3, it is case 2/4 + + This is expected to be a fail safe script. i.e any of run_id or statefile may not be specified + and relies on whims of the user _-_ + """ + # if not resume get run number and create result_folder/run_{run_num} + # if someone is resuming we expect them to give the exact folder name upto run num. + + # this part of code searches for run_id i.e will work only if we are using wandb + if run_id is None: + # if result folder if of type folder/run_, then search for current checkpoint and + # run-id else we will just create a new run with run_ + + regex = r"^.*/?run_[0-9]+/?$" + if re.match(regex, result_folder): + + # search for checkpoint and run-id if using wandb + if wandb_use: + # search in wandb folder if it exists else we want a new run + if os.path.exists(f"{result_folder}/wandb/"): + # case 1 + for folder in sorted(os.listdir(f"{result_folder}/wandb/"), reverse=True): + # assume run_<##> will have only single run, + # also no other crap in this folder + if os.path.exists(f"{result_folder}/wandb/{folder}/current_model.pt"): + run_id = folder.split("-")[-1] + logger.info(f"using run id {run_id}") + # we are done break out of for loop + break + else: + # case 3 + # if not using wandb search within run_ directory + logger.info(f"not using wandb") + if os.path.exists(f"{result_folder}/current_model.pt"): + statefile = f"{result_folder}/current_model.pt" + logger.info(f"using statefile {statefile}") + else: + # just start a new run + pass + else: + # trailing is not run_; that means user wants a new fresh run + # so we give a fresh run and create a new folder + # case 2/4 + last_run_num = max( + [0] + [try_int(i[-4:]) for i in os.listdir(result_folder)]) + 1 + result_folder = f"{result_folder}/run_{last_run_num:04d}" + logger.info(f"Creating new run with {result_folder}") + safe_makedirs(result_folder) + + # search for last checkpoint in case --statefile is none and we are resuming + if run_id is not None and statefile is None: + folders = sorted(os.listdir(f"{result_folder}/wandb"), reverse=True) + for folder in folders: + if run_id in folder: + # check for current_model.pt + if os.path.exists(f"{result_folder}/wandb/{folder}/current_model.pt"): + statefile = f"{result_folder}/wandb/{folder}/current_model.pt" + logger.info(f"Using state file {statefile} and run id {run_id}") + break + if statefile is None: + raise Exception("Did not find statefile, exiting!!") + return statefile, run_id, result_folder + + +if __name__ == "__main__": + # test boxify_dict + a = {"model.a": 1, "m odel.b": 2, "alpha": 3} + print(boxify_dict(a)) + + try: + a = {"model.a": 1, "model.b": 2, "model": 3} + print(boxify_dict(a)) + except Exception as e: + print(e) + + try: + a = {"model": 4, "model.a": 1, 
"model.b": 2, "model": 3} + print(boxify_dict(a)) + except Exception as e: + print(e) + + try: + a = {"model.a": 1, "model": 4, "model.b": 2, "model": 3} + print(boxify_dict(a)) + except Exception as e: + print(e) + + try: + a = {"model": {"attr1": 1, "attr2": {"attr_attr_3": 3}}, "train": 10} + print(flatten(a)) + except Exception as e: + print(e) + + +def set_seed(seed): + if isinstance(seed, list): + torch_seed, numpy_seed, random_seed = seed + else: + torch_seed, numpy_seed, random_seed = seed, seed, seed + + torch.manual_seed(torch_seed) + numpy.random.seed(numpy_seed) + random.seed(random_seed) diff --git a/lib/utils/samplers.py b/lib/utils/samplers.py new file mode 100644 index 0000000..45badb8 --- /dev/null +++ b/lib/utils/samplers.py @@ -0,0 +1,51 @@ +""" torch samplers for different distributions""" + +import numpy as np +import torch +from scipy.linalg import circulant + + +def sample_gaussian(mean, sigma, tril_sigma=False): + noise = torch.randn_like(mean) + + # we getting sigma + if tril_sigma: + z_sample = torch.bmm(sigma, noise.unsqueeze(dim=2)).squeeze() + mean + else: + z_sample = noise * sigma + mean + return z_sample + + +def sample_echo(f, s, m=None, replace=False, pop=True): + """ + f, s : are the outputs of encoder (shape : [B, Z] for f) + s is shape [B, Z] or [B, Z, Z] + tril_sigma: if we have s as diagonal matrix or lt matrix + m : number of samples to consider to generate noise when replace + is true (default to batch_size) + replace : sampling with replacement or not (if sampling with + replacement, pop is not considered) + pop: If true, remove the sample to which noise is being added + detach_noise_grad : detach gradient of noise or not + """ + batch_size, z_size = f.shape[0], f.shape[1:] + + # get indices + if not replace: + indices = circulant(np.arange(batch_size)) + if pop: + # just ignore the first column + indices = indices[:, 1:] + for i in indices: + np.random.shuffle(i) + else: + m = batch_size if m is None else m + indices = np.random.choice(batch_size, size=(batch_size, m), replace=True) + + f_arr = f[indices.reshape(-1)].view(indices.shape + z_size) + s_arr = s[indices.reshape(-1)].view(indices.shape + z_size) + + epsilon = f_arr[:, 0] + torch.sum(f_arr[:, 1:] * torch.cumprod(s_arr[:, :-1], dim=1), dim=1) + + z_sample = f + s * epsilon + return z_sample diff --git a/lib/utils/torch_utils.py b/lib/utils/torch_utils.py new file mode 100644 index 0000000..0849dd7 --- /dev/null +++ b/lib/utils/torch_utils.py @@ -0,0 +1,99 @@ +""" things torch should have but it doesn't""" +import logging + +import torch +import torch.nn as nn +from torch.autograd import Function + +logger = logging.getLogger() +EPSILON = 1e-8 + + +# reset seed +def reset_seed(): + while True: + try: + torch.seed() + except RuntimeError as _: + logger.error("Error generating seed") + else: + break + + +class Reshape(nn.Module): + """ + Reshape module that reshapes any input to (batch_size, ...shape) + by default it does flattening but you can pass any shape. 
+ """ + + def __init__(self, shape=(-1,)): + super().__init__() + self.shape = shape + + def forward(self, x): + batch_size = x.shape[0] + return x.view((batch_size,) + self.shape) + + def extra_repr(self): + return f"shape={self.shape}" + + +class Offset(torch.nn.Module): + def __init__(self, offset, net): + super().__init__() + self.offset = nn.Parameter(offset, requires_grad=False) + self.net = net + + def forward(self, *args): + batch_size = args[0].shape[0] + return self.offset.expand((batch_size, -1, -1, -1)) + 1e-8 # + self.net(*args) + + +def batch_eye(N, D, device="cpu"): + x = torch.eye(D, device=device) + x = x.unsqueeze(0) + x = x.repeat(N, 1, 1) + return x + + +def batch_eye_like(tensor): + assert len(tensor.shape) == 3 and tensor.shape[1] == tensor.shape[2] + N = tensor.shape[0] + D = tensor.shape[1] + return batch_eye(N, D, device=tensor.device) + + +class _RevGrad(Function): + @staticmethod + def forward(ctx, input_): + ctx.save_for_backward(input_) + output = input_ + return output + + @staticmethod + def backward(ctx, grad_output): + grad_input = None + if ctx.needs_input_grad[0]: + grad_input = -grad_output + return grad_input + + +revgrad = _RevGrad.apply + + +class RevGrad(nn.Module): + def __init__(self, *args, **kwargs): + """ + A gradient reversal layer. + This layer has no parameters, and simply reverses the gradient + in the backward pass. + """ + super().__init__(*args, **kwargs) + + def forward(self, input_): + return revgrad(input_) + + +def infer_shape(net, input_shape): + x = torch.rand((2,) + input_shape) + return net(x).shape[1:]