Thanks for such a great library! I need an implicit integrator that incorporates a mass matrix, so I thought I'd create an example for people to work with. I don't have the bandwidth/time/knowledge to make this work in general (e.g. for DAEs, or for IRK/DIRK/SDIRK/ESDIRK solvers), but it should serve as a starting point that makes it clear where and how one can adapt other methods (e.g. by replacing `_implicit_relation_f` and/or `_implicit_relation_k` in `_solver/runge_kutta.py`). I suspect it would be easy to add this functionality to the library in general, but it would likely either be a) a breaking change, since I'm inserting something into `args`, or b) more work than I'm willing to do, since you'd have to design the interface and then implement it for all the appropriate solvers.
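In symbols, the residual below solves for the increment $z_1$ satisfying

$$M(t_1,\, y_0 + z_1)\, z_1 = \Delta t \, f(t_1,\, y_0 + z_1), \qquad y_1 = y_0 + z_1,$$

i.e. backward Euler applied to $M(t, y)\,\dot{y} = f(t, y)$, with the mass action $M$ passed in as the first element of `args`.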
```python
from typing import Callable, ClassVar, TypeAlias

import jax.numpy as jnp
import optimistix as optx
from equinox.internal import ω

import diffrax
from diffrax import (
    AbstractAdaptiveSolver,
    AbstractImplicitSolver,
    AbstractTerm,
    LocalLinearInterpolation,
    RESULTS,
    with_stepsize_controller_tols,
)
from diffrax._custom_types import Args, BoolScalarLike, DenseInfo, RealScalarLike, VF, Y
from diffrax._heuristics import is_sde

_SolverState: TypeAlias = None


def _implicit_relation(z1, nonlinear_solve_args):
    # Residual for the root find: find the increment `z1` such that
    #     M(t1, y1) z1 = f(t1, y1) * dt,    where y1 = y0 + z1,
    # with the mass action `M` supplied as the first element of the user `args`.
    vf_prod, t1, y0, args, control = nonlinear_solve_args
    mass_transform = args[0]
    y1_guess = (y0**ω + z1**ω).ω
    dy = mass_transform(z1, t1, y1_guess, args, control)
    diff = (vf_prod(t1, y1_guess, args, control) ** ω - dy**ω).ω
    return diff


class ImplicitEulerMass(AbstractImplicitSolver, AbstractAdaptiveSolver):
    r"""Implicit Euler method with a mass matrix.

    A-B-L stable 1st order SDIRK method. Has an embedded 2nd order Heun method for
    adaptive step sizing. Uses 1 stage. Uses a 1st order local linear interpolation
    for dense output.
    """

    term_structure: ClassVar = AbstractTerm
    # We actually have enough information to use 3rd order Hermite interpolation.
    #
    # We don't use it as this seems to be quite a bad choice for low-order solvers:
    # it produces very oscillatory interpolations.
    interpolation_cls: ClassVar[
        Callable[..., LocalLinearInterpolation]
    ] = LocalLinearInterpolation
    root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(optx.Chord)()
    root_find_max_steps: int = 10

    def order(self, terms):
        del terms
        return 1

    def error_order(self, terms):
        if is_sde(terms):
            return None
        else:
            return 2

    def init(
        self,
        terms: AbstractTerm,
        t0: RealScalarLike,
        t1: RealScalarLike,
        y0: Y,
        args: Args,
    ) -> _SolverState:
        del terms, t0, t1, y0, args
        return None

    def step(
        self,
        terms: AbstractTerm,
        t0: RealScalarLike,
        t1: RealScalarLike,
        y0: Y,
        args: Args,
        solver_state: _SolverState,
        made_jump: BoolScalarLike,
    ) -> tuple[Y, Y, DenseInfo, _SolverState, RESULTS]:
        del made_jump
        control = terms.contr(t0, t1)
        # Could use FSAL here but that would mean we'd need to switch to working with
        # `f0 = terms.vf(t0, y0, args)`, and that gets quite hairy quite quickly.
        # (C.f. `AbstractRungeKutta.step`.)
        # If we wanted FSAL then really the correct thing to do would just be to
        # write out a `ButcherTableau` and use `AbstractSDIRK`.
        k0 = terms.vf_prod(t0, y0, args, control)
        # Pack the user-level `args` (whose first element is the mass transform)
        # into the arguments passed to the root find.
        args = (terms.vf_prod, t1, y0, args, control)
        nonlinear_sol = optx.root_find(
            _implicit_relation,
            self.root_finder,
            k0,
            args,
            throw=False,
            max_steps=self.root_find_max_steps,
        )
        k1 = nonlinear_sol.value
        y1 = (y0**ω + k1**ω).ω
        # Use the trapezoidal rule for adaptive step sizing.
        y_error = (0.5 * (k1**ω - k0**ω)).ω
        dense_info = dict(y0=y0, y1=y1)
        solver_state = None
        result = RESULTS.promote(nonlinear_sol.result)
        return y1, y_error, dense_info, solver_state, result

    def func(
        self,
        terms: AbstractTerm,
        t0: RealScalarLike,
        y0: Y,
        args: Args,
    ) -> VF:
        return terms.vf(t0, y0, args)
```

Here's a small example of its use:
```python
def rhs(t, y, args):
    _, scale = args
    d_y = -scale * (y**2)
    return d_y


def mass_transform(dy, t1, y1, args, control):
    # Mass action M(y1) dy = y1 * dy, so the ODE being solved is
    # y * dy/dt = -scale * y**2, i.e. dy/dt = -scale * y.
    return y1 * dy


scale = 0.1
term = diffrax.ODETerm(rhs)
solver = ImplicitEulerMass()
t1 = 10
saveat_t = jnp.linspace(0, t1, num=101)
saveat = diffrax.SaveAt(ts=saveat_t)
stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-4)
true_traj = jnp.exp(-scale * saveat_t)
sol = diffrax.diffeqsolve(
    term,
    solver,
    t0=0,
    t1=t1,
    dt0=0.01,
    y0=1.0,
    args=(mass_transform, scale),
    saveat=saveat,
    stepsize_controller=stepsize_controller,
)
```
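As a quick sanity check, the computed trajectory can be compared against the analytic solution `true_traj` defined above; it should agree up to roughly the solver tolerances:

```python
# y * dy/dt = -scale * y**2 reduces to dy/dt = -scale * y, with solution
# exp(-scale * t), so the max error should be small.
max_err = jnp.max(jnp.abs(sol.ys - true_traj))
print(max_err)
```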
Here's a more difficult example, for the Robertson DAE (all values borrowed from Julia's SciML). It takes a lot of steps because backward Euler is, practically speaking, a poor choice of implicit solver for this problem; it's meant as a general illustration of how to formulate such problems.

```python
def rober_rhs(t, y, args):
    y1, y2, y3 = y
    _, k1, k2, k3 = args
    du1 = -k1 * y1 + k3 * y2 * y3
    du2 = k1 * y1 - k3 * y2 * y3 - k2 * y2 * y2
    du3 = y1 + y2 + y3 - 1  # algebraic constraint: the concentrations sum to one
    return jnp.array([du1, du2, du3])


def rober_mass(dy, t1, y1, args, control):
    # Singular mass matrix diag(1, 1, 0): the zero row turns the third
    # equation into a purely algebraic constraint.
    return jnp.array([dy[0], dy[1], 0.0])

term = diffrax.ODETerm(rober_rhs)
solver = ImplicitEulerMass()
t1 = 1e1
saveat_t = jnp.logspace(-6, jnp.log10(t1), 1001)
saveat = diffrax.SaveAt(ts=saveat_t)
stepsize_controller = diffrax.PIDController(rtol=1e-8, atol=1e-8)
args = (rober_mass, 0.04, 3e7, 1e4)
y0 = jnp.array([1.0, 0.0, 0.0])
sol = diffrax.diffeqsolve(
    term,
    solver,
    t0=0,
    t1=t1,
    dt0=1e-5,
    y0=y0,
    args=args,
    saveat=saveat,
    stepsize_controller=stepsize_controller,
    max_steps=10000000,
)
```

The trajectories should match up nicely with those provided by SciML.
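To see just how many steps "a lot" is, the solve statistics can be inspected (`sol.stats` is diffrax's per-solve statistics dictionary, of which `num_steps` is one entry):

```python
# Backward Euler is only first order, so at rtol = atol = 1e-8 the adaptive
# controller has to take a very large number of small steps.
print(sol.stats["num_steps"])
```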
