Skip to content

Add Metal DLPack zero-copy sharing#3531

Open
XXXXRT666 wants to merge 13 commits into
ml-explore:mainfrom
XXXXRT666:metal-dlpack-zero-copy-draft
Open

Add Metal DLPack zero-copy sharing#3531
XXXXRT666 wants to merge 13 commits into
ml-explore:mainfrom
XXXXRT666:metal-dlpack-zero-copy-draft

Conversation

@XXXXRT666
Copy link
Copy Markdown
Contributor

@XXXXRT666 XXXXRT666 commented May 11, 2026

Proposed changes

This draft adds zero-copy Metal DLPack sharing for MLX arrays and PyTorch MPS tensors.

This PR builds on the merged DLPack import PR #3495 and requires nanobind support.

The main changes are:

  • Import Metal DLPack arrays by wrapping the underlying Metal buffer instead of copying through CPU.
  • Export MLX arrays to Metal DLPack using the MLX Metal buffer and DLPack byte_offset.
  • Add mx.from_dlpack(..., copy=...) controls for Metal DLPack inputs.
  • Keep mx.array(...) zero-copy for Metal DLPack inputs unless an explicit different dtype is requested.
  • Document the explicit synchronization requirements between PyTorch MPS and MLX.

The shared lifetime is tied to the exported or imported buffer. Synchronization remains explicit: PyTorch writes require torch.mps.synchronize() before MLX reads, and MLX writes require mx.eval(...) before PyTorch reads.

For MLX arrays exported to PyTorch, later MLX updates may rebind the MLX array to a new buffer while the PyTorch tensor continues to reference the exported buffer.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@megacpp
Copy link
Copy Markdown

megacpp commented May 13, 2026

Hi @XXXXRT666 — read through this PR after @awni redirected us here from #3548. The nb::ndarray<nb::ro, nb::c_contig> approach over the in-flight nanobind PR (#1338) is materially cleaner than the manual capsule parsing we had in our downstream PoC, and lifting is_host_accessible() into mlx/allocator.h is the right level of abstraction. Closed the RFC; happy with this being the path forward for #2848.

Wanted to offer some testing help that complements the PyTorch MPS bring-up you have:

