[BUG] Crash with a minimal ZeRO stage 3 NVMe checkpointing example #4565

Closed
eisene opened this issue Oct 25, 2023 · 3 comments · Fixed by #4702
Labels: bug, training

Comments

eisene (Contributor) commented Oct 25, 2023

Describe the bug

The simplest possible training script with ZeRO stage 3 and NVMe offload for the optimizer crashes in model.step() with an error at

  File "/home/eeisenst/workspace/DeepSpeed/deepspeed/runtime/zero/stage3.py", line 2002, in unscale_and_clip_grads
    self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale)

To Reproduce

import os
import deepspeed
import deepspeed.comm as dist
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
import torch


class SimpleModel(torch.nn.Module):

    def __init__(self, hidden_dim, empty_grad=False, nlayers=1):
        super(SimpleModel, self).__init__()
        self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for i in range(nlayers)])
        if empty_grad:
            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.empty_grad = empty_grad

    def forward(self, x, y):
        if len(self.linears) == 1:
            x = self.linears[0](x)
        else:
            for i, l in enumerate(self.linears):
                x = self.linears[i // 2](x) + l(x)
        return self.cross_entropy_loss(x, y)


def random_dataset(total_samples, hidden_dim, device, dtype=torch.half):
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    return train_dataset


def random_dataloader(model, total_samples, hidden_dim, device, dtype=torch.half):
    batch_size = model.train_micro_batch_size_per_gpu()
    train_dataset = random_dataset(total_samples, hidden_dim, device, dtype=dtype)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
    return train_loader

tmpdir = "/home/eeisenst/workspace/temp/temp"    # CHANGE THIS TO SOMETHING CONVENIENT
zero_dir, ckpt_dir = os.path.join(tmpdir, "zero"), os.path.join(tmpdir, "checkpoint")

torch.manual_seed(12345)

config_dict = {
    "train_micro_batch_size_per_gpu": 1,
    "steps_per_print": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-6
        }
    },
    "fp16": {
        "enabled": True,
        "initial_scale_power": 2
    },
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": OffloadDeviceEnum.cpu
            # "device": OffloadDeviceEnum.nvme,
            # "nvme_path": str(zero_dir)
        },
        "offload_optimizer": {
            # "device": OffloadDeviceEnum.cpu
            "device": OffloadDeviceEnum.nvme,
            "nvme_path": str(zero_dir)
        },
        "sub_group_size": 100,
        "stage3_max_live_parameters": 100,
        "stage3_param_persistence_threshold": 0,
    },
    "aio": {
        "block_size": 1048576       # Minimum AIO bytes, anything smaller than this will not be offloaded
    }
}

hidden_dim, nlayers = 2048, 5
with deepspeed.zero.Init(config_dict_or_path=config_dict):
    model = SimpleModel(hidden_dim, nlayers=nlayers, empty_grad=False)

model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
data_loader = random_dataloader(model=model,
                                total_samples=10,
                                hidden_dim=hidden_dim,
                                device=model.device,
                                dtype=torch.float16)
dist.barrier()
for n, batch in enumerate(data_loader):
    loss = model(batch[0], batch[1])
    model.backward(loss)
    model.step()

Expected behavior
This script should exit with no error.

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [YES] ...... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.0
 [WARNING]  using untested triton version (2.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/eeisenst/miniconda3/envs/deepspeed-test/lib/python3.11/site-packages/torch']
torch version .................... 2.0.1
deepspeed install path ........... ['/home/eeisenst/workspace/DeepSpeed/deepspeed']
deepspeed info ................... 0.10.4+6c6a1ec0, 6c6a1ec0, nvme_ckpt
torch cuda version ............... 11.8
torch hip version ................ None
nvcc version ..................... 11.7
deepspeed wheel compiled w. ...... torch 2.0, cuda 11.8
shared memory (/dev/shm) size .... 15.57 GB

System info (please complete the following information):

  • OS: Fedora 37
  • GPU count and types: 1x 3080 Ti on test machine, but it doesn't seem to matter
  • Python version: 3.11.5

Environment:

mamba install -c conda-forge pip python pytest pytorch gcc=11 libaio rust cmake

Build command:

CFLAGS="-I$CONDA_PREFIX/include/" LDFLAGS="-L$CONDA_PREFIX/lib/" DS_BUILD_CPU_ADAM=1 DS_BUILD_AIO=1  DS_BUILD_UTILS=1 pip install -e . --global-option="build_ext" --global-option="-j8" --no-cache -v

Launcher context
deepspeed launcher

Docker context
No docker.

Additional context

It seems that the problem is caused by the following two lines (1334-1335 of deepspeed/runtime/zero/stage3.py, in DeepSpeedZeroOptimizer_Stage3.partition_grads):

            # offload the gradient partition if applicable
            if self.offload_optimizer:
                i, dest_offset, _ = self.grad_position[self.get_param_id(param)]
                offload_fp32_gradients = {}                   # THIS IS THE BUG???
                offload_fp32_offsets = {}                     # THIS IS THE BUG???

This resets the dictionaries of offloaded gradients and offsets, so that later in the same function lines 1357-1361 do nothing:

        if self.offload_optimizer and self.swap_optimizer:
            for i in offload_fp32_gradients.keys():
                self.optimizer_swapper.swap_out_gradients(parameter=self.fp32_partitioned_groups_flat[i],
                                                          gradient_offsets=offload_fp32_offsets[i],
                                                          gradient_tensors=offload_fp32_gradients[i])

Commenting out the lines marked BUG makes the script work as expected.
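
For reference, here is roughly how the relevant parts of partition_grads read once those re-initializations are removed. This is a paraphrased sketch assembled from the snippets above; the method signature and the loop structure are simplified, not copied from DeepSpeed:

def partition_grads(self, params_to_release, grad_partitions):    # signature paraphrased
    # the accumulators are initialized exactly once, at the top of the function
    offload_fp32_gradients = {}
    offload_fp32_offsets = {}

    for param, grad_partition in zip(params_to_release, grad_partitions):    # loop paraphrased
        ...
        # offload the gradient partition if applicable
        if self.offload_optimizer:
            i, dest_offset, _ = self.grad_position[self.get_param_id(param)]
            # no re-initialization here, so whatever has been accumulated for each
            # sub-group survives until the swap-out at the end of the function
            ...

    # with the dicts left intact, this block now actually swaps the gradients out to NVMe
    if self.offload_optimizer and self.swap_optimizer:
        for i in offload_fp32_gradients.keys():
            self.optimizer_swapper.swap_out_gradients(parameter=self.fp32_partitioned_groups_flat[i],
                                                      gradient_offsets=offload_fp32_offsets[i],
                                                      gradient_tensors=offload_fp32_gradients[i])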

eisene added the bug and training labels on Oct 25, 2023
eisene (Contributor, Author) commented Oct 25, 2023

I forgot to mention that the minimal example sends the engine through this code path (deepspeed/runtime/engine.py, lines 1953-1955):

        if allreduce_gradients and self.enable_backward_allreduce:
            # Traditional code path that allreduces the module parameter grads
            self.allreduce_gradients()

Along this code path the gradients get erased without being offloaded, as described above.

tjruwase (Contributor) commented
@eisene, thanks for reporting this issue. I believe you have correctly identified the buggy re-initialization of the two dicts:

offload_fp32_gradients = {}
offload_fp32_offsets = {}

These dicts are correctly initialized at the beginning of the function:

offload_fp32_gradients = {}
offload_fp32_offsets = {}

Are you able to submit a PR deleting the re-initializations? Thanks!
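
For anyone skimming this thread, the failure mode in isolation is the standard accumulate-then-flush pitfall; the toy snippet below (not DeepSpeed code) shows why re-creating the dicts inside the loop turns the final swap-out into a no-op:

# toy illustration only -- unrelated to the actual DeepSpeed internals
def collect(items, reset_inside_loop):
    pending = {}                      # correct place to initialize the accumulator
    for key, value in items:
        if reset_inside_loop:
            pending = {}              # the bug: wipes everything gathered so far
        pending.setdefault(key, []).append(value)
    # the "flush" at the end only sees whatever survived the loop
    return pending

items = [("a", 1), ("a", 2), ("b", 3)]
print(collect(items, reset_inside_loop=True))    # {'b': [3]}  -- earlier entries lost
print(collect(items, reset_inside_loop=False))   # {'a': [1, 2], 'b': [3]}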

eisene (Contributor, Author) commented Nov 17, 2023

Done, see #4702
