-
Notifications
You must be signed in to change notification settings - Fork 245
dsl: Introduce abstractions for multi-stage time integrators #2599
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
1d830b8
7f087b3
214d882
d6c4d4a
78f8a0b
1c9d517
11db48b
83dfb04
d47a106
1f93a45
eea3a52
11d1429
4637ac2
ac1da7e
e9b3533
dc3dd77
1fd4a02
a0c45c1
93c6e3f
ef8d1ac
fa5acac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -63,11 +63,7 @@ def solve(eq, target, **kwargs): | |
| sols_temp = sols[0] | ||
|
|
||
| method = kwargs.get("method", None) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is method a string here, or is it the class for the method? In the latter case, it would remove the need to have the
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a string. The idea is that the user provides a string to identify which time integrator to apply. |
||
| if method is not None: | ||
| method_cls = resolve_method(method) | ||
| return method_cls(target, sols_temp)._evaluate(**kwargs) | ||
| else: | ||
| return sols_temp | ||
| return sols_temp if method is None else resolve_method(method)(target, sols_temp) | ||
|
|
||
|
|
||
| def linsolve(expr, target, **kwargs): | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this file should be moved to somewhere like
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I renamed the class to Regarding the file location, it’s currently in |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -127,16 +127,17 @@ def _evaluate(self, **kwargs): | |
| - `s` stage equations of the form `k_i = rhs evaluated at intermediate state` | ||
| - 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)` | ||
| """ | ||
| eq_num = kwargs['eq_num'] | ||
| stage_id = kwargs.get('sregistry').make_name(prefix='k') | ||
|
|
||
| u = self.lhs.function | ||
| rhs = self.rhs | ||
| grid = u.grid | ||
| t = grid.time_dim | ||
| dt = t.spacing | ||
|
|
||
| # Create temporary Functions to hold each stage | ||
|
||
| # k = [Array(name=f'k{eq_num}{i}', dimensions=grid.shape, grid=grid, dtype=u.dtype) for i in range(self.s)] # Trying Array | ||
| k = [Function(name=f'k{eq_num}{i}', grid=grid, space_order=u.space_order, dtype=u.dtype) | ||
| # k = [Array(name=f'{stage_id}{i}', dimensions=grid.shape, grid=grid, dtype=u.dtype) for i in range(self.s)] # Trying Array | ||
| k = [Function(name=f'{stage_id}{i}', grid=grid, space_order=u.space_order, dtype=u.dtype) | ||
|
||
| for i in range(self.s)] | ||
|
|
||
| stage_eqs = [] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,9 +4,11 @@ | |
| from devito import (Grid, Function, TimeFunction, | ||
| Derivative, Operator, solve, Eq) | ||
| from devito.types.multistage import resolve_method | ||
| from devito.ir.support import SymbolRegistry | ||
| from devito.ir.equations import lower_multistage | ||
|
|
||
|
|
||
| def test_multistage_solve(time_int='RK44'): | ||
| def test_multistage_object(time_int='RK44'): | ||
|
||
| extent = (1, 1) | ||
| shape = (3, 3) | ||
| origin = (0, 0) | ||
|
|
@@ -25,20 +27,19 @@ def test_multistage_solve(time_int='RK44'): | |
| # Source definition | ||
| src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=float64) | ||
| src_spatial.data[1, 1] = 1 | ||
| f0 = 0.01 | ||
| src_temporal = (1 - 2 * (pi * f0 * (t * dt - 1 / f0)) ** 2) * exp(-(pi * f0 * (t * dt - 1 / f0)) ** 2) | ||
| src_temporal = (1 - 2 * (t*dt - 1)**2) | ||
|
||
|
|
||
| # PDE system (2D acoustic) | ||
| system_eqs_rhs = [U[1] + src_spatial * src_temporal, | ||
| Derivative(U[0], (x, 2), fd_order=2) + | ||
| Derivative(U[0], (y, 2), fd_order=2) + | ||
| src_spatial * src_temporal] | ||
| Derivative(U[0], (x, 2), fd_order=2) + | ||
| Derivative(U[0], (y, 2), fd_order=2) + | ||
| src_spatial * src_temporal] | ||
|
|
||
| # Time integration scheme | ||
| return [solve(system_eqs_rhs[i] - U[i], U[i], method=time_int, eq_num=i) for i in range(2)] | ||
| # Class of the time integration scheme | ||
| return [resolve_method(time_int)(U[i], system_eqs_rhs[i]) for i in range(2)] | ||
|
|
||
|
|
||
| def test_multistage_op_constructing_directly(time_int='RK44'): | ||
| def test_multistage_lower_multistage(time_int='RK44'): | ||
| extent = (1, 1) | ||
|
||
| shape = (3, 3) | ||
| origin = (0, 0) | ||
|
|
@@ -57,23 +58,55 @@ def test_multistage_op_constructing_directly(time_int='RK44'): | |
| # Source definition | ||
| src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=float64) | ||
| src_spatial.data[1, 1] = 1 | ||
| f0 = 0.01 | ||
| src_temporal = (1 - 2 * (pi * f0 * (t * dt - 1 / f0)) ** 2) * exp(-(pi * f0 * (t * dt - 1 / f0)) ** 2) | ||
| src_temporal = (1 - 2 * (t*dt - 1)**2) | ||
|
|
||
| # PDE system (2D acoustic) | ||
| system_eqs_rhs = [U[1] + src_spatial * src_temporal, | ||
| Derivative(U[0], (x, 2), fd_order=2) + | ||
| Derivative(U[0], (y, 2), fd_order=2) + | ||
| src_spatial * src_temporal] | ||
|
|
||
| # Time integration scheme | ||
|
|
||
| # Class of the time integration scheme | ||
| pdes = [resolve_method(time_int)(U[i], system_eqs_rhs[i]) for i in range(2)] | ||
| op = Operator(pdes, subs=grid.spacing_map) | ||
| op(dt=0.01, time=1) | ||
|
|
||
| sregistry=SymbolRegistry() | ||
|
|
||
| def test_multistage_op_computing_directly(time_int='RK44'): | ||
| return lower_multistage(pdes, sregistry=sregistry) | ||
|
|
||
|
|
||
|
|
||
| def test_multistage_solve(time_int='RK44'): | ||
| extent = (1, 1) | ||
| shape = (3, 3) | ||
| origin = (0, 0) | ||
|
|
||
| # Grid setup | ||
| grid = Grid(origin=origin, extent=extent, shape=shape, dtype=float64) | ||
| x, y = grid.dimensions | ||
| dt = grid.stepping_dim.spacing | ||
| t = grid.time_dim | ||
|
|
||
| # Define wavefield unknowns: u (displacement) and v (velocity) | ||
| fun_labels = ['u', 'v'] | ||
| U = [TimeFunction(name=name, grid=grid, space_order=2, | ||
| time_order=1, dtype=float64) for name in fun_labels] | ||
|
|
||
| # Source definition | ||
| src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=float64) | ||
| src_spatial.data[1, 1] = 1 | ||
| src_temporal = (1 - 2 * (t*dt - 1)**2) | ||
|
|
||
| # PDE system (2D acoustic) | ||
| system_eqs_rhs = [U[1] + src_spatial * src_temporal, | ||
| Derivative(U[0], (x, 2), fd_order=2) + | ||
| Derivative(U[0], (y, 2), fd_order=2) + | ||
| src_spatial * src_temporal] | ||
|
|
||
| # Time integration scheme | ||
| return [solve(system_eqs_rhs[i] - U[i], U[i], method=time_int) for i in range(2)] | ||
|
|
||
|
|
||
| def test_multistage_op_computing_1eq(time_int='RK44'): | ||
|
||
| extent = (1, 1) | ||
| shape = (200, 200) | ||
| origin = (0, 0) | ||
|
|
@@ -85,40 +118,44 @@ def test_multistage_op_computing_directly(time_int='RK44'): | |
| t = grid.time_dim | ||
|
|
||
| # Define wavefield unknowns: u (displacement) and v (velocity) | ||
| u_multi_stage = TimeFunction(name='u_multi_stage', grid=grid, space_order=2, time_order=1, dtype=float64) | ||
| fun_labels = ['u_multi_stage', 'v_multi_stage'] | ||
| U_multi_stage = [TimeFunction(name=name, grid=grid, space_order=2, | ||
|
||
| time_order=1, dtype=float64) for name in fun_labels] | ||
|
|
||
| # Source definition | ||
| src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=float64) | ||
| src_spatial.data[1, 1] = 1 | ||
| f0 = 0.01 | ||
| src_temporal = (1 - 2 * (pi * f0 * (t * dt - 1 / f0)) ** 2) * exp(-(pi * f0 * (t * dt - 1 / f0)) ** 2) | ||
| src_temporal = (1 - 2 * (t*dt - 1)**2) | ||
|
|
||
| # PDE (2D heat eq.) | ||
| eq_rhs = (Derivative(u_multi_stage, (x, 2), fd_order=2) + Derivative(u_multi_stage, (y, 2), fd_order=2) + | ||
| src_spatial * src_temporal) | ||
| # PDE system | ||
| system_eqs_rhs = [U_multi_stage[1] + src_spatial * src_temporal, | ||
| Derivative(U_multi_stage[0], (x, 2), fd_order=2) + | ||
| Derivative(U_multi_stage[0], (y, 2), fd_order=2) + | ||
| src_spatial * src_temporal] | ||
|
|
||
| # Time integration scheme | ||
| pde = [MultiStage(eq_rhs, u_multi_stage, method=time_int)] | ||
| op = Operator(pde, subs=grid.spacing_map) | ||
| pdes = [resolve_method(time_int)(U_multi_stage[i], system_eqs_rhs[i]) for i in range(2)] | ||
|
||
| op = Operator(pdes, subs=grid.spacing_map) | ||
| op(dt=0.01, time=1) | ||
|
||
|
|
||
| # Solving now using Devito's standard time solver | ||
| u = TimeFunction(name='u', grid=grid, space_order=2, time_order=1, dtype=float64) | ||
| eq_rhs = (Derivative(u, (x, 2), fd_order=2) + Derivative(u, (y, 2), fd_order=2) + | ||
| src_spatial * src_temporal) | ||
| # Define wavefield unknowns: u (displacement) and v (velocity) | ||
| fun_labels = ['u', 'v'] | ||
| U = [TimeFunction(name=name, grid=grid, space_order=2, | ||
| time_order=1, dtype=float64) for name in fun_labels] | ||
| system_eqs_rhs = [U[1] + src_spatial * src_temporal, | ||
| Derivative(U[0], (x, 2), fd_order=2) + | ||
| Derivative(U[0], (y, 2), fd_order=2) + | ||
| src_spatial * src_temporal] | ||
|
|
||
| # Time integration scheme | ||
| pde = Eq(u, solve(eq_rhs - u, u)) | ||
| op = Operator(pde, subs=grid.spacing_map) | ||
| pdes = [Eq(U[i], system_eqs_rhs[i]) for i in range(2)] | ||
| op = Operator(pdes, subs=grid.spacing_map) | ||
| op(dt=0.01, time=1) | ||
|
||
|
|
||
| return max(abs(u_multi_stage.data[0, :] - u.data[0, :])) | ||
| return max(abs(U_multi_stage[0].data[0, :] - U[0].data[0, :])) | ||
|
|
||
| # test_multistage_op_constructing_directly() | ||
|
|
||
| # test_multistage_op_computing_directly() | ||
|
|
||
| def test_multistage_op_solve_computing(time_int='RK44'): | ||
| def test_multistage_op_computing_directly(time_int='RK44'): | ||
| extent = (1, 1) | ||
| shape = (200, 200) | ||
| origin = (0, 0) | ||
|
|
@@ -129,22 +166,21 @@ def test_multistage_op_solve_computing(time_int='RK44'): | |
| dt = grid.stepping_dim.spacing | ||
| t = grid.time_dim | ||
|
|
||
| # Define unknown for the 'time_int' method: u (heat) | ||
| u_time_int = TimeFunction(name='u', grid=grid, space_order=2, time_order=1, dtype=float64) | ||
| # Define wavefield unknowns: u (displacement) and v (velocity) | ||
| u_multi_stage = TimeFunction(name='u_multi_stage', grid=grid, space_order=2, time_order=1, dtype=float64) | ||
|
|
||
| # Source definition | ||
| src_spatial = Function(name="src_spat", grid=grid, space_order=2, dtype=float64) | ||
| src_spatial.data[1, 1] = 1 | ||
| f0 = 0.01 | ||
| src_temporal = (1 - 2 * (pi * f0 * (t * dt - 1 / f0)) ** 2) * exp(-(pi * f0 * (t * dt - 1 / f0)) ** 2) | ||
| src_temporal = (1 - 2 * (t*dt - 1)**2) | ||
|
|
||
| # PDE (2D heat eq.) | ||
| eq_rhs = (Derivative(u_time_int, (x, 2), fd_order=2) + Derivative(u_time_int, (y, 2), fd_order=2) + | ||
| src_spatial * src_temporal) | ||
| eq_rhs = (Derivative(u_multi_stage, (x, 2), fd_order=2) + Derivative(u_multi_stage, (y, 2), fd_order=2) + | ||
| src_spatial * src_temporal) | ||
|
|
||
| # Time integration scheme | ||
| pde = solve(eq_rhs - u_time_int, u_time_int, method=time_int) | ||
| op=Operator(pde, subs=grid.spacing_map) | ||
| pde = [resolve_method(time_int)(eq_rhs, u_multi_stage)] | ||
| op = Operator(pde, subs=grid.spacing_map) | ||
| op(dt=0.01, time=1) | ||
|
||
|
|
||
| # Solving now using Devito's standard time solver | ||
|
|
@@ -157,6 +193,4 @@ def test_multistage_op_solve_computing(time_int='RK44'): | |
| op = Operator(pde, subs=grid.spacing_map) | ||
| op(dt=0.01, time=1) | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nitpick: Too many blank lines at end of file
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| return max(abs(u_time_int.data[0,:]-u.data[0,:])) | ||
|
|
||
| # test_multistage_op_solve_computing() | ||
| return max(abs(u_multi_stage.data[0, :] - u.data[0, :])) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you do
return [_lower_multistage(expr, **kwargs) for i in exprs for expr in i]?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did something like that...