Update comments in Llama
TimoImhof committed Feb 24, 2025
1 parent ff3ea80 commit 2394479
Showing 1 changed file with 3 additions and 23 deletions.
26 changes: 3 additions & 23 deletions src/adapters/models/llama/modeling_llama.py
@@ -18,7 +18,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """PyTorch LLaMA model."""
-import warnings
 from typing import Callable, Optional, Tuple
 
 import torch
@@ -55,6 +54,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # >>> START AH Changes <<<
         bsz, q_len, _ = hidden_states.size()
 
         query_states = self.q_proj(hidden_states)
@@ -66,7 +66,6 @@
         key_states = key_states.view(-1, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(-1, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        # >>> START AH Changes <<<
         query_states, key_states, value_states = match_attn_matrices_for_parallel(
             query_states, key_states, value_states
         )
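
For readers unfamiliar with `match_attn_matrices_for_parallel`: it aligns the query/key/value batch dimensions so that parallel adapter composition, which replicates inputs across branches, can proceed through a single attention call. Below is a minimal sketch of that idea, assuming the helper simply repeats tensors along the batch dimension up to the largest batch size; the `match_batch_for_parallel` function here is hypothetical and for illustration only, not the library's implementation.

```python
import torch


def match_batch_for_parallel(*states: torch.Tensor) -> tuple:
    """Illustrative sketch only: repeat each (batch, heads, seq_len, head_dim)
    tensor along dim 0 so that all tensors share the largest batch size."""
    max_bsz = max(s.shape[0] for s in states)
    matched = []
    for s in states:
        if s.shape[0] != max_bsz:
            # assumes max_bsz is a multiple of the smaller batch size,
            # as when inputs are replicated for parallel branches
            s = s.repeat(max_bsz // s.shape[0], 1, 1, 1)
        matched.append(s)
    return tuple(matched)


# usage sketch: queries already replicated for 2 parallel branches, keys/values not yet
q = torch.randn(4, 8, 16, 64)  # (batch * n_parallel, num_heads, seq_len, head_dim)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)
q, k, v = match_batch_for_parallel(q, k, v)
print(q.shape, k.shape, v.shape)  # all now torch.Size([4, 8, 16, 64])
```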
@@ -127,28 +126,9 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
-        **kwargs,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        """
-        Args:
-            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-            attention_mask (`torch.FloatTensor`, *optional*):
-                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
-                query_sequence_length, key_sequence_length)` if default attention is used.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            use_cache (`bool`, *optional*):
-                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-                (see `past_key_values`).
-            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-        """
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use"
-                " `attention_mask` instead.`"
-            )
 
         residual = hidden_states
 
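A side note on the new `**kwargs: Unpack[FlashAttentionKwargs]` annotation: typing `**kwargs` with `Unpack` over a `TypedDict` lets static type checkers validate the keyword arguments forwarded to the attention implementation. Below is a minimal sketch of the pattern using a hypothetical `ExampleAttnKwargs` stand-in, not the real `FlashAttentionKwargs` definition from transformers.

```python
from typing import Optional

from typing_extensions import TypedDict, Unpack


class ExampleAttnKwargs(TypedDict, total=False):
    """Hypothetical stand-in for transformers' FlashAttentionKwargs."""

    max_length_q: Optional[int]
    max_length_k: Optional[int]


def attention_forward(hidden_states, **kwargs: Unpack[ExampleAttnKwargs]):
    # A type checker now knows which keyword arguments may appear in **kwargs
    # and what their types are, instead of treating them as opaque.
    return hidden_states, kwargs.get("max_length_q")


out, max_q = attention_forward("dummy input", max_length_q=128)
```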
