Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
run-tests:
strategy:
matrix:
python-version: [ "3.10", "3.12" ]
python-version: [ "3.11", "3.13" ]
os: [ ubuntu-latest ]
fail-fast: false
runs-on: ${{ matrix.os }}
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/against_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def speedtest(fn, name):
# INTEGRATE WITH scan


@jax.checkpoint # pyright: ignore
@jax.checkpoint
def body(carry, t):
u, v, dt = carry
u = u + du(t, v, None) * dt
Expand Down
20 changes: 19 additions & 1 deletion diffrax/_step_size_controller/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def callback(_keep_step, _t1):
step_info = None
else:
step_index, step_ts = controller_state.step_info
# We actaully bump `next_t0` past any `step_ts` whilst checking where to
# We actually bump `next_t0` past any `step_ts` whilst checking where to
# clip `next_t1`. This is in case we have a set up like the following:
# ```python
# ClipStepSizeController(
Expand All @@ -376,6 +376,24 @@ def callback(_keep_step, _t1):
else:
jump_index, jump_ts = controller_state.jump_info
next_t0, made_jump2 = _bump_next_t0(next_t0, jump_ts)
# This next line is to fix
# https://github.com/patrick-kidger/diffrax/issues/713
# TODO: should we add this to the `step_ts` branch as well?
#
# What's going on here is that we may have
# the `diffeqsolve(t0=...)` be prevbefore a jump time (for example due to a
# previous diffeqsolve targeting that time), in which case during `.init`
# we will obtain `t0 = t1 = prevbefore(jump_time)`.
# The `_bump_next_t0` will then move `next_t0` to after the `jump_time`...
# whilst leaving `next_t1` unchanged! We actually end up `next_t1 < next_t0`
# which is very not okay.
#
# The fix is to ensure that `next_t1` is itself bumped to at least this
# value. As a final detail, we need to make it `nextafter` so that we don't
# have a zero-length interval – in this case an underlying PID controller
# would just never change the interval size at all, since it acts
# multiplicatively. (And even just 1 ULP is enough to unstick it.)
next_t1 = jnp.maximum(eqxi.nextafter(next_t0), next_t1)
made_jump = made_jump | made_jump2
jump_index = _find_idx_with_hint(next_t0, jump_ts, jump_index)
next_t1 = _clip_t(next_t1, jump_index, jump_ts, True)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ keywords = ["jax", "dynamical-systems", "differential-equations", "deep-learning
license = {file = "LICENSE"}
name = "diffrax"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.11"
urls = {repository = "https://github.com/patrick-kidger/diffrax"}
version = "0.7.0"

Expand Down
27 changes: 27 additions & 0 deletions test/test_adaptive_stepsize_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import optimistix as optx
import pytest
from diffrax._step_size_controller.clip import _find_idx_with_hint
from jaxtyping import Array
Expand Down Expand Up @@ -361,3 +362,29 @@ def test_jump_at_t1_with_large_t1_in_float32():
saveat=saveat,
)
assert sol.ts == jnp.array([t1])


# https://github.com/patrick-kidger/diffrax/issues/713
def test_t0_at_jump_time():
jump_time = 0.98
controller = diffrax.PIDController(rtol=1e-6, atol=1e-6)
controller = diffrax.ClipStepSizeController(controller, jump_ts=[jump_time])
sol = diffrax.diffeqsolve(
diffrax.ODETerm(lambda t, y, args: jnp.zeros_like(y)),
diffrax.Heun(),
t0=eqxi.prevbefore(jnp.asarray(jump_time)),
t1=1.2,
dt0=None,
y0=jnp.array([0, 0, 0, 0.0]),
stepsize_controller=controller,
event=diffrax.Event(
cond_fn=lambda t, y, args, **kw: jump_time - t,
root_finder=optx.Newton(atol=1e-4, rtol=1e-4),
direction=True,
),
max_steps=100,
)
# And in particular not an event.
# What used to happen was something very weird where we'd oscillate across the
# jump time.
assert sol.result == diffrax.RESULTS.successful
Loading