-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathsetup.sh
42 lines (35 loc) · 1.4 KB
/
setup.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#!/bin/bash
#
# Bootstrap a simple_grpo checkout for GRPO training.
#
# Steps:
#   1. Clone simple_grpo into ./simple_grpo (relative to the current directory).
#   2. Create a virtual environment at simple_grpo/venv and install
#      requirements.txt into it.
#   3. Install a pinned commit of a custom TRL branch (grpo-vram-optimization)
#      for memory optimization.
#   4. Prompt for Hugging Face and Weights & Biases API keys and append
#      `export` lines for them to venv/bin/activate, so they are set whenever
#      the venv is activated.
#
# Usage: bash setup.sh
# Interactive: reads two API keys from stdin (input is not echoed).
# NOTE: the keys are stored in plaintext in simple_grpo/venv/bin/activate.
set -euo pipefail

readonly REPO_URL="https://github.com/minosvasilias/simple_grpo.git"
readonly REPO_DIR="simple_grpo"
readonly TRL_URL="https://github.com/andyl98/trl.git"
readonly TRL_BRANCH="grpo-vram-optimization"
readonly TRL_DIR="trl_custom"
# Tested, stable commit on TRL_BRANCH.
readonly TRL_COMMIT="ccc95472f6245f2db00986a08ca16da68bf32c14"

die() {
  printf 'error: %s\n' "$*" >&2
  exit 1
}

#######################################
# Prompt for a secret, export it, and persist it to the venv activate script.
# Arguments:
#   $1 - environment variable name to set (e.g. HF_TOKEN)
#   $2 - prompt text shown to the user
# Outputs:
#   Appends `export NAME=<shell-quoted value>` to venv/bin/activate (relative
#   to the current directory). The value itself is never printed.
#   Empty input: nothing is exported or persisted; a warning goes to stderr.
#######################################
prompt_secret() {
  local name=$1 prompt=$2 value=""
  # -s keeps the key off the terminal. `|| true`: EOF on stdin (e.g. no tty)
  # leaves value empty/partial instead of aborting the script under set -e.
  read -rsp "$prompt" value || true
  printf '\n'
  if [[ -z "$value" ]]; then
    printf 'warning: no value entered; %s not set.\n' "$name" >&2
    return 0
  fi
  export "$name=$value"
  # %q shell-quotes the value so characters such as spaces, $ or quotes
  # survive intact when activate is later sourced.
  printf 'export %s=%q\n' "$name" "$value" >> venv/bin/activate
  printf '%s has been set (value hidden).\n' "$name"
}

main() {
  echo "=== Cloning simple_grpo repository ==="
  # Clones into ./simple_grpo; git aborts if that path already exists.
  git clone "$REPO_URL" "$REPO_DIR"
  cd "$REPO_DIR" || die "cannot cd to $REPO_DIR"

  echo "=== Setting up the Python environment ==="
  # Prefer python3: many distributions ship no bare `python` command.
  local python_bin
  python_bin=$(command -v python3 || command -v python) \
    || die "no python3 or python found on PATH"
  "$python_bin" -m venv venv
  # The generated activate script may reference unset variables on some
  # Python versions; relax -u only while sourcing it.
  # venv/bin/activate is created at runtime, so shellcheck cannot follow it.
  # shellcheck disable=SC1091
  set +u
  source venv/bin/activate
  set -u

  echo "=== Installing requirements ==="
  # `python -m pip` guarantees the venv's own pip is the one run/upgraded.
  python -m pip install --upgrade pip
  python -m pip install -r requirements.txt

  echo "=== Installing custom TRL branch ==="
  # Custom TRL branch for memory optimization, pinned to a tested commit.
  # git -C / ./path avoid cd-ing in and out of the clone.
  git clone --branch "$TRL_BRANCH" "$TRL_URL" "$TRL_DIR"
  git -C "$TRL_DIR" checkout "$TRL_COMMIT"
  python -m pip install "./$TRL_DIR"

  echo "=== Setting up Huggingface API key ==="
  prompt_secret HF_TOKEN "Please enter your Huggingface API key to access gated models such as Llama-3.1-8B-Instruct: "

  echo "=== Setting up Wandb API key ==="
  prompt_secret WANDB_API_KEY "Please enter your Wandb API key to log training runs: "

  echo "=== Setup complete! ==="
  # This script runs in a child shell, so the caller's shell is not activated.
  printf 'Activate the environment with: source %s/venv/bin/activate\n' "$PWD"
}

main "$@"