Skip to content

FFT: fix grid and mini-batch#87

Open
AntonOresten wants to merge 1 commit into
NVIDIA:mainfrom
AntonOresten:fft-batch
Open

FFT: fix grid and mini-batch#87
AntonOresten wants to merge 1 commit into
NVIDIA:mainfrom
AntonOresten:fft-batch

Conversation

@AntonOresten
Copy link
Copy Markdown

Description

The FFT sample currently passes the full batch as the kernel's BS constant while also launching a grid of (BS, 1, 1) blocks. Each block then loads a (BS, …) tile, so every block redundantly processes the whole batch. The cost scales with batch and spills hard at modest sizes.

This PR re-interprets the kernel's BS constant as a per-block minibatch size and sizes the grid as Batch // BS accordingly. The wrapper exposes it as a minibatch: int = 1 parameter; the kernel and its internal tile shapes are unchanged.

Measured on a DGX Spark, N=512, batch=64, factors=(8,8,8), minibatch=1: kernel launch time goes from 2376 μs -> 12 μs (~200x), with twiddle factors precomputed (as they would be in any real use).

Sweep of minibatch ∈ {1, 2, 4} are all correct against torch.fft.fft, with no consistent win for minibatch > 1 at these problem sizes (registers/shared mem fill quickly). Thus, BS / minibatch could, and maybe should be dropped entirely.

x-ref JuliaGPU/cuTile.jl#232

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

Signed-off-by: AntonOresten <antonoresten@proton.me>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant