diff --git a/src/adapters/models/llama/modeling_llama.py b/src/adapters/models/llama/modeling_llama.py
index 95d114224..9ff83f5c2 100644
--- a/src/adapters/models/llama/modeling_llama.py
+++ b/src/adapters/models/llama/modeling_llama.py
@@ -18,7 +18,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """PyTorch LLaMA model."""
-import warnings
 from typing import Callable, Optional, Tuple
 
 import torch
@@ -55,6 +54,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # >>> START AH Changes <<<
         bsz, q_len, _ = hidden_states.size()
 
         query_states = self.q_proj(hidden_states)
@@ -66,7 +66,6 @@ def forward(
         key_states = key_states.view(-1, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(-1, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        # >>> START AH Changes <<<
         query_states, key_states, value_states = match_attn_matrices_for_parallel(
             query_states, key_states, value_states
         )
@@ -127,28 +126,9 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
-        **kwargs,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        """
-        Args:
-            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-            attention_mask (`torch.FloatTensor`, *optional*):
-                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
-                query_sequence_length, key_sequence_length)` if default attention is used.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            use_cache (`bool`, *optional*):
-                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-                (see `past_key_values`).
-            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-        """
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use"
-                " `attention_mask` instead.`"
-            )
         residual = hidden_states
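
For context on the relocated `# >>> START AH Changes <<<` block: `match_attn_matrices_for_parallel` is applied to the freshly projected query/key/value states so that their batch dimensions agree when a Parallel adapter composition has replicated part of the input. The snippet below is a minimal, hypothetical sketch of that matching step; the `_sketch` name and the repeat-based logic are assumptions for illustration, not the library's actual implementation.

import torch
from typing import Tuple


def match_attn_matrices_for_parallel_sketch(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Assumption: under a Parallel composition one of the projections may already be
    # replicated n_parallel times along dim 0 (shape (n_parallel * bsz, heads, seq, head_dim)),
    # while the others still carry the original batch size. Repeat the smaller tensors
    # block-wise so all three share the largest batch dimension.
    max_bsz = max(query.shape[0], key.shape[0], value.shape[0])

    def _expand(t: torch.Tensor) -> torch.Tensor:
        if t.shape[0] == max_bsz:
            return t
        return t.repeat(max_bsz // t.shape[0], 1, 1, 1)

    return _expand(query), _expand(key), _expand(value)

For example, if the query states arrive with batch size 4 while key/value states still have batch size 2, the key/value tensors are repeated along dim 0 so that the subsequent attention computation sees consistent shapes.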