[CPU - Linux] AVX SIMD backend for fp16 and bf16 matmul #3502
@dhiltgen I remember ollama was doing something similar. Could you please check whether this can coexist with your work?

I have an older PR #3019 which I've been meaning to break up into smaller chunks. I'll add inline comments on this PR with some suggestions on how this could become a partial precursor to that broader implementation.
@@ -0,0 +1,432 @@
// Copyright © 2025 Apple Inc.
#pragma once
There was a problem hiding this comment.
Since this is largely GEMM-oriented helpers this could move to mlx/backend/cpu/gemms/avx2_gemm_simd.h and use a GEMM-private namespace rather than mlx::core::simd. That lets a future broad AVX2 SIMD layer land without colliding with this PR.
There was a problem hiding this comment.
Thank you, I have been working on implementing a more full set of AVX2 instructions to follow this PR. Does it still make sense for me to submit that?
There was a problem hiding this comment.
I would be happy to review it if it is a small-scale pull request like this one.
Thank you for the callout @zcbenz and for the feedback @dhiltgen. I made some changes and would appreciate any feedback. As mentioned in the PR description and in a reply above, I've been working on a fuller set of AVX2 instructions to follow this PR. I'd like to submit that (assuming this PR is good!) if it still makes sense. Let me know.
inline Simd<T, N> fma(Simd<T, N> a, Simd<T, N> b, Simd<T, N> c);

// Simd<float, 8> — wraps __m256 for AVX operations.
using float8 = Simd<float, 8>;
There was a problem hiding this comment.
float8 is an ambiguous name, since it can also mean an 8-bit floating-point number. I think just using Simd<float, 8> would not be so bad?
There was a problem hiding this comment.
Thank you, that makes sense! At some point I thought the alias made the code cleaner.
inline float8 fma<float, 8>(float8 a, float8 b, float8 c) {
#ifdef __AVX2__
  return float8(_mm256_fmadd_ps(a, b, c));
#else
There was a problem hiding this comment.
Is it necessary to provide a fallback, since this file is guaranteed to be compiled with -mavx2 -mfma -mf16c?
There was a problem hiding this comment.
This makes sense too! I initially wasn't sure whether it was realistic to expect all three instruction sets to be available.
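To make the two suggestions above concrete, here is a minimal sketch of a wrapper that spells out Simd<float, 8> without a float8 alias and exposes a lane-wise fma. The namespace sketch and the helpers load, store, and axpy are hypothetical names used only for illustration, not the PR's code. The scalar branch exists only so this standalone snippet builds without -mavx2 -mfma; in a file that is always compiled with those flags, that branch could be dropped as the reviewer suggests.

```cpp
#include <cassert>
#include <cstddef>
#if defined(__AVX2__) && defined(__FMA__)
#include <immintrin.h>
#endif

namespace sketch {  // hypothetical namespace, for illustration only

template <typename T, int N>
struct Simd;  // primary template, specialized per (type, width)

// Written as Simd<float, 8> everywhere, so nothing is named float8.
template <>
struct Simd<float, 8> {
#if defined(__AVX2__) && defined(__FMA__)
  __m256 value;    // one 256-bit register holding 8 floats
#else
  float value[8];  // stand-in so the sketch builds without AVX2/FMA
#endif
};

inline Simd<float, 8> load(const float* p) {
  Simd<float, 8> r;
#if defined(__AVX2__) && defined(__FMA__)
  r.value = _mm256_loadu_ps(p);  // unaligned load of 8 floats
#else
  for (int i = 0; i < 8; ++i) r.value[i] = p[i];
#endif
  return r;
}

inline void store(float* p, Simd<float, 8> v) {
#if defined(__AVX2__) && defined(__FMA__)
  _mm256_storeu_ps(p, v.value);
#else
  for (int i = 0; i < 8; ++i) p[i] = v.value[i];
#endif
}

// Lane-wise a * b + c.
inline Simd<float, 8> fma(Simd<float, 8> a, Simd<float, 8> b, Simd<float, 8> c) {
  Simd<float, 8> r;
#if defined(__AVX2__) && defined(__FMA__)
  r.value = _mm256_fmadd_ps(a.value, b.value, c.value);
#else
  for (int i = 0; i < 8; ++i) r.value[i] = a.value[i] * b.value[i] + c.value[i];
#endif
  return r;
}

// y[i] += alpha * x[i], processed 8 lanes at a time; n must be a multiple of 8.
inline void axpy(float alpha, const float* x, float* y, std::size_t n) {
  assert(n % 8 == 0);
  float splat[8];
  for (int i = 0; i < 8; ++i) splat[i] = alpha;
  const Simd<float, 8> va = load(splat);
  for (std::size_t i = 0; i < n; i += 8) {
    store(y + i, fma(va, load(x + i), load(y + i)));
  }
}

}  // namespace sketch
```

Calling sketch::axpy(2.0f, x, y, 8) updates y in place; both branches compute the same per-lane a * b + c.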
message(
  STATUS "Compiler supports AVX2/FMA/F16C - enabling AVX2 SIMD backend")
target_compile_options(mlx PRIVATE -mavx2 -mfma -mf16c)
target_compile_definitions(mlx PRIVATE MLX_USE_AVX2)
There was a problem hiding this comment.
Rather than defining it in CMake, it would be more robust to define it in the source code:
#if defined(__AVX2__) && defined(__F16C__)
#define MLX_USE_AVX2
#endif
There was a problem hiding this comment.
I'm less sure about how I implemented this piece, but I'll push a commit with the change here "soon".
A_packed_buf.reset(MC_BLOCK * KC_BLOCK);
B_packed_buf.reset(KC_BLOCK * NC_BLOCK);
C_acc_buf.reset(M * NC_BLOCK);
There was a problem hiding this comment.
We don't allocate inside the implementations, because such allocations bypass the buffer cache and are expensive.
The proper way to do this is to create an mx::array, allocate memory for it, and then call encoder.add_temporary and encoder.set_input_array on it. For example:
mlx/mlx/backend/cpu/quantized.cpp
Lines 1336 to 1341 in 7b7c124
aligned_unique_ptr() : ptr_(nullptr), size_(0) {}

explicit aligned_unique_ptr(size_t size) : size_(size) {
  ptr_ = static_cast<T*>(aligned_alloc(32, size * sizeof(T)));
There was a problem hiding this comment.
I'm not sure aligned_alloc is necessary here; malloc(size) should be able to ensure 256-bit alignment if size is a multiple of 32 bytes.
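One caveat worth checking before relying on malloc here: the C and C++ standards only require malloc to return storage aligned for fundamental types (alignof(std::max_align_t), typically 16 bytes on x86-64 Linux), so 32-byte alignment may not be guaranteed even when the size is a multiple of 32. If aligned loads are needed, a minimal sketch of a 32-byte-aligned buffer could look like the following. The names make_aligned, aligned_buffer, and is_aligned are hypothetical, the sketch assumes C++17 for std::aligned_alloc, and it rounds the byte count up because std::aligned_alloc expects the size to be a multiple of the alignment.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <new>
#include <type_traits>

namespace sketch {  // hypothetical namespace, for illustration only

// Memory from std::aligned_alloc is released with std::free.
struct FreeDeleter {
  void operator()(void* p) const noexcept { std::free(p); }
};

template <typename T>
using aligned_buffer = std::unique_ptr<T[], FreeDeleter>;

// Allocate `count` elements of T aligned to `alignment` bytes.
template <typename T>
aligned_buffer<T> make_aligned(std::size_t count, std::size_t alignment = 32) {
  static_assert(std::is_trivially_destructible<T>::value,
                "std::free does not run destructors");
  // The alignment must be a non-zero power of two.
  assert(alignment != 0 && (alignment & (alignment - 1)) == 0);
  std::size_t bytes = count * sizeof(T);
  // Round up: std::aligned_alloc expects size to be a multiple of alignment.
  bytes = (bytes + alignment - 1) / alignment * alignment;
  if (bytes == 0) bytes = alignment;
  void* p = std::aligned_alloc(alignment, bytes);
  if (p == nullptr) throw std::bad_alloc();
  return aligned_buffer<T>(static_cast<T*>(p));
}

// True if p is a multiple of `alignment` bytes.
inline bool is_aligned(const void* p, std::size_t alignment) {
  return reinterpret_cast<std::uintptr_t>(p) % alignment == 0;
}

}  // namespace sketch
```

If the kernels only ever use unaligned loads and stores, plain allocation may be sufficient and this wrapper is not needed. It also does not address the separate review comment about routing temporaries through the buffer cache.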
Proposed changes
This PR adds an AVX SIMD backend for fp16 and bf16 matmul (GEMM and GEMV) on CPU for Linux. It follows from the discussion in #2037 and is a precursor to adding the full set of AVX SIMD instructions in a follow-up PR. Let me know what you think; I'd appreciate any feedback (including adjustments to benchmarking methodology).
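For readers less familiar with bf16, here is a minimal scalar reference for a bf16 GEMV that can be used to sanity-check a SIMD result. It is not the PR's implementation: the function names are hypothetical, and accumulating in fp32 is an assumption made for this sketch. It relies only on the fact that a bf16 value is the upper 16 bits of the corresponding fp32 bit pattern.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

namespace sketch {  // hypothetical namespace, for illustration only

// Widen bf16 to fp32 by placing the 16 bits in the upper half of the word.
inline float bf16_to_float(std::uint16_t h) {
  std::uint32_t bits = static_cast<std::uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Narrow fp32 to bf16 with round-to-nearest-even.
inline std::uint16_t float_to_bf16(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  if ((bits & 0x7FFFFFFFu) > 0x7F800000u) {
    // NaN: set a mantissa bit so truncation cannot turn it into infinity.
    return static_cast<std::uint16_t>((bits >> 16) | 0x0040u);
  }
  std::uint32_t rounding = 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<std::uint16_t>((bits + rounding) >> 16);
}

// y = A * x for a row-major m x n bf16 matrix A and a bf16 vector x.
// Assumption: products are accumulated in fp32.
inline void gemv_bf16_reference(const std::uint16_t* A,
                                const std::uint16_t* x,
                                float* y,
                                std::size_t m,
                                std::size_t n) {
  assert(A != nullptr && x != nullptr && y != nullptr);
  for (std::size_t i = 0; i < m; ++i) {
    float acc = 0.0f;
    for (std::size_t j = 0; j < n; ++j) {
      acc += bf16_to_float(A[i * n + j]) * bf16_to_float(x[j]);
    }
    y[i] = acc;
  }
}

}  // namespace sketch
```

A reference like this is slow, but it gives a result to compare the vectorized kernels against, within a tolerance that allows for a different summation order.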
I modified bench_gemm.py and bench_gemv.py in benchmarks/python/blas so they'd complete in a reasonable amount of time. I ran them with a build of mlx from this PR and against the official mlx release for comparison. Note I left out the other dtypes from the benchmarked results printed below due to potential build differences (could be an error on my part). I built mlx with:

Bench setup
6.18.9-arch1-2 x86_64
mlx-cpu==0.31.2
torch==2.5.1+cpu

Bench results
GEMM - branch (this PR)
GEMM - mlx-cpu==0.31.2
GEMV - branch (this PR)
GEMV - mlx-cpu==0.31.2

Checklist
Put an x in the boxes that apply.
pre-commit run --all-files to format my code / installed pre-commit prior to committing changes