docs: Add ADR for Python fm-training-estimator #10

Merged · 7 commits · Oct 23, 2024
161 changes: 161 additions & 0 deletions adrs/001-resource-estimator-library.md
@@ -0,0 +1,161 @@
---
title: Resource Estimator Library
---

- **Author(s)**: Angel Luu (@aluu317)
- **Signer(s)**: Praveen Jayachandran, Ashok Pon Kumar Sree Prakash @ashokponkumar, Chander Govindarajan @ChanderG
- **Date (YYYY-MM-DD)**: 2024-10-31
- **Obsoletes ADRs**: N/A
- **Modified By ADRs**: N/A
- **Relevant Issues**: N/A

## Problem Context

Users of the tuning/training stack currently have no way of estimating how much memory, time, or money it takes to run a training job. They often hit out-of-memory (OOM) errors due to lack of memory. Users don't have enough information to make trade-off decisions on time vs. cost. Platform admins do not have the information needed to better schedule and pack jobs onto GPUs.

In order to be useful, the capability of estimating resources must be exposed to tuning/training users. The primary user personas of this service include training users and platform admins.

This ADR defines a Resource Estimator Python Library that provides an estimate of resource requirements for training runs.

## Impact Table

| AI Functionality | Operational Functionality |
| ---------------- | ------------------------- |
| Tuning Stack     | APIs                      |

## Decision

- We will expose the resource estimator service as a Python library `fm_training_estimator`, hosted as Open Source at the repo [fm-training-estimator](https://github.com/foundation-model-stack/fm-training-estimator) and published to [PyPI](https://pypi.org/).
- This Python library can be installed and plugged into any UI backend or a Docker image by a product team.
- The `fm_training_estimator` library exposes methods to calculate memory, time, tokens, and cost. Each method allows the user to pass training data as input for the "learned" or "hybrid" model; if training data is missing, the "theory" model is used.
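The arithmetic behind the "theory" model can be illustrated with a minimal sketch. This is illustrative only — the function name, constants, and the decision to omit activation memory are assumptions for this sketch, not the library's actual formulas:

```python
def theory_memory_gib(num_params: float, dtype_bytes: int = 4) -> dict:
    """Rough "theory"-style memory breakdown for full fine-tuning with Adam.

    Illustrative sketch only: counts weights, gradients, and Adam's two
    fp32 moment buffers; activation memory is omitted for brevity.
    """
    gib = 1024 ** 3
    model = num_params * dtype_bytes       # weights
    gradient = num_params * dtype_bytes    # one gradient value per weight
    optimizer = num_params * 4 * 2         # Adam: two fp32 moments per weight
    return {
        "model_memory_gib": model / gib,
        "gradient_memory_gib": gradient / gib,
        "optimizer_memory_gib": optimizer / gib,
        "total_gib": (model + gradient + optimizer) / gib,
    }

# A 3B-parameter model in fp32 needs roughly 45 GiB before activations,
# which is why OOM surprises are common without an estimator.
print(round(theory_memory_gib(3e9)["total_gib"], 1))  # → 44.7
```

The learned and hybrid models would refine such closed-form figures with regression over user-provided training runs.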

### Alternatives to Python library deliverable
We have considered choices of:
- Alternative 1: A new Docker image containing a FastAPI server with a REST interface. When a product team integrates it as a service, they can run this Docker image; a server runs on localhost and can then be queried with GET/POST calls to produce estimates.

- Alternative 2: A new Docker image with a Python script, similar to fms-hf-tuning, which accepts a JSON config, calls the necessary Python scripts to compute the estimate, and saves the results to a file.

Both alternatives provide additional value to consumers. However, neither provides the flexibility of choosing how the library is integrated and consumed.

## Consequences

- By using this library, users must supply their own dataset for the estimator to generate a learned model, and they assume responsibility for the security and privacy of that data. They can use the flight service plugin where applicable.
- The library can be used as a backend component of a larger UI effort, or as part of a Docker image. Product teams can consume the library however they see fit and create their own build/update process.

## High Level Design

- The `EstimateInput` data class (not all fields are required) defines the set of configs the library uses to calculate the results. It includes a list of `JobConfig` instances — each of which in turn bundles different types of configs (HF training args `HFArguments`, fms-hf-tuning additional args `FMArguments`, data args `DataArguments`, infrastructure args `InfraArguments`, and PEFT LoRA args `PeftLoraConfig`) — plus `EstimatorMetadata` with metadata parameters. The input can be read from a JSON file using `--input_file_path` or `-f`.

Example of an `EstimateInput` with all fields defined:
```json
{
"estimator": { // EstimatorMetadata
"base_data_path": "data.csv",
"method": "theory", // theory, learned, hybrid
"token_estimation_version": 0
},
"job_configs": [{ // list of [JobConfig]
"hf_training": { // HFArguments
"output_dir": "./output"
},
"fm": { // FMArguments
"base_model_path": "ibm-granite/granite-3b-code-base",
"flash_attention_v2": "false",
"lora_config": null,
"max_seq_length": 2048,
"block_size": 2048,
"data_config_file": "data_config.json",
"prompt_tuning_config": null,
"torch_dtype": "float32",
"technique": "full"
},
"data": { // DataArguments
"te_approach": 0,
"dataset": null,
"dataset_text_field": "text",
"dataset_split": "test",
"dataset_config_name": null
},
"infra": { // InfraArguments
"numGpusPerPod": 1,
"numPods": 1,
"gpu_memory_in_gb": 80,
"gpuModel": "A100"
},
"peft_lora": { // PeftLoraConfig
"r": 4,
"lora_alpha": 8,
"lora_dropout": 0.1,
"target_modules": "[q_proj, v_proj]"
}
}]
}
```

- The API exposes five functions — four per-quantity estimates plus an aggregate `estimate`:

Function `estimate_memory` returns a `MemoryEstimate`:
```python
{
"memory": { # MemoryEstimate
"total_mem_estimate": "44.6 GiB",
"activation_memory": "34.7 GiB",
"gradient_memory": "2.5 GiB",
"model_memory": "2.5 GiB",
"optimizer_memory": "4.9 GiB",
"num_gpus": 2
}
}
```
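In the simplest case, `num_gpus` can be derived from the total memory estimate and the per-GPU memory from `InfraArguments`. The sketch below shows only the idea; the `usable_fraction` margin is an assumption, and the library's actual heuristic clearly accounts for more (the example above reports 2 GPUs for 44.6 GiB on 80 GB cards):

```python
import math


def gpus_needed(total_mem_gib: float, gpu_memory_in_gb: int = 80,
                usable_fraction: float = 0.75) -> int:
    """Smallest GPU count whose usable pooled memory covers the estimate.

    `usable_fraction` is an assumed safety margin: frameworks reserve
    memory, so the full card is never available to the training job.
    """
    usable = gpu_memory_in_gb * usable_fraction
    return max(1, math.ceil(total_mem_gib / usable))


print(gpus_needed(44.6))      # 44.6 GiB against 80 GB cards → 1
print(gpus_needed(44.6, 40))  # against 40 GB cards → 2
```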

Function `estimate_time` returns a `TimeEstimate`:
```python
{
"time": { # TimeEstimate
"time": "40s"
}
}
```

Function `estimate_tokens` returns a `TokensEstimate`:
```python
{
"tokens": { # TokensEstimate
"tps": "5259.07373046875"
}
}
```

Function `estimate_cost` returns a `CostEstimate`:
```python
{
"cost": { # CostEstimate
"usd": "0.0"
}
}
```
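The cost figure follows mechanically from the time estimate and GPU count once an hourly price is known. A minimal sketch — the flat $2/GPU-hour rate is an assumption; this ADR does not specify a pricing source:

```python
def estimate_cost_usd(time_seconds: float, num_gpus: int,
                      usd_per_gpu_hour: float = 2.0) -> float:
    """Cost = wall-clock hours x GPU count x hourly GPU price (flat rate)."""
    return round(time_seconds / 3600 * num_gpus * usd_per_gpu_hour, 4)


# The 40 s, 2-GPU example above at an assumed $2/GPU-hour:
print(estimate_cost_usd(40, 2))  # → 0.0444
```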

Function `estimate` returns an `Estimate` that includes all four types of estimates above:
```python
{
"estimate": { # Estimate
"memory": { # MemoryEstimate
"total_mem_estimate": "44.6 GiB",
"activation_memory": "34.7 GiB",
"gradient_memory": "2.5 GiB",
"model_memory": "2.5 GiB",
"optimizer_memory": "4.9 GiB",
"num_gpus": 2
},
"time": { # TimeEstimate
"time": "40s"
},
"tokens": { # TokensEstimate
"tps": "5259.07373046875"
},
"cost": { # CostEstimate
"usd": "0.0"
}
}
}
```
110 changes: 106 additions & 4 deletions fm_training_estimator/config/arguments.py
@@ -1,5 +1,7 @@
# Standard
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional

# Third Party
from peft.tuners.lora import LoraConfig
@@ -9,7 +11,7 @@

@dataclass
class PeftPromptTuningConfig(PromptTuningConfig):
-"""dataclass for promptuning config
+"""dataclass for prompt tuning config

Args:
PromptTuningConfig (_type_): imported directly from peft library
@@ -18,7 +20,7 @@ class PeftPromptTuningConfig(PromptTuningConfig):

@dataclass
class PeftLoraConfig:
-"""Dataclass for lora config
+"""Dataclass for LoRA tuning config

Not directly imported from peft LoraConfig due to complexity.
"""
@@ -65,6 +67,16 @@ class InfraArguments:
)


class TuningTechnique(Enum):
"""Enumerate different tuning techniques the FM Training Estimator can perform estimation on."""

LORA = "lora"
"""LoRA tuning technique."""

FULL = "full"
"""Full fine-tuning technique."""


@dataclass
class FMArguments:
"""dataclass to store additional args not covered by standard HF argument dataclasses"""
@@ -116,13 +128,16 @@ class FMArguments:
},
)

-technique: str = field(
-    default="full", metadata={"help": ("Fine-tuning technique being used")}
+technique: TuningTechnique = field(
+    default=TuningTechnique.FULL,
+    metadata={"help": ("Fine-tuning technique being used")},
)


@dataclass
class DataArguments:
"""dataclass to define args handling training data as input for estimation."""

te_approach: int = field(
default=0, metadata={"help": ("Approach to use for Token Estimation")}
)
@@ -144,3 +159,90 @@ class DataArguments:
default=None,
metadata={"help": ("dataset configuration to use, in case of HF dataset")},
)


class EstimatorMethod(Enum):
"""Enumerate different estimation models the FM Training Estimator is to use to make an estimation."""

THEORY = "theory"
"""Theory model for estimation."""

LEARNED = "learned"
"""Learned model for estimation, based on user provided training data."""

HYBRID = "hybrid"
"""Hybrid model for estimation, a combination of theory and learned models."""


@dataclass
class EstimatorMetadata:
"""Metadata for the FM Training Estimator."""

base_data_path: str
method: List[EstimatorMethod]
token_estimation_version: str


@dataclass
class JobConfig:
"""Dataclass that represents a set of different configs for a tuning job to make estimate on."""

hf_training: HFTrainingArguments = field(default_factory=HFTrainingArguments)
fm: FMArguments = field(default_factory=FMArguments)
data: DataArguments = field(default_factory=DataArguments)
infra: InfraArguments = field(default_factory=InfraArguments)
peft_lora: PeftLoraConfig = field(default_factory=PeftLoraConfig)


@dataclass
class EstimateInput:
"""
The dataclass that is an input to the estimate function.
It includes a list of different training job configs and metadata about the estimator.
"""

job_configs: List[JobConfig]
estimator_metadata: Optional[EstimatorMetadata] = None


@dataclass
class TimeEstimate:
"""The estimated time response to estimate_time function."""

time: str


@dataclass
class MemoryEstimate:
"""The estimated memory response to estimate_memory function."""

total_mem_estimate: str
activation_memory: str
gradient_memory: str
model_memory: str
optimizer_memory: str
num_gpus: int


@dataclass
class TokenEstimate:
"""The estimated token response to the estimate_tokens function."""

tps: float


@dataclass
class CostEstimate:
"""The estimated cost response to estimate_cost function."""

usd: float


@dataclass
class Estimate:
"""The estimate response to estimate function, including time, memory, tokens and cost."""

memory: MemoryEstimate
time: TimeEstimate
tokens: TokenEstimate
cost: CostEstimate