diff --git a/README.md b/README.md
index 672b1880..eef777a6 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,34 @@
-# ⚡️ Nanotron
+⚡️ Nanotron
+
+Philosophy • Core Features • Installation • Usage • Contributions
+
The objective of this library is to provide easy distributed primitives in order to train a variety of models efficiently using 3D parallelism. For more information about the internal design of the library or 3D parallelism in general, please check out [[docs.md]](./docs/docs.md) and [[3d_parallelism.md]](./docs/3d_parallelism.md).
@@ -28,12 +58,10 @@ To install (in a new env):
```bash
pip install torch
pip install packaging; pip install "flash-attn>=2.5.0" --no-build-isolation
-git clone git@github.com:huggingface/nanotron.git
-cd nanotron
-pip install -e .
+pip install nanotron
```
-Also nice to have `transformers` `datasets` `python-etcd` `tensorboardX`: `pip install transformers datasets python-etcd tensorboardX`
+Also nice to have: `pip install transformers datasets python-etcd tensorboardX`
We also support a set of flavors that you can install using `pip install -e ".[$FLAVOR]"`:
- `dev`: Use if you are developing in `nanotron`. In particular, it installs our linter mechanism. On top of that, you have to run `pre-commit install` afterwards.
@@ -68,7 +96,6 @@ pre-commit run --config .pre-commit-config.yaml --all-files
Features we would like to add:
- [ ] Support `torch.compile`
-- [ ] Support `torch.distributed.rpc`
- [ ] More optimized kernels
- [ ] Support Zero3
- [ ] Other PP schedules (such as Interleaved 1f1b...)
diff --git a/pyproject.toml b/pyproject.toml
index f3372750..ebb81b8f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "nanotron"
-version = "0.2"
+version = "0.4"
description = "Minimalistic Large Language Model Training and Finetuning"
authors = [
{name = "Nouamane Tazi", email="nouamane@huggingface.co"},
diff --git a/run_generate.py b/run_generate.py
index e280a5fd..0f52b8ed 100644
--- a/run_generate.py
+++ b/run_generate.py
@@ -15,9 +15,19 @@
import torch
from nanotron import distributed as dist
from nanotron import logging
-from nanotron.config import GenerationArgs, LoggingArgs, ParallelismArgs, get_config_from_file
-from nanotron.generation.decode import GenerationInput, TokenizerConfig, decode_text, decode_tokenized
-from nanotron.logging import log_rank, set_logger_verbosity_format
+from nanotron.config import (
+ GenerationArgs,
+ LoggingArgs,
+ ParallelismArgs,
+ get_config_from_file,
+)
+from nanotron.generation.decode import (
+ GenerationInput,
+ TokenizerConfig,
+ decode_text,
+ decode_tokenized,
+)
+from nanotron.logging import log_rank, set_ranks_logging_level
from nanotron.models import build_model
from nanotron.parallel import ParallelContext
from nanotron.parallel.parameters import sanity_check
@@ -32,9 +42,7 @@
get_synced_random_state,
set_random_seed,
)
-from nanotron.serialize import (
- load_weights,
-)
+from nanotron.serialize import load_weights
from nanotron.trainer import CONFIG_TO_MODEL_CLASS, mark_tied_parameters
try:
@@ -86,12 +94,8 @@ def main():
log_level_replica="info",
)
- if dist.get_rank(parallel_context.world_pg) == 0:
- if logging_config.log_level is not None:
- set_logger_verbosity_format(logging_config.log_level, parallel_context=parallel_context)
- else:
- if logging_config.log_level_replica is not None:
- set_logger_verbosity_format(logging_config.log_level_replica, parallel_context=parallel_context)
+ # Set log levels
+ set_ranks_logging_level(parallel_context=parallel_context, logging_config=logging_config)
log_rank(f"model_config: {model_config}", logger=logger, level=logging.INFO, rank=0)
log_rank(f"tokenizer_path: {tokenizer_path}", logger=logger, level=logging.INFO, rank=0)
diff --git a/src/nanotron/__init__.py b/src/nanotron/__init__.py
index 09888577..896a370c 100644
--- a/src/nanotron/__init__.py
+++ b/src/nanotron/__init__.py
@@ -1 +1 @@
-__version__ = "0.2"
+__version__ = "0.4"
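
The README now points at a PyPI install and the version is bumped to 0.4 in both `pyproject.toml` and `src/nanotron/__init__.py`. A minimal, hedged sanity check that an environment picked up the matching release (assumes `pip install nanotron` succeeded and that the distribution name matches the import name):

```python
# Post-install check: the package __version__ (src/nanotron/__init__.py) and the
# distribution metadata (pyproject.toml) should both report the 0.4 release.
from importlib.metadata import version

import nanotron

assert nanotron.__version__ == version("nanotron") == "0.4"
print(f"nanotron {nanotron.__version__} is installed")
```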
diff --git a/src/nanotron/logging.py b/src/nanotron/logging.py
index 0e8441dd..72d68efc 100644
--- a/src/nanotron/logging.py
+++ b/src/nanotron/logging.py
@@ -18,13 +18,24 @@
import sys
from dataclasses import dataclass
from functools import lru_cache
-from logging import CRITICAL, DEBUG, ERROR, FATAL, INFO, NOTSET, WARNING, Formatter, Logger
+from logging import (
+ CRITICAL,
+ DEBUG,
+ ERROR,
+ FATAL,
+ INFO,
+ NOTSET,
+ WARNING,
+ Formatter,
+ Logger,
+)
from typing import List, Optional, Union
import torch
from torch import distributed as torch_dist
from nanotron import distributed as dist
+from nanotron.config.config import LoggingArgs
from nanotron.parallel import ParallelContext
log_levels = {
@@ -283,7 +294,6 @@ def set_logger_verbosity_format(logging_level: str, parallel_context: ParallelCo
f"TP={dist.get_rank(parallel_context.tp_pg)}{expert_parallel_log}{'|' + node_name if node_name else ''}]: %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
)
- # TODO @thomasw21: `logging.log_levels` returns valid lg log levels
log_level = log_levels[logging_level]
# main root logger
@@ -299,4 +309,13 @@ def set_logger_verbosity_format(logging_level: str, parallel_context: ParallelCo
set_formatter(formatter=formatter)
+def set_ranks_logging_level(parallel_context: ParallelContext, logging_config: LoggingArgs):
+ if dist.get_rank(parallel_context.world_pg) == 0:
+ if logging_config.log_level is not None:
+ set_logger_verbosity_format(logging_config.log_level, parallel_context=parallel_context)
+ else:
+ if logging_config.log_level_replica is not None:
+ set_logger_verbosity_format(logging_config.log_level_replica, parallel_context=parallel_context)
+
+
_configure_library_root_logger()
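
The new `set_ranks_logging_level` helper centralizes the rank-dependent log-level selection that `run_generate.py` and `trainer.py` used to duplicate: rank 0 of the world group follows `log_level`, every other rank follows `log_level_replica`. A hedged sketch of the call pattern, assuming a distributed launch (e.g. `torchrun`) and that `ParallelContext` and `LoggingArgs` accept the keyword arguments shown:

```python
# Sketch of the centralized log-level setup mirroring run_generate.py and
# trainer.py after this change; the constructor keyword arguments below are
# assumptions, adapt them to your own config.
from nanotron.config import LoggingArgs
from nanotron.logging import set_ranks_logging_level
from nanotron.parallel import ParallelContext

parallel_context = ParallelContext(
    tensor_parallel_size=1,
    pipeline_parallel_size=1,
    data_parallel_size=1,
)
logging_config = LoggingArgs(log_level="debug", log_level_replica="info")

# One call replaces the per-rank if/else each entry point used to carry:
# world-group rank 0 uses log_level, all other ranks use log_level_replica.
set_ranks_logging_level(parallel_context=parallel_context, logging_config=logging_config)
```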
diff --git a/src/nanotron/parallel/pipeline_parallel/__init__.py b/src/nanotron/parallel/pipeline_parallel/__init__.py
new file mode 100644
index 00000000..a4d66e50
--- /dev/null
+++ b/src/nanotron/parallel/pipeline_parallel/__init__.py
@@ -0,0 +1,5 @@
+from nanotron.parallel.pipeline_parallel.engine import PipelineEngine
+from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
+from nanotron.parallel.pipeline_parallel.utils import get_pp_rank_of
+
+__all__ = ["PipelineEngine", "TensorPointer", "get_pp_rank_of"]
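
This new package `__init__.py` only re-exports the three pipeline-parallel names, so callers such as `trainer.py` below can import them from the package root instead of reaching into the submodules. A small illustration of the equivalence (assumes `nanotron` is importable; existing deep imports keep working):

```python
# The package-level names are the same objects as their submodule counterparts.
from nanotron.parallel.pipeline_parallel import PipelineEngine, TensorPointer, get_pp_rank_of
from nanotron.parallel.pipeline_parallel.engine import PipelineEngine as DeepPipelineEngine
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer as DeepTensorPointer
from nanotron.parallel.pipeline_parallel.utils import get_pp_rank_of as deep_get_pp_rank_of

assert PipelineEngine is DeepPipelineEngine
assert TensorPointer is DeepTensorPointer
assert get_pp_rank_of is deep_get_pp_rank_of
```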
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index dc4a44dd..0d6c7994 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -45,7 +45,7 @@
human_format,
log_memory,
log_rank,
- set_logger_verbosity_format,
+ set_ranks_logging_level,
)
from nanotron.models import NanotronModel, build_model
from nanotron.models.base import check_model_has_grad
@@ -55,9 +55,11 @@
from nanotron.parallel import ParallelContext
from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp
from nanotron.parallel.parameters import NanotronParameter, sanity_check
-from nanotron.parallel.pipeline_parallel.engine import PipelineEngine
-from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
-from nanotron.parallel.pipeline_parallel.utils import get_pp_rank_of
+from nanotron.parallel.pipeline_parallel import (
+ PipelineEngine,
+ TensorPointer,
+ get_pp_rank_of,
+)
from nanotron.parallel.tensor_parallel.nn import (
TensorParallelLinearMode,
TensorParallelRowLinear,
@@ -143,14 +145,7 @@ def __init__(
self.pre_init()
# Set log levels
- if dist.get_rank(self.parallel_context.world_pg) == 0:
- if self.config.logging.log_level is not None:
- set_logger_verbosity_format(self.config.logging.log_level, parallel_context=self.parallel_context)
- else:
- if self.config.logging.log_level_replica is not None:
- set_logger_verbosity_format(
- self.config.logging.log_level_replica, parallel_context=self.parallel_context
- )
+ set_ranks_logging_level(parallel_context=self.parallel_context, logging_config=self.config.logging)
# Log benchmark info
if os.environ.get("NANOTRON_BENCHMARK", "0") == "1":
@@ -198,8 +193,6 @@ def __init__(
)
# Define iteration start state
- self.start_iteration_step: int
- self.consumed_train_samples: int
if self.init_checkpoint_path is not None:
checkpoint_metadata = load_meta(
parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path
@@ -266,8 +259,6 @@ def train(
self.save_checkpoint()
if isinstance(dataloader_or_dls, tuple):
- dataloader_or_dls[1] if len(dataloader_or_dls) > 1 else None
- dataloader_or_dls[2] if len(dataloader_or_dls) > 2 else None
dataloader = dataloader_or_dls[0]
else:
dataloader = dataloader_or_dls