Skip to content

feat(fast): add sliding-window SDPA kernel path#3552

Open
rabbitson87 wants to merge 2 commits into
ml-explore:mainfrom
rabbitson87:sliding-window-attention-kernel
Open

feat(fast): add sliding-window SDPA kernel path#3552
rabbitson87 wants to merge 2 commits into
ml-explore:mainfrom
rabbitson87:sliding-window-attention-kernel

Conversation

@rabbitson87
Copy link
Copy Markdown

feat(fast::sdpa): sliding-window attention via has_window function constant + kb_start truncation

Summary

Adds optional window_size parameter to mlx::fast::scaled_dot_product_attention enabling sliding-window attention (each Q position attends only to the last window_size K positions, combined with causal masking). The implementation mirrors the existing do_causal upper-bound (kb_lim) pattern with a symmetric lower-bound (kb_start) in the steel attention kernel, skipping K-blocks below the window's lower edge wholesale.

Motivation

Sliding-window attention is used by Gemma 2, Gemma 3, Gemma 4, Mistral 7B, and several long-context Qwen variants. Currently mlx callers must construct an explicit [L_q, L_kv] bool mask Array to express the window pattern, which:

  1. Materializes the mask tensor (8 KB at L=1024 → 64 MB at L=8192).
  2. Forces the matmul fallback for head_dim ∉ {64, 80, 128} since the steel kernel's sdpa_full_supported_head_dim doesn't include 256/512.
  3. Materializes the full [B, H, L, L] scores tensor in the fallback path (1 GB bf16 at L=8192, B=1, H=8) → significant DRAM pressure.

With explicit window_size, the steel kernel can:

  • Compute valid K-block range [kb_start, kb_lim) per Q-tile and skip the rest entirely.
  • Apply per-element left-edge mask only for the first ceil(BQ/BK) blocks past kb_start (symmetric with the existing causal upper-edge mask).
  • Avoid mask Array allocation + scan entirely.

Performance (Gemma 4 26B-A4B, M3 Max, macOS 26.3.1)

Methodology: 3 timed trials per cell, alternating ON/OFF within context to
control thermal drift, 1 untimed warmup per cell, 30s cooldown between
trials, 60s between configs, 90s between contexts. STEPS=32 decode,
WARMUP=4 (bench-internal). Window size W=1024.

Prefill (single-pass, sliding-window layers exercised):

Context Baseline (array-mask path) Windowed kernel Δ K-block skip ratio
2K 986.5 ± 8.3 tok/s 1039.4 ± 11.5 tok/s +5.4% 50.0%
4K 863.7 ± 7.5 tok/s 997.6 ± 9.0 tok/s +15.5% 75.0%
8K 678.8 ± 0.6 tok/s 914.7 ± 1.1 tok/s +34.8% 87.5%

Decode (32 steps each, post-prefill autoregressive):

Context Baseline Windowed Δ
2K 49.0 ± 15.5 tok/s 49.9 ± 19.6 tok/s neutral (small-ctx timing noise)
4K 47.1 ± 10.1 tok/s 46.8 ± 9.7 tok/s neutral
8K 50.9 ± 1.2 tok/s 50.4 ± 1.5 tok/s neutral (tight std)

The prefill speedup tracks the theoretical K-block skip ratio
floor((L − W) / L) — at 8K, 87.5% of K-blocks below the window's lower
edge are skipped wholesale, yielding a 34.8% end-to-end prefill speedup
once the unchanged full-attention layers, MoE experts, and dense layers
are amortized in. The 8K measurement is essentially noise-free (baseline
σ = 0.6 tok/s, windowed σ = 1.1 tok/s over 3 trials), making the +34.8%
margin statistically unambiguous.

Decode is neutral at all contexts. The 2K/4K decode std is high because
total decode time at small ctx is short and timing jitter dominates per
trial; the 8K decode measurement (σ ≈ 1.3 tok/s on both arms) is the
canonical reading for decode neutrality.

Raw measurements: see "Reproducibility" section below.

Implementation

Kernel (steel_attention.h / steel_attention_nax.h)

constant bool has_window [[function_constant(303)]];

// Lower-bound K-block truncation, mirrors do_causal upper-bound:
int kb_start = 0;
if (has_window) {
    int q_min = tid.x * BQ + params->qL_off;
    int k_min = q_min - params->window_size + 1;
    if (k_min > 0) kb_start = min(kb_lim, k_min / BK);
}
for (int kb = kb_start; kb < kb_lim; kb++) { ... }

// Per-element left-edge mask (only first ceil(BQ/BK) blocks past kb_start):
if (has_window && kb < kb_start + ((BQ + BK - 1) / BK)) {
    // row_pos - col_pos >= W → mask = -inf
}

The window check uses K-relative coordinates (qL_off = kL - qL equals kv_offset - cache_first_held_pos for both rotated and non-rotated caches), so chunked prefill with rotating sliding-window cache works automatically.

Public API

