Skip to content

Commit 8857ef1

Browse files
Fixed case in which t0 is prevbefore a jump time
1 parent b91138f commit 8857ef1

File tree

3 files changed

+47
-2
lines changed

3 files changed

+47
-2
lines changed

benchmarks/against_scan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def speedtest(fn, name):
3636
# INTEGRATE WITH scan
3737

3838

39-
@jax.checkpoint # pyright: ignore
39+
@jax.checkpoint
4040
def body(carry, t):
4141
u, v, dt = carry
4242
u = u + du(t, v, None) * dt

diffrax/_step_size_controller/clip.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ def callback(_keep_step, _t1):
356356
step_info = None
357357
else:
358358
step_index, step_ts = controller_state.step_info
359-
# We actaully bump `next_t0` past any `step_ts` whilst checking where to
359+
# We actually bump `next_t0` past any `step_ts` whilst checking where to
360360
# clip `next_t1`. This is in case we have a set up like the following:
361361
# ```python
362362
# ClipStepSizeController(
@@ -376,6 +376,24 @@ def callback(_keep_step, _t1):
376376
else:
377377
jump_index, jump_ts = controller_state.jump_info
378378
next_t0, made_jump2 = _bump_next_t0(next_t0, jump_ts)
379+
# This next line is to fix
380+
# https://github.com/patrick-kidger/diffrax/issues/713
381+
# TODO: should we add this to the `step_ts` branch as well?
382+
#
383+
# What's going on here is that we may have
384+
# the `diffeqsolve(t0=...)` be prevbefore a jump time (for example due to a
385+
# previous diffeqsolve targeting that time), in which case during `.init`
386+
# we will obtain `t0 = t1 = prevbefore(jump_time)`.
387+
# The `_bump_next_t0` will then move `next_t0` to after the `jump_time`...
388+
# whilst leaving `next_t1` unchanged! We actually end up `next_t1 < next_t0`
389+
# which is very not okay.
390+
#
391+
# The fix is to ensure that `next_t1` is itself bumped to at least this
392+
# value. As a final detail, we need to make it `nextafter` so that we don't
393+
# have a zero-length interval – in this case an underlying PID controller
394+
# would just never change the interval size at all, since it acts
395+
# multiplicatively. (And even just 1 ULP is enough to unstick it.)
396+
next_t1 = jnp.maximum(eqxi.nextafter(next_t0), next_t1)
379397
made_jump = made_jump | made_jump2
380398
jump_index = _find_idx_with_hint(next_t0, jump_ts, jump_index)
381399
next_t1 = _clip_t(next_t1, jump_index, jump_ts, True)

test/test_adaptive_stepsize_controller.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import jax.numpy as jnp
88
import jax.random as jr
99
import jax.tree_util as jtu
10+
import optimistix as optx
1011
import pytest
1112
from diffrax._step_size_controller.clip import _find_idx_with_hint
1213
from jaxtyping import Array
@@ -361,3 +362,29 @@ def test_jump_at_t1_with_large_t1_in_float32():
361362
saveat=saveat,
362363
)
363364
assert sol.ts == jnp.array([t1])
365+
366+
367+
# https://github.com/patrick-kidger/diffrax/issues/713
368+
def test_t0_at_jump_time():
369+
jump_time = 0.98
370+
controller = diffrax.PIDController(rtol=1e-6, atol=1e-6)
371+
controller = diffrax.ClipStepSizeController(controller, jump_ts=[jump_time])
372+
sol = diffrax.diffeqsolve(
373+
diffrax.ODETerm(lambda t, y, args: jnp.zeros_like(y)),
374+
diffrax.Heun(),
375+
t0=eqxi.prevbefore(jnp.asarray(jump_time)),
376+
t1=1.2,
377+
dt0=None,
378+
y0=jnp.array([0, 0, 0, 0.0]),
379+
stepsize_controller=controller,
380+
event=diffrax.Event(
381+
cond_fn=lambda t, y, args, **kw: jump_time - t,
382+
root_finder=optx.Newton(atol=1e-4, rtol=1e-4),
383+
direction=True,
384+
),
385+
max_steps=100,
386+
)
387+
# And in particular not an event.
388+
# What used to happen was something very weird where we'd oscillate across the
389+
# jump time.
390+
assert sol.result == diffrax.RESULTS.successful

0 commit comments

Comments
 (0)