[ENH] Implements DoRA #790

Open · wants to merge 4 commits into base: main
1 change: 1 addition & 0 deletions docs/overview.md
@@ -56,6 +56,7 @@ Identifiers and configuration classes are explained in more detail in the [next
| `prefix_tuning_flat` | `PrefixTuningConfig(flat=True)` | [Prefix Tuning](methods.html#prefix-tuning) |
| `lora` | `LoRAConfig()` | [LoRA](methods.html#lora) |
| `ia3` | `IA3Config()` | [IA³](methods.html#ia-3) |
| `dora` | `DoRAConfig()` | [DoRA](methods.html#dora) |
| `mam` | `MAMConfig()` | [Mix-and-Match Adapters](method_combinations.html#mix-and-match-adapters) |
| `unipelt` | `UniPELTConfig()` | [UniPELT](method_combinations.html#unipelt) |
| `prompt_tuning` | `PromptTuningConfig()` | [Prompt Tuning](methods.html#prompt-tuning) |
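As a usage note for the new table row: the `dora` identifier and `DoRAConfig` are what this PR adds, while the surrounding workflow follows the existing `adapters` API (`adapters.init`, `add_adapter`, `train_adapter`). A minimal sketch, assuming the PR is installed; the model checkpoint and adapter name below are placeholders:

```python
from transformers import AutoModelForSequenceClassification

import adapters
from adapters import DoRAConfig

model = AutoModelForSequenceClassification.from_pretrained("roberta-base")
adapters.init(model)  # enable adapter support on a plain Transformers model

# The string identifier "dora" and the config class added in this PR should be equivalent here.
model.add_adapter("dora_adapter", config=DoRAConfig(r=8, alpha=8))
model.train_adapter("dora_adapter")  # freeze the base model, train only the adapter weights
```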
2 changes: 2 additions & 0 deletions src/adapters/__init__.py
@@ -52,6 +52,7 @@
"DynamicAdapterFusionConfig",
"IA3Config",
"LoRAConfig",
"DoRAConfig",
"LoReftConfig",
"MAMConfig",
"ModelAdaptersConfig",
@@ -169,6 +170,7 @@
DynamicAdapterFusionConfig,
IA3Config,
LoRAConfig,
DoRAConfig,
LoReftConfig,
MAMConfig,
ModelAdaptersConfig,
23 changes: 23 additions & 0 deletions src/adapters/configuration/adapter_config.py
@@ -535,6 +535,28 @@ class IA3Config(LoRAConfig):
dtype: Optional[str] = None


@dataclass(eq=False)
class DoRAConfig(LoRAConfig):
"""
The 'Weight-Decomposed Low-Rank Adaptation' DoRA method was preposed by Liu et al. (2022). See https://arxiv.org/pdf/2402.09353.
The DoRA method proposes that the weight matrix of a layer can be decomposed into magnitude and directional components, and finetunes both.
The directional component is then decomposed further via two LoRA low-rank matrices.
During training, the directional matrix is scaled to unit norm by the vector-wise columns of the matrix and multiplied
by the magnutude matrix to obtain the final weights.
"""

selfattn_lora: bool = True
intermediate_lora: bool = False
output_lora: bool = False

r: int = 8
alpha: int = 8
dropout: float = 0.0
composition_mode: str = "dora"
init_weights: str = "lora"
use_gating: bool = False
dtype: Optional[str] = None


@dataclass(eq=False)
class ReftConfig(AdapterConfig):
"""
@@ -770,6 +792,7 @@ def __init__(
"prompt_tuning": PromptTuningConfig(),
"lora": LoRAConfig(),
"ia3": IA3Config(),
"dora": DoRAConfig(),
"loreft": LoReftConfig(),
"noreft": NoReftConfig(),
"direft": DiReftConfig(),
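For readers unfamiliar with the method the docstring summarizes: DoRA re-parameterizes a pretrained weight `W0` as a magnitude vector times a column-normalized direction, and the direction is updated through the usual LoRA product `B @ A`. The following is a minimal, self-contained sketch of that decomposition only; it is not taken from this PR's module code, and all tensor names are illustrative:

```python
import torch

out_features, in_features, r = 16, 32, 4
scaling = 8 / r  # alpha / r, matching the DoRAConfig defaults

W0 = torch.randn(out_features, in_features)      # frozen pretrained weight
lora_A = torch.randn(r, in_features) * 0.02      # trainable low-rank factors
lora_B = torch.zeros(out_features, r)
m = W0.norm(p=2, dim=0, keepdim=True)            # magnitude, initialized to the column norms of W0

# Direction: pretrained weight plus the scaled low-rank update, normalized column-wise.
V = W0 + scaling * (lora_B @ lora_A)
W_adapted = m * (V / V.norm(p=2, dim=0, keepdim=True))

# With lora_B initialized to zeros, the adapted weight equals W0 at the start of training.
assert torch.allclose(W_adapted, W0, atol=1e-5)
```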
75 changes: 75 additions & 0 deletions src/adapters/methods/lora.py
@@ -177,6 +177,79 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor):
return hidden_states, gate


class DoRA(nn.Module):
def __init__(
self,
lora_A_shape,
lora_B_shape,
config: LoRAConfig,
gating_heads: int = 1,
):
super().__init__()
assert config.composition_mode == "dora", "DoRA module only supports composition_mode='dora'."
self.r = config.r
self.lora_alpha = config.alpha
self.composition_mode = config.composition_mode
self.attn_matrices = config.attn_matrices
self.use_gating = config.use_gating
# Optional dropout
if config.dropout > 0.0:
self.lora_dropout = nn.Dropout(p=config.dropout)
else:
self.lora_dropout = lambda x: x

dtype = getattr(torch, config.dtype) if config.dtype else None

# Actual trainable parameters
self.lora_A = nn.Parameter(torch.zeros(lora_A_shape, dtype=dtype))
self.lora_B = nn.Parameter(torch.zeros(lora_B_shape, dtype=dtype))
self.scaling = self.lora_alpha / self.r
self.m = nn.Parameter(torch.ones(lora_B_shape, dtype=dtype))  # DoRA magnitude parameter, initialized to ones

if config.init_weights == "lora":
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
elif config.init_weights == "bert":
nn.init.normal_(self.lora_A, std=0.02)
nn.init.normal_(self.lora_B, std=0.02)
elif config.init_weights == "ia3":
nn.init.ones_(self.lora_A)
nn.init.ones_(self.lora_B)
else:
raise ValueError("Unknown init_weights type: {}".format(config.init_weights))

if self.use_gating:
self.gate = nn.Linear(lora_A_shape[-1], gating_heads)
nn.init.normal_(self.gate.weight, std=0.02)

@property
def delta_w(self) -> torch.Tensor:
return self.lora_B @ self.lora_A

def com(self, weights: torch.Tensor, added: torch.Tensor, scaling=None) -> torch.Tensor:
"""Performs the composition operation between existing and injected weights."""
if scaling is None:
scaling = self.scaling
return weights * (added * scaling)

def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor:
"""Inverts the composition operation between existing and injected weights."""
return weights / (added * self.scaling)

def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor):
if hidden_states is None:
hidden_states = layer_input
hidden_states = self.lora_dropout(hidden_states) @ torch.t(self.lora_A) @ torch.t(self.lora_B)
if self.use_gating:
gate = torch.sigmoid(self.gate(layer_input))
gate = torch.mean(gate, dim=1).unsqueeze(-1)
hidden_states = hidden_states * gate
else:
gate = None

return hidden_states, gate


class LoRALayer(AdapterLayerBase):
adapter_modules_name = "loras"

@@ -213,6 +286,8 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
lora_cls = LoRA
elif lora_config.composition_mode == "scale":
lora_cls = IA3
elif lora_config.composition_mode == "dora":
lora_cls = DoRA
else:
raise ValueError(f"Unknown composition_mode: {lora_config.composition_mode}")
lora = lora_cls(
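To show how the new dispatch branch gets exercised: `LoRALayer.add_adapter` picks the module class from `composition_mode`, so a config with `composition_mode="dora"` is routed to the `DoRA` class above. A rough standalone construction of the module, assuming the usual `(r, in_features)` / `(out_features, r)` convention for `lora_A_shape` / `lora_B_shape` (those shapes are not spelled out in this diff):

```python
import torch

from adapters import DoRAConfig
from adapters.methods.lora import DoRA  # module added in this PR

config = DoRAConfig()  # composition_mode="dora", r=8, alpha=8
in_features, out_features = 32, 16

dora = DoRA(
    lora_A_shape=(config.r, in_features),   # assumed (r, in_features)
    lora_B_shape=(out_features, config.r),  # assumed (out_features, r)
    config=config,
)

hidden_states = torch.randn(2, 5, in_features)  # (batch, seq_len, in_features)
delta, gate = dora(hidden_states, layer_input=hidden_states)
print(delta.shape)  # torch.Size([2, 5, 16]); gate is None unless use_gating=True
```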