We maintain a downstream TileLang fork (https://github.com/DatasunriseOU/tilelang) whose TVM-FFI bridge exports kDLMetal DLPack capsules for tensors backed by id<MTLBuffer>. That gives a non-PyTorch Metal DLPack producer that exercises the same import path you're adding here. Specifically it covers:

  • kDLMetal producers that do not require the PyTorch-MPS workaround for __dlpack__ — exercises the import path directly.
  • Round-trip mx.array → DLPack → TVM-FFI Metal kernel → DLPack → mx.array zero-copy.
  • Custom Metal kernels (via mlx.fast.metal_kernel) consuming an imported mx.array whose underlying MTLBuffer was allocated outside MLX.
  • storageMode matrix (we hit Shared and Managed; Private is the obvious edge case the spec needs to nail down — your is_host_accessible() decision likely answers this implicitly but worth a sanity check from the producer side).
  • byte_offset != 0 cases that we'd previously rejected outright in our PoC — your PR seems to handle these via byte_offset-aware import; happy to write a TileLang-side test for it.

If useful, once the PR converges I can:

  1. Pull this branch into our TileLang test matrix and report back on any rough edges (CI on macOS Metal hardware).
  2. Send a minimal standalone repro (no TileLang dependency) for any of the above scenarios if you'd like them as additions to python/tests/test_array.py.
  3. Beta-test the mx.from_dlpack(..., copy=...) semantics against the dtype-mismatch-shares case (002360faa) once the API stabilizes.

Tag me here when you'd like input — no rush, just don't want this to slip past once it's review-ready.

(For the orthogonal mx.empty() piece that was also in our PoC, opened it as a separate issue per @awni's guidance.)

megacpp pushed a commit to DatasunriseOU/mlx that referenced this pull request May 13, 2026
Required by ml-explore/mlx PR ml-explore#3531 (Metal DLPack zero-copy sharing).
SHA: 33f52e635db5e6229060481d16a167230a1a474b
PR:   wjakob/nanobind#1338
Branch: metal-dlpack-cast
@XXXXRT666 XXXXRT666 force-pushed the metal-dlpack-zero-copy-draft branch from 002360f to 4e16f1d Compare May 14, 2026 04:39
@McPatate
Copy link
Copy Markdown

This would be super cool if it landed for end to end "0-copy" support in safetensors! I'm working (safetensors/safetensors#767) on adding reading bytes from disk in raw MTLBuffers, which can then be handed to the framework via dlpack with 0-copy. Works well with torch, would be happy to see that land in mlx!

Also, support for byte_offset !=0 would be nice (already in the PR but commenting to notify it's useful) since we can go one step further: currently the mps path is pread -> MTLBuffer, but that goes through kernel pages before hitting userspace buffer. Having byte_offset non zero support would enable mmap-ing the file and creating MTLBuffers that reference specific slices of the mmap, which would demand-fault pages from disk into the page cache on first access and give userspace access directly, leaving only the disk -> kernel-side copy.

Quick question on the dl_tensor.data convention, torch's mps treats it as id<MTLBuffer>, passing the contents segfaults. Curious to know which direction MLX will be taking, as it impacts us downstream!

megacpp pushed a commit to DatasunriseOU/mlx that referenced this pull request May 14, 2026
megacpp pushed a commit to DatasunriseOU/mlx that referenced this pull request May 14, 2026
Required by ml-explore/mlx PR ml-explore#3531 (Metal DLPack zero-copy sharing).
SHA: 33f52e635db5e6229060481d16a167230a1a474b
PR:   wjakob/nanobind#1338
Branch: metal-dlpack-cast
@XXXXRT666
Copy link
Copy Markdown
Contributor Author

Quick question on the dl_tensor.data convention, torch's mps treats it as id<MTLBuffer>, passing the contents segfaults. Curious to know which direction MLX will be taking, as it impacts us downstream!

https://dmlc.github.io/dlpack/latest/c_api.html#c.DLTensor.data

The data pointer points to the allocated data. This will be CUDA device pointer, cl_mem handle in OpenCL, or id<MTLBuffer> for Metal.

@XXXXRT666 XXXXRT666 force-pushed the metal-dlpack-zero-copy-draft branch from 4e16f1d to a17cd99 Compare May 19, 2026 07:44
@XXXXRT666 XXXXRT666 marked this pull request as ready for review May 19, 2026 08:40
Comment thread docs/src/usage/numpy.rst Outdated
Comment thread mlx/backend/cuda/allocator.cpp Outdated
Comment thread mlx/backend/metal/allocator.cpp
Comment thread CMakeLists.txt
Comment thread python/src/convert.cpp Outdated
@XXXXRT666
Copy link
Copy Markdown
Contributor Author

Metal DLPack benchmark

Mean over 50 measured iterations after 5 warmups on M4. Each timed iteration synchronizes the producer before timing and synchronizes/evaluates the result before stopping the timer.

  • Branch: mlx 0.32.0.dev20260521+04665e3bf, torch 2.12.0
  • Baseline: PyPI mlx 0.31.2, torch 2.12.0
  • Shapes: 1024x1024, 2048x2048, 4096x4096; dtypes: float32, float16.
  • Bandwidth is effective bandwidth computed as tensor bytes divided by mean time. For zero-copy paths it measures conversion overhead, not physical memory-copy bandwidth.

PyTorch MPS -> MLX Metal

Branch uses mx.array(torch_mps_tensor). Baseline uses the legacy path mx.array(torch_mps_tensor.cpu()).

dtype shape branch mean baseline mean comparison
float32 1024x1024 0.0017 ms 1.0455 ms, 3.7 GiB/s 618x lower latency, 618x bandwidth
float32 2048x2048 0.0015 ms 1.8094 ms, 8.6 GiB/s 1167x lower latency, 1167x bandwidth
float32 4096x4096 0.0015 ms 7.7090 ms, 8.1 GiB/s 5209x lower latency, 5209x bandwidth
float16 1024x1024 0.0016 ms 0.6499 ms, 3.0 GiB/s 399x lower latency, 399x bandwidth
float16 2048x2048 0.0015 ms 1.9847 ms, 3.9 GiB/s 1305x lower latency, 1305x bandwidth
float16 4096x4096 0.0018 ms 4.1107 ms, 7.6 GiB/s 2306x lower latency, 2306x bandwidth

MLX Metal -> PyTorch MPS

Both variants call torch.utils.dlpack.from_dlpack(mx_array) and then ensure the result is on MPS with to("mps") if needed.

dtype shape branch mean baseline mean comparison
float32 1024x1024 0.0040 ms 0.5166 ms, 7.6 GiB/s 128x lower latency, 128x bandwidth
float32 2048x2048 0.0017 ms 1.5256 ms, 10.2 GiB/s 923x lower latency, 923x bandwidth
float32 4096x4096 0.0018 ms 2.3682 ms, 26.4 GiB/s 1353x lower latency, 1353x bandwidth
float16 1024x1024 0.0014 ms 0.4257 ms, 4.6 GiB/s 295x lower latency, 295x bandwidth
float16 2048x2048 0.0018 ms 0.7522 ms, 10.4 GiB/s 428x lower latency, 428x bandwidth
float16 4096x4096 0.0014 ms 1.4215 ms, 22.0 GiB/s 1036x lower latency, 1036x bandwidth

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants