import torch

from torchvision import datasets, transforms

from trainer import Trainer
from config import get_config
from utils import prepare_dirs
from data_loader import get_test_loader, get_train_valid_loader, VIEWPOINT_EXPS


# Trade cuDNN speed for run-to-run reproducibility: request deterministic
# kernels and turn off benchmark (autotuning) mode.
(torch.backends.cudnn.deterministic,
 torch.backends.cudnn.benchmark) = (True, False)


def main(config):
    """Prepare directories, seed RNGs, build the data loader and run the trainer.

    Depending on ``config.is_train`` this either trains a model or evaluates
    it. When evaluating, ``config.attack`` selects ``Trainer.test_attack``;
    otherwise ``Trainer.test`` runs.

    Args:
        config: Parsed configuration object, as returned by ``get_config()``.
            Attributes read here: ``random_seed``, ``is_train``, ``data_dir``,
            ``dataset``, ``batch_size``, ``exp``, ``valid_size``, ``shuffle``,
            ``familiar`` and ``attack``. ``Trainer`` may read more.
    """
    # ensure directories are setup
    prepare_dirs(config)

    # ensure reproducibility
    torch.manual_seed(config.random_seed)
    kwargs = {}
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config.random_seed)
        # Extra options forwarded to the data-loader helpers below
        # (presumably torch DataLoader kwargs -- confirm in data_loader).
        # NOTE(review): pin_memory=False on CUDA gives up faster
        # host->device copies; confirm this is intentional.
        kwargs = {'num_workers': 4, 'pin_memory': False}

    # instantiate data loaders
    if config.is_train:
        data_loader = get_train_valid_loader(
            config.data_dir, config.dataset, config.batch_size,
            config.random_seed, config.exp, config.valid_size,
            config.shuffle, **kwargs
        )
    else:
        data_loader = get_test_loader(
            config.data_dir, config.dataset, config.batch_size,
            config.exp, config.familiar, **kwargs
        )

    # instantiate trainer
    trainer = Trainer(config, data_loader)

    # dispatch to the requested mode: train, adversarial test, or plain test
    if config.is_train:
        trainer.train()
    elif config.attack:
        trainer.test_attack()
    else:
        trainer.test()

if __name__ == '__main__':
    # get_config() returns a pair. Only the first element (the config) is
    # used; the second (presumably leftover CLI arguments -- confirm in
    # config.py) is discarded.
    cli_config, _leftover = get_config()
    main(cli_config)