Skip to content

(Non-constant) mass matrix example #710

@dannys4

Description

@dannys4

Thanks for such a great library! I need an implicit integrator that incorporates a mass matrix, so I thought I'd write up an example for others to work from. I don't have the bandwidth, time, or knowledge to make this work for general cases (e.g., DAEs, IRK/DIRK/SDIRK/ESDIRK solvers, etc.). Instead, this is meant as a starting point that shows where and how other methods could be adapted (e.g., by replacing `_implicit_relation_f` and/or `_implicit_relation_k` in `_solver/runge_kutta.py`). I suspect it would be easy to add this functionality to the library in general, but it would likely either (a) be a breaking change, since I'm inserting the mass transform into `args`, or (b) be more work than I'm willing to do, since the interface would need to be designed carefully and then implemented for every appropriate solver.

import jax.numpy as jnp
import diffrax
from typing import ClassVar, TypeAlias, Callable
from equinox.internal import ω
from diffrax._custom_types import Args, BoolScalarLike, DenseInfo, RealScalarLike, VF, Y
from diffrax._heuristics import is_sde
from diffrax import LocalLinearInterpolation, with_stepsize_controller_tols, RESULTS, AbstractTerm, AbstractAdaptiveSolver, AbstractImplicitSolver
import optimistix as optx
# This solver carries no state between steps: `init` returns None and `step`
# always hands back None as the new solver state.
_SolverState: TypeAlias = None

def _implicit_relation(z1, nonlinear_solve_args):
    """Residual of the mass-matrix backward Euler relation.

    The root finder drives this to zero in the increment ``z1 = y1 - y0``:

        vf_prod(t1, y0 + z1, args, control) - mass_transform(z1, t1, y0 + z1, args, control)

    For an ``ODETerm``, ``vf_prod`` is ``f(t1, y1, args) * control`` with
    ``control = t1 - t0``. So a root satisfies ``M(y1) z1 = f(t1, y1) * dt``.

    Args:
        z1: current guess for the increment ``y1 - y0`` (a PyTree shaped like y0).
        nonlinear_solve_args: tuple ``(vf_prod, t1, y0, args, control)``. Here
            ``args[0]`` is the user's ``mass_transform(z, t1, y1, args, control)``.

    Returns:
        The residual, with the same PyTree structure as ``z1``.
    """
    vf_prod, t1, y0, args, control = nonlinear_solve_args
    mass_fn = args[0]
    y1_candidate = (y0**ω + z1**ω).ω
    # Left-hand side: M(t1, y1) @ z1 (evaluated before the vector field, as before).
    mass_times_increment = mass_fn(z1, t1, y1_candidate, args, control)
    # Right-hand side: f(t1, y1) * control.
    rhs_increment = vf_prod(t1, y1_candidate, args, control)
    return (rhs_increment**ω - mass_times_increment**ω).ω

class ImplicitEulerMass(AbstractImplicitSolver, AbstractAdaptiveSolver):
    r"""Implicit Euler method with a (possibly state-dependent) mass matrix.

    Integrates systems of the form

    $$M(t, y)\,\frac{\mathrm{d}y}{\mathrm{d}t} = f(t, y, \text{args}),$$

    where $M$ may be singular (e.g. a semi-explicit DAE written in mass-matrix
    form). Each step uses `root_finder` to find the increment $z_1 = y_1 - y_0$
    satisfying

    $$M(t_1, y_1)\, z_1 = f(t_1, y_1, \text{args})\,\Delta t.$$

    The mass matrix is passed through `args`. `args[0]` must be a callable
    `mass_transform(z, t1, y1, args, control)` that returns $M(t_1, y_1)\,z$.
    The whole `args` tuple, including `args[0]`, is also passed to the vector
    field.

    This is a 1st order method. Step sizes are adapted by comparing against the
    2nd order trapezoidal rule. Dense output uses a 1st order local linear
    interpolation.
    """

    term_structure: ClassVar = AbstractTerm
    # We actually have enough information to use 3rd order Hermite interpolation.
    #
    # We don't use it as this seems to be quite a bad choice for low-order solvers: it
    # produces very oscillatory interpolations.
    interpolation_cls: ClassVar[
        Callable[..., LocalLinearInterpolation]
    ] = LocalLinearInterpolation

    # Chord root finder whose rtol/atol are taken from the step size controller.
    root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(optx.Chord)()
    # Maximum number of root-finder iterations per step.
    root_find_max_steps: int = 10

    def order(self, terms):
        del terms
        return 1

    def error_order(self, terms):
        if is_sde(terms):
            return None
        else:
            return 2

    def init(
        self,
        terms: AbstractTerm,
        t0: RealScalarLike,
        t1: RealScalarLike,
        y0: Y,
        args: Args,
    ) -> _SolverState:
        del terms, t0, t1, y0, args
        return None

    def step(
        self,
        terms: AbstractTerm,
        t0: RealScalarLike,
        t1: RealScalarLike,
        y0: Y,
        args: Args,
        solver_state: _SolverState,
        made_jump: BoolScalarLike,
    ) -> tuple[Y, Y, DenseInfo, _SolverState, RESULTS]:
        """Take one mass-matrix backward Euler step from `t0` to `t1`.

        Returns `(y1, y_error, dense_info, solver_state, result)`. A root-finding
        failure is reported through `result` rather than raised (`throw=False`).
        """
        del made_jump
        control = terms.contr(t0, t1)
        # Could use FSAL here but that would mean we'd need to switch to working with
        # `f0 = terms.vf(t0, y0, args)`, and that gets quite hairy quite quickly.
        # (C.f. `AbstractRungeKutta.step`.)
        # If we wanted FSAL then really the correct thing to do would just be to
        # write out a `ButcherTableau` and use `AbstractSDIRK`.
        k0 = terms.vf_prod(t0, y0, args, control)
        # Use a separate name so the user's `args` stays available for the error
        # estimate below.
        nonlinear_solve_args = (terms.vf_prod, t1, y0, args, control)
        nonlinear_sol = optx.root_find(
            _implicit_relation,
            self.root_finder,
            k0,  # initial guess for the increment z1 = y1 - y0
            nonlinear_solve_args,
            throw=False,
            max_steps=self.root_find_max_steps,
        )

        # z1 solves M(y1) z1 = f(t1, y1) * dt.
        z1 = nonlinear_sol.value
        y1 = (y0**ω + z1**ω).ω
        # Error estimate: implicit Euler vs. the trapezoidal rule.
        #
        # `z1` is a state increment, but `k0 = f(t0, y0) * dt` is a right-hand-side
        # evaluation. The two are only comparable when M = I, so subtracting them
        # (as the identity-mass implicit Euler does) gives an O(dt) "error" in
        # general. For a singular M, the algebraic components would report the whole
        # increment as error.
        #
        # Instead, compare right-hand-side evaluations:
        #     0.5 * dt * (f(t1, y1) - f(t0, y0)).
        # For constant M this equals M @ (z_euler - z_trapezoid). It reduces to the
        # usual estimate when M = I, and it is ~0 for algebraic rows whose
        # constraint holds at both endpoints. The cost is one extra vector field
        # evaluation per step.
        k1 = terms.vf_prod(t1, y1, args, control)
        y_error = (0.5 * (k1**ω - k0**ω)).ω
        dense_info = dict(y0=y0, y1=y1)
        solver_state = None
        result = RESULTS.promote(nonlinear_sol.result)
        return y1, y_error, dense_info, solver_state, result

    def func(
        self,
        terms: AbstractTerm,
        t0: RealScalarLike,
        y0: Y,
        args: Args,
    ) -> VF:
        """Evaluate the vector field `f(t0, y0, args)` at a single point.

        NOTE(review): with a non-identity mass matrix this returns `f`, not the
        time derivative `M^{-1} f`. Any diffrax machinery that treats it as
        `dy/dt` will be inaccurate. This likely includes automatic initial step
        selection when `dt0=None`, so prefer passing an explicit `dt0`.
        """
        return terms.vf(t0, y0, args)

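Here's a small example, $y\dot{y} = -\lambda y^2$, which the astute reader will recognize as exponential decay. The mass "matrix" is $M(y) = y$ and the right-hand side is $f(y) = -\lambda y^2$. Dividing through by $y$ gives $\dot{y} = -\lambda y$, so with $y(0) = 1$ the exact solution is $y(t) = e^{-\lambda t}$ (in the code, $\lambda$ is `scale`).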

import jax
import jax.numpy as jnp

def rhs(t, y, args):
    """Right-hand side ``f(t, y) = -scale * y**2`` of the decay example.

    ``args`` must be the pair ``(mass_transform, scale)``. Only ``scale`` is
    used here; ``t`` is ignored.
    """
    _mass, rate = args
    return -rate * (y**2)

def mass_transform(dy, t1, y1, args, control):
    """Apply the scalar mass "matrix" ``M(y) = y`` to ``dy``.

    Only ``y1`` is used. The remaining parameters exist so the signature
    matches ``mass_transform(z, t1, y1, args, control)``, which is how the
    solver calls it.
    """
    del t1, args, control
    return y1 * dy

# Integrate y * dy/dt = -scale * y**2 (equivalently dy/dt = -scale * y) from
# y(0) = 1, then compare against the exact solution exp(-scale * t).
scale = 0.1
term = diffrax.ODETerm(rhs)
solver = ImplicitEulerMass()
t1 = 10.0
saveat_t = jnp.linspace(0.0, t1, num=101)
saveat = diffrax.SaveAt(ts=saveat_t)
stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-4)
true_traj = jnp.exp(-scale * saveat_t)
sol = diffrax.diffeqsolve(
    term,
    solver,
    t0=0.0,
    t1=t1,
    dt0=0.01,
    y0=1.0,  # float initial condition so the state has an inexact dtype
    args=(mass_transform, scale),  # args[0] must be the mass transform
    saveat=saveat,
    stepsize_controller=stepsize_controller,
)
# Check the numerical solution against the analytic one.
max_abs_error = float(jnp.max(jnp.abs(sol.ys - true_traj)))
print(f"max |y(t) - exp(-scale * t)| = {max_abs_error:.2e}")

Here's a more difficult example: the Robertson DAE (all values borrowed from the corresponding Julia SciML example). It takes a lot of steps because backward Euler is, in practice, a poor choice of implicit solver for this problem. It is meant more as a general illustration of how to formulate such problems.

def rober_rhs(t, y, args):
    """Right-hand side of the Robertson problem in mass-matrix form.

    ``args`` is ``(mass_transform, k1, k2, k3)``. The first two components are
    the kinetic rate equations. The third is the algebraic constraint
    ``y1 + y2 + y3 - 1``, which is paired with a zero row of the mass matrix.
    """
    del t
    _, rate_a, rate_b, rate_c = args
    conc_a, conc_b, conc_c = y
    consumed_a = rate_a * conc_a
    recombined = rate_c * conc_b * conc_c
    return jnp.array(
        [
            recombined - consumed_a,
            consumed_a - recombined - rate_b * conc_b * conc_b,
            conc_a + conc_b + conc_c - 1,
        ]
    )

def rober_mass(dy, t1, y1, args, control):
    """Apply the constant mass matrix ``diag(1, 1, 0)`` to ``dy``.

    The zero in the last slot turns the third equation into an algebraic
    constraint. The other parameters are unused and exist only to match the
    ``mass_transform(z, t1, y1, args, control)`` signature.
    """
    del t1, y1, args, control
    first, second = dy[0], dy[1]
    return jnp.array([first, second, 0])

# Robertson problem written as a DAE with mass matrix diag(1, 1, 0); the rate
# constants follow the SciML example.
term = diffrax.ODETerm(rober_rhs)
solver = ImplicitEulerMass()
t1 = 1e1
stepsize_controller = diffrax.PIDController(rtol=1e-8, atol=1e-8)
args = (rober_mass, 0.04, 3e7, 1e4)  # (mass transform, k1, k2, k3)
y0 = jnp.array([1.0, 0.0, 0.0])
# Log-spaced save times from 1e-6 up to t1. Built before `SaveAt` so the
# previous example's `saveat_t` is never reused by accident.
saveat_t = jnp.logspace(-6, jnp.log10(t1), 1001)
saveat = diffrax.SaveAt(ts=saveat_t)
sol = diffrax.diffeqsolve(
    term,
    solver,
    t0=0.0,
    t1=t1,
    dt0=1e-5,
    y0=y0,
    args=args,
    saveat=saveat,
    stepsize_controller=stepsize_controller,
    max_steps=10_000_000,
)

The resulting trajectories should closely match those produced by SciML.

Image

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions