Skip to content

Commit

Permalink
mergekit gpu 1226 (#9702)
Browse files Browse the repository at this point in the history
* mergekit gpu 1226

* merge model gpu

* merge gpu

* add lora model

* change valueerror

* add lora

* gpu test
  • Loading branch information
Mangodadada authored Jan 21, 2025
1 parent 7c1c9ba commit ac095f5
Show file tree
Hide file tree
Showing 8 changed files with 531 additions and 76 deletions.
11 changes: 5 additions & 6 deletions paddlenlp/mergekit/merge_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class MergeConfig:
default="np", metadata={"help": "Tensor type to use for the merge. Choose np(CPU Only) or pd (CPU/GPU)"}
)
n_process: int = field(default=1, metadata={"help": "Number of processes to use for the merge."})
merge_preifx: str = field(default="model", metadata={"help": "Prefix name: model or master_weights"})
merge_prefix: str = field(default="model", metadata={"help": "Prefix name: model or master_weights"})
merge_method: str = field(default="linear", metadata={"help": "The merge strategy."})
merge_type: str = field(default="linear", metadata={"help": "The type of merge process."})
sparsify_type: str = field(default=None, metadata={"help": "The type of sparsify process."})
Expand Down Expand Up @@ -73,12 +73,11 @@ def __post_init__(self):
def config_check(self):
if self.output_path is not None:
os.makedirs(self.output_path, exist_ok=True)
if self.tensor_type not in ["np"]:
raise ValueError(f"Unsupported tensor type: {self.tensor_type}. Support 'np' only.")
if self.device != "cpu":
logger.warning(f"Currently only support cpu device, but got {self.device}. Setting `device` to `cpu`.")
if self.tensor_type not in ["np", "pd"]:
raise ValueError(f"Unsupported tensor type: {self.tensor_type}. Support 'np' and 'pd' only.")
if self.device == "gpu" and self.tensor_type == "np":
logger.warning("np only support cpu device, but got gpu. Setting `device` to `cpu`.")
self.device = "cpu"
self.tensor_type = "np"

