
Minimum VRAM requirements #62

Open
soulde opened this issue Feb 11, 2025 · 5 comments

Comments

@soulde

soulde commented Feb 11, 2025

Can I train the model with a single RTX 4090? Is the single GPU mentioned in the README an H100 or A100, which would mean I need at least 80 GB of VRAM for training?

@bruno686

I also only have a 4090 and am waiting for an answer.

@patrickstar-sjh

On RTX 4090s, the 0.5B model runs. For the 1.5B and 3B models I hit CUDA OOM with 4 GPUs and memory overflow with 8 GPUs. That's what happens on my 8x 4090 machine.

@prvnsmpth

I was able to run the 0.5B model on an RTX 4090 for a few minutes before Ray errored out with OOM (RAM, not VRAM; I'm running on a machine with 32 GB of RAM).

For the 1.5B model, I don't think you can get it to run on the 4090 (unless you can figure out how to quantize to 8-bit or lower). I tried to reproduce it on an A100 40G and failed with GPU OOM. So now I'm trying an A100 80G, and training seems to be progressing OK. Peak GPU memory usage is ~64 GB:

[Screenshot: GPU memory usage during 1.5B training on the A100 80G, peaking around 64 GB]
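
For anyone who wants to try the 8-bit route mentioned above, here is a minimal sketch of loading the policy weights with bitsandbytes 8-bit quantization via transformers. The model id is just a placeholder, and whether this repo's PPO/Ray training stack actually accepts quantized weights is an open question; this only shows the loading step.

```python
# Minimal sketch: load a model in 8-bit via bitsandbytes to shrink weight memory.
# NOTE: "Qwen/Qwen2.5-1.5B" is only an example model id, and the RL training code
# in this repo may not support quantized weights out of the box.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-1.5B",
    quantization_config=quant_config,
    device_map="auto",           # place layers on the available GPU(s)
    torch_dtype=torch.bfloat16,  # dtype for the non-quantized parts
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B")
```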

@MonkeyNi

MonkeyNi commented Mar 8, 2025

> Can I train the model with a single RTX 4090? Is the single GPU mentioned in the README an H100 or A100, which would mean I need at least 80 GB of VRAM for training?

Yes, you can use GRPO instead of PPO. Here are my results. (#5 (comment))
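
For context on why GRPO helps with VRAM: it replaces PPO's learned value model with a group-relative baseline, so no critic has to fit in GPU memory. Below is an illustrative sketch of the group-normalized advantage computation, not this repo's implementation.

```python
# Illustrative sketch of GRPO's group-relative advantages: for each prompt we sample a
# group of responses, then normalize each response's reward by the group mean and std.
# No value network is needed, which is where the VRAM savings over PPO come from.
from typing import List

def grpo_advantages(group_rewards: List[float], eps: float = 1e-6) -> List[float]:
    mean = sum(group_rewards) / len(group_rewards)
    var = sum((r - mean) ** 2 for r in group_rewards) / len(group_rewards)
    std = var ** 0.5
    return [(r - mean) / (std + eps) for r in group_rewards]

# Example: rewards for 4 sampled responses to the same prompt
print(grpo_advantages([1.0, 0.0, 0.0, 1.0]))
```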

@CiTY-GO

CiTY-GO commented Mar 9, 2025

I found that my Flash Attention 2 didn't work. How can I solve it?

Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dtype in Qwen2ForTokenClassification is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
(WorkerDict pid=61308) You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
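
A sketch of what the warning is asking for: load the model in half precision via the `torch_dtype` argument and move it to the GPU before using Flash Attention 2. The model id below is a placeholder, not taken from this repo's config.

```python
# Sketch of the fix the transformers warning suggests: give the model a fp16/bf16 dtype
# and put it on the GPU, since Flash Attention 2 supports neither fp32 nor CPU execution.
# "Qwen/Qwen2.5-0.5B" is only an example model id.
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "Qwen/Qwen2.5-0.5B",
    torch_dtype=torch.bfloat16,               # FA2 only supports fp16/bf16
    attn_implementation="flash_attention_2",
)
model.to("cuda")  # FA2 expects the model on GPU, not CPU
```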
