Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
AbstractLocalInterpolation as AbstractLocalInterpolation,
FourthOrderPolynomialInterpolation as FourthOrderPolynomialInterpolation,
LocalLinearInterpolation as LocalLinearInterpolation,
RodasInterpolation as RodasInterpolation,
ThirdOrderHermitePolynomialInterpolation as ThirdOrderHermitePolynomialInterpolation, # noqa: E501
)
from ._misc import adjoint_rms_seminorm as adjoint_rms_seminorm
Expand Down Expand Up @@ -106,6 +107,8 @@
QUICSORT as QUICSORT,
Ralston as Ralston,
ReversibleHeun as ReversibleHeun,
Rodas5p as Rodas5p,
Ros3p as Ros3p,
SEA as SEA,
SemiImplicitEuler as SemiImplicitEuler,
ShARK as ShARK,
Expand Down
1 change: 1 addition & 0 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
Euler,
EulerHeun,
ItoMilstein,
Ros3p,
StratonovichMilstein,
)
from ._step_size_controller import (
Expand Down
92 changes: 92 additions & 0 deletions diffrax/_local_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from jaxtyping import Float64


if TYPE_CHECKING:
from typing import ClassVar as AbstractVar
else:
from equinox import AbstractVar
import jax.flatten_util as fu
from equinox.internal import ω
from jaxtyping import Array, ArrayLike, PyTree, Shaped

Expand Down Expand Up @@ -137,3 +139,93 @@ def _eval(_coeffs):
return jnp.polyval(_coeffs, t)

return jtu.tree_map(_eval, self.coeffs)


class RodasInterpolation(AbstractLocalInterpolation):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

new interpolation method added for rodas class.

r"""Interpolation method for Rodas type solver.

??? cite "Reference"
```bibtex
@book{book,
author = {Hairer, Ernst and Wanner, Gerhard},
year = {1996},
month = {01},
pages = {},
title = {Solving Ordinary Differential Equations II. Stiff and
Differential-Algebraic Problems},
volume = {14},
journal = {Springer Verlag Series in Comput. Math.},
doi = {10.1007/978-3-662-09947-6}
}
```
"""

coeff: AbstractVar[np.ndarray]

stage_poly_coeffs: Float64[Array, "order stage"]
t0: RealScalarLike
t1: RealScalarLike
y0: PyTree[Shaped[ArrayLike, " ?*dims"], "Y"]
k: Float64[Array, "stage dims"]

def __init__(
self,
*,
t0: RealScalarLike,
t1: RealScalarLike,
y0: PyTree[Shaped[ArrayLike, " ?*dims"], "Y"],
k: Float64[Array, "stage dims"],
):
stage_poly_coeffs = []
for i in range(len(self.coeff)):
if i == len(self.coeff) - 1:
stage_poly_coeffs.append(self.coeff[i])
continue
stage_poly_coeffs.append(self.coeff[i] - self.coeff[i + 1])

self.stage_poly_coeffs = jnp.array(
np.transpose(stage_poly_coeffs), dtype=jnp.float64
)
self.y0 = y0
self.k = k
self.t0 = t0
self.t1 = t1

def evaluate(
self, t0: RealScalarLike, t1: RealScalarLike | None = None, left: bool = True
) -> PyTree[Array]:
del left
if t1 is not None:
return self.evaluate(t1) - self.evaluate(t0)

t = linear_rescale(self.t0, t0, self.t1)

def eval_increment():
with jax.numpy_dtype_promotion("standard"):
weighted_increment = jax.vmap(
lambda coeff, stage_k: (t * jnp.polyval(jnp.flip(coeff), t))
* stage_k
)(self.stage_poly_coeffs, self.k)
return jnp.sum(weighted_increment, axis=0).astype(self.k.dtype)

y0, unravel = fu.ravel_pytree(self.y0)
y1 = y0 + eval_increment()
return unravel(y1)

@classmethod
def from_k(
cls,
*,
t0: RealScalarLike,
t1: RealScalarLike,
y0: PyTree[Shaped[ArrayLike, " ?*dims"], "Y"],
k: Float64[Array, "stage dims"],
):
return cls(t0=t0, t1=t1, y0=y0, k=k)


RodasInterpolation.__init__.__doc__ = """**Arguments:**
Let `k` and `order` denote the stages and order of the solver.
- `coeff`: The coefficients of the Rodas interpolation. They represent the coefficients
of b(τ). Should be a numpy array of shape `(order - 1, k)`.
"""
2 changes: 2 additions & 0 deletions diffrax/_solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from .quicsort import QUICSORT as QUICSORT
from .ralston import Ralston as Ralston
from .reversible_heun import ReversibleHeun as ReversibleHeun
from .rodas5p import Rodas5p as Rodas5p
from .ros3p import Ros3p as Ros3p
from .runge_kutta import (
AbstractDIRK as AbstractDIRK,
AbstractERK as AbstractERK,
Expand Down
238 changes: 238 additions & 0 deletions diffrax/_solver/rodas5p.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
from collections.abc import Callable
from typing import ClassVar

import numpy as np

from diffrax._local_interpolation import RodasInterpolation

from .rosenbrock import AbstractRosenbrock, RosenbrockTableau


_tableau = RosenbrockTableau(
a_lower=(
np.array(
[
0.6358126895828704,
]
),
np.array([0.31242290829798824, 0.0971569310417652]),
np.array([1.3140825753299277, 1.8583084874257945, -2.1954603902496506]),
np.array(
[
0.42153145792835994,
0.25386966273009,
-0.2365547905326239,
-0.010005969169959593,
]
),
np.array(
[
1.712028062121536,
2.4456320333807953,
-3.117254839827603,
-0.04680538266310614,
0.006400126988377645,
]
),
np.array(
[
-0.9993030215739269,
-1.5559156221686088,
3.1251564324842267,
0.24141811637172583,
-0.023293468307707062,
0.21193756319429014,
],
),
np.array(
[
-0.003487250199264519,
-0.1299669712056423,
1.525941760806273,
1.1496140949123888,
-0.7043357115882416,
-1.0497034859198033,
0.21193756319429014,
]
),
),
c_lower=(
np.array([-0.6358126895828704]),
np.array([-0.4219499144476441, -0.12845036137023838]),
np.array([0.38766328985840337, -2.0150665034868993, 3.2201109377224792]),
np.array(
[
3.165730533008969,
1.3574038770338352,
-2.1414486119160854,
-0.2677977215559399,
]
),
np.array(
[
-2.711331083695463,
-4.001547655549404,
6.24241127231183,
0.28822349903483196,
-0.02969359529608471,
]
),
np.array(
[
0.9958157713746624,
1.4259486509629664,
-1.5992146716779536,
0.9081959785406629,
-0.6810422432805345,
-1.2616410491140935,
]
),
np.array(
[
0.12584733011227164,
0.1802058530898342,
-0.20210253993991456,
0.11477428094984177,
-0.08606747399894099,
0.08161021050037465,
-0.42620522390775717,
]
),
),
α=np.array(
[
0.0,
0.6358126895828704,
0.4095798393397535,
0.9769306725060716,
0.4288403609558664,
0.9999999999999998,
0.9999999999999999,
1.0000000000000002,
]
),
γ=np.array(
[
0.21193756319429014,
-0.42387512638858027,
-0.3384627126235924,
1.8046452872882734,
2.325825639765069,
9.71445146547012e-16,
2.220446049250313e-16,
-3.3306690738754696e-16,
]
),
m_sol=np.array(
[
0.12236007991300712,
0.050238881884191906,
1.3238392208663585,
1.2643883758622305,
-0.7904031855871826,
-0.9680932754194287,
-0.214267660713467,
0.21193756319429014,
]
),
m_error=np.array(
[
-0.003487250199264519,
-0.1299669712056423,
1.525941760806273,
1.1496140949123888,
-0.7043357115882416,
-1.0497034859198033,
0.21193756319429014,
0.0,
]
),
)


class _Rodas5pInterpolation(RodasInterpolation):
coeff: ClassVar[np.ndarray] = np.array(
[
[
0.12236007991300712,
0.050238881884191906,
1.3238392208663585,
1.2643883758622305,
-0.7904031855871826,
-0.9680932754194287,
-0.214267660713467,
0.21193756319429014,
],
[
-0.8232744916805133,
0.3181483349120214,
0.16922330104086836,
-0.049879453396320994,
0.19831791977261218,
0.31488148287699225,
-0.16387506167704194,
0.036457968151382296,
],
[
-0.6726085201965635,
-1.3128972079520966,
9.467244336394248,
12.924520918142036,
-9.002714541842755,
-11.404611057341922,
-1.4210850083209667,
1.4221510811179898,
],
[
1.4025185206933914,
0.9860299407499886,
-11.006871867857507,
-14.112585514422294,
9.574969612795117,
12.076626078349426,
2.114222828697341,
-1.0349095990054304,
],
],
dtype=np.float64,
)


class Rodas5p(AbstractRosenbrock):
r"""Rodas5p method.

5th order Rosenbrock method for solving stiff equations.

??? cite "Reference"

@article{Steinebach2023,
author = {Steinebach, Gerd},
title = {Construction of Rosenbrock--Wanner method Rodas5P and numerical
benchmarks within the Julia Differential Equations package},
journal = {BIT Numerical Mathematics},
year = {2023},
volume = {63},
number = {2},
pages = {27},
doi = {10.1007/s10543-023-00967-x},
url = {https://doi.org/10.1007/s10543-023-00967-x},
issn = {1572-9125},
date = {2023-04-17}
}

"""

tableau: ClassVar[RosenbrockTableau] = _tableau

interpolation_cls: ClassVar[Callable[..., _Rodas5pInterpolation]] = (
_Rodas5pInterpolation.from_k
)

rodas: ClassVar[bool] = True

def order(self, terms):
del terms
return 5


Rodas5p.__init__.__doc__ = """**Arguments:** None"""
Loading