Summary
deepdiff.DeepHash (used by DataJoint's tripartite make for referential integrity checks) fails when the fetched data contains JAX arrays, as is common when using equinox Modules as data containers.
Minimal Reproduction
import jax.numpy as jnp
import deepdiff
# JAX 0-d scalar
deepdiff.DeepHash(jnp.float32(1.0))
# TypeError: iteration over a 0-d array
# JAX 1-d array
deepdiff.DeepHash(jnp.ones(3))
# TypeError: iteration over a 0-d array
# equinox Module with JAX fields (common pattern for datasets)
import equinox as eqx
class MyDataset(eqx.Module):
    data: jnp.ndarray
    scalar: jnp.ndarray

    def __init__(self):
        self.data = jnp.ones((3, 4))
        self.scalar = jnp.float32(1.0)

deepdiff.DeepHash(MyDataset())
# TypeError: iteration over a 0-d array
Root Cause
DeepHash._hash checks isinstance(obj, Iterable) (line ~79 of deephash.py). JAX arrays implement __iter__, so they match. But JAX 0-d arrays (scalars) raise TypeError: iteration over a 0-d array when iterated. This differs from numpy 0-d arrays, which deepdiff handles via the numbers type check (numpy scalars are registered as numbers).
This affects the tripartite make pattern because _populate1 calls DeepHash(fetched_data) on whatever make_fetch returns. If the fetched data includes JAX arrays (e.g., a dataset loaded from the database), the integrity check fails before computation even starts.
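The mechanism can be illustrated without JAX or deepdiff installed. The sketch below is only an illustration: ZeroDimArray is a hypothetical stand-in for a 0-d array, and classify is a hypothetical dispatcher that mimics an Iterable-based check. Neither is deepdiff's or JAX's actual code.

```python
from collections.abc import Iterable


class ZeroDimArray:
    """Hypothetical stand-in for a 0-d array: it defines __iter__,
    so it passes the Iterable check, but iterating it raises."""

    ndim = 0

    def __iter__(self):
        raise TypeError("iteration over a 0-d array")


def classify(obj):
    """Hypothetical dispatcher that treats anything Iterable as a container."""
    if isinstance(obj, Iterable):
        try:
            return ("iterable", list(obj))
        except TypeError as exc:
            # The isinstance check passed, but iteration itself failed.
            return ("iteration failed", str(exc))
    return ("leaf", obj)


scalar = ZeroDimArray()
passes_check = isinstance(scalar, Iterable)
outcome = classify(scalar)
print(passes_check, outcome)
```

The isinstance check only looks for an __iter__ method on the class; it says nothing about whether iteration will succeed, which is why a type check placed before the Iterable branch is needed.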
Workaround
We monkey-patch DeepHash._hash to convert JAX arrays to numpy before hashing:
import numpy as np
import deepdiff.deephash as _dh
from deepdiff.helper import get_id
from jaxlib._jax import ArrayImpl
_orig_hash = _dh.DeepHash._hash
def _patched_hash(self, obj, parent, parents_ids=frozenset()):
    if isinstance(obj, ArrayImpl):
        jax_key = get_id(obj)
        obj = np.asarray(obj)
        if obj.ndim == 0:
            obj = obj.item()
        result = _orig_hash(self, obj, parent, parents_ids)
        self.hashes[jax_key] = result
        return result
    return _orig_hash(self, obj, parent, parents_ids)

_dh.DeepHash._hash = _patched_hash
Suggested Fix
DeepHash could handle array-like objects more robustly by:
- Checking for numpy/JAX array types before the Iterable check
- Converting array-likes to numpy via np.asarray() (which both numpy and JAX support)
- Or registering JAX scalar types alongside numpy scalar types in the numbers tuple
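The second option could look roughly like the sketch below, applied as a pre-processing step before hashing. This is an illustration under stated assumptions, not deepdiff's code: normalize_array_like and normalize_tree are hypothetical names, and the duck-typed test (an object exposing both __array__ and ndim) is an assumption about how array-likes might be detected. The example uses numpy inputs only so that it runs without JAX.

```python
import numpy as np


def normalize_array_like(obj):
    """Convert an array-like to numpy, or to a Python scalar if it is 0-d."""
    # Assumption of this sketch: anything with __array__ and ndim is an array.
    if hasattr(obj, "__array__") and hasattr(obj, "ndim"):
        arr = np.asarray(obj)
        return arr.item() if arr.ndim == 0 else arr
    return obj


def normalize_tree(obj):
    """Apply normalize_array_like through nested dicts, lists, and tuples."""
    if isinstance(obj, dict):
        return {key: normalize_tree(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [normalize_tree(value) for value in obj]
    if isinstance(obj, tuple):
        return tuple(normalize_tree(value) for value in obj)
    return normalize_array_like(obj)


fetched = {"scalar": np.float32(1.0), "data": [np.ones(3), "label"]}
clean = normalize_tree(fetched)
print(type(clean["scalar"]), type(clean["data"][0]))
```

This sketch only walks built-in containers; custom containers such as equinox Modules would need their own traversal, which is one reason handling the conversion inside DeepHash itself seems preferable.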
Versions
- deepdiff: 9.0.0
- jax: 0.9.2
- datajoint: 0.14.9
- equinox: 0.12.5