// mlx/fast.h
MLX_API array scaled_dot_product_attention(
    const array& queries,
    const array& keys,
    const array& values,
    const float scale,
    const std::string& mask_mode = "",
    std::optional<array> mask_arr = {},
    const std::optional<array>& sinks = {},
    int window_size = 0,    // NEW (default 0 = no window)
    StreamOrDevice s = {});

use_fallback

window_size > 0 allows BD=256 path even when LUMEN_GEMMA4_PREFILL_FAST_BD256 env gate is not set — the window-aware kernel is faster than fallback on non-NAX hardware (87.5% K-block skip dominates over the lack of NAX tensor unit fragments).

Backward compatibility

  • Default window_size = 0 preserves all existing behavior. Function constant has_window = false → kernel takes original code path with kb_start = 0.
  • Python: mx.fast.scaled_dot_product_attention(...) gets a new optional window_size kwarg. No breakage to existing callers.
  • C++: signature change is parameter addition with default, ABI-compatible at the C++ level.
  • C ABI: mlx_fast_scaled_dot_product_attention passes window_size=0 internally; no signature change required for C consumers.

Test coverage

The implementation has been validated against Gemma 4 26B-A4B on M3 Max:

  • Functional smoke (chat completions produce coherent output).
  • Multi-trial perf A/B confirming the speedup at 2K/4K/8K.
  • Chunked-prefill rotation case (chunk_size=1024, prompt_len=4096, 4 chunks) — kernel produces valid outputs via K-relative coordinate math.

Outstanding before merge (would appreciate maintainer feedback on test expectations):

  • Bit-identical output vs the manually-constructed sliding mask Array path. In our measurement bf16 accumulation order differs between kernel and array-mask path → greedy argmax can flip when top-K logits are close. Both are mathematically valid sliding attention. Maintainers may want to add tolerance-based tests; we did not include explicit unit tests for this.
  • D ∈ {64, 80, 128, 256} smoke (only D=256 directly tested; the other BD instantiations should behave the same — same kernel code, different specialization).
  • NAX path verification (kernel was updated symmetrically but tested only on non-NAX M3 Max).
  • VJP support — the windowed forward is supported; backward (ScaledDotProductAttentionVJP) was not extended. Probably needs same treatment but it's out of scope for inference-only use.

Notes for reviewers

  • The is_equivalent extension to compare window_size_ is required for correctness — primitives with different windows must not be deduplicated by mlx's graph optimizer. Independent clean A/B (3 trials, alternating, 8K) confirms decode throughput is unchanged (50.9 → 50.4 tok/s with σ ≈ 1.3 on both arms) with windowed kernel ON. An initial -22% decode reading turned out to be measurement variance at low trial count.
  • Two minor host-side helpers (hash_name extended with _has_window_, fallback lambda extended for sliding mask construction) — included for completeness.
  • BD=512 instantiation was independently attempted (BQ=16, BK=8, WM=2, padQ=0 conditional) to unlock head_dim=512 full-attention paths. It builds and runs functionally but A/B showed -25% prefill regression on M3 Max (non-NAX) — the WM=2 / padQ=0 trade-offs vs matmul fallback's tuned gemm dispatch are net negative. Reverted, not included in this PR. May be worth revisiting once NAX-capable M5+ hardware is available.

Files changed

mlx/backend/metal/kernels/steel/attn/params.h                    +2
mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h   +47
mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +46
mlx/backend/metal/scaled_dot_product_attention.cpp               +~30  (only sliding-window related)
mlx/backend/cuda/scaled_dot_product_attention.cpp                +4    (fallback for unsupported windowed path)
mlx/backend/no_gpu/primitives.cpp                                +1    (signature sync)
mlx/fast.h                                                       +9
mlx/fast.cpp                                                     +~65  (only sliding-window related)
mlx/fast_primitives.h                                            +15
python/src/fast.cpp                                              +~30  (Python kwarg binding/docs)

Total ~200 lines of sliding-window-only changes.

Reproducibility

The performance numbers above were collected with the following protocol on
a quiet M3 Max (no other GPU consumers, mains power):

hardware:    M3 Max (40-core GPU)
os:          macOS 26.3.1
model:       Gemma 4 26B-A4B Q4 (lumen-rs native backend, mlx ~main)
window_size: 1024 (Gemma-4 sliding layers)
bench:       lumen-rs `bench_gemma4_native_e2e`
decode:      STEPS=32, WARMUP=4
arms:        windowed ON (default)   vs   windowed OFF (LUMEN_GEMMA4_SDPA_WINDOWED=0)
trials/cell: 3 (timed) + 1 (untimed warmup)
order:       alternating ON / OFF within each context (thermal-drift control)
cooldowns:   30s between trials, 60s between configs, 90s between contexts

Raw per-trial readings (prefill tok/s | decode tok/s):

ctx=2K  ON  T1: 1032.8 | 64.6     OFF T1:  998.2 | 27.2
ctx=2K  ON  T2: 1055.6 | 22.1     OFF T2:  981.6 | 58.1
ctx=2K  ON  T3: 1029.9 | 62.9     OFF T3:  979.6 | 61.7
ctx=4K  ON  T1:  991.1 | 53.4     OFF T1:  857.4 | 53.8
ctx=4K  ON  T2:  991.4 | 53.8     OFF T2:  859.6 | 54.7
ctx=4K  ON  T3: 1010.4 | 33.1     OFF T3:  874.2 | 32.8
ctx=8K  ON  T1:  915.9 | 48.2     OFF T1:  678.6 | 49.2
ctx=8K  ON  T2:  915.0 | 51.5     OFF T2:  678.1 | 51.8
ctx=8K  ON  T3:  913.2 | 51.4     OFF T3:  679.6 | 51.6

The 8K prefill measurement is the tightest signal in the dataset
(σ = 0.6 / 1.1 tok/s on the two arms) and is the load-bearing evidence
for the kernel's prefill speedup. The smaller-context numbers carry more
timing noise but show consistent directionality and align with the
predicted K-block skip ratio (50% / 75% / 87.5% at 2K / 4K / 8K).

@angeloskath angeloskath requested a review from jagrit06 May 16, 2026 00:27

if (has_window) {
int q_min = tid.x * BQ + params->qL_off;
int k_min = q_min - params->window_size + 1;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the + 1 here needed ?
For given (adjusted) query index q, the window should attend to keys [q - window_size, q] inclusive - meaning the first k would be just q_min - params->window_size to match the logic

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the + 1 is intentional with the current definition of window_size.

I interpreted window_size = W as the number of visible positions in the left window, including the current query position. So the valid key range is:

[q - W + 1, q]
For example, with window_size = 2, query q should attend to {q - 1, q}, not {q - 2, q}. Without the + 1, the window would include W + 1 positions.

That said, if we want window_size to mean “number of previous keys in addition to the current key”, then your suggested formula would be correct. I’ll make sure the docs and implementation use the same convention.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, there seems to be a gap between between the lower level implementation providers (flash attention and Jax sdpa) vs the upper level implementation use (hugging face)
See below: https://github.com/huggingface/transformers/blob/52b82b299171721fbe7b04fe056187f7aed2e2cc/src/transformers/modeling_flash_attention_utils.py#L627

Screenshot 2026-05-20 at 1 46 57 PM

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This decision comes down to what convention we wish to follow, give me a bit to get back to you on that front

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, thanks for checking.

Just to make the current PR intent explicit: the implementation is currently following the HF/model-level convention where window_size=W means a total causal window of W visible positions including the current token, so the valid range is [q - W + 1, q].

I agree the lower-level provider convention is different: FlashAttention/JAX-style APIs expose left/right window sizes with inclusive bounds, where a left window of W plus right 0 would cover W previous keys plus the current key.

I’ll hold off on further changes until MLX decides which convention the public API should follow. If MLX wants provider-style semantics, I can update the kernel, fallback mask, docs, and tests to use [q - W, q]; if MLX wants HF/model-style semantics, the current + 1 should stay and I’ll make that convention explicit in the docs/tests.

@jagrit06
Copy link
Copy Markdown
Member

From what I have seen, FlashAttention/pytorch implement the default window size as -1 which means an infinite window length
Granted, they do so because they allow for left and right side windows, in which case (ws, 0) or (0, ws) are valid inputs
As implemented, 0 makes sense to mean no window, but I think we should stick to convention and let -1 mean invite window i.e: No local masking
@angeloskath any opinions ?

Comment thread mlx/fast.cpp
Comment on lines +763 to +768
if (window_size > 0) {
auto window_mask =
less(q_idx, add(k_idx, array(window_size, k_idx.dtype()), s), s);
if (!has_arr_mask) {
return window_mask;
}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This window mask (in the case of non-causal attention) doesn't stop at the right end of the window
For example, if we have a window size of 2, and seq length of 5, then
At q_idx = 3, k_idx = 4, the mask above would be calculated as (3 < (4 + 2)) = true, but the left window stops at 3
Is this intended ?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You’re right, that is not intended.

The current non-causal fallback only applies the left-window condition:

q_idx < k_idx + window_size
but it does not apply the right bound:

k_idx <= q_idx
So in your example, q_idx = 3, k_idx = 4, and window_size = 2 would incorrectly pass the mask. The intended left-window range is [q - window_size + 1, q], so k_idx = 4 should be masked out.

I’ll fix this by adding the missing right-bound check, or alternatively restrict window_size to causal attention only if that better matches the intended API.

@rabbitson87
Copy link
Copy Markdown
Author

From what I have seen, FlashAttention/pytorch implement the default window size as -1 which means an infinite window length Granted, they do so because they allow for left and right side windows, in which case (ws, 0) or (0, ws) are valid inputs As implemented, 0 makes sense to mean no window, but I think we should stick to convention and let -1 mean invite window i.e: No local masking @angeloskath any opinions ?

Agreed, that convention makes sense.

I updated the PR so window_size = -1 is now the default and means no local masking / infinite window. window_size = 0 is still accepted as no-window as well, to keep the implementation behavior permissive, while values < -1 are rejected.

This is in the follow-up commit: 77eb21d.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants