Compute-graph based mergekit-extract-lora (#505)
Now with better embedding handling, multi-gpu execution, and lazy
loading/saving of tensors.

When extracting a LoRA from an 8B model, execution time goes from ~6
minutes down to 40 seconds with `--cuda --multi-gpu` on an 8-GPU
machine.

Additionally, the `--sv-epsilon` flag can be used to set a tolerance for
singular values, opportunistically reducing rank when the fine-tuned
difference is inherently of lower rank.
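
For example (model names here are just placeholders):

```sh
mergekit-extract-lora \
  --model my-org/llama-3-8b-finetune \
  --base-model meta-llama/Meta-Llama-3-8B \
  --out-path extracted_lora \
  --max-rank=32 \
  --sv-epsilon=1e-4 \
  --cuda --multi-gpu
```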

Also reimplement a couple of merge methods using the `@easy_define`
decorator and add some missing tests.
cg123 authored Feb 7, 2025
1 parent 86c30b6 commit a2dda31
Showing 14 changed files with 793 additions and 1,148 deletions.
README.md (45 changes: 22 additions & 23 deletions)
@@ -3,7 +3,6 @@
`mergekit` is a toolkit for merging pre-trained language models. `mergekit` uses an out-of-core approach to perform unreasonably elaborate merges in resource-constrained situations. Merges can be run entirely on CPU or accelerated with as little as 8 GB of VRAM. Many merging algorithms are supported, with more coming as they catch my attention.

## Contents

- [Why Merge Models?](#why-merge-models)
- [Features](#features)
- [Installation](#installation)
@@ -240,21 +239,21 @@ A quick overview of the currently supported merge methods:

| Method | `merge_method` value | Multi-Model | Uses base model |
| ------------------------------------------------------------------------------------------------ | -------------------- | ----------- | --------------- |
| Linear ([Model Soups](https://arxiv.org/abs/2203.05482)) | `linear` | ✅ | ❌ |
| SLERP | `slerp` | ❌ | ✅ |
| Nearswap | `nearswap` | ❌ | ✅ |
| [Task Arithmetic](https://arxiv.org/abs/2212.04089) | `task_arithmetic` | ✅ | ✅ |
| [TIES](https://arxiv.org/abs/2306.01708) | `ties` | ✅ | ✅ |
| [DARE](https://arxiv.org/abs/2311.03099) [TIES](https://arxiv.org/abs/2306.01708) | `dare_ties` | ✅ | ✅ |
| [DARE](https://arxiv.org/abs/2311.03099) [Task Arithmetic](https://arxiv.org/abs/2212.04089) | `dare_linear` | ✅ | ✅ |
| Passthrough | `passthrough` | ❌ | ❌ |
| [Model Breadcrumbs](https://arxiv.org/abs/2312.06795) | `breadcrumbs` | ✅ | ✅ |
| [Model Breadcrumbs](https://arxiv.org/abs/2312.06795) + [TIES](https://arxiv.org/abs/2306.01708) | `breadcrumbs_ties` | ✅ | ✅ |
| [Model Stock](https://arxiv.org/abs/2403.19522) | `model_stock` | ✅ | ✅ |
| NuSLERP | `nuslerp` | ❌ | ✅ |
| [DELLA](https://arxiv.org/abs/2406.11617) | `della` | ✅ | ✅ |
| [DELLA](https://arxiv.org/abs/2406.11617) [Task Arithmetic](https://arxiv.org/abs/2212.04089) | `della_linear` | ✅ | ✅ |
| [SCE](https://arxiv.org/abs/2408.07990) | `sce` | ✅ | ✅ |

### Linear

@@ -285,13 +284,14 @@ Parameters:

Computes "task vectors" for each model by subtracting a base model. Merges the task vectors linearly and adds back the base. Works great for models that were fine tuned from a common ancestor. Also a super useful mental framework for several of the more involved merge methods.

Parameters: same as [Linear](#linear), plus:
- `lambda` - scaling factor applied after weighted sum of task vectors
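
A minimal illustrative config (model names are placeholders):

```yaml
merge_method: task_arithmetic
base_model: path/to/base
models:
  - model: path/to/finetune-a
    parameters:
      weight: 0.7
  - model: path/to/finetune-b
    parameters:
      weight: 0.3
parameters:
  lambda: 1.0
dtype: float16
```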

### [TIES](https://arxiv.org/abs/2306.01708)

Builds on the task arithmetic framework. Resolves interference between models by sparsifying the task vectors and applying a sign consensus algorithm. Allows you to merge a larger number of models and retain more of their strengths.

Parameters: same as [Task Arithmetic](#task-arithmetic), plus:

- `density` - fraction of weights in differences from the base model to retain
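
A sketch of a typical config (model names are placeholders; `weight` comes from the task arithmetic parameter set):

```yaml
merge_method: ties
base_model: path/to/base
models:
  - model: path/to/finetune-a
    parameters:
      weight: 0.5
      density: 0.6
  - model: path/to/finetune-b
    parameters:
      weight: 0.5
      density: 0.6
dtype: float16
```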

@@ -309,7 +309,7 @@ Parameters: same as [TIES](#ties) for `dare_ties`, or [Linear](#linear) for `dar

An extension of task arithmetic that discards both small and extremely large differences from the base model. As with DARE, the Model Breadcrumbs algorithm can be used with (`breadcrumbs_ties`) or without (`breadcrumbs`) the sign consensus algorithm of TIES.

Parameters: same as [Task Arithmetic](#task-arithmetic), plus:

- `density` - fraction of weights in differences from the base model to retain
- `gamma` - fraction of largest magnitude differences to remove
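
An illustrative config (model names are placeholders; top-level `parameters` set defaults for all models):

```yaml
merge_method: breadcrumbs_ties
base_model: path/to/base
models:
  - model: path/to/finetune-a
    parameters:
      weight: 1.0
parameters:
  density: 0.9
  gamma: 0.01
dtype: float16
```
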
@@ -340,17 +340,16 @@ To replicate the behavior of the original `slerp` method, set `weight` to `1-t`

Building upon DARE, DELLA uses adaptive pruning based on parameter magnitudes. DELLA first ranks parameters in each row of delta parameters and assigns drop probabilities inversely proportional to their magnitudes. This allows it to retain more important changes while reducing interference. After pruning, it rescales the remaining parameters, similar to [DARE](#dare). DELLA can be used with (`della`) or without (`della_linear`) the sign-election step of TIES.

Parameters: same as [Task Arithmetic](#task-arithmetic), plus:

- `density` - fraction of weights in differences from the base model to retain
- `epsilon` - maximum change in drop probability based on magnitude. Drop probabilities assigned will range from `density - epsilon` to `density + epsilon`. (When selecting values for `density` and `epsilon`, ensure that the range of probabilities falls within 0 to 1)
- `lambda` - scaling factor for the final merged delta parameters before merging with the base parameters.
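
For instance (model names are placeholders; values chosen so `density ± epsilon` stays within 0 to 1):

```yaml
merge_method: della
base_model: path/to/base
models:
  - model: path/to/finetune-a
    parameters:
      weight: 1.0
      density: 0.5
      epsilon: 0.15
parameters:
  lambda: 1.1
dtype: float16
```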

### [SCE](https://arxiv.org/abs/2408.07990)

SCE introduces adaptive matrix-level merging weights based on parameter variances. SCE first selects the top-k% elements from each parameter matrix that exhibit high variance across all delta parameters. Following this selection, SCE calculates matrix-level merging weights based on the sum of squares of elements in the delta parameters. Finally, it erases minority elements, a step similar to the sign election process in TIES.

Parameters: same as [TIES](#ties), plus:

- `select_topk` - fraction of elements with the highest variance in the delta parameters to retain.
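
A minimal sketch (model names are placeholders; SCE derives the merge weights itself):

```yaml
merge_method: sce
base_model: path/to/base
models:
  - model: path/to/finetune-a
  - model: path/to/finetune-b
parameters:
  select_topk: 0.1
dtype: float16
```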

@@ -361,7 +360,7 @@ Mergekit allows extracting PEFT-compatible low-rank approximations of finetuned
### Usage

```sh
mergekit-extract-lora --model finetuned_model_id_or_path --base-model base_model_id_or_path --out-path output_path [--no-lazy-unpickle] [--cuda] [--max-rank=desired_rank] [--sv-epsilon=tol]
```
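
The resulting adapter loads with PEFT like any other; a minimal sketch, reusing the placeholder paths from the command above:

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained(
    "base_model_id_or_path", torch_dtype=torch.bfloat16
)
# If the adapter's model card warns that the vocabulary size changed,
# resize the embeddings first, e.g.: base.resize_token_embeddings(new_size)
model = PeftModel.from_pretrained(base, "output_path")
```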

## Mixture of Experts merging
mergekit/card.py (45 changes: 16 additions & 29 deletions)
@@ -1,7 +1,6 @@
# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: BUSL-1.1

import os
from typing import Generator, List, Optional

@@ -187,49 +186,37 @@ def generate_card(


def generate_card_lora(
    base_ref: ModelReference,
    finetuned_ref: ModelReference,
    invocation: str,
    name: str,
    base_vocab_size: Optional[int] = None,
    final_vocab_size: Optional[int] = None,
) -> str:
    """Generate a markdown card for an extracted LoRA adapter."""
    if not name:
        name = "Untitled LoRA Model (1)"

    hf_bases = list(extract_hf_paths([base_ref, finetuned_ref]))
    tags = ["mergekit", "peft"]

    details = (
        f"This LoRA adapter was extracted from {modelref_md(finetuned_ref)} "
        f"and uses {modelref_md(base_ref)} as a base."
    )

    if base_vocab_size and final_vocab_size and base_vocab_size != final_vocab_size:
        verb = "extended" if final_vocab_size > base_vocab_size else "reduced"
        details += (
            f"\n\n> [!WARNING]\n> The vocabulary size has been {verb} from the base "
            f"model's {base_vocab_size} to {final_vocab_size}. To load this adapter, "
            f"you must first call `model.resize_token_embeddings({final_vocab_size})`."
        )

    return CARD_TEMPLATE_LORA.format(
        metadata=yaml.dump(
            {"base_model": hf_bases, "tags": tags, "library_name": "peft"}
        ),
        name=name,
        details=details,
        base_model=base_ref.model.path,
        finetuned_model=finetuned_ref.model.path,
        invocation=invocation,
    )
mergekit/io/lazy_tensor_loader.py (41 changes: 23 additions & 18 deletions)
@@ -5,6 +5,7 @@
import logging
import os
import os.path
import threading
from dataclasses import dataclass
from typing import Dict, List, Optional

@@ -106,11 +107,13 @@ class LazyTensorLoader:
    index: ShardedTensorIndex
    current_shard: Optional[TensorLoader]
    lazy_unpickle: bool
    lock: threading.Lock

    def __init__(self, index: ShardedTensorIndex, lazy_unpickle: bool = True):
        self.index = index
        self.current_shard = None
        self.lazy_unpickle = lazy_unpickle
        self.lock = threading.Lock()

    def get_tensor(
        self,
@@ -125,27 +128,29 @@ def get_tensor(
                    key = alias
                    break

        with self.lock:
            # current_shard is shared state; the lock keeps concurrent readers
            # (e.g. multi-GPU worker threads) from clobbering each other's shard.
            if self.current_shard is None or key not in self.current_shard.keys():
                if key not in self.index.tensor_paths:
                    if raise_on_missing:
                        raise KeyError(key)
                    return None

                self.current_shard = None
                self.current_keys = None

                shard_file = self.index.tensor_paths[key]
                shard_full_path = os.path.join(self.index.base_path, shard_file)
                logging.debug(f"Opening shard {shard_full_path}")
                self.current_shard = TensorLoader.get(
                    shard_full_path, use_lazy_unpickle=self.lazy_unpickle, device=device
                )

            return self.current_shard.get_tensor(key).to(device)

    def flush(self):
        with self.lock:
            self.current_shard = None
            self.current_keys = None

    @classmethod
    def from_disk(
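
With the lock in place, a single loader can be shared across worker threads. A rough sketch (the tensor names are illustrative, and `ShardedTensorIndex.from_disk` is assumed to build the index):

```python
from concurrent.futures import ThreadPoolExecutor

from mergekit.io.lazy_tensor_loader import LazyTensorLoader, ShardedTensorIndex

loader = LazyTensorLoader(ShardedTensorIndex.from_disk("/path/to/model"))
names = [f"model.layers.{i}.self_attn.q_proj.weight" for i in range(4)]
with ThreadPoolExecutor(max_workers=4) as pool:
    # get_tensor serializes shard swaps internally, so concurrent reads are safe
    tensors = list(pool.map(loader.get_tensor, names))
```
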
mergekit/io/tasks.py (2 changes: 2 additions & 0 deletions)
@@ -151,6 +151,7 @@ class TensorWriterTask(Task[TensorWriter]):
    out_path: str
    max_shard_size: int
    safe_serialization: bool = True
    override_basename: Optional[str] = None

    def arguments(self) -> Dict[str, Task]:
        return {}
@@ -160,6 +161,7 @@ def execute(self, **_kwargs) -> TensorWriter:
            self.out_path,
            max_shard_size=self.max_shard_size,
            safe_serialization=self.safe_serialization,
            override_basename=self.override_basename,
        )

    def main_thread_only(self):
mergekit/io/tensor_writer.py (10 changes: 9 additions & 1 deletion)
@@ -5,7 +5,7 @@
import logging
import os
import threading
from typing import Dict, Optional

import safetensors
import torch
@@ -15,6 +15,7 @@

class TensorWriter:
    out_path: str
    override_basename: Optional[str]
    max_shard_size: int
    shards_written: int
    weight_map: Dict[str, str]
@@ -28,10 +29,12 @@ def __init__(
        out_path: str,
        max_shard_size: int = 1000 * 1000 * 1000 * 5,
        safe_serialization: bool = True,
        override_basename: Optional[str] = None,
    ) -> None:
        os.makedirs(out_path, exist_ok=True)

        self.out_path = out_path
        self.override_basename = override_basename
        self.max_shard_size = max_shard_size
        self.safe_serialization = safe_serialization
        self.shards_written = 0
@@ -50,6 +53,7 @@ def save_tensor(self, name: str, tensor: torch.Tensor, clone: bool = False):
        with self.lock:
            if (
                self.current_shard
                # a negative max_shard_size disables sharding entirely
                and self.max_shard_size >= 0
                and self.current_shard_size + tensor_size > self.max_shard_size
            ):
                self._flush_current_shard()
@@ -126,6 +130,10 @@ def finalize(self):
        )

    def _get_name_components(self):
        if self.override_basename:
            return self.override_basename, (
                "safetensors" if self.safe_serialization else "bin"
            )
        if self.safe_serialization:
            return "model", "safetensors"
        return "pytorch_model", "bin"
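
Presumably this is what lets the LoRA extraction write PEFT-style filenames; a hypothetical sketch, not the actual call site:

```python
import torch

from mergekit.io.tensor_writer import TensorWriter

# "adapter_model" yields adapter_model.safetensors, the basename PEFT expects
writer = TensorWriter("out_path", override_basename="adapter_model")
writer.save_tensor("base_model.model.embed_tokens.lora_A.weight", torch.zeros(16, 4096))
writer.finalize()
```
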
mergekit/merge_methods/__init__.py (2 changes: 2 additions & 0 deletions)
@@ -2,6 +2,8 @@
# SPDX-License-Identifier: BUSL-1.1

import mergekit.merge_methods.multislerp
import mergekit.merge_methods.nearswap
import mergekit.merge_methods.sce
from mergekit.merge_methods.base import MergeMethod
from mergekit.merge_methods.generalized_task_arithmetic import (
    GeneralizedTaskArithmeticMerge,