feat(fast): add sliding-window SDPA kernel path#3552
Conversation
|
|
||
| if (has_window) { | ||
| int q_min = tid.x * BQ + params->qL_off; | ||
| int k_min = q_min - params->window_size + 1; |
There was a problem hiding this comment.
Is the + 1 here needed ?
For given (adjusted) query index q, the window should attend to keys [q - window_size, q] inclusive - meaning the first k would be just q_min - params->window_size to match the logic
There was a problem hiding this comment.
Yes, the + 1 is intentional with the current definition of window_size.
I interpreted window_size = W as the number of visible positions in the left window, including the current query position. So the valid key range is:
[q - W + 1, q]
For example, with window_size = 2, query q should attend to {q - 1, q}, not {q - 2, q}. Without the + 1, the window would include W + 1 positions.
That said, if we want window_size to mean “number of previous keys in addition to the current key”, then your suggested formula would be correct. I’ll make sure the docs and implementation use the same convention.
There was a problem hiding this comment.
Good point, there seems to be a gap between between the lower level implementation providers (flash attention and Jax sdpa) vs the upper level implementation use (hugging face)
See below: https://github.com/huggingface/transformers/blob/52b82b299171721fbe7b04fe056187f7aed2e2cc/src/transformers/modeling_flash_attention_utils.py#L627
There was a problem hiding this comment.
This decision comes down to what convention we wish to follow, give me a bit to get back to you on that front
There was a problem hiding this comment.
Sounds good, thanks for checking.
Just to make the current PR intent explicit: the implementation is currently following the HF/model-level convention where window_size=W means a total causal window of W visible positions including the current token, so the valid range is [q - W + 1, q].
I agree the lower-level provider convention is different: FlashAttention/JAX-style APIs expose left/right window sizes with inclusive bounds, where a left window of W plus right 0 would cover W previous keys plus the current key.
I’ll hold off on further changes until MLX decides which convention the public API should follow. If MLX wants provider-style semantics, I can update the kernel, fallback mask, docs, and tests to use [q - W, q]; if MLX wants HF/model-style semantics, the current + 1 should stay and I’ll make that convention explicit in the docs/tests.
|
From what I have seen, FlashAttention/pytorch implement the default window size as |
| if (window_size > 0) { | ||
| auto window_mask = | ||
| less(q_idx, add(k_idx, array(window_size, k_idx.dtype()), s), s); | ||
| if (!has_arr_mask) { | ||
| return window_mask; | ||
| } |
There was a problem hiding this comment.
This window mask (in the case of non-causal attention) doesn't stop at the right end of the window
For example, if we have a window size of 2, and seq length of 5, then
At q_idx = 3, k_idx = 4, the mask above would be calculated as (3 < (4 + 2)) = true, but the left window stops at 3
Is this intended ?
There was a problem hiding this comment.
You’re right, that is not intended.
The current non-causal fallback only applies the left-window condition:
q_idx < k_idx + window_size
but it does not apply the right bound:
k_idx <= q_idx
So in your example, q_idx = 3, k_idx = 4, and window_size = 2 would incorrectly pass the mask. The intended left-window range is [q - window_size + 1, q], so k_idx = 4 should be masked out.
I’ll fix this by adding the missing right-bound check, or alternatively restrict window_size to causal attention only if that better matches the intended API.
Agreed, that convention makes sense. I updated the PR so This is in the follow-up commit: 77eb21d. |
feat(fast::sdpa): sliding-window attention via has_window function constant + kb_start truncation
Summary
Adds optional
window_sizeparameter tomlx::fast::scaled_dot_product_attentionenabling sliding-window attention (each Q position attends only to the lastwindow_sizeK positions, combined with causal masking). The implementation mirrors the existingdo_causalupper-bound (kb_lim) pattern with a symmetric lower-bound (kb_start) in the steel attention kernel, skipping K-blocks below the window's lower edge wholesale.Motivation
Sliding-window attention is used by Gemma 2, Gemma 3, Gemma 4, Mistral 7B, and several long-context Qwen variants. Currently mlx callers must construct an explicit
[L_q, L_kv]bool mask Array to express the window pattern, which:head_dim ∉ {64, 80, 128}since the steel kernel'ssdpa_full_supported_head_dimdoesn't include 256/512.[B, H, L, L]scores tensor in the fallback path (1 GB bf16 at L=8192, B=1, H=8) → significant DRAM pressure.With explicit
window_size, the steel kernel can:[kb_start, kb_lim)per Q-tile and skip the rest entirely.ceil(BQ/BK)blocks pastkb_start(symmetric with the existing causal upper-edge mask).Performance (Gemma 4 26B-A4B, M3 Max, macOS 26.3.1)
Methodology: 3 timed trials per cell, alternating ON/OFF within context to
control thermal drift, 1 untimed warmup per cell, 30s cooldown between
trials, 60s between configs, 90s between contexts. STEPS=32 decode,
WARMUP=4 (bench-internal). Window size W=1024.
Prefill (single-pass, sliding-window layers exercised):
Decode (32 steps each, post-prefill autoregressive):
The prefill speedup tracks the theoretical K-block skip ratio
floor((L − W) / L)— at 8K, 87.5% of K-blocks below the window's loweredge are skipped wholesale, yielding a 34.8% end-to-end prefill speedup
once the unchanged full-attention layers, MoE experts, and dense layers
are amortized in. The 8K measurement is essentially noise-free (baseline
σ = 0.6 tok/s, windowed σ = 1.1 tok/s over 3 trials), making the +34.8%
margin statistically unambiguous.
Decode is neutral at all contexts. The 2K/4K decode std is high because
total decode time at small ctx is short and timing jitter dominates per
trial; the 8K decode measurement (σ ≈ 1.3 tok/s on both arms) is the
canonical reading for decode neutrality.
Raw measurements: see "Reproducibility" section below.
Implementation
Kernel (
steel_attention.h/steel_attention_nax.h)The window check uses K-relative coordinates (
qL_off = kL - qLequalskv_offset - cache_first_held_posfor both rotated and non-rotated caches), so chunked prefill with rotating sliding-window cache works automatically.Public API
use_fallback
window_size > 0allows BD=256 path even whenLUMEN_GEMMA4_PREFILL_FAST_BD256env gate is not set — the window-aware kernel is faster than fallback on non-NAX hardware (87.5% K-block skip dominates over the lack of NAX tensor unit fragments).Backward compatibility
window_size = 0preserves all existing behavior. Function constanthas_window = false→ kernel takes original code path withkb_start = 0.mx.fast.scaled_dot_product_attention(...)gets a new optionalwindow_sizekwarg. No breakage to existing callers.mlx_fast_scaled_dot_product_attentionpasseswindow_size=0internally; no signature change required for C consumers.Test coverage
The implementation has been validated against Gemma 4 26B-A4B on M3 Max:
Outstanding before merge (would appreciate maintainer feedback on test expectations):
ScaledDotProductAttentionVJP) was not extended. Probably needs same treatment but it's out of scope for inference-only use.Notes for reviewers
is_equivalentextension to comparewindow_size_is required for correctness — primitives with different windows must not be deduplicated by mlx's graph optimizer. Independent clean A/B (3 trials, alternating, 8K) confirms decode throughput is unchanged (50.9 → 50.4 tok/s with σ ≈ 1.3 on both arms) with windowed kernel ON. An initial -22% decode reading turned out to be measurement variance at low trial count.hash_nameextended with_has_window_, fallback lambda extended for sliding mask construction) — included for completeness.Files changed
Total ~200 lines of sliding-window-only changes.
Reproducibility
The performance numbers above were collected with the following protocol on
a quiet M3 Max (no other GPU consumers, mains power):
Raw per-trial readings (prefill tok/s | decode tok/s):
The 8K prefill measurement is the tightest signal in the dataset
(σ = 0.6 / 1.1 tok/s on the two arms) and is the load-bearing evidence
for the kernel's prefill speedup. The smaller-context numbers carry more
timing noise but show consistent directionality and align with the
predicted K-block skip ratio (50% / 75% / 87.5% at 2K / 4K / 8K).