From 8ade23cc6aec7c3bd3d80fef6378cafaade75bbe Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Fri, 15 Nov 2024 17:29:41 +0100 Subject: [PATCH] remove hook for bnb 4-bit (#3223) * relax dispatch for bnb * style --- src/accelerate/big_modeling.py | 14 +++++++++----- src/accelerate/utils/imports.py | 9 +++++++-- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/accelerate/big_modeling.py b/src/accelerate/big_modeling.py index 9101153077c..c0708373730 100644 --- a/src/accelerate/big_modeling.py +++ b/src/accelerate/big_modeling.py @@ -37,6 +37,7 @@ find_tied_parameters, get_balanced_memory, infer_auto_device_map, + is_bnb_available, is_mlu_available, is_musa_available, is_npu_available, @@ -351,16 +352,19 @@ def dispatch_model( # Error early if the device map is incomplete. check_device_map(model, device_map) - # for backward compatibility - is_bnb_quantized = ( - getattr(model, "is_quantized", False) or getattr(model, "is_loaded_in_8bit", False) - ) and getattr(model, "quantization_method", "bitsandbytes") == "bitsandbytes" + # We need to force hooks for quantized models that can't be moved with to() + if getattr(model, "quantization_method", "bitsandbytes") == "bitsandbytes": + # since bnb 0.43.2, we can move 4-bit models + if getattr(model, "is_loaded_in_8bit", False) or ( + getattr(model, "is_loaded_in_4bit", False) and not is_bnb_available(min_version="0.43.2") + ): + force_hooks = True # We attach hooks if the device_map has at least 2 different devices or if # force_hooks is set to `True`. Otherwise, the model in already loaded # in the unique device and the user can decide where to dispatch the model. 
# If the model is quantized, we always force-dispatch the model - if (len(set(device_map.values())) > 1) or is_bnb_quantized or force_hooks: + if (len(set(device_map.values())) > 1) or force_hooks: if main_device is None: if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}: main_device = "cpu" diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index f408e60d9d1..aeafe91cf3c 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -179,8 +179,13 @@ def is_8bit_bnb_available(): return False -def is_bnb_available(): - return _is_package_available("bitsandbytes") +def is_bnb_available(min_version=None): + package_exists = _is_package_available("bitsandbytes") + if package_exists and min_version is not None: + bnb_version = version.parse(importlib.metadata.version("bitsandbytes")) + return compare_versions(bnb_version, ">=", min_version) + else: + return package_exists def is_bitsandbytes_multi_backend_available():