diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index f97873ce646..ccb5bf3403b 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -1043,10 +1043,17 @@ def _forward_mha(
         mask = None
         masks = kwargs.get("masks")
+        cache_len = k.size(-2) - seq_len
         if masks:
-            cache_len = k.size(-2) - seq_len
             mask = masks[cache_len]
 
+        prev_tokens_masks = kwargs.get("prev_tokens_masks")
+        new_tokens_masks = kwargs.get("new_tokens_masks")
+        if prev_tokens_masks is not None and new_tokens_masks is not None:
+            prev_mask = prev_tokens_masks[cache_len]
+            new_mask = new_tokens_masks[cache_len]
+            mask = torch.cat([prev_mask, new_mask], dim=-1).contiguous()
+
         if not self.decompose_sdpa_in_mha:
             if self.n_rep > 1:
                 k = k.repeat_interleave(self.n_rep, dim=1)
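
For context, a minimal standalone sketch of the mask selection this hunk introduces: `cache_len` is recovered from `k`'s second-to-last dimension, used to index the per-cache-length lookup tables, and the previous-token and new-token masks are concatenated along the last dimension. The shapes, dict structure, and mask contents below are illustrative assumptions; in the actual module these tables arrive through `kwargs` and are built elsewhere in `static_attention.py`.

```python
import torch

seq_len, max_cache_len = 4, 8

# Hypothetical lookup tables keyed by cache length, standing in for
# kwargs["prev_tokens_masks"] and kwargs["new_tokens_masks"] in the diff.
prev_tokens_masks = {
    c: torch.zeros(1, 1, seq_len, c) for c in range(max_cache_len + 1)
}
new_tokens_masks = {
    c: torch.triu(
        torch.full((1, 1, seq_len, seq_len), float("-inf")), diagonal=1
    )
    for c in range(max_cache_len + 1)
}

# k is assumed to have shape (batch, n_heads, cache_len + seq_len, head_dim),
# so the cache length falls out of its second-to-last dimension.
k = torch.randn(1, 8, 6 + seq_len, 64)
cache_len = k.size(-2) - seq_len  # 6

prev_mask = prev_tokens_masks[cache_len]  # mask over previously cached tokens
new_mask = new_tokens_masks[cache_len]    # causal mask over the new tokens
mask = torch.cat([prev_mask, new_mask], dim=-1).contiguous()

print(mask.shape)  # torch.Size([1, 1, 4, 10]) == (_, _, seq_len, cache_len + seq_len)
```

The resulting mask spans the full attention width (`cache_len + seq_len`), matching the single-table `masks[cache_len]` path it supplements when both per-segment tables are provided.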