From ed19236440ab927963a80e32e5fd8db369ff146e Mon Sep 17 00:00:00 2001 From: Sidd Karamcheti Date: Mon, 6 Mar 2023 17:38:46 -0800 Subject: [PATCH] Add Sth-Sth-v2 Preprocessing & XLA Pretraining Script (#10) * Update xpretrain documentation * Update README * Add Sth-Sth-v2 Preprocessing Pipeline * Add model initialization stub * Add full XLA pretraining pipeline * Add v1.0.0 with full preprocessing/XLA pretraining pipeline --- README.md | 4 +- examples/xla-reference/README.md | 2 +- examples/xla-reference/xpretrain.py | 833 ++++++++++++++++++++++- pyproject.toml | 6 +- voltron/conf/__init__.py | 3 + voltron/conf/accelerators.py | 52 ++ voltron/conf/datasets.py | 6 +- voltron/conf/models.py | 435 ++++++++++++ voltron/conf/tracking.py | 44 ++ voltron/preprocessing/stream_datasets.py | 833 +++++++++++++++++++++++ voltron/preprocessing/transforms.py | 2 +- voltron/util/checkpointing.py | 114 ++++ voltron/util/xla_logger.py | 306 +++++++++ 13 files changed, 2632 insertions(+), 8 deletions(-) create mode 100644 voltron/conf/accelerators.py create mode 100644 voltron/conf/models.py create mode 100644 voltron/conf/tracking.py create mode 100644 voltron/preprocessing/stream_datasets.py create mode 100644 voltron/util/checkpointing.py create mode 100644 voltron/util/xla_logger.py diff --git a/README.md b/README.md index 44f6117..f5b4822 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@
- Voltron Logo + Voltron Logo
@@ -83,7 +83,7 @@ our paper. ## API -![Voltron Framework](https://github.com/siddk/voltron-robotics/blob/main/docs/assets/voltron-framework.png) +![Voltron Framework](https://raw.githubusercontent.com/siddk/voltron-robotics/main/docs/assets/voltron-framework.png) The package `voltron` provides the following functionality for using and adapting existing representations: diff --git a/examples/xla-reference/README.md b/examples/xla-reference/README.md index d90d831..1ebeee8 100644 --- a/examples/xla-reference/README.md +++ b/examples/xla-reference/README.md @@ -9,4 +9,4 @@ To get things to work, we had to add some non-intuitive code to facilitate PyTor data parallel training pipeline). As a result, `xpretrain.py` is here mostly for documentation purposes, with a fully refactored version `pretrain.py` forthcoming. -We also include the original cloud preprocesssing script `xpreprocess.py` for completeness. +We also include the original cloud preprocessing script `xpreprocess.py` for completeness (this is more general). diff --git a/examples/xla-reference/xpretrain.py b/examples/xla-reference/xpretrain.py index c234e50..774a072 100644 --- a/examples/xla-reference/xpretrain.py +++ b/examples/xla-reference/xpretrain.py @@ -1,6 +1,835 @@ """ xpretrain.py -TODO :: Reference script for PyTorch XLA (TPU-based) pretraining on the non-Qualcomm version of Sth-Sth-v2; this is - mostly for completeness =>> the hope is that the regular `pretrain.py` script is more general and maintained. +(The `x` prefix indicates this is a script geared for XLA/TPU backends *only*)! + +Reference script for PyTorch XLA (TPU-based) pretraining on the non-Qualcomm version of Sth-Sth-v2; this is +mostly for completeness =>> the hope is that the regular `pretrain.py` script is more general and maintained. + +Focuses on multi-TPU (XLA) training --> but also supports single-core TPU training, as the default distributed mp.spawn +behavior just collapses into a single thread! 
Loads and preprocesses dataset, instantiates a model, and runs training. + +Run with `python xpretrain.py` (will by default use the configuration specified by `DEFAULTS` below). """ +import os +import re +import time +from collections import deque +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional + +import hydra +import jsonlines +import numpy as np +import torch +import torch_xla.core.xla_model as xm +import torch_xla.debug.metrics as met +import torch_xla.distributed.parallel_loader as parallel +import wandb +from hydra.core.config_store import ConfigStore +from omegaconf import MISSING, OmegaConf +from torch.utils.data import DataLoader +from tqdm import tqdm + +from voltron.conf import AcceleratorConfig, DatasetConfig, ModelConfig, TrackingConfig +from voltron.models import VMVP, VR3M, VRN3M, VCond, VDual, VGen +from voltron.overwatch import OverwatchRich +from voltron.preprocessing.stream_datasets import get_epoch_datasets +from voltron.util import set_global_seed +from voltron.util.checkpointing import CheckpointSaver +from voltron.util.distributed import ResumeableDistributedSampler +from voltron.util.xla_logger import ( + log_epoch_end_update, + log_vcond_train_update, + log_vdual_train_update, + log_vgen_train_update, + log_vmvp_train_update, + log_vr3m_train_update, + log_vrn3m_train_update, +) + +# Set Defaults (Hydra w/ Structured Configs) +DEFAULTS = [ + "_self_", + {"model": "v-cond"}, + {"dataset": "sth-sth-v2"}, + {"accelerator": "tpu-v3-8"}, + {"tracking": "voltron-tracking"}, + {"override hydra/job_logging": "overwatch_rich"}, +] + + +@dataclass +class PretrainConfig: + # fmt: off + defaults: List[Any] = field(default_factory=lambda: DEFAULTS) + hydra: Dict[str, Any] = field(default_factory=lambda: { + "run": {"dir": "./runs/train/${model.identifier}+dataset-${dataset.name}"} + }) + + # Command Line Arguments + run_id: Optional[str] = None # Run ID for 
Logging + seed: int = 21 # Random Seed (for reproducibility) + + # Resume / Debug Behavior + resume: bool = True # Whether to resume an existing run... + resume_epoch: Optional[int] = None # Epoch to resume (if auto-resuming)... + checkpoint_path: Optional[str] = None # Path to the specific checkpoint to load! + wandb_resume_id: Optional[str] = None # W&B Run ID for `resume` behavior... + + # Composable / Structured Arguments + model: ModelConfig = MISSING # Model architecture for pretraining + dataset: DatasetConfig = MISSING # List of datasets for pretraining + accelerator: AcceleratorConfig = MISSING # Accelerator configuration + tracking: TrackingConfig = MISSING # Run/experiment tracking configuration + # fmt: on + + +# Hydra Setup :: Retrieve ConfigStore (Singleton) & Register Components +cs = ConfigStore.instance() +cs.store(group="hydra/job_logging", name="overwatch_rich", node=OverwatchRich) # Annoying - configure logger for Hydra +cs.store(name="config", node=PretrainConfig) + + +# ruff: noqa: C901 +def xpretrain(cfg: PretrainConfig) -> None: + # Identify if `is_rank_zero` --> We only log from the rank zero process! + is_rank_zero = xm.is_master_ordinal(local=False) + xm.master_print("Voltron Training :: Assembling the Legendary Defender...") + + # Create Unique Run Name -- if `resume = True` we assume same "run_id" + run_id = cfg.run_id + if run_id is None: + run_id = run_dir = f"{cfg.model.identifier}+{cfg.dataset.name}-x{cfg.seed}" + cfg.run_id = run_id + else: + cfg.run_id = run_dir = run_id + + if is_rank_zero: + os.makedirs(run_dir, exist_ok=True) + + xm.master_print( + '\t=>> "If you get too worried about what could go wrong, you might miss a chance to do something great."' + ) + + # Set Randomness, get DataLoader worker initialization function (to ensure any random augmentations!) 
+ worker_init_fn = set_global_seed(cfg.seed) + + # Model Initialization Logic + xm.master_print("Initializing Model and Placing on Different Devices...") + if cfg.model.arch == "v-mvp": + xm.master_print(f"Initializing MVP variant `{cfg.model.identifier}`") + model = VMVP( + resolution=cfg.dataset.resolution, + patch_size=cfg.model.patch_size, + encoder_depth=cfg.model.encoder_depth, + encoder_embed_dim=cfg.model.encoder_embed_dim, + encoder_n_heads=cfg.model.encoder_n_heads, + decoder_depth=cfg.model.decoder_depth, + decoder_embed_dim=cfg.model.decoder_embed_dim, + decoder_n_heads=cfg.model.decoder_n_heads, + optimizer=cfg.model.optimizer, + schedule=cfg.model.schedule, + base_lr=cfg.model.base_lr, + min_lr=cfg.model.min_lr, + effective_bsz=cfg.model.effective_bsz, + betas=cfg.model.betas, + weight_decay=cfg.model.weight_decay, + warmup_epochs=cfg.dataset.warmup_epochs, + max_epochs=cfg.dataset.max_epochs, + mlp_ratio=cfg.model.mlp_ratio, + norm_pixel_loss=cfg.model.norm_pixel_loss, + ) + + elif cfg.model.arch == "v-r3m": + xm.master_print(f"Initializing R3M (ViT) Variant `{cfg.model.identifier}`") + model = VR3M( + resolution=cfg.dataset.resolution, + patch_size=cfg.model.patch_size, + depth=cfg.model.depth, + embed_dim=cfg.model.embed_dim, + n_heads=cfg.model.n_heads, + language_model=cfg.model.language_model, + hf_cache=cfg.model.hf_cache, + language_dim=cfg.model.language_dim, + reward_dim=cfg.model.reward_dim, + n_negatives=cfg.model.n_negatives, + lang_reward_weight=cfg.model.lang_reward_weight, + tcn_weight=cfg.model.tcn_weight, + l1_weight=cfg.model.l1_weight, + l2_weight=cfg.model.l2_weight, + optimizer=cfg.model.optimizer, + schedule=cfg.model.schedule, + lr=cfg.model.lr, + min_lr=cfg.model.min_lr, + warmup_epochs=cfg.dataset.warmup_epochs, + max_epochs=cfg.dataset.max_epochs, + mlp_ratio=cfg.model.mlp_ratio, + ) + + elif cfg.model.arch == "v-rn3m": + xm.master_print(f"Intializing R3M (ResNet) Variant `{cfg.model.identifier}`") + model = VRN3M( + 
resolution=cfg.dataset.resolution, + fc_dim=cfg.model.fc_dim, + language_model=cfg.model.language_model, + hf_cache=cfg.model.hf_cache, + language_dim=cfg.model.language_dim, + reward_dim=cfg.model.reward_dim, + n_negatives=cfg.model.n_negatives, + lang_reward_weight=cfg.model.lang_reward_weight, + tcn_weight=cfg.model.tcn_weight, + l1_weight=cfg.model.l1_weight, + l2_weight=cfg.model.l2_weight, + optimizer=cfg.model.optimizer, + lr=cfg.model.lr, + ) + + elif cfg.model.arch == "v-cond": + xm.master_print(f"Initializing Voltron V-Cond variant `{cfg.model.identifier}`") + model = VCond( + resolution=cfg.dataset.resolution, + patch_size=cfg.model.patch_size, + encoder_depth=cfg.model.encoder_depth, + encoder_embed_dim=cfg.model.encoder_embed_dim, + encoder_n_heads=cfg.model.encoder_n_heads, + decoder_depth=cfg.model.decoder_depth, + decoder_embed_dim=cfg.model.decoder_embed_dim, + decoder_n_heads=cfg.model.decoder_n_heads, + language_model=cfg.model.language_model, + hf_cache=cfg.model.hf_cache, + language_dim=cfg.model.language_dim, + optimizer=cfg.model.optimizer, + schedule=cfg.model.schedule, + base_lr=cfg.model.base_lr, + min_lr=cfg.model.min_lr, + effective_bsz=cfg.model.effective_bsz, + betas=cfg.model.betas, + weight_decay=cfg.model.weight_decay, + warmup_epochs=cfg.dataset.warmup_epochs, + max_epochs=cfg.dataset.max_epochs, + mlp_ratio=cfg.model.mlp_ratio, + norm_pixel_loss=cfg.model.norm_pixel_loss, + ) + + elif cfg.model.arch == "v-dual": + xm.master_print(f"Initializing Voltron V-Dual variant `{cfg.model.identifier}`") + model = VDual( + resolution=cfg.dataset.resolution, + patch_size=cfg.model.patch_size, + encoder_depth=cfg.model.encoder_depth, + encoder_embed_dim=cfg.model.encoder_embed_dim, + encoder_n_heads=cfg.model.encoder_n_heads, + decoder_depth=cfg.model.decoder_depth, + decoder_embed_dim=cfg.model.decoder_embed_dim, + decoder_n_heads=cfg.model.decoder_n_heads, + language_model=cfg.model.language_model, + hf_cache=cfg.model.hf_cache, + 
language_dim=cfg.model.language_dim, + optimizer=cfg.model.optimizer, + schedule=cfg.model.schedule, + base_lr=cfg.model.base_lr, + min_lr=cfg.model.min_lr, + effective_bsz=cfg.model.effective_bsz, + betas=cfg.model.betas, + weight_decay=cfg.model.weight_decay, + warmup_epochs=cfg.dataset.warmup_epochs, + max_epochs=cfg.dataset.max_epochs, + mlp_ratio=cfg.model.mlp_ratio, + norm_pixel_loss=cfg.model.norm_pixel_loss, + ) + + elif cfg.model.arch == "v-gen": + xm.master_print(f"Initializing Voltron V-Gen variant `{cfg.model.identifier}`") + model = VGen( + resolution=cfg.dataset.resolution, + patch_size=cfg.model.patch_size, + encoder_depth=cfg.model.encoder_depth, + encoder_embed_dim=cfg.model.encoder_embed_dim, + encoder_n_heads=cfg.model.encoder_n_heads, + decoder_depth=cfg.model.decoder_depth, + decoder_embed_dim=cfg.model.decoder_embed_dim, + decoder_n_heads=cfg.model.decoder_n_heads, + language_model=cfg.model.language_model, + hf_cache=cfg.model.hf_cache, + language_dim=cfg.model.language_dim, + max_lang_len=cfg.dataset.max_lang_len, + vocab_size=cfg.model.vocab_size, + mae_weight=cfg.model.mae_weight, + lm_weight=cfg.model.lm_weight, + optimizer=cfg.model.optimizer, + schedule=cfg.model.schedule, + base_lr=cfg.model.base_lr, + min_lr=cfg.model.min_lr, + effective_bsz=cfg.model.effective_bsz, + betas=cfg.model.betas, + weight_decay=cfg.model.weight_decay, + warmup_epochs=cfg.dataset.warmup_epochs, + max_epochs=cfg.dataset.max_epochs, + mlp_ratio=cfg.model.mlp_ratio, + norm_pixel_loss=cfg.model.norm_pixel_loss, + ) + + else: + raise NotImplementedError(f"Model Architecture `{cfg.model.arch}` is not supported!") + + # We use gradient accumulation to honor the effective batch size specified... + assert cfg.model.effective_bsz % cfg.model.device_bsz == 0, "Device bsz must evenly divide effective bsz!" 
+ accumulate_grad_batches = cfg.model.effective_bsz // cfg.model.device_bsz // xm.xrt_world_size() + xm.master_print( + f"Running `{cfg.model.identifier}` model w/ Effective Batch Size of `{cfg.model.effective_bsz}`, " + f"Per-Device Batch Size of `{cfg.model.device_bsz}`, " + f"Distributed World Size of `{xm.xrt_world_size()}` and `{accumulate_grad_batches}` Accumulation Steps" + ) + + # If Resuming =>> Load Model from Checkpoint + start_checkpoint, start_epoch, start_step = None, 0, 0 + if cfg.resume: + # **IMPORTANT**: We're making a few assumptions on resuming that should eventually become explicit checks: + # - `accumulate_grad_batches` is exactly the same when resuming; this means: + # + `cfg.model.effective_bsz`, `cfg.model.device_bsz`, & `cfg.accelerator.num_accelerators` are the same! + # - The Weights & Biases directory `run_dir/wandb` only contains a *single run* + # - The `param_groups` in `optimizer.state_dict()` are exactly the same across resumes! + # + This means that (and generally should be true for resuming altogether) the architecture is the same! + # - The `cfg.seed` should be the same (again, should generally be true...) + if cfg.checkpoint_path is None: + xm.master_print("Resuming :: Attempting to Automatically Load Checkpoint -- Searching!") + checkpoint_path = Path(run_dir) / "checkpoints" + if checkpoint_path.exists() and any(checkpoint_path.iterdir()): + # Parse out the latest "complete" epoch checkpoint, as well as any "local step" checkpoints... + checkpoints = list(checkpoint_path.iterdir()) + complete_checkpoint, complete_epoch = max( + [ + (c, int(re.search("epoch=(.+?)-train", c.name).group(1))) + for c in checkpoints + if "local-epoch=" not in str(c) + ], + key=lambda x: x[1], + ) + + # Case 1 :: We have "local step" checkpoints --> will always override any "full epoch" checkpoints... 
+ local = [ + ( + c, + int(re.search("local-epoch=(.+?)-step", c.name).group(1)), + int(re.search("step=(.+?)[.-]", c.name).group(1)), + ) + for c in checkpoints + if "local-epoch=" in str(c) + ] + if len(local) > 0: + # Parse out (epoch, "highest" step) + assert no great "full epoch" checkpoint exists! + start_checkpoint, start_epoch, start_step = max(local, key=lambda x: x[1:]) + assert start_epoch == complete_epoch, "Epoch mismatch in `resume` from local_step!" + + # Case 2 :: Otherwise, we're just going to start with the last "complete" epoch... + else: + start_checkpoint, start_epoch = complete_checkpoint, complete_epoch + + else: + xm.master_print("No Checkpoints Found -- Starting Run from Scratch!") + + else: + xm.master_print(f"Resuming :: Loading from Checkpoint `{cfg.checkpoint_path}`...") + start_checkpoint = cfg.checkpoint_path + + # Actually Load the Checkpoint State! + if start_checkpoint is not None: + xm.master_print(f"Resuming :: Loading Model & Optimizer State Dictionaries from `{start_checkpoint}`") + checkpoint = torch.load(str(start_checkpoint)) + model_state_dict, optimizer_state_dict = checkpoint + model.load_state_dict(model_state_dict) + + # Logging / W&B Handling + if is_rank_zero: + xm.master_print("Initializing Weights & Biases + JSONL + Checkpoint Saver on Rank Zero ONLY...") + tags = None + if cfg.tracking.tags is None: + tags = [cfg.model.identifier, cfg.dataset.name, "pretraining"] + + # W&B Initialize & Log all Hyperparameters (Only on ordinal 0) + wandb_resume_id = None + if cfg.resume and cfg.wandb_resume_id is None: + xm.master_print("Resuming :: Attempting to Automatically Load W&B Resume ID -- Searching!") + wandb_path = Path("wandb") + if wandb_path.exists() and any((wandb_path / "latest-run").iterdir()): + # Parse out the unique resume_id from the `.wandb` file... 
+ wandb_fns = [f.name for f in (wandb_path / "latest-run").iterdir() if str(f).endswith(".wandb")] + assert len(wandb_fns) == 1, f"There should only be 1 `.wandb` file... found {len(wandb_fns)}!" + + # Regex match on `run-{id}.wandb`... + wandb_resume_id = re.search("run-(.+?).wandb", wandb_fns[0]).group(1) + + # Otherwise, assert that we're starting from scratch! + else: + assert start_checkpoint is None, "Trying to restart a run from checkpoint without a valid W&B ID!" + + elif cfg.resume: + xm.master_print(f"Resuming :: Using Specified W&B Resume ID = `{cfg.wandb_resume_id}`") + wandb_resume_id = cfg.wandb_resume_id + + # Initialize Weights & Biases + xm.master_print(f"W&B Resume is {cfg.resume} w/ W&B Resume ID = {wandb_resume_id}!") + wandb.init( + project=cfg.tracking.project, + entity=cfg.tracking.entity, + config=cfg, + name=run_id, + dir=f"{os.getcwd()}" if cfg.tracking.dir is None else cfg.tracking.dir, + tags=tags, + notes=cfg.tracking.notes, + resume="allow" if start_checkpoint is not None else False, + id=wandb_resume_id, + # Weird line because PT-TPU VMs don't come with a working install of Tensorflow... + settings=wandb.Settings(_disable_stats=True), + ) + + # Initialize JSONL Logger (append only mode) --> last "global step" will always take precedence. + with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger: + js_logger.write( + { + "run_id": run_id, + "start_time": datetime.now().strftime("%m-%d-%H:%M"), + "hparams": OmegaConf.to_container(cfg), + } + ) + + # Rank Zero Node will take time to spin up the loggers & checkpointer... might as well rendezvous? + xm.rendezvous("Logging...") + + # === Here Be Dragons === + # Time to handle device placement -- Note - example code doesn't specify device idx - why not? + # > https://github.com/pytorch/xla/blob/3c0d68da07702995a592ea70f27868cd76fa0755/test/test_train_mp_mnist.py#L114 + # > Results in printing [xla:0] and [xla:1] a bunch... no [xla:2-7]? This feels bad...? 
+ # + # |=> Debugging Try: `xm.xla_device(n=xm.get_ordinal()) ---> hangs completely? + # +=> *ANSWER*: https://github.com/pytorch/xla/issues/2345#issuecomment-657114819 + # >> "Make no assumptions and don't try to build them manually..." + device = xm.xla_device() + model = model.train().to(device) + optimizer, update_lr = model.configure_optimizer() + global_step, train_losses, lrs, start_time, resume_time = 0, deque(maxlen=128), [], time.time(), 0 + + # If resuming (valid `start_checkpoint`) -- patch the optimizer state dictionary, and load! + if start_checkpoint is not None: + patched_optimizer_state_dict = { + "state": optimizer_state_dict, + "param_groups": optimizer.state_dict()["param_groups"], + } + optimizer.load_state_dict(patched_optimizer_state_dict) + + # Create step timing... + step_times, step_start_time = deque(maxlen=128), time.time() + + # Create Model/Architecture-Specific Trackers... + if cfg.model.arch == "v-mvp": + reconstruction_losses = deque(maxlen=128) + + elif cfg.model.arch in {"v-r3m", "v-rn3m"}: + tcn_losses, reward_losses, l1_losses, l2_losses = [deque(maxlen=128) for _ in range(4)] + tcn_accuracies, reward_accuracies = [deque(maxlen=128) for _ in range(2)] + + elif cfg.model.arch == "v-cond": + reconstruction_losses = deque(maxlen=128) + + elif cfg.model.arch == "v-dual": + reconstruction_losses = deque(maxlen=128) + zero_reconstruction, k_reconstruction = deque(maxlen=128), deque(maxlen=128) + + elif cfg.model.arch == "v-gen": + reconstruction_losses, lm_losses, lm_ppl = deque(maxlen=128), deque(maxlen=128), deque(maxlen=128) + zero_reconstruction, k_reconstruction = deque(maxlen=128), deque(maxlen=128) + + else: + raise NotImplementedError(f"Trackers for Model `{cfg.model.arch}` not implemented!") + + # 0th Checkpoint - Pull out optimizer state explicitly (`groups` are not serializable & can easily be replicated) + saver = CheckpointSaver(cfg.tracking.checkpoint_strategy, run_dir, cfg.accelerator.accelerator) + if start_checkpoint 
is None and start_epoch == 0: + xm.master_print("Saving 0th Epoch Checkpoint...") + saver.save( + epoch=0, is_local_step=False, model=model, optimizer=optimizer, duration=0, train_loss=None, val_loss=None + ) + + # Run on all processes --> retrieve "0th epoch" dataset! + # =>> Important, ensures data is locked across models, for the given epoch! + xm.master_print(f"Retrieving Dataset `{cfg.dataset.name}` prepared for `{cfg.model.arch}`!") + train_dataset, val_dataset = get_epoch_datasets( + 0, + cfg.dataset.name, + cfg.dataset.normalization, + cfg.model.arch, + cfg.dataset.stream, + cfg.dataset.artifact_path, + cfg.dataset.stream_prefix, + cfg.model.data_modality, + cfg.model.get("lang_dropout", None), + cfg.model.get("gen_ratio", None), + ) + + # Loading Datasets might take time... rendezvous to be safe + xm.rendezvous("Retrieved Datasets...") + + # Iterate through Epochs, Evaluating at the end of each Training Epoch! + # >> Everything in this loop should happen across all workers, except for the logging (ordinal 0)! + xm.master_print("Starting Training Loop...") + for epoch in range(start_epoch, cfg.dataset.max_epochs): + xm.master_print(f"\t[Epoch {epoch}] Building Distributed Sampler & DataLoaders...") + train_dataset.set_epoch(epoch) + + # ResumeableDistributedSampler operates at over *examples* --> start_step (full_batches) * effective_bsz + seen_examples = start_step * cfg.model.effective_bsz + train_sampler = ResumeableDistributedSampler( + seen_examples, + start_epoch, + train_dataset, + num_replicas=xm.xrt_world_size(), + rank=xm.get_ordinal(), + shuffle=True, + seed=cfg.seed, + ) + + # Set epoch appropriately for the `train_sampler` --> necessary to trigger "normal" logic! 
+ train_sampler.set_epoch(epoch) + + train_dataloader = DataLoader( + train_dataset, + batch_size=cfg.model.device_bsz, + sampler=train_sampler, + shuffle=False if train_sampler else True, + num_workers=cfg.accelerator.num_workers, + drop_last=True, + worker_init_fn=worker_init_fn, + prefetch_factor=4, + ) + + # NOTE :: We're not sharding the Validation set --> *everybody* will run forward passes on the *same* data! + # > We will have to reduce_mesh() later... unclear why, but the torch_xla folks seem keen on it; might lead to + # > weird rendezvous/hang issues if Validation is big enough... + val_dataloader = DataLoader( + val_dataset, + batch_size=cfg.model.device_bsz, + shuffle=False, + num_workers=4, + drop_last=True, + worker_init_fn=worker_init_fn, + ) + + # Initializing the Dataloaders might take time depending on process... + xm.rendezvous("Initialized Dataloaders...") + + # Leverage the *special* API that handles synchronizing TPU cores across batches! + # > NOTE: This is super important! + xm.master_print("\tSetting up Parallel MpDeviceLoaders...") + train_device_loader = parallel.MpDeviceLoader(train_dataloader, device) + val_device_loader = parallel.MpDeviceLoader(val_dataloader, device) + + # Book-keeping & LR setting when `resuming` --> only do this on start_epoch! + if epoch == start_epoch: + if start_checkpoint is not None: + global_step = start_step + ((len(train_dataset) // cfg.model.effective_bsz) * start_epoch) + resume_time = int(re.search("-t=(.+?).pt", str(start_checkpoint)).group(1)) + lrs.append(update_lr(start_epoch, start_step / (len(train_dataset) // cfg.model.effective_bsz))) + else: + lrs.append(update_lr(start_epoch, 0)) + + # Iterate... + step_start_time = time.time() + with tqdm(total=len(train_device_loader) // accumulate_grad_batches, disable=not is_rank_zero) as progress: + for train_idx, batch in enumerate(train_device_loader): + if cfg.model.arch == "v-mvp": + # Run a forward pass through the MAE... 
other return vals are reconstructions (pixel norm) & mask + loss, _, _ = model(batch) + reconstruction_losses.append(loss) + + elif cfg.model.arch in {"v-r3m", "v-rn3m"}: + imgs, lang, lang_mask = batch + loss, tcn_loss, reward_loss, l1_loss, l2_loss, tcn_acc, rew_acc = model(imgs, lang, lang_mask) + + # Add to trackers + tcn_losses.append(tcn_loss) + reward_losses.append(reward_loss) + l1_losses.append(l1_loss) + l2_losses.append(l2_loss) + tcn_accuracies.append(tcn_acc) + reward_accuracies.append(rew_acc) + + elif cfg.model.arch == "v-cond": + img, lang, lang_mask = batch + loss, _, _ = model(img, lang, lang_mask) + reconstruction_losses.append(loss) + + elif cfg.model.arch == "v-dual": + imgs, lang, lang_mask = batch + loss, [zero_loss, k_loss] = model(imgs, lang, lang_mask) + + # Add to trackers + reconstruction_losses.append(loss) + zero_reconstruction.append(zero_loss) + k_reconstruction.append(k_loss) + + elif cfg.model.arch == "v-gen": + imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight = batch + loss, reconstruction_loss, lm_loss, [zero_loss, k_loss] = model( + imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight + ) + + # Add to trackers + reconstruction_losses.append(reconstruction_loss) + lm_losses.append(lm_loss) + lm_ppl.append(torch.exp(lm_loss)) + zero_reconstruction.append(zero_loss) + k_reconstruction.append(k_loss) + + else: + raise NotImplementedError(f"Forward Pass Logic for Model `{cfg.model.arch}` not implemented!") + + # Write Loss to Loggers (prior to accumulation normalization) + train_losses.append(loss) + + # Normalize loss to account for accumulation + loss = loss / accumulate_grad_batches + loss.backward() + + # Gradient Accumulation =>> Note: skip any errant batches at the end... + if (train_idx + 1) % accumulate_grad_batches == 0: + xm.optimizer_step(optimizer) # Note call to xm.optimizer_step() -- has implicit mark_step! 
+ optimizer.zero_grad() + + # Add to `step_times` + step_times.append(time.time() - step_start_time) + + # Logging --> Because there is no guarantee processes will be in sync, we need a `closure` + # > Ref: https://pytorch.org/xla/release/1.11/index.html#torch_xla.core.xla_model.add_step_closure + if is_rank_zero and global_step % cfg.tracking.log_frequency == 0: + if cfg.model.arch == "v-mvp": + xm.add_step_closure( + log_vmvp_train_update, + args=( + epoch, + global_step, + run_id, + train_losses, + lrs[-1], + reconstruction_losses, + step_times, + ), + ) + + elif cfg.model.arch == "v-r3m": + xm.add_step_closure( + log_vr3m_train_update, + args=( + epoch, + global_step, + run_id, + train_losses, + lrs[-1], + tcn_losses, + reward_losses, + l1_losses, + l2_losses, + tcn_accuracies, + reward_accuracies, + step_times, + ), + ) + + elif cfg.model.arch == "v-rn3m": + xm.add_step_closure( + log_vrn3m_train_update, + args=( + epoch, + global_step, + run_id, + train_losses, + lrs[-1], + tcn_losses, + reward_losses, + l1_losses, + l2_losses, + tcn_accuracies, + reward_accuracies, + step_times, + ), + ) + + elif cfg.model.arch == "v-cond": + xm.add_step_closure( + log_vcond_train_update, + args=( + epoch, + global_step, + run_id, + train_losses, + lrs[-1], + reconstruction_losses, + step_times, + ), + ) + + elif cfg.model.arch == "v-dual": + xm.add_step_closure( + log_vdual_train_update, + args=( + epoch, + global_step, + run_id, + train_losses, + lrs[-1], + reconstruction_losses, + zero_reconstruction, + k_reconstruction, + step_times, + ), + ) + + elif cfg.model.arch == "v-gen": + xm.add_step_closure( + log_vgen_train_update, + args=( + epoch, + global_step, + run_id, + train_losses, + lrs[-1], + reconstruction_losses, + lm_losses, + lm_ppl, + zero_reconstruction, + k_reconstruction, + step_times, + ), + ) + + else: + raise NotImplementedError(f"Log Update for Model `{cfg.model.arch}` not implemented!") + + # Increment Global Step _after_ logging! 
+ global_step += 1 + + # Save checkpoint subject to *local_step = (train_idx + 1) // accumulate_grad_batches* + saver.save( + epoch=epoch, + is_local_step=True, + model=model, + optimizer=optimizer, + duration=int(time.time() - start_time) + resume_time, + local_step=start_step + ((train_idx + 1) // accumulate_grad_batches), + ) + + # Update LR every `accumulation_steps` iterations... + lrs.append( + update_lr( + epoch, + (start_step + ((train_idx + 1) // accumulate_grad_batches)) + / (len(train_dataset) // cfg.model.effective_bsz), + ) + ) + + # Reset `step_start_time` + step_start_time = time.time() + + # Update `progress` each time we take a gradient step! + progress.update() + + # After each forward pass, mark a step, to compile XLA graph for a single forward pass! + # =>> Note :: this is important, with gradient accumulation, the graph can get massive otherwise! + xm.mark_step() + + else: + # Clear gradients and reset start step (regardless) at end of the loop + optimizer.zero_grad() + start_step = 0 + + # Redundant, but Synchronous Validation Epoch... + xm.master_print("Validating...") + val_losses = [] + with torch.no_grad(): + for batch in tqdm(val_device_loader, disable=not is_rank_zero): + if cfg.model.arch == "v-mvp": + loss, _, _ = model(batch) + elif cfg.model.arch in {"v-r3m", "v-rn3m"}: + imgs, lang, lang_mask = batch + loss, _, _, _, _, _, _ = model(imgs, lang, lang_mask) + elif cfg.model.arch == "v-cond": + img, lang, lang_mask = batch + loss, _, _ = model(img, lang, lang_mask) + elif cfg.model.arch == "v-dual": + imgs, lang, lang_mask = batch + loss, _ = model(imgs, lang, lang_mask) + elif cfg.model.arch == "v-gen": + imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight = batch + loss, _, _, _ = model(imgs, lang_con, lang_con_mask, lang_gen, lang_gen_mask, lang_gen_weight) + else: + raise NotImplementedError(f"Forward Pass Logic for Model `{cfg.model.arch} not implemented!") + + # Just append to val_losses... 
+ val_losses.append(loss) + + # Compute Val Loss & *mesh reduce* --> Why? :: the XLA people said so! + val_loss = torch.stack(val_losses).mean().item() + val_loss = xm.mesh_reduce("val_loss", val_loss, np.mean) # All replicas should just return the same thing? + + # Logging --> add another `closure` for end-of-epoch cleanup --> compute `duration` as well... + duration = int(time.time() - start_time) + resume_time + if is_rank_zero: + xm.add_step_closure( + log_epoch_end_update, + args=( + cfg.model.arch, + epoch, + global_step, + run_id, + duration, + train_losses, + val_loss, + lrs[-1], + step_times, + ), + ) + + # Save Checkpoint (at end of Epoch) + saver.save( + epoch=epoch + 1, + is_local_step=False, + model=model, + optimizer=optimizer, + duration=duration, + train_loss=train_losses[-1].item(), + val_loss=val_loss, + ) + + # Dump TPU Debugging Metrics... + if is_rank_zero: + with open("tpu-debug-metrics.log", "w") as f: + f.write(met.metrics_report()) + + # Exiting w/ Multiprocessing is a Nightmare... try to join? + xm.master_print("...and that's all, folks!") + xm.rendezvous("Cheers!") + + # Sleep for like 3 minutes... get W&B to finish syncing logs + wandb.finish() + time.sleep(150) + + +def mp_fn(_: int, cfg: PretrainConfig) -> None: + torch.set_default_tensor_type("torch.FloatTensor") + + # Let's Start Pretraining! + xpretrain(cfg) + + +@hydra.main(config_path=None, config_name="config") +def main(cfg: PretrainConfig) -> None: + import torch_xla.distributed.xla_multiprocessing as xmp + + # Call XMP Spawn w/ the Config as the sole argument... 
+ xmp.spawn(mp_fn, args=(cfg,), nprocs=cfg.accelerator.num_accelerators, start_method="spawn") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index a7c4335..ca82c64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ authors = [ {name = "Siddharth Karamcheti", email="skaramcheti@cs.stanford.edu"} ] description = "Voltron: Language-Driven Representation Learning for Robotics." -version = "0.0.1" +version = "1.0.0" readme = "README.md" requires-python = ">=3.8" keywords = ["robotics", "representation learning", "natural language processing", "machine learning"] @@ -30,15 +30,19 @@ classifiers = [ dependencies = [ "av", "gdown", + "google-cloud-storage", "einops", "hurry.filesize", "hydra-core==1.1.1", # Lock Hydra =>> future versions break... + "jsonlines", "omegaconf==2.1.2", # Lock OmegaConf =>> future versions break... "opencv-python==4.2.0.32", # Lock OpenCV =>> just in case... + "rich", "torch", "torchvision", "torchaudio", "transformers", + "wandb", ] [project.optional-dependencies] diff --git a/voltron/conf/__init__.py b/voltron/conf/__init__.py index 1d7893d..39c3e78 100644 --- a/voltron/conf/__init__.py +++ b/voltron/conf/__init__.py @@ -1 +1,4 @@ +from .accelerators import AcceleratorConfig from .datasets import DatasetConfig +from .models import ModelConfig +from .tracking import TrackingConfig diff --git a/voltron/conf/accelerators.py b/voltron/conf/accelerators.py new file mode 100644 index 0000000..b4dbc87 --- /dev/null +++ b/voltron/conf/accelerators.py @@ -0,0 +1,52 @@ +""" +accelerators.py + +Base Hydra Structured Configs for defining various accelerator schemes. Uses a simple single inheritance structure. 
+""" +from dataclasses import dataclass + +from hydra.core.config_store import ConfigStore +from omegaconf import MISSING + + +@dataclass +class AcceleratorConfig: + accelerator: str = MISSING + num_accelerators: int = MISSING + num_workers: int = MISSING + + +@dataclass +class TPUv2OneConfig(AcceleratorConfig): + accelerator = "tpu" + num_accelerators = 1 + num_workers = 4 + + +@dataclass +class TPUv2EightConfig(AcceleratorConfig): + accelerator = "tpu" + num_accelerators = 8 + num_workers = 4 + + +@dataclass +class TPUv3OneConfig(AcceleratorConfig): + accelerator = "tpu" + num_accelerators = 1 + num_workers = 8 + + +@dataclass +class TPUv3EightConfig(AcceleratorConfig): + accelerator = "tpu" + num_accelerators = 8 + num_workers = 8 + + +# Create a configuration group `accelerator` and populate with the above... +cs = ConfigStore.instance() +cs.store(group="accelerator", name="tpu-v2-1", node=TPUv2OneConfig) +cs.store(group="accelerator", name="tpu-v2-8", node=TPUv2EightConfig) +cs.store(group="accelerator", name="tpu-v3-1", node=TPUv3OneConfig) +cs.store(group="accelerator", name="tpu-v3-8", node=TPUv3EightConfig) diff --git a/voltron/conf/datasets.py b/voltron/conf/datasets.py index 2aeebe4..49ecdbb 100644 --- a/voltron/conf/datasets.py +++ b/voltron/conf/datasets.py @@ -16,7 +16,11 @@ class DatasetConfig: name: str = MISSING path: str = MISSING - artifact_path: str = to_absolute_path("/mnt/home") + artifact_path: str = to_absolute_path("data/processed/sth-sth-v2") + + # Streaming Parameters (assumes fully preprocessed dataset lives at `stream_prefix/...`) + stream: bool = True + stream_prefix: str = "data/processed" # Dataset-Specific Parameters resolution: int = 224 diff --git a/voltron/conf/models.py b/voltron/conf/models.py new file mode 100644 index 0000000..da8625c --- /dev/null +++ b/voltron/conf/models.py @@ -0,0 +1,435 @@ +""" +models.py + +Base Hydra Structured Config for defining various pretraining model architectures and appropriate configurations. 
Uses a +simple single inheritance structure. +""" +from dataclasses import dataclass +from typing import Tuple + +from hydra.core.config_store import ConfigStore +from hydra.utils import to_absolute_path +from omegaconf import MISSING + + +@dataclass +class ModelConfig: + arch: str = MISSING + identifier: str = MISSING + + # Dataset Modality + data_modality: str = MISSING + + # Default Vision Transformer Configuration + patch_size: int = 16 + mlp_ratio: float = 4.0 + + # Number of examples one can safely fit on an accelerator w/ this model! + device_bsz: int = MISSING + + +# @Data-Locked Reproductions --- Encompasses MVP (MAE) + R3M + + +# MVP (Base Masked Autoencoder) +@dataclass +class MVPConfig(ModelConfig): + arch: str = "v-mvp" + identifier: str = MISSING + + # Dataset Modality + data_modality: str = "state" + + # Base MAE Parameters + mask_ratio: float = 0.75 + + # Architecture Parameters + encoder_depth: int = MISSING + encoder_embed_dim: int = MISSING + encoder_n_heads: int = MISSING + + decoder_depth: int = MISSING + decoder_embed_dim: int = MISSING + decoder_n_heads: int = MISSING + + # MAE Loss/Objective Configuration + norm_pixel_loss: bool = True + effective_bsz: int = 1024 + device_bsz: int = MISSING + + # Optimization Parameters + optimizer: str = "adamw" + schedule: str = "linear-warmup+cosine-decay" + base_lr: float = 1.5e-4 + min_lr: float = 0.0 + betas: Tuple[float, float] = (0.9, 0.95) + weight_decay: float = 0.05 + + +@dataclass +class MVPSmallConfig(MVPConfig): + identifier = "r-mvp" + + # Architecture Parameters -- should match ViT Small Architecture to the letter! + # Note: Small is defined in TIMM: + # > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L683 + encoder_depth = 12 + encoder_embed_dim = 384 + encoder_n_heads = 6 + + decoder_depth = 6 + decoder_embed_dim = 192 + decoder_n_heads = 6 + + # Number of examples one can safely fit on an accelerator w/ this model! 
+ # > TPU-v3: max of 128 per device. + device_bsz = 128 + + +# R3M Models --> Just different visual encoders, roughly following the above! +@dataclass +class R3MConfig(ModelConfig): + arch: str = "v-r3m" + identifier: str = MISSING + + # Dataset Modality + data_modality: str = "quintet+language" + + # ViT Architecture Parameters + depth: int = MISSING + embed_dim: int = MISSING + n_heads: int = MISSING + + # Effective Batch Size + effective_bsz: int = 1024 + + # Language Model Parameters + language_model: str = "distilbert-base-uncased" + hf_cache: str = to_absolute_path("data/hf-cache") + language_dim: int = 768 + vocab_size: int = 30522 + reward_dim: int = 1024 + + # Loss/Objective Configuration + lang_reward_weight: float = 1.0 + tcn_weight: float = 1.0 + l1_weight: float = 1e-5 + l2_weight: float = 1e-5 + n_negatives: int = 3 + device_bsz: int = MISSING + + # Optimization Parameters + optimizer: str = "adam" + schedule: str = "linear-warmup+cosine-decay" + lr: float = 1e-4 + min_lr: float = 0.0 + + +@dataclass +class R3MSmallConfig(R3MConfig): + identifier = "r-r3m-vit" + + # Architecture Parameters -- should match ViT Small Architecture to the letter! 
+ # Note: Small is defined in TIMM: + # > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L683 + depth = 12 + embed_dim = 384 + n_heads = 6 + + # Device Batch Size + device_bsz = 32 + + +# R3M -- ResNet50 Encoder (instead of ViT) +@dataclass +class ResNet3MConfig(ModelConfig): + arch: str = "v-rn3m" + identifier: str = MISSING + + # Dataset Modality + data_modality: str = "quintet+language" + + # Effective Batch Size + effective_bsz: int = 1024 + + # Architecture Parameters + fc_dim: int = MISSING + + # Language Model Parameters + language_model: str = "distilbert-base-uncased" + hf_cache: str = to_absolute_path("data/hf-cache") + language_dim: int = 768 + vocab_size: int = 30522 + reward_dim: int = 1024 + + # Loss/Objective Configuration + lang_reward_weight: float = 1.0 + tcn_weight: float = 1.0 + l1_weight: float = 1e-5 + l2_weight: float = 1e-5 + n_negatives: int = 3 + device_bsz: int = MISSING + + # Optimization Parameters + optimizer: str = "adam" + lr: float = 1e-4 + + +class RN3M50Config(ResNet3MConfig): + identifier = "r-r3m-rn50" + + # Architecture Parameters + fc_dim = 2048 + + # Device Batch Size + device_bsz = 32 + + +# @Voltron Models -- VCond, VDual, VGen + + +# VCond -- Single Frame + Language Conditioning +@dataclass +class VCondConfig(ModelConfig): + arch: str = "v-cond" + identifier: str = MISSING + + # Dataset Modality + data_modality: str = "state+language" + + # Base MAE Parameters + mask_ratio: float = 0.75 + + # Base Language Parameters --> full sentence dropout only... 
+ language_model: str = "distilbert-base-uncased" + hf_cache: str = to_absolute_path("data/hf-cache") + language_dim: int = 768 + vocab_size: int = 30522 + lang_dropout: float = MISSING + + # Architecture Parameters + encoder_depth: int = MISSING + encoder_embed_dim: int = MISSING + encoder_n_heads: int = MISSING + + decoder_depth: int = MISSING + decoder_embed_dim: int = MISSING + decoder_n_heads: int = MISSING + + # MAE Loss/Objective Configuration + norm_pixel_loss: bool = True + effective_bsz: int = 1024 + device_bsz: int = MISSING + + # Optimization Parameters + optimizer: str = "adamw" + schedule: str = "linear-warmup+cosine-decay" + base_lr: float = 1.5e-4 + min_lr: float = 0.0 + betas: Tuple[float, float] = (0.9, 0.95) + weight_decay: float = 0.05 + + +@dataclass +class VCondSmallConfig(VCondConfig): + identifier = "v-cond" + + # No language dropout... + lang_dropout = 0.0 + + # Architecture Parameters -- should match ViT Small Architecture to the letter! + # Note: Small is defined in TIMM: + # > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L683 + encoder_depth = 12 + encoder_embed_dim = 384 + encoder_n_heads = 6 + + decoder_depth = 6 + decoder_embed_dim = 192 + decoder_n_heads = 6 + + # Number of examples one can safely fit on an accelerator w/ this model! + # > TPU-v3: max of 128 per device + device_bsz = 128 + + +@dataclass +class VCondBaseConfig(VCondConfig): + identifier = "v-cond-base" + + # No language dropout... + lang_dropout = 0.0 + + # Architecture Parameters -- should match ViT Base Architecture to the letter! 
+ # Note: Base is defined in TIMM & Original MAE Repository: + # > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L723 + # > https://github.com/facebookresearch/mae/blob/main/models_mae.py#L223 + encoder_depth = 12 + encoder_embed_dim = 768 + encoder_n_heads = 12 + + decoder_depth = 8 + decoder_embed_dim = 512 + decoder_n_heads = 16 + + # Number of examples one can safely fit on an accelerator w/ this model! + # > TPU-v3: max of 128 per device! + device_bsz = 128 + + +# VDual - Dual Frame (0th Frame + Kth frame) + Language Conditioning +@dataclass +class VDualConfig(ModelConfig): + arch: str = "v-dual" + identifier: str = MISSING + + # Dataset Modality + data_modality: str = "state+ok" + + # Base MAE Parameters + mae_weight: float = 1.0 + mask_ratio: float = 0.75 + + # Base Language Parameters --> full sentence dropout only... + language_model: str = "distilbert-base-uncased" + hf_cache: str = to_absolute_path("data/hf-cache") + language_dim: int = 768 + vocab_size: int = 30522 + lang_dropout: float = MISSING + + # Architecture Parameters + encoder_depth: int = MISSING + encoder_embed_dim: int = MISSING + encoder_n_heads: int = MISSING + + decoder_depth: int = MISSING + decoder_embed_dim: int = MISSING + decoder_n_heads: int = MISSING + + # MAE Loss/Objective Configuration -- Cut effective batch size since we see 12-25x contexts per batch example! + norm_pixel_loss: bool = True + effective_bsz: int = 1024 + device_bsz: int = MISSING + + # Optimization Parameters + optimizer: str = "adamw" + schedule: str = "linear-warmup+cosine-decay" + base_lr: float = 1.5e-4 + min_lr: float = 0.0 + betas: Tuple[float, float] = (0.9, 0.95) + weight_decay: float = 0.05 + + +@dataclass +class VDualSmallConfig(VDualConfig): + identifier = "v-dual" + + # No language dropout... + lang_dropout = 0.0 + + # Architecture Parameters -- should match ViT Small Architecture to the letter! 
+ # Note: Small is defined in TIMM: + # > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L683 + encoder_depth = 12 + encoder_embed_dim = 384 + encoder_n_heads = 6 + + decoder_depth = 6 + decoder_embed_dim = 192 + decoder_n_heads = 6 + + # Number of examples one can safely fit on an accelerator w/ this model! + # > TPU-v3: max of 128 per device! + device_bsz = 128 + + +# VGen - Dual Frame with Language Conditioning AND Language Generation +@dataclass +class VGenConfig(ModelConfig): + arch: str = "v-gen" + identifier: str = MISSING + + # Dataset Modality + data_modality: str = "state+ok" + + # Base MAE & LM Parameters --> LM Weight is set such that mae & lang loss are ~same order of magnitude + mae_weight: float = 1.0 + lm_weight: float = 0.5 + mask_ratio: float = 0.75 + gen_ratio: float = MISSING + + # Base Language Parameters + language_model: str = "distilbert-base-uncased" + hf_cache: str = to_absolute_path("data/hf-cache") + language_dim: int = 768 + vocab_size: int = 30522 + + # Architecture Parameters + encoder_depth: int = MISSING + encoder_embed_dim: int = MISSING + encoder_n_heads: int = MISSING + + decoder_depth: int = MISSING + decoder_embed_dim: int = MISSING + decoder_n_heads: int = MISSING + + # MAE Loss/Objective Configuration -- Cut effective batch size since we see 12-25x contexts per batch example! + norm_pixel_loss: bool = True + effective_bsz: int = 1024 + device_bsz: int = MISSING + + # Optimization Parameters + optimizer: str = "adamw" + schedule: str = "linear-warmup+cosine-decay" + base_lr: float = 1.5e-4 + min_lr: float = 0.0 + betas: Tuple[float, float] = (0.9, 0.95) + weight_decay: float = 0.05 + + +@dataclass +class VGen50SmallConfig(VGenConfig): + identifier = "v-gen" + + # LM Parameters --> control % of examples that are for "language generation" (no conditioning) + gen_ratio = 0.50 + + # Architecture Parameters -- should match ViT Small Architecture to the letter! 
+ # Note: Small is defined in TIMM: + # > https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L683 + encoder_depth = 12 + encoder_embed_dim = 384 + encoder_n_heads = 6 + + decoder_depth = 6 + decoder_embed_dim = 192 + decoder_n_heads = 6 + + # Number of examples one can safely fit on an accelerator w/ this model! + # > TPU-v3: max of 64 per device! + device_bsz = 64 + + +# Create a configuration group `model` and populate with the above... +cs = ConfigStore.instance() + +# === @Data-Locked Reproductions === + +# Image-Only MAE/MVP Architectures +cs.store(group="model", name="r-mvp", node=MVPSmallConfig) + +# R3M Architectures - ViT & ResNet50 +cs.store(group="model", name="r-r3m-vit", node=R3MSmallConfig) +cs.store(group="model", name="r-r3m-rn50", node=RN3M50Config) + +# === @Voltron === + +# VCond Architectures +cs.store(group="model", name="v-cond", node=VCondSmallConfig) +cs.store(group="model", name="v-cond-base", node=VCondBaseConfig) + +# VDual +cs.store(group="model", name="v-dual", node=VDualSmallConfig) + +# VGen +cs.store(group="model", name="v-gen", node=VGen50SmallConfig) diff --git a/voltron/conf/tracking.py b/voltron/conf/tracking.py new file mode 100644 index 0000000..39f317d --- /dev/null +++ b/voltron/conf/tracking.py @@ -0,0 +1,44 @@ +""" +tracking.py + +Base Hydra Structured Config for defining various run & experiment tracking configurations, e.g., via Weights & Biases. +Uses a simple single inheritance structure. +""" +from dataclasses import dataclass +from typing import List, Optional, Tuple + +from hydra.core.config_store import ConfigStore +from omegaconf import MISSING + + +@dataclass +class TrackingConfig: + # Generic Logging Frequency --> Matters more for XLA/TPUs... set this to be as large as you can stomach! 
+ log_frequency: int = 100 + + # Checkpointing Strategy --> Save each epoch, keep most recent `idx[0]` checkpoints & *every* `idx[1]` checkpoints + # Additionally, save (locally) a checkpoint every `idx[2]` steps for the current epoch (-1). + checkpoint_strategy: Tuple[int, int, int] = (1, 1, 1500) + + # Weights & Biases Setup + project: str = "voltron-pretraining" + entity: str = "voltron-robotics" + + # Notes & Tags are at the discretion of the user... see below + notes: str = MISSING + tags: Optional[List[str]] = None + + # Directory to save W&B Metadata & Logs in General -- if None, defaults to `logs/` in the Hydra CWD + directory: Optional[str] = None + + +@dataclass +class VoltronTrackingConfig(TrackingConfig): + # Note: I really like using notes to keep track of things, so will crash unless specified with run. + # > For `tags` I like to populate based on other args in the script, so letting it remain None + notes: str = MISSING + + +# Create a configuration group `tracking` and populate with the above... +cs = ConfigStore.instance() +cs.store(group="tracking", name="voltron-tracking", node=VoltronTrackingConfig) diff --git a/voltron/preprocessing/stream_datasets.py b/voltron/preprocessing/stream_datasets.py new file mode 100644 index 0000000..7f1e05e --- /dev/null +++ b/voltron/preprocessing/stream_datasets.py @@ -0,0 +1,833 @@ +""" +stream_datasets.py + +Core PyTorch Datasets for the various "flavors" of data used by the various models under study. Crucially, each dataset +loads from the corresponding "batch" serialized files, that define the exact data to use. + +Notably, these serialized files control exactly what data is seen by *all* methods **across epochs.** Using them is +fairly critical to reproducibility & fair comparison.
+ +This specific file contains logic for a "streaming" DataLoader; data is fetched (within the dataloader, by each +worker) via an open connection over the network to a GCS bucket, materializing data as raw BytesIO objects fed to +PIL.Image constructors. +""" +import json +import os +from io import BytesIO +from pathlib import Path +from typing import Any, Optional, Tuple + +import numpy as np +import torch +from google.api_core.exceptions import NotFound +from google.auth.exceptions import TransportError +from google.cloud import storage +from google.resumable_media._helpers import _LOGGER +from PIL import Image, UnidentifiedImageError +from torch.utils.data import Dataset, get_worker_info +from torchvision.io import read_image +from torchvision.transforms import Compose +from torchvision.transforms.functional import pil_to_tensor + +from voltron.preprocessing.transforms import get_online_transform +from voltron.util.distributed import get_rank + +# NOTE --> IF STREAMING JPEGS, WE NEED TO USE PILLOW TO READ FILES (w/o extracting locally...) +# =>> Instead of `read_image(file)` assume we have "fname" and open fileobj (as BytesIO) -- remember to `seek(0)` +# +# > from PIL import Image +# > from torchvision.transforms.functional import pil_to_tensor +# > tensor = pil_to_tensor(Image.open(fileobj) +# |--> This returns a `torch.uint8` Tensor of shape [3, 224, 224] --> *verified* equivalent to `read_image` + +# Create Global GCS Client... +# =>> Multiprocessing.spawn() does not inherit from base shell --> need to set service account key... +# =>> TODO :: Figure out how to fetch num_accelerators & num_workers programatically... 
+os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/mnt/home/auth/gcp-auth.json" +N_CORES, BUCKETS = 8, [storage.Client().bucket("voltron-ANONYMIZED") for _ in range(8 * 8)] + +# Suppress Google Cloud Loggers +_LOGGER.propagate = False +storage.blob._logger.propagate = False + + +class PretrainDataset(Dataset): + def set_epoch(self, epoch: int) -> None: + self.epoch = epoch + + +class StateDataset(PretrainDataset): + def __init__( + self, + epoch: int, + index_path: Path, + img_transform: Compose, + stream: bool = False, + prefix: Optional[Path] = None, + is_val: bool = False, + do_retry: bool = True, + n_retries: int = 3, + ) -> None: + super().__init__() + self.index_path, self.stream, self.is_val, self.val_loaded = index_path, stream, is_val, False + self.epoch, self.transform, self.elements, self.prefix = epoch, img_transform, None, prefix + self.r = N_CORES * get_rank() + self.do_retry, self.n_retries = do_retry, n_retries + + # === Retrieve Epoch Batches (only call `set_epoch` inside __init__() === + self.set_epoch(self.epoch) + + def set_epoch(self, epoch: int) -> None: + # Not Streaming --> Read from local disk... + if not self.stream: + if self.is_val and not self.val_loaded: + with open(self.index_path / "state" / "validation-batches.json", "r") as f: + self.elements = json.load(f) + + # Set `val_loaded` and move on... + self.val_loaded = True + + elif not self.is_val: + with open(self.index_path / "state" / f"train-epoch={epoch}-batches.json", "r") as f: + self.elements = json.load(f) + + # Streaming --> Beam directly from Bucket + else: + if self.is_val and not self.val_loaded: + blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "state" / "validation-batches.json")) + self.elements = json.loads(blob.download_as_string()) + + # `elements[i]["state"] currently maps to disk path... remove all but `parent/child.jpg` + for element in self.elements: + element["state"] = "/".join(element["state"].split("/")[-2:]) + + # Set `val_loaded` and move on... 
+ self.val_loaded = True + + elif not self.is_val: + blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "state" / f"train-epoch={epoch}-batches.json")) + self.elements = json.loads(blob.download_as_string()) + + # `elements[i]["state"]` currently maps to disk path... remove all but `parent/child.jpg` + for element in self.elements: + element["state"] = "/".join(element["state"].split("/")[-2:]) + + def __getitem__(self, index: int) -> torch.Tensor: + """Return single frame as torch Tensor.""" + if not self.stream: + return self.transform(read_image(self.elements[index]["state"])) + else: + # Multiplex w/ num_worker idx... + worker_info = get_worker_info() + r = (self.r + worker_info.id) if worker_info is not None else self.r + + # Streaming + Retry Logic (in case of a bad connection -- retry same file!) + frame_path = self.elements[index]["state"] + for _i in range(self.n_retries): + try: + # Stream JPEG contents into BytesIO (seek back to 0), then into PIL Image.open() + if self.is_val: + blob, fobj = BUCKETS[r].blob(str(self.prefix / "val" / frame_path)), BytesIO() + else: + blob, fobj = BUCKETS[r].blob(str(self.prefix / "train" / frame_path)), BytesIO() + + # Error Handling --> we've already run several dry-run/verification trials, but about ~0.004% of + # the time, we'll hit some sort of TCP/Transport error; this might even go up + # with multiple runs happening at the same time. + # + # To address this, we're adopting the simplest possible "retry" strategy that + # immediately tries to re-download the same file (and crashes if not possible). + # This ensures reproducibility, but puts some extra effort onto the user... + + # File download... + blob.download_to_file(fobj) + fobj.seek(0) + + # Image loading... + img_tensor = pil_to_tensor(Image.open(fobj)) + + # Return transformed image... 
+ return self.transform(img_tensor) + + except (NotFound, TransportError, UnidentifiedImageError, OSError) as e: + # At the minimum --> print the broken file (obnoxiously!) + print(f"=>> BROKEN FILE :: {frame_path}") + if not self.do_retry: + raise e + else: + continue + + # If we've exhausted our retries --> raise an informative ValueError + raise ValueError(f"Failed to fix state `{self.elements[index]['state']}` w/ {self.n_retries} retries...") + + def __len__(self) -> int: + return len(self.elements) + + +class StateLanguageDataset(PretrainDataset): + def __init__( + self, + epoch: int, + index_path: Path, + img_transform: Compose, + lang_dropout: Optional[float] = None, + stream: bool = False, + prefix: Optional[Path] = None, + is_val: bool = False, + do_retry: bool = True, + n_retries: int = 3, + ) -> None: + super().__init__() + self.index_path, self.stream, self.is_val, self.val_loaded = index_path, stream, is_val, False + self.epoch, self.transform, self.elements, self.prefix = epoch, img_transform, None, prefix + self.lang_dropout = 0.0 if (lang_dropout is None or lang_dropout == 0) else lang_dropout + self.dropout_indices = set() + self.r = N_CORES * get_rank() + self.do_retry, self.n_retries = do_retry, n_retries + + # Load Language Index & Retrieve Epoch 0 Batches + language_path = "val-language-index.json" if self.is_val else "train-language-index.json" + if not self.stream: + with open(self.index_path / language_path, "r") as f: + self.language_index = json.load(f) + else: + blob = BUCKETS[self.r].blob(str(self.prefix / "index" / language_path)) + self.language_index = json.loads(blob.download_as_string()) + + # === Retrieve Epoch Batches (only call `set_epoch` inside __init__() === + self.set_epoch(self.epoch) + + def set_epoch(self, epoch: int) -> None: + # Not Streaming --> Read from local disk... 
+ if not self.stream: + if self.is_val and not self.val_loaded: + with open(self.index_path / "state+language" / "validation-batches.json", "r") as f: + self.elements = json.load(f) + + # Assemble the set of dropout indices for the given epoch... + n_drop = int(self.lang_dropout * len(self.elements)) + self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False)) + + # Set `val_loaded` and move on... + self.val_loaded = True + + elif not self.is_val: + with open(self.index_path / "state+language" / f"train-epoch={epoch}-batches.json", "r") as f: + self.elements = json.load(f) + + # Assemble the set of dropout indices for the given epoch... + n_drop = int(self.lang_dropout * len(self.elements)) + self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False)) + + # Streaming --> Beam directly from Bucket + else: + if self.is_val and not self.val_loaded: + blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "state+language" / "validation-batches.json")) + self.elements = json.loads(blob.download_as_string()) + + # `elements[i]["state"]` currently maps to disk path... remove all but `parent/child.jpg` + for element in self.elements: + element["state"] = "/".join(element["state"].split("/")[-2:]) + + # Assemble the set of dropout indices for the given epoch... + n_drop = int(self.lang_dropout * len(self.elements)) + self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False)) + + # Set `val_loaded` and move on... + self.val_loaded = True + + elif not self.is_val: + blob = BUCKETS[self.r].blob( + str(self.prefix / "index" / "state+language" / f"train-epoch={epoch}-batches.json") + ) + self.elements = json.loads(blob.download_as_string()) + + # `elements[i]["state"] currently maps to disk path... remove all but `parent/child.jpg` + for element in self.elements: + element["state"] = "/".join(element["state"].split("/")[-2:]) + + # Assemble the set of dropout indices for the given epoch... 
+ n_drop = int(self.lang_dropout * len(self.elements)) + self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False)) + + def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Return the frame and language, decomposed as the input_ids, and attention_mask.""" + vid = self.elements[index]["vid"] + lang = torch.tensor(self.language_index[vid]["input_ids"], dtype=torch.int64) + lang_mask = torch.tensor(self.language_index[vid]["attention_mask"], dtype=torch.int64) + + # Dropout Language Check --> (Naive Zeroing leads to NaN --> just want the "CLS" token) + if index in self.dropout_indices: + # Initial language token is *always* = `101` --> last token always = `102` + lang[1:] *= 0 + lang_mask[1:] *= 0 + + # Retrieve Single Frame + if not self.stream: + img = self.transform(read_image(self.elements[index]["state"])) + return img, lang, lang_mask + else: + # Multiplex w/ num_worker idx... + worker_info = get_worker_info() + r = (self.r + worker_info.id) if worker_info is not None else self.r + + # Streaming + Retry Logic (in case of a bad connection -- retry same file!) + frame_path = self.elements[index]["state"] + for _i in range(self.n_retries): + try: + # Stream JPEG contents into BytesIO (seek back to 0), then into PIL Image.open() + if self.is_val: + blob, fobj = BUCKETS[r].blob(str(self.prefix / "val" / frame_path)), BytesIO() + else: + blob, fobj = BUCKETS[r].blob(str(self.prefix / "train" / frame_path)), BytesIO() + + # Error Handling --> we've already run several dry-run/verification trials, but about ~0.004% of + # the time, we'll hit some sort of TCP/Transport error; this might even go up + # with multiple runs happening at the same time. + # + # To address this, we're adopting the simplest possible "retry" strategy that + # immediately tries to re-download the same file (and crashes if not possible). + # This ensures reproducibility, but puts some extra effort onto the user... 
+ + # File download... + blob.download_to_file(fobj) + fobj.seek(0) + + # Image loading... + img_tensor = pil_to_tensor(Image.open(fobj)) + + # Assemble transformed image and return... + img = self.transform(img_tensor) + return img, lang, lang_mask + + except (NotFound, TransportError, UnidentifiedImageError, OSError) as e: + # At the minimum --> print the broken file (obnoxiously!) + print(f"=>> BROKEN FILE :: {frame_path}") + if not self.do_retry: + raise e + else: + continue + + # If we've exhausted our retries --> raise an informative ValueError + raise ValueError(f"Failed to fix state `{self.elements[index]['state']}` w/ {self.n_retries} retries...") + + def __len__(self) -> int: + return len(self.elements) + + +class StateOKDataset(PretrainDataset): + def __init__( + self, + epoch: int, + index_path: Path, + img_transform: Compose, + lang_dropout: Optional[float] = None, + stream: bool = False, + prefix: Optional[Path] = None, + no_lang: bool = False, + is_val: bool = False, + do_retry: bool = True, + n_retries: int = 3, + ) -> None: + super().__init__() + self.index_path, self.stream, self.is_val, self.val_loaded = index_path, stream, is_val, False + self.epoch, self.transform, self.elements, self.prefix = epoch, img_transform, None, prefix + self.no_lang, self.lang_dropout = no_lang, 0.0 if (lang_dropout is None or lang_dropout == 0) else lang_dropout + self.dropout_indices = set() + self.r = N_CORES * get_rank() + self.do_retry, self.n_retries = do_retry, n_retries + + # Load Language Index & Retrieve Epoch 0 Batches + if not self.no_lang: + language_path = "val-language-index.json" if self.is_val else "train-language-index.json" + if not self.stream: + with open(self.index_path / language_path, "r") as f: + self.language_index = json.load(f) + else: + blob = BUCKETS[self.r].blob(str(self.prefix / "index" / language_path)) + self.language_index = json.loads(blob.download_as_string()) + + # === Retrieve Epoch Batches (only call `set_epoch` inside 
__init__() === + self.set_epoch(self.epoch) + + def set_epoch(self, epoch: int) -> None: + # Not Streaming --> Read from local disk... + if not self.stream: + if self.is_val and not self.val_loaded: + with open(self.index_path / "state+ok" / "validation-batches.json", "r") as f: + self.elements = json.load(f) + + # Assemble the set of dropout indices for the given epoch... + n_drop = int(self.lang_dropout * len(self.elements)) + self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False)) + + # Set `val_loaded` and move on... + self.val_loaded = True + + elif not self.is_val: + with open(self.index_path / "state+ok" / f"train-epoch={epoch}-batches.json", "r") as f: + self.elements = json.load(f) + + # Assemble the set of dropout indices for the given epoch... + n_drop = int(self.lang_dropout * len(self.elements)) + self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False)) + + # Streaming --> Beam directly from Bucket + else: + if self.is_val and not self.val_loaded: + blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "state+ok" / "validation-batches.json")) + self.elements = json.loads(blob.download_as_string()) + + # `elements[i]["states"]` currently maps to disk path.. remove all but `parent/child.jpg` + for element in self.elements: + element["states"] = ["/".join(x.split("/")[-2:]) for x in element["states"]] + + # Assemble the set of dropout indices for the given epoch... + n_drop = int(self.lang_dropout * len(self.elements)) + self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False)) + + # Set `val_loaded` and move on... + self.val_loaded = True + + elif not self.is_val: + blob = BUCKETS[self.r].blob( + str(self.prefix / "index" / "state+ok" / f"train-epoch={epoch}-batches.json") + ) + self.elements = json.loads(blob.download_as_string()) + + # `elements[i]["states"]` currently maps to disk path..
# ruff: noqa: C901
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fetch the frames for element `index`, together with its tokenized language unless `self.no_lang` is set.

    Returns `(imgs, lang, lang_mask)` when language is enabled and only `imgs` otherwise. In streaming mode the
    frames are downloaded from the bucket, retrying the same files up to `self.n_retries` times.
    """
    vid = self.elements[index]["vid"]
    with_lang = not self.no_lang

    # Tokenized language (input ids + attention mask) --> only when requested
    if with_lang:
        lang = torch.tensor(self.language_index[vid]["input_ids"], dtype=torch.int64)
        lang_mask = torch.tensor(self.language_index[vid]["attention_mask"], dtype=torch.int64)

        # Language dropout --> keep only the first token (`101`); zeroing everything leads to NaN
        if index in self.dropout_indices:
            lang[1:] *= 0
            lang_mask[1:] *= 0

    # Local disk --> decode the frames straight from their paths
    if not self.stream:
        imgs = torch.stack([self.transform(read_image(s)) for s in self.elements[index]["states"]])
        return (imgs, lang, lang_mask) if with_lang else imgs

    # Streaming --> offset the bucket client index by the DataLoader worker id (when running inside a worker)
    worker_info = get_worker_info()
    r = self.r if worker_info is None else (self.r + worker_info.id)
    split = "val" if self.is_val else "train"

    # Retry loop --> a failed attempt restarts the download of *all* frames for this element
    frame_paths, current_frame = list(self.elements[index]["states"]), None
    for _attempt in range(self.n_retries):
        try:
            imgs = []
            for current_frame in frame_paths:
                # Stream JPEG bytes into memory, rewind, then decode via PIL
                blob, fobj = BUCKETS[r].blob(str(self.prefix / split / current_frame)), BytesIO()
                blob.download_to_file(fobj)
                fobj.seek(0)
                imgs.append(self.transform(pil_to_tensor(Image.open(fobj))))

            assert len(imgs) == 2, "Something went awry with try/except in StateOK Dataset..."
            imgs = torch.stack(imgs)
            return (imgs, lang, lang_mask) if with_lang else imgs

        except (NotFound, TransportError, UnidentifiedImageError, OSError) as e:
            # Always surface the offending file; only swallow the error when retries are enabled
            print(f"=>> BROKEN FILE :: {current_frame}")
            if not self.do_retry:
                raise e

    # Every attempt failed --> raise an informative ValueError
    raise ValueError(f"Failed to fix states `{self.elements[index]['states']}` w/ {self.n_retries} retries...")


def __len__(self) -> int:
    """Number of elements in the currently loaded batch index."""
    return len(self.elements)
class GenStateOKDataset(PretrainDataset):
    """`state+ok` frame pairs whose language is split into a *conditioning* copy and a *generation* copy.

    A `gen_ratio` fraction of the elements (resampled every time batches are loaded) are "generate" examples:
    their conditioning language is reduced to the first token and the generation weight is 1. For all other
    elements the generation target is reduced to the first token and the generation weight is 0.
    """

    def __init__(
        self,
        epoch: int,
        index_path: Path,
        img_transform: Compose,
        gen_ratio: float,
        stream: bool = False,
        prefix: Optional[Path] = None,
        is_val: bool = False,
        do_retry: bool = True,
        n_retries: int = 3,
    ) -> None:
        """
        :param epoch: Epoch whose training batches are loaded on construction.
        :param index_path: Local directory holding the language index and the `state+ok` batch files.
        :param img_transform: Transform applied to every decoded frame.
        :param gen_ratio: Fraction of elements treated as "generate" examples.
        :param stream: If True, read indices and frames from the bucket instead of local disk.
        :param prefix: Bucket path prefix (used only when streaming).
        :param is_val: Select the validation split instead of the train split.
        :param do_retry: When streaming, retry failed downloads instead of re-raising immediately.
        :param n_retries: Maximum number of download attempts per element when streaming.
        """
        super().__init__()
        self.index_path, self.stream, self.is_val, self.val_loaded = index_path, stream, is_val, False
        self.epoch, self.transform, self.elements, self.prefix = epoch, img_transform, None, prefix
        self.gen_ratio, self.gen_indices = gen_ratio, set()
        self.r = N_CORES * get_rank()
        self.do_retry, self.n_retries = do_retry, n_retries

        # Load Language Index & Retrieve Epoch 0 Batches
        language_path = "val-language-index.json" if self.is_val else "train-language-index.json"
        if not self.stream:
            with open(self.index_path / language_path, "r") as f:
                self.language_index = json.load(f)
        else:
            blob = BUCKETS[self.r].blob(str(self.prefix / "index" / language_path))
            self.language_index = json.loads(blob.download_as_string())

        # === Retrieve Epoch Batches (only call `set_epoch` inside __init__() ===
        self.set_epoch(self.epoch)

    def set_epoch(self, epoch: int) -> None:
        """Load the batch index for `epoch` and resample which elements are "generate" examples.

        Validation batches are loaded once (guarded by `self.val_loaded`); later calls on a validation dataset
        are no-ops. Training batches are re-read on every call.
        """
        if self.is_val and self.val_loaded:
            return

        batch_file = "validation-batches.json" if self.is_val else f"train-epoch={epoch}-batches.json"
        if not self.stream:
            # Not Streaming --> Read from local disk...
            with open(self.index_path / "state+ok" / batch_file, "r") as f:
                self.elements = json.load(f)
        else:
            # Streaming --> Beam directly from Bucket
            blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "state+ok" / batch_file))
            self.elements = json.loads(blob.download_as_string())

            # `elements[i]["states"]` currently maps to disk path.. remove all but `parent/child.jpg`
            for element in self.elements:
                element["states"] = ["/".join(x.split("/")[-2:]) for x in element["states"]]

        # Assemble the set of *generation* indices for the given epoch...
        n_gen = int(self.gen_ratio * len(self.elements))
        self.gen_indices = set(np.random.choice(len(self.elements), n_gen, replace=False))

        if self.is_val:
            self.val_loaded = True

    def __getitem__(
        self, index: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float]:
        """Return both states/frames, language to condition on, language to generate, decomposed as input_ids/mask.

        The final tuple entry is the generation weight: 1 for "generate" examples, 0 otherwise.
        """
        vid = self.elements[index]["vid"]

        # Fetch language to condition on / generate (two independent copies of the same tokens)...
        lang_condition = torch.tensor(self.language_index[vid]["input_ids"], dtype=torch.int64)
        lang_condition_mask = torch.tensor(self.language_index[vid]["attention_mask"], dtype=torch.int64)
        lang_gen = torch.tensor(self.language_index[vid]["input_ids"], dtype=torch.int64)
        lang_gen_mask = torch.tensor(self.language_index[vid]["attention_mask"], dtype=torch.int64)

        # Generate/Condition Check --> (Naive Zeroing leads to NaN --> always keep the initial token)
        if index in self.gen_indices:
            # Generating --> condition on the initial token only, generate everything
            lang_condition[1:] *= 0
            lang_condition_mask[1:] *= 0
            lang_gen_weight = 1
        else:
            # Conditioning --> generation target is a dummy (initial token only) with weight 0
            lang_gen[1:] *= 0
            lang_gen_mask[1:] *= 0
            lang_gen_weight = 0

        # Retrieve Frames --> local disk
        if not self.stream:
            imgs = torch.stack([self.transform(read_image(s)) for s in self.elements[index]["states"]])
            return imgs, lang_condition, lang_condition_mask, lang_gen, lang_gen_mask, lang_gen_weight

        # Multiplex w/ num_worker idx...
        worker_info = get_worker_info()
        r = (self.r + worker_info.id) if worker_info is not None else self.r
        split = "val" if self.is_val else "train"

        # Streaming + Retry Logic (in case of a bad connection -- retry same files!)
        #   About ~0.004% of downloads hit a TCP/Transport error (per the original author's note); the strategy
        #   is to immediately re-download the same files, which keeps the data order reproducible.
        frame_paths, current_frame = list(self.elements[index]["states"]), None
        for _i in range(self.n_retries):
            try:
                imgs = []
                for current_frame in frame_paths:
                    # Stream JPEG contents into BytesIO (seek back to 0), then PIL Image.open()
                    blob, fobj = BUCKETS[r].blob(str(self.prefix / split / current_frame)), BytesIO()
                    blob.download_to_file(fobj)
                    fobj.seek(0)
                    imgs.append(self.transform(pil_to_tensor(Image.open(fobj))))

                # Stack...
                assert len(imgs) == 2, "Something went awry with try/except in GenStateOK Dataset..."
                imgs = torch.stack(imgs)
                return imgs, lang_condition, lang_condition_mask, lang_gen, lang_gen_mask, lang_gen_weight

            except (NotFound, TransportError, UnidentifiedImageError, OSError):
                # At the minimum --> print the broken file (obnoxiously!)
                print(f"=>> BROKEN FILE :: {current_frame}")
                if not self.do_retry:
                    raise

        # If we've exhausted our retries --> raise an informative ValueError
        raise ValueError(f"Failed to fix states `{self.elements[index]['states']}` w/ {self.n_retries} retries...")

    def __len__(self) -> int:
        """Number of elements in the currently loaded batch index."""
        return len(self.elements)
class QuintetDataset(PretrainDataset):
    """`quintet+language` dataset: every element carries a list of frames under `states` plus its language."""

    def __init__(
        self,
        epoch: int,
        index_path: Path,
        img_transform: Compose,
        lang_dropout: Optional[float] = None,
        stream: bool = False,
        prefix: Optional[Path] = None,
        is_val: bool = False,
    ) -> None:
        """
        :param epoch: Epoch whose training batches are loaded on construction.
        :param index_path: Local directory holding the language index and the `quintet+language` batch files.
        :param img_transform: Transform applied to every decoded frame.
        :param lang_dropout: Fraction of elements whose language is dropped (None is treated as 0.0).
        :param stream: If True, read indices and frames from the bucket instead of local disk.
        :param prefix: Bucket path prefix (used only when streaming).
        :param is_val: Select the validation split instead of the train split.
        """
        super().__init__()
        self.index_path, self.stream, self.is_val, self.val_loaded = index_path, stream, is_val, False
        self.epoch, self.transform, self.elements, self.prefix = epoch, img_transform, None, prefix
        self.lang_dropout = 0.0 if (lang_dropout is None or lang_dropout == 0) else lang_dropout
        self.dropout_indices = set()
        self.r = N_CORES * get_rank()

        # Load Language Index & Retrieve Epoch 0 Batches
        language_path = "val-language-index.json" if self.is_val else "train-language-index.json"
        if not self.stream:
            with open(self.index_path / language_path, "r") as f:
                self.language_index = json.load(f)
        else:
            blob = BUCKETS[self.r].blob(str(self.prefix / "index" / language_path))
            self.language_index = json.loads(blob.download_as_string())

        # === Retrieve Epoch Batches (only call `set_epoch` inside __init__() ===
        self.set_epoch(self.epoch)

    def set_epoch(self, epoch: int) -> None:
        """Load the batch index for `epoch` and resample the language-dropout indices.

        Validation batches are loaded once (guarded by `self.val_loaded`); later calls on a validation dataset
        are no-ops. Training batches are re-read on every call.
        """
        if self.is_val and self.val_loaded:
            return

        batch_file = "validation-batches.json" if self.is_val else f"train-epoch={epoch}-batches.json"
        if not self.stream:
            # Not Streaming --> Read from local disk...
            with open(self.index_path / "quintet+language" / batch_file, "r") as f:
                self.elements = json.load(f)
        else:
            # Streaming --> Beam directly from Bucket
            blob = BUCKETS[self.r].blob(str(self.prefix / "index" / "quintet+language" / batch_file))
            self.elements = json.loads(blob.download_as_string())

            # `elements[i]["states"]` currently maps to disk path.. remove all but `parent/child.jpg`
            for element in self.elements:
                element["states"] = ["/".join(x.split("/")[-2:]) for x in element["states"]]

        # Assemble the set of dropout indices for the given epoch...
        n_drop = int(self.lang_dropout * len(self.elements))
        self.dropout_indices = set(np.random.choice(len(self.elements), n_drop, replace=False))

        if self.is_val:
            self.val_loaded = True

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return all 5 states/frames, and language, decomposed as the input_ids and attention_mask."""
        vid = self.elements[index]["vid"]
        lang = torch.tensor(self.language_index[vid]["input_ids"], dtype=torch.int64)
        lang_mask = torch.tensor(self.language_index[vid]["attention_mask"], dtype=torch.int64)

        # Dropout Language Check --> (Naive Zeroing leads to NaN --> just want the "CLS" token)
        #   The *first* token (`101`) is what survives below, i.e. the same token the sibling datasets keep.
        if index in self.dropout_indices:
            # Initial language token is *always* = `101` --> last token always = `102`
            lang[1:] *= 0
            lang_mask[1:] *= 0

        # Retrieve Frames
        if not self.stream:
            imgs = torch.stack([self.transform(read_image(s)) for s in self.elements[index]["states"]])
        else:
            # Multiplex w/ num_worker idx...
            worker_info, imgs = get_worker_info(), []
            r = (self.r + worker_info.id) if worker_info is not None else self.r
            split = "val" if self.is_val else "train"

            # Stream JPEG contents into BytesIO (seek back to 0), then PIL Image.open()
            for state in self.elements[index]["states"]:
                blob, fobj = BUCKETS[r].blob(str(self.prefix / split / state)), BytesIO()

                # Download into FileObj & Rewind...
                blob.download_to_file(fobj)
                fobj.seek(0)

                # Add to imgs...
                imgs.append(self.transform(pil_to_tensor(Image.open(fobj))))

            # Stack...
            imgs = torch.stack(imgs)

        return imgs, lang, lang_mask

    def __len__(self) -> int:
        """Number of elements in the currently loaded batch index."""
        return len(self.elements)
def get_epoch_datasets(
    epoch: int,
    name: str,
    normalization: Tuple[Any, Any],
    model_arch: str,
    stream: bool,
    artifact_path: str,
    stream_prefix: str,
    data_modality: str,
    lang_dropout: Optional[float] = None,
    gen_ratio: Optional[float] = None,
) -> Tuple[PretrainDataset, PretrainDataset]:
    """Build the `(train, val)` dataset pair for dataset `name` under the requested `data_modality`.

    Both datasets share the same index directory (`artifact_path / name / "index"`), online image transform and
    streaming prefix; the validation dataset is the same class constructed with `is_val=True`.

    :raises NotImplementedError: If `data_modality` is not one of the supported modalities.
    """
    index = Path(artifact_path) / name / "index"
    img_transform = get_online_transform(name, model_arch, normalization)
    prefix = Path(stream_prefix) / name if stream else stream_prefix

    def build_pair(dataset_cls: Any, *args: Any, **kwargs: Any) -> Tuple[PretrainDataset, PretrainDataset]:
        # Train first, then val --> identical arguments, only `is_val` differs
        train = dataset_cls(epoch, index, img_transform, *args, **kwargs)
        val = dataset_cls(epoch, index, img_transform, *args, is_val=True, **kwargs)
        return train, val

    # Switch on `data_modality`
    if data_modality == "state":
        return build_pair(StateDataset, stream, prefix)

    if data_modality == "state+language":
        return build_pair(StateLanguageDataset, lang_dropout, stream, prefix)

    if data_modality == "state+ok":
        if gen_ratio is not None:
            # Special Generative Language Dataset...
            return build_pair(GenStateOKDataset, gen_ratio, stream, prefix)

        # `v-dual` consumes frames only --> skip language entirely
        return build_pair(StateOKDataset, lang_dropout, stream, prefix, no_lang=(model_arch == "v-dual"))

    if data_modality == "quintet+language":
        return build_pair(QuintetDataset, lang_dropout, stream, prefix)

    raise NotImplementedError(f"Support for data modality `{data_modality}` not yet implemented!")
class FixedDeck(deque):
    """Bounded deque whose `append` reports which element (if any) was pushed out to make room."""

    def __init__(self, maxlen):
        super().__init__(maxlen=maxlen)

    def append(self, x: Any) -> Any:
        """Append `x` on the right; return the left-most element if it gets evicted, otherwise None."""
        pop_value = self[0] if len(self) == self.maxlen else None

        # Perform parent append and return popped value, if any!
        super().append(x)
        return pop_value


class CheckpointSaver:
    def __init__(self, strategy: Tuple[int, int, int], run_dir: str, accelerator: str):
        """
        Create a checkpoint saver with the provided strategy that saves to the given path, with handling for the
        specific hardware accelerator.

        :param strategy: Strategy `(k, m, s)` -- keep the `k` most recent epoch checkpoints, additionally keep
                         every `m`-th epoch checkpoint (`m = -1` disables this), and write within-epoch step
                         checkpoints every `s` steps (`s = -1` disables this).
        :param run_dir: Path to root of `run_dir`
        :param accelerator: Hardware accelerator to run on -- TODO :: Only TPUs supported right now!
        """
        import torch_xla.core.xla_model as xm

        (self.k, self.m, self.s), self.run_dir, self.accelerator = strategy, run_dir, accelerator
        self.recents, self.intervals, self.step_checkpoints = FixedDeck(maxlen=self.k), set(), set()

        # If `self.s` is -1 --> disable step_checkpoints
        self.enable_step = self.s != -1

        # Create "checkpoints" subdirectory
        self.path = Path(run_dir) / "checkpoints"
        if xm.is_master_ordinal(local=False):
            os.makedirs(self.path, exist_ok=True)

            # Populate `step_checkpoints` on __init__ (if resuming *within* an epoch...)
            #   NOTE(review): the collapsed patch does not show whether this scan sat inside the master guard;
            #   it is kept here because only the master ever adds to / removes from `step_checkpoints`, and the
            #   directory is only guaranteed to exist on the master at this point -- confirm against upstream.
            self.step_checkpoints.update([c for c in self.path.iterdir() if "local-epoch=" in str(c)])

        # Create Saver
        xm.master_print(f"Created Saver w/ `k` = {self.k}, `m` = {self.m}`, `s` = {self.s}!")

    def _is_interval_epoch(self, epoch: int) -> bool:
        """Return True iff `epoch` is one of the every-`m` checkpoints that must be kept permanently.

        `m <= 0` disables interval retention. (A bare `epoch % m == 0` is True for *every* epoch when `m = -1`,
        which silently pinned all checkpoints, and divides by zero when `m = 0`.)
        """
        return self.m > 0 and epoch % self.m == 0

    def _epoch_checkpoint_path(
        self, epoch: int, duration: int, train_loss: Optional[float], val_loss: Optional[float]
    ) -> Path:
        """Build the end-of-epoch checkpoint filename; a missing loss is rendered as `inf`."""
        train_str = "inf" if train_loss is None else f"{train_loss:.4f}"
        val_str = "inf" if val_loss is None else f"{val_loss:.4f}"
        return self.path / f"epoch={epoch}-train={train_str}-val={val_str}-t={duration}.pt"

    def save(
        self,
        epoch: int,
        is_local_step: bool,
        model: Any,
        optimizer: Any,
        duration: int,
        local_step: Optional[int] = None,
        train_loss: Optional[float] = None,
        val_loss: Optional[float] = None,
    ) -> None:
        """Performs the save operation, unlinking existing stale checkpoints, if necessary.

        :param epoch: Current epoch (part of the checkpoint filename).
        :param is_local_step: True when called from within an epoch, False at the end of an epoch.
        :param model: Object exposing `state_dict()`.
        :param optimizer: Object exposing `state_dict()`; only its `"state"` entry is serialized.
        :param duration: Value written into the `t=` field of the filename.
        :param local_step: Step within the epoch; must be provided when `is_local_step` is True.
        :param train_loss: Training loss for the filename (end-of-epoch saves only).
        :param val_loss: Validation loss for the filename (end-of-epoch saves only).
        """
        import torch_xla.core.xla_model as xm

        # Check if saving a `local_step` (within an epoch) or if saving an `epoch`
        if self.enable_step and is_local_step and (local_step % self.s) == 0:
            # Create filename
            step_checkpoint = self.path / f"local-epoch={epoch}-step={local_step}-t={duration}.pt"

            # Perform actual save action...
            #   > IMPORTANT --> XLA/XM will throw an error if optimizer has "param_groups" so only save "state"...
            xm.save([model.state_dict(), optimizer.state_dict()["state"]], step_checkpoint)
            if xm.is_master_ordinal(local=False):
                self.step_checkpoints.add(step_checkpoint)

        elif not is_local_step:
            # Create filename
            checkpoint = self._epoch_checkpoint_path(epoch, duration, train_loss, val_loss)

            # Perform actual save action...
            #   > IMPORTANT --> XLA/XM will throw an error if optimizer has "param_groups" so only save "state"...
            xm.save([model.state_dict(), optimizer.state_dict()["state"]], checkpoint)

            if xm.is_master_ordinal(local=False):
                # Conditional Check for M -- Keep if modulated by interval
                if self._is_interval_epoch(epoch):
                    self.intervals.add(checkpoint)

                # Remove all "step_checkpoints" now that we successfully made it to the end of the epoch!
                while self.step_checkpoints:
                    os.remove(self.step_checkpoints.pop())

                # Finally, recency add & unlink/delete if necessary
                to_remove = self.recents.append(checkpoint)
                if to_remove is not None and to_remove not in self.intervals:
                    os.remove(to_remove)

                # Extra Redundancy --> Save Model to W&B...
                #   NOTE(review): placed under the master guard (original nesting is not recoverable from the
                #   collapsed patch) -- confirm against upstream.
                wandb.save(str(checkpoint))
# === Generic (Cross-Model) Epoch End Update ===
def log_epoch_end_update(
    arch: str,
    epoch: int,
    global_step: int,
    run_id: str,
    duration: int,
    train_losses: List[torch.Tensor],
    val_loss: float,
    lr: float,
    step_times: List[float],
) -> None:
    """Summarize a finished epoch: print via `xm.master_print`, then log to W&B and append to `{run_id}.jsonl`.

    `train_losses` and `step_times` are averaged; `arch` must be one of the architecture ids mapped below.
    """
    epoch_loss = torch.stack(list(train_losses)).mean()
    mean_step_time = np.mean(list(step_times))

    # Console summary for the epoch
    xm.master_print(
        f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {epoch_loss:.4f} "
        f"-- Val Loss :: {val_loss:.4f} -- Total Time (sec) :: {duration}"
    )

    # Architecture id --> display name used in the metric keys
    display_names = {
        "v-mvp": "MVP",
        "v-r3m": "R3M (ViT)",
        "v-rn3m": "R3M (RN)",
        "v-cond": "V-Cond",
        "v-dual": "V-Dual",
        "v-gen": "V-Gen",
    }
    p_arch = display_names[arch]

    # Same scalar is reported under both the "Train Epoch Loss" and "Train Loss" keys
    epoch_loss_value = epoch_loss.item()
    metrics = {
        "Pretrain/Step": global_step,
        "Pretrain/Epoch": epoch,
        "Pretrain/Training Duration": duration,
        "Pretrain/Step Time": mean_step_time,
        f"Pretrain/{p_arch} Train Epoch Loss": epoch_loss_value,
        f"Pretrain/{p_arch} Train Loss": epoch_loss_value,
        f"Pretrain/{p_arch} Validation Loss": val_loss,
        "Pretrain/Learning Rate": lr,
    }

    # Weights & Biases + JSONL
    wandb.log(metrics, step=global_step)
    with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
        js_logger.write(metrics)
# === Data-Locked Reproductions ===


def log_vmvp_train_update(
    epoch: int,
    global_step: int,
    run_id: str,
    train_losses: List[torch.Tensor],
    lr: float,
    reconstruction_losses: List[torch.Tensor],
    step_times: List[float],
) -> None:
    """Log averaged V-MVP training metrics to the console, W&B, and `{run_id}.jsonl`."""

    def _avg(values: List[torch.Tensor]) -> torch.Tensor:
        return torch.stack(list(values)).mean()

    loss, recon = _avg(train_losses), _avg(reconstruction_losses)
    mean_step_time = np.mean(list(step_times))

    # Console --> only the aggregated train loss
    xm.master_print(
        f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {loss:.4f}"
    )

    # Weights & Biases + JSONL
    metrics = {
        "Pretrain/Step": global_step,
        "Pretrain/Epoch": epoch,
        "Pretrain/V-MVP Train Loss": loss.item(),
        "Pretrain/Reconstruction Loss": recon.item(),
        "Pretrain/Learning Rate": lr,
        "Pretrain/Step Time": mean_step_time,
    }
    wandb.log(metrics, step=global_step)
    with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
        js_logger.write(metrics)
def log_vr3m_train_update(
    epoch: int,
    global_step: int,
    run_id: str,
    train_losses: List[torch.Tensor],
    lr: float,
    tcn_losses: List[torch.Tensor],
    reward_losses: List[torch.Tensor],
    l1_losses: List[torch.Tensor],
    l2_losses: List[torch.Tensor],
    tcn_accuracies: List[torch.Tensor],
    reward_accuracies: List[torch.Tensor],
    step_times: List[float],
) -> None:
    """Log averaged V-R3M training metrics to the console, W&B, and `{run_id}.jsonl`."""

    def _avg(values: List[torch.Tensor]) -> torch.Tensor:
        return torch.stack(list(values)).mean()

    loss = _avg(train_losses)
    tcn, reward = _avg(tcn_losses), _avg(reward_losses)
    l1, l2 = _avg(l1_losses), _avg(l2_losses)
    tcn_acc, reward_acc = _avg(tcn_accuracies), _avg(reward_accuracies)
    mean_step_time = np.mean(list(step_times))

    # Console --> only the aggregated train loss
    xm.master_print(
        f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {loss:.4f}"
    )

    # Weights & Biases + JSONL
    metrics = {
        "Pretrain/Step": global_step,
        "Pretrain/Epoch": epoch,
        "Pretrain/V-R3M Train Loss": loss.item(),
        "Pretrain/TCN Loss": tcn.item(),
        "Pretrain/Reward Loss": reward.item(),
        "Pretrain/L1 Loss": l1.item(),
        "Pretrain/L2 Loss": l2.item(),
        "Pretrain/TCN Accuracy": tcn_acc.item(),
        "Pretrain/Reward Accuracy": reward_acc.item(),
        "Pretrain/Learning Rate": lr,
        "Pretrain/Step Time": mean_step_time,
    }
    wandb.log(metrics, step=global_step)
    with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
        js_logger.write(metrics)
def log_vrn3m_train_update(
    epoch: int,
    global_step: int,
    run_id: str,
    train_losses: List[torch.Tensor],
    lr: float,
    tcn_losses: List[torch.Tensor],
    reward_losses: List[torch.Tensor],
    l1_losses: List[torch.Tensor],
    l2_losses: List[torch.Tensor],
    tcn_accuracies: List[torch.Tensor],
    reward_accuracies: List[torch.Tensor],
    step_times: List[float],
) -> None:
    """Log averaged V-RN3M training metrics to the console, W&B, and `{run_id}.jsonl`."""

    def _avg(values: List[torch.Tensor]) -> torch.Tensor:
        return torch.stack(list(values)).mean()

    loss = _avg(train_losses)
    tcn, reward = _avg(tcn_losses), _avg(reward_losses)
    l1, l2 = _avg(l1_losses), _avg(l2_losses)
    tcn_acc, reward_acc = _avg(tcn_accuracies), _avg(reward_accuracies)
    mean_step_time = np.mean(list(step_times))

    # Console --> only the aggregated train loss
    xm.master_print(
        f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {loss:.4f}"
    )

    # Weights & Biases + JSONL
    metrics = {
        "Pretrain/Step": global_step,
        "Pretrain/Epoch": epoch,
        "Pretrain/V-RN3M Train Loss": loss.item(),
        "Pretrain/TCN Loss": tcn.item(),
        "Pretrain/Reward Loss": reward.item(),
        "Pretrain/L1 Loss": l1.item(),
        "Pretrain/L2 Loss": l2.item(),
        "Pretrain/TCN Accuracy": tcn_acc.item(),
        "Pretrain/Reward Accuracy": reward_acc.item(),
        "Pretrain/Learning Rate": lr,
        "Pretrain/Step Time": mean_step_time,
    }
    wandb.log(metrics, step=global_step)
    with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
        js_logger.write(metrics)
# === Voltron Models ===
def log_vcond_train_update(
    epoch: int,
    global_step: int,
    run_id: str,
    train_losses: List[torch.Tensor],
    lr: float,
    reconstruction_losses: List[torch.Tensor],
    step_times: List[float],
) -> None:
    """Log averaged V-Cond training metrics to the console, W&B, and `{run_id}.jsonl`."""

    def _avg(values: List[torch.Tensor]) -> torch.Tensor:
        return torch.stack(list(values)).mean()

    loss, recon = _avg(train_losses), _avg(reconstruction_losses)
    mean_step_time = np.mean(list(step_times))

    # Console --> only the aggregated train loss
    xm.master_print(
        f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {loss:.4f}"
    )

    # Weights & Biases + JSONL
    metrics = {
        "Pretrain/Step": global_step,
        "Pretrain/Epoch": epoch,
        "Pretrain/V-Cond Train Loss": loss.item(),
        "Pretrain/Reconstruction Loss": recon.item(),
        "Pretrain/Learning Rate": lr,
        "Pretrain/Step Time": mean_step_time,
    }
    wandb.log(metrics, step=global_step)
    with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
        js_logger.write(metrics)
def log_vdual_train_update(
    epoch: int,
    global_step: int,
    run_id: str,
    train_losses: List[torch.Tensor],
    lr: float,
    reconstruction_losses: List[torch.Tensor],
    zero_reconstruction_losses: List[torch.Tensor],
    k_reconstruction_losses: List[torch.Tensor],
    step_times: List[float],
) -> None:
    """Log averaged V-Dual training metrics to the console, W&B, and `{run_id}.jsonl`."""

    def _avg(values: List[torch.Tensor]) -> torch.Tensor:
        return torch.stack(list(values)).mean()

    loss, recon = _avg(train_losses), _avg(reconstruction_losses)
    zero_recon, k_recon = _avg(zero_reconstruction_losses), _avg(k_reconstruction_losses)
    mean_step_time = np.mean(list(step_times))

    # Console --> only the aggregated train loss
    xm.master_print(
        f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {loss:.4f}"
    )

    # Weights & Biases + JSONL
    metrics = {
        "Pretrain/Step": global_step,
        "Pretrain/Epoch": epoch,
        "Pretrain/V-Dual Train Loss": loss.item(),
        "Pretrain/Reconstruction Loss": recon.item(),
        "Pretrain/Zero Reconstruction Loss": zero_recon.item(),
        "Pretrain/K Reconstruction Loss": k_recon.item(),
        "Pretrain/Learning Rate": lr,
        "Pretrain/Step Time": mean_step_time,
    }
    wandb.log(metrics, step=global_step)
    with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
        js_logger.write(metrics)
def log_vgen_train_update(
    epoch: int,
    global_step: int,
    run_id: str,
    train_losses: List[torch.Tensor],
    lr: float,
    reconstruction_losses: List[torch.Tensor],
    lm_losses: List[torch.Tensor],
    lm_ppl: List[torch.Tensor],
    zero_reconstruction_losses: List[torch.Tensor],
    k_reconstruction_losses: List[torch.Tensor],
    step_times: List[float],
) -> None:
    """Log averaged V-Gen training metrics to the console, W&B, and `{run_id}.jsonl`.

    The language-modeling loss and perplexity are written under both the `CLM ...` and `LM ...` keys.
    """

    def _avg(values: List[torch.Tensor]) -> torch.Tensor:
        return torch.stack(list(values)).mean()

    loss, recon = _avg(train_losses), _avg(reconstruction_losses)
    lm, ppl = _avg(lm_losses), _avg(lm_ppl)
    zero_recon, k_recon = _avg(zero_reconstruction_losses), _avg(k_reconstruction_losses)
    mean_step_time = np.mean(list(step_times))

    # Console --> aggregated train loss plus its reconstruction / LM components
    xm.master_print(
        f"Epoch {epoch:03d}, Global Step {global_step:06d} || LR :: {lr:.6f} -- Train Loss :: {loss:.4f} --"
        f" Reconstruction Loss {recon:.4f} -- LM Loss {lm:.4f}"
    )

    # Weights & Biases + JSONL
    lm_value, ppl_value = lm.item(), ppl.item()
    metrics = {
        "Pretrain/Step": global_step,
        "Pretrain/Epoch": epoch,
        "Pretrain/V-Gen Train Loss": loss.item(),
        "Pretrain/Reconstruction Loss": recon.item(),
        "Pretrain/CLM Loss": lm_value,
        "Pretrain/CLM Perplexity": ppl_value,
        "Pretrain/LM Loss": lm_value,
        "Pretrain/LM Perplexity": ppl_value,
        "Pretrain/Zero Reconstruction Loss": zero_recon.item(),
        "Pretrain/K Reconstruction Loss": k_recon.item(),
        "Pretrain/Learning Rate": lr,
        "Pretrain/Step Time": mean_step_time,
    }
    wandb.log(metrics, step=global_step)
    with jsonlines.open(f"{run_id}.jsonl", mode="a", sort_keys=True) as js_logger:
        js_logger.write(metrics)