elif self.merge_method not in ["linear", "ties", "slerp", "della_linear", "della", "dare_linear", "dare_ties"]:
raise ValueError(
Expand Down
88 changes: 80 additions & 8 deletions paddlenlp/mergekit/merge_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import numpy as np
import paddle


class MergeMethod:
Expand Down Expand Up @@ -46,8 +47,14 @@ def linear(self, tensor_list):
if self.merge_config.tensor_type == "np":
tensor_output = sum(weight * tensor for weight, tensor in zip(weight_list, tensor_list))
return tensor_output
elif self.merge_config.tensor_type == "pd":
stacked_tensors = paddle.stack(tensor_list, axis=0)
weights = paddle.to_tensor(weight_list, dtype=stacked_tensors.dtype)
weights = weights.reshape([-1] + [1] * (len(stacked_tensors.shape) - 1))
weighted_sum = paddle.sum(stacked_tensors * weights, axis=0)
return weighted_sum
else:
raise NotImplementedError("Paddle Tensor is not supported yet.")
raise ValueError(f"Unkonwn tensor type {self.merge_config.tensor_type}")

def slerp(self, tensor_list):
"""
Expand Down Expand Up @@ -85,17 +92,45 @@ def slerp(self, tensor_list):
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
s1 = sin_theta_t / sin_theta_0

return s0 * t0_copy + s1 * t1_copy
elif self.merge_config.tensor_type == "pd":
t0, t1 = tensor_list
# Copy the tensors to reuse them later
t0_copy = t0.clone()
t1_copy = t1.clone()

# Normalize the tensors to get the directions and angles
t0 = self.normalize(t0)
t1 = self.normalize(t1)

# Dot product with the normalized tensors
dot = paddle.sum(t0 * t1)
# If absolute value of dot product is almost 1, vectors are ~colinear, so use lerp
if paddle.abs(dot) > self.merge_config.slerp_dot_threshold:
return (1 - self.merge_config.slerp_alpha) * t0_copy + self.merge_config.slerp_alpha * t1_copy

# Calculate initial angle between t0 and t1
theta_0 = paddle.acos(dot)
sin_theta_0 = paddle.sin(theta_0)

# Angle at timestep t
theta_t = theta_0 * self.merge_config.slerp_alpha
sin_theta_t = paddle.sin(theta_t)

# Finish the slerp algorithm
s0 = paddle.sin(theta_0 - theta_t) / sin_theta_0
s1 = sin_theta_t / sin_theta_0

return s0 * t0_copy + s1 * t1_copy
else:
raise NotImplementedError("Paddle Tensor is not supported yet.")
raise ValueError(f"Unkonwn tensor type {self.merge_config.tensor_type}")

def ties(self, tensor_list):
if self.merge_config.tensor_type == "np":
# Get weight tensor
mask_dtype = tensor_list[0].dtype
weight_list = self.merge_config.weight_list
tensor_list = [weight * tensor for (weight, tensor) in zip(weight_list, tensor_list)]

# Elect majority sign
sign_tensor_list = [np.sign(tensor).astype(mask_dtype) for tensor in tensor_list]
if self.merge_config.ties_elect_type == "sum":
Expand All @@ -117,14 +152,51 @@ def ties(self, tensor_list):
divisor[np.abs(divisor) < 1e-8] = 1
merge_tensor /= divisor
return merge_tensor

elif self.merge_config.tensor_type == "pd":
mask_dtype = tensor_list[0].dtype
weight_list = self.merge_config.weight_list
stacked_tensors = paddle.stack(tensor_list, axis=0)
weights = paddle.to_tensor(weight_list, dtype=stacked_tensors.dtype)
weights = weights.reshape([-1] + [1] * (len(stacked_tensors.shape) - 1))
weighted_tensors = stacked_tensors * weights
# Elect majority sign
if self.merge_config.ties_elect_type == "sum":
majority_sign = (paddle.sum(weighted_tensors, axis=0) >= 0).astype(mask_dtype) * 2 - 1
elif self.merge_config.ties_elect_type == "count":
stacked_signs = paddle.sign(stacked_tensors).astype(mask_dtype)
majority_sign = (paddle.sum(stacked_signs, axis=0) >= 0).astype(mask_dtype) * 2 - 1
else:
raise NotImplementedError(f"ties_elect_type: {self.merge_config.ties_elect_type} is unknown.")

# Merge
stacked_masks = (paddle.sign(weighted_tensors) == majority_sign).astype(mask_dtype)
masked_tensors = stacked_masks * weighted_tensors
merge_tensor = paddle.sum(masked_tensors, axis=0)
# Normalize
if self.merge_config.normalize:
weight_masks = stacked_masks * weights
divisor = paddle.sum(weight_masks, axis=0)
divisor = paddle.where(paddle.abs(divisor) < 1e-8, paddle.ones_like(divisor), divisor)
merge_tensor /= divisor

return merge_tensor
else:
raise NotImplementedError("Paddle Tensor is not supported yet.")
raise ValueError(f"Unkonwn tensor type {self.merge_config.tensor_type}")

def normalize(self, t):
"""
Normalize a vector by its L2 norm.
"""
norm_t = np.linalg.norm(t)
if norm_t > self.merge_config.slerp_normalize_eps:
t = t / norm_t
return t
if self.merge_config.tensor_type == "np":
norm_t = np.linalg.norm(t)
if norm_t > self.merge_config.slerp_normalize_eps:
t = t / norm_t
return t
elif self.merge_config.tensor_type == "pd":
norm_t = paddle.norm(t, p=2)
if norm_t > self.merge_config.slerp_normalize_eps:
t = t / norm_t
return t
else:
raise ValueError(f"Unkonwn tensor type {self.merge_config.tensor_type}")
Loading

0 comments on commit ac095f5

Please # to comment.