-
Notifications
You must be signed in to change notification settings - Fork 4.2k
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
[inf] Add config var to enable keeping module on host #6846
base: master
Are you sure you want to change the base?
Conversation
Using keep_module_on_host config var will let us control if the loaded checkpoints to model parameters will be moved to the device or stay on host
deepspeed/module_inject/auto_tp.py
Outdated
@@ -17,9 +17,11 @@ | |||
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list | |||
|
|||
|
|||
def move(tensor, device): | |||
def move(tensor, device, keep_module_on_host=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Won't it be simpler to modify callers to pass device='cpu' when keep_module_on_host=True
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@tjruwase Can you please rereview and retrigger the CI? |
@@ -554,6 +554,7 @@ def test(self, model_w_task, injection_policy, query, inf_kwargs, assert_fn, dty | |||
|
|||
|
|||
@pytest.mark.seq_inference | |||
@pytest.mark.parametrize('keep_module_on_host', [True, False]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it useful to validate that tensors are on cpu
when keep_module_on_host=True
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tjruwase
You're right.
A check was added.
Using keep_module_on_host config var will let us control if the loaded checkpoints to model parameters will be moved to the device or stay on host