Skip to content

Commit

Permalink
Bump to v0.2.6
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Dec 28, 2022
1 parent 63670fd commit a6ec178
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion flash_attn/modules/mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seql
kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
# If we're processing the prompt, causal=None (use self.causal).
# If we're decoding, then causal=False.
causal = False if inference_params.sequence_len_offset == 0 else None
causal = None if inference_params.sequence_len_offset == 0 else False
context = self.inner_cross_attn(q, kv, causal=causal)
else:
if not self.return_residual:
Expand Down
2 changes: 1 addition & 1 deletion flash_attn/utils/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def greedy_decode(input_ids, model, max_length):
inference_params.sequence_len_offset = seqlen_og
while True:
position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
dtype=torch.long, device=input_ids.device)
dtype=torch.long, device=input_ids.device)
logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
inference_params=inference_params).logits[:, -1]
scores.append(logits)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def append_nvcc_threads(nvcc_extra_args):

setup(
name="flash_attn",
version="0.2.5",
version="0.2.6-1",
packages=find_packages(
exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
),
Expand Down

0 comments on commit a6ec178

Please sign in to comment.