@@ -356,7 +356,7 @@ def callback(_keep_step, _t1):
356356 step_info = None
357357 else :
358358 step_index , step_ts = controller_state .step_info
359- # We actaully bump `next_t0` past any `step_ts` whilst checking where to
359+ # We actually bump `next_t0` past any `step_ts` whilst checking where to
360360 # clip `next_t1`. This is in case we have a set up like the following:
361361 # ```python
362362 # ClipStepSizeController(
@@ -376,6 +376,24 @@ def callback(_keep_step, _t1):
376376 else :
377377 jump_index , jump_ts = controller_state .jump_info
378378 next_t0 , made_jump2 = _bump_next_t0 (next_t0 , jump_ts )
379+ # This next line is to fix
380+ # https://github.com/patrick-kidger/diffrax/issues/713
381+ # TODO: should we add this to the `step_ts` branch as well?
382+ #
383+ # What's going on here is that we may have
384+ # the `diffeqsolve(t0=...)` be prevbefore a jump time (for example due to a
385+ # previous diffeqsolve targeting that time), in which case during `.init`
386+ # we will obtain `t0 = t1 = prevbefore(jump_time)`.
387+ # The `_bump_next_t0` will then move `next_t0` to after the `jump_time`...
388+ # whilst leaving `next_t1` unchanged! We actually end up `next_t1 < next_t0`
389+ # which is very not okay.
390+ #
391+ # The fix is to ensure that `next_t1` is itself bumped to at least this
392+ # value. As a final detail, we need to make it `nextafter` so that we don't
393+ # have a zero-length interval – in this case an underlying PID controller
394+ # would just never change the interval size at all, since it acts
395+ # multiplicatively. (And even just 1 ULP is enough to unstick it.)
396+ next_t1 = jnp .maximum (eqxi .nextafter (next_t0 ), next_t1 )
379397 made_jump = made_jump | made_jump2
380398 jump_index = _find_idx_with_hint (next_t0 , jump_ts , jump_index )
381399 next_t1 = _clip_t (next_t1 , jump_index , jump_ts , True )
0 commit comments