Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
1d830b8
implementation of multi-stage time integrators
fernanvr May 5, 2025
7f087b3
Merge remote-tracking branch 'upstream/main' into multi-stage-time-in…
fernanvr Jun 13, 2025
214d882
Return of first PR comments
fernanvr Jun 13, 2025
d6c4d4a
Return of first PR comments
fernanvr Jun 13, 2025
78f8a0b
2nd PR revision
fernanvr Jun 23, 2025
1c9d517
2nd PR revision
fernanvr Jun 23, 2025
11db48b
2rd PR, updating tests and suggestions of 2nd PR revision
fernanvr Jun 25, 2025
83dfb04
3rd PR, updating tests and suggestions of 2nd PR revision
fernanvr Jun 25, 2025
d47a106
4th PR revision, code refining and improving tests
fernanvr Jun 25, 2025
1f93a45
4th PR revision, code refining and improving tests
fernanvr Jun 25, 2025
eea3a52
5th PR revision, one suggestion from EdC and improving tests
fernanvr Jun 26, 2025
11d1429
including two more Runge-Kutta methods and improving tests: checking …
fernanvr Jul 1, 2025
4637ac2
changes to consider coupled Multistage equations
fernanvr Jul 16, 2025
ac1da7e
Improvements of the HORK_EXP
fernanvr Aug 15, 2025
e9b3533
Merge branch 'main' into multi-stage-time-integrator
fernanvr Aug 15, 2025
dc3dd77
Merge remote-tracking branch 'upstream/main' into multi-stage-time-in…
fernanvr Oct 8, 2025
1fd4a02
tuples, improved class names, extensive tests
fernanvr Oct 8, 2025
a0c45c1
improving spacing in some tests
fernanvr Oct 8, 2025
93c6e3f
Add MFE time stepping Jupyter notebook
fernanvr Oct 23, 2025
ef8d1ac
Remove MFE_time_size.ipynb notebook
fernanvr Oct 23, 2025
fa5acac
Update multistage implementation and tests
fernanvr Oct 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ def _(exprs, **kwargs):
Handle iterables of expressions.
"""
lowered = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you do return [_lower_multistage(expr, **kwargs) for i in exprs for expr in i]?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did something like that...

for i, expr in enumerate(exprs):
lowered.extend(_lower_multistage(expr, eq_num=i))
for expr in exprs:
lowered.extend(_lower_multistage(expr, **kwargs))
return lowered


Expand Down
6 changes: 1 addition & 5 deletions devito/operations/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,7 @@ def solve(eq, target, **kwargs):
sols_temp = sols[0]

method = kwargs.get("method", None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is method a string here, or is it the class for the method? In the latter case, it would remove the need to have the method_registry mapper. Furthermore, it would allow you to have method.resolve(target, sols_temp) here, which is tidier

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a string. The idea is that the user provides a string to identify which time integrator to apply.

if method is not None:
method_cls = resolve_method(method)
return method_cls(target, sols_temp)._evaluate(**kwargs)
else:
return sols_temp
return sols_temp if method is None else resolve_method(method)(target, sols_temp)


def linsolve(expr, target, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def _lower_exprs(cls, expressions, **kwargs):
* Apply substitution rules;
* Shift indices for domain alignment.
"""
expressions = lower_multistage(expressions)
expressions = lower_multistage(expressions, **kwargs)

expand = kwargs['options'].get('expand', True)

Expand Down
7 changes: 4 additions & 3 deletions devito/types/multistage.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this file should be moved to somewhere like devito/timestepping/rungekutta.py or devito/timestepping/explicitmultistage.py that way additional timesteppers can be contributed as new files. (I'm thinking about implicit multistage, backward difference formulae etc...)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed the class to HighOrderRungeKuttaExponential. I realize the name might be confusing since this particular Runge-Kutta is explicit, but “EXP” was intended to highlight the exponential aspect. I’ve also updated the other class names based on your suggestions.

Regarding the file location, it’s currently in /types as recommended by @mloubout (see suggestion). Personally, I think both /timestepping and /types are reasonable options. Perhaps we can discuss this with @EdCaunt and @FabioLuporini to reach a consensus.

Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,17 @@ def _evaluate(self, **kwargs):
- `s` stage equations of the form `k_i = rhs evaluated at intermediate state`
- 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)`
"""
eq_num = kwargs['eq_num']
stage_id = kwargs.get('sregistry').make_name(prefix='k')

u = self.lhs.function
rhs = self.rhs
grid = u.grid
t = grid.time_dim
dt = t.spacing

# Create temporary Functions to hold each stage
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: these are Array now

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right!

# k = [Array(name=f'k{eq_num}{i}', dimensions=grid.shape, grid=grid, dtype=u.dtype) for i in range(self.s)] # Trying Array
k = [Function(name=f'k{eq_num}{i}', grid=grid, space_order=u.space_order, dtype=u.dtype)
# k = [Array(name=f'{stage_id}{i}', dimensions=grid.shape, grid=grid, dtype=u.dtype) for i in range(self.s)] # Trying Array
k = [Function(name=f'{stage_id}{i}', grid=grid, space_order=u.space_order, dtype=u.dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kwargs.get('sregistry').make_name(prefix='k') wants to be inside this loop to ensure that all names are unique

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

for i in range(self.s)]

stage_eqs = []
Expand Down
126 changes: 80 additions & 46 deletions tests/test_multistage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from devito import (Grid, Function, TimeFunction,
Derivative, Operator, solve, Eq)
from devito.types.multistage import resolve_method
from devito.ir.support import SymbolRegistry
from devito.ir.equations import lower_multistage


def test_multistage_solve(time_int='RK44'):
def test_multistage_object(time_int='RK44'):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using pytest.mark.parametrize here. To add, some test classes like TestLowering, TestAPI, TestRK, etc would help with organisation of this file and running specific batches of tests

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

extent = (1, 1)
shape = (3, 3)
origin = (0, 0)
Expand All @@ -25,20 +27,19 @@ def test_multistage_solve(time_int='RK44'):
# Source definition
src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=float64)
src_spatial.data[1, 1] = 1
f0 = 0.01
src_temporal = (1 - 2 * (pi * f0 * (t * dt - 1 / f0)) ** 2) * exp(-(pi * f0 * (t * dt - 1 / f0)) ** 2)
src_temporal = (1 - 2 * (t*dt - 1)**2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For what its worth, tests of this kind don't need to have any physical significance, so long as they produce the desired behaviour in the compiler that you are testing for. For example, you could probably omit the source terms entirely and probably the derivatives too, simply creating a multistage timestepper out of a trivial equation that adds one to the solution at each timestep or similar.

However this is still a well-made and focussed test as-is

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I think it’s important to have a test that includes derivatives and source terms. However, I agree that simpler examples should also be included. I’ve added one without those elements.


# PDE system (2D acoustic)
system_eqs_rhs = [U[1] + src_spatial * src_temporal,
Derivative(U[0], (x, 2), fd_order=2) +
Derivative(U[0], (y, 2), fd_order=2) +
src_spatial * src_temporal]
Derivative(U[0], (x, 2), fd_order=2) +
Derivative(U[0], (y, 2), fd_order=2) +
src_spatial * src_temporal]

# Time integration scheme
return [solve(system_eqs_rhs[i] - U[i], U[i], method=time_int, eq_num=i) for i in range(2)]
# Class of the time integration scheme
return [resolve_method(time_int)(U[i], system_eqs_rhs[i]) for i in range(2)]


def test_multistage_op_constructing_directly(time_int='RK44'):
def test_multistage_lower_multistage(time_int='RK44'):
extent = (1, 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lots of boilerplate is repeated in these tests. Consider a convenience function for setting up the grid etc

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

were created two functions to reduce code repetition

shape = (3, 3)
origin = (0, 0)
Expand All @@ -57,23 +58,55 @@ def test_multistage_op_constructing_directly(time_int='RK44'):
# Source definition
src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=float64)
src_spatial.data[1, 1] = 1
f0 = 0.01
src_temporal = (1 - 2 * (pi * f0 * (t * dt - 1 / f0)) ** 2) * exp(-(pi * f0 * (t * dt - 1 / f0)) ** 2)
src_temporal = (1 - 2 * (t*dt - 1)**2)

# PDE system (2D acoustic)
system_eqs_rhs = [U[1] + src_spatial * src_temporal,
Derivative(U[0], (x, 2), fd_order=2) +
Derivative(U[0], (y, 2), fd_order=2) +
src_spatial * src_temporal]

# Time integration scheme

# Class of the time integration scheme
pdes = [resolve_method(time_int)(U[i], system_eqs_rhs[i]) for i in range(2)]
op = Operator(pdes, subs=grid.spacing_map)
op(dt=0.01, time=1)

sregistry=SymbolRegistry()

def test_multistage_op_computing_directly(time_int='RK44'):
return lower_multistage(pdes, sregistry=sregistry)



def test_multistage_solve(time_int='RK44'):
extent = (1, 1)
shape = (3, 3)
origin = (0, 0)

# Grid setup
grid = Grid(origin=origin, extent=extent, shape=shape, dtype=float64)
x, y = grid.dimensions
dt = grid.stepping_dim.spacing
t = grid.time_dim

# Define wavefield unknowns: u (displacement) and v (velocity)
fun_labels = ['u', 'v']
U = [TimeFunction(name=name, grid=grid, space_order=2,
time_order=1, dtype=float64) for name in fun_labels]

# Source definition
src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=float64)
src_spatial.data[1, 1] = 1
src_temporal = (1 - 2 * (t*dt - 1)**2)

# PDE system (2D acoustic)
system_eqs_rhs = [U[1] + src_spatial * src_temporal,
Derivative(U[0], (x, 2), fd_order=2) +
Derivative(U[0], (y, 2), fd_order=2) +
src_spatial * src_temporal]

# Time integration scheme
return [solve(system_eqs_rhs[i] - U[i], U[i], method=time_int) for i in range(2)]


def test_multistage_op_computing_1eq(time_int='RK44'):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since all of the tests in this file pertain to MultiStage, you can drop multistage from all function names within the file for concision

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you'r right, dropped

extent = (1, 1)
shape = (200, 200)
origin = (0, 0)
Expand All @@ -85,40 +118,44 @@ def test_multistage_op_computing_directly(time_int='RK44'):
t = grid.time_dim

# Define wavefield unknowns: u (displacement) and v (velocity)
u_multi_stage = TimeFunction(name='u_multi_stage', grid=grid, space_order=2, time_order=1, dtype=float64)
fun_labels = ['u_multi_stage', 'v_multi_stage']
U_multi_stage = [TimeFunction(name=name, grid=grid, space_order=2,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use of a capital U here makes this look like a class, consider renaming

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed for U_multi_stage and for U.

time_order=1, dtype=float64) for name in fun_labels]

# Source definition
src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=float64)
src_spatial.data[1, 1] = 1
f0 = 0.01
src_temporal = (1 - 2 * (pi * f0 * (t * dt - 1 / f0)) ** 2) * exp(-(pi * f0 * (t * dt - 1 / f0)) ** 2)
src_temporal = (1 - 2 * (t*dt - 1)**2)

# PDE (2D heat eq.)
eq_rhs = (Derivative(u_multi_stage, (x, 2), fd_order=2) + Derivative(u_multi_stage, (y, 2), fd_order=2) +
src_spatial * src_temporal)
# PDE system
system_eqs_rhs = [U_multi_stage[1] + src_spatial * src_temporal,
Derivative(U_multi_stage[0], (x, 2), fd_order=2) +
Derivative(U_multi_stage[0], (y, 2), fd_order=2) +
src_spatial * src_temporal]

# Time integration scheme
pde = [MultiStage(eq_rhs, u_multi_stage, method=time_int)]
op = Operator(pde, subs=grid.spacing_map)
pdes = [resolve_method(time_int)(U_multi_stage[i], system_eqs_rhs[i]) for i in range(2)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still not a fan of this resolve_method("method_name") API. I think MethodClass(lhs, rhs) is far less ambiguous

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did a change on this, not quite as your suggestion because I think it is not friendly asking to the user to import the specific class of the time integration.

op = Operator(pdes, subs=grid.spacing_map)
op(dt=0.01, time=1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should assert a norm or similar

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand your point. But the idea of this test if only to check if op.apply() executes for multistage objects. Is not about the convergence. Do you think it is unnecessary do that?


# Solving now using Devito's standard time solver
u = TimeFunction(name='u', grid=grid, space_order=2, time_order=1, dtype=float64)
eq_rhs = (Derivative(u, (x, 2), fd_order=2) + Derivative(u, (y, 2), fd_order=2) +
src_spatial * src_temporal)
# Define wavefield unknowns: u (displacement) and v (velocity)
fun_labels = ['u', 'v']
U = [TimeFunction(name=name, grid=grid, space_order=2,
time_order=1, dtype=float64) for name in fun_labels]
system_eqs_rhs = [U[1] + src_spatial * src_temporal,
Derivative(U[0], (x, 2), fd_order=2) +
Derivative(U[0], (y, 2), fd_order=2) +
src_spatial * src_temporal]

# Time integration scheme
pde = Eq(u, solve(eq_rhs - u, u))
op = Operator(pde, subs=grid.spacing_map)
pdes = [Eq(U[i], system_eqs_rhs[i]) for i in range(2)]
op = Operator(pdes, subs=grid.spacing_map)
op(dt=0.01, time=1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should also assert something

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


return max(abs(u_multi_stage.data[0, :] - u.data[0, :]))
return max(abs(U_multi_stage[0].data[0, :] - U[0].data[0, :]))

# test_multistage_op_constructing_directly()

# test_multistage_op_computing_directly()

def test_multistage_op_solve_computing(time_int='RK44'):
def test_multistage_op_computing_directly(time_int='RK44'):
extent = (1, 1)
shape = (200, 200)
origin = (0, 0)
Expand All @@ -129,22 +166,21 @@ def test_multistage_op_solve_computing(time_int='RK44'):
dt = grid.stepping_dim.spacing
t = grid.time_dim

# Define unknown for the 'time_int' method: u (heat)
u_time_int = TimeFunction(name='u', grid=grid, space_order=2, time_order=1, dtype=float64)
# Define wavefield unknowns: u (displacement) and v (velocity)
u_multi_stage = TimeFunction(name='u_multi_stage', grid=grid, space_order=2, time_order=1, dtype=float64)

# Source definition
src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=float64)
src_spatial.data[1, 1] = 1
f0 = 0.01
src_temporal = (1 - 2 * (pi * f0 * (t * dt - 1 / f0)) ** 2) * exp(-(pi * f0 * (t * dt - 1 / f0)) ** 2)
src_temporal = (1 - 2 * (t*dt - 1)**2)

# PDE (2D heat eq.)
eq_rhs = (Derivative(u_time_int, (x, 2), fd_order=2) + Derivative(u_time_int, (y, 2), fd_order=2) +
src_spatial * src_temporal)
eq_rhs = (Derivative(u_multi_stage, (x, 2), fd_order=2) + Derivative(u_multi_stage, (y, 2), fd_order=2) +
src_spatial * src_temporal)

# Time integration scheme
pde = solve(eq_rhs - u_time_int, u_time_int, method=time_int)
op=Operator(pde, subs=grid.spacing_map)
pde = [resolve_method(time_int)(eq_rhs, u_multi_stage)]
op = Operator(pde, subs=grid.spacing_map)
op(dt=0.01, time=1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again should assert a norm. Can also be consolidated with the previous test via parameterisation

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same, the idea is to check if op.apply() executes without an error...


# Solving now using Devito's standard time solver
Expand All @@ -157,6 +193,4 @@ def test_multistage_op_solve_computing(time_int='RK44'):
op = Operator(pde, subs=grid.spacing_map)
op(dt=0.01, time=1)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: Too many blank lines at end of file

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

return max(abs(u_time_int.data[0,:]-u.data[0,:]))

# test_multistage_op_solve_computing()
return max(abs(u_multi_stage.data[0, :] - u.data[0, :]))