generated from siddk/kindling
-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Sth-Sth-v2 Preprocessing & XLA Pretraining Script (#10)
* Update xpretrain documentation * Update README * Add Sth-Sth-v2 Preprocessing Pipeline * Add model initialization stub * Add full XLA pretraining pipeline * Add v1.0.0 with full preprocessing/XLA pretraining pipeline
- Loading branch information
Showing
13 changed files
with
2,632 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,4 @@ | ||
from .accelerators import AcceleratorConfig | ||
from .datasets import DatasetConfig | ||
from .models import ModelConfig | ||
from .tracking import TrackingConfig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
""" | ||
accelerator.py | ||
Base Hydra Structured Configs for defining various accelerator schemes. Uses a simple single inheritance structure. | ||
""" | ||
from dataclasses import dataclass | ||
|
||
from hydra.core.config_store import ConfigStore | ||
from omegaconf import MISSING | ||
|
||
|
||
@dataclass | ||
class AcceleratorConfig: | ||
accelerator: str = MISSING | ||
num_accelerators: int = MISSING | ||
num_workers: int = MISSING | ||
|
||
|
||
@dataclass | ||
class TPUv2OneConfig(AcceleratorConfig): | ||
accelerator = "tpu" | ||
num_accelerators = 1 | ||
num_workers = 4 | ||
|
||
|
||
@dataclass | ||
class TPUv2EightConfig(AcceleratorConfig): | ||
accelerator = "tpu" | ||
num_accelerators = 8 | ||
num_workers = 4 | ||
|
||
|
||
@dataclass | ||
class TPUv3OneConfig(AcceleratorConfig): | ||
accelerator = "tpu" | ||
num_accelerators = 1 | ||
num_workers = 8 | ||
|
||
|
||
@dataclass | ||
class TPUv3EightConfig(AcceleratorConfig): | ||
accelerator = "tpu" | ||
num_accelerators = 8 | ||
num_workers = 8 | ||
|
||
|
||
# Create a configuration group `accelerator` and populate with the above... | ||
cs = ConfigStore.instance() | ||
cs.store(group="accelerator", name="tpu-v2-1", node=TPUv2OneConfig) | ||
cs.store(group="accelerator", name="tpu-v2-8", node=TPUv2EightConfig) | ||
cs.store(group="accelerator", name="tpu-v3-1", node=TPUv3OneConfig) | ||
cs.store(group="accelerator", name="tpu-v3-8", node=TPUv3EightConfig) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.