diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index e43494974d..ec241fc967 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -6,6 +6,7 @@ from sympy import Pow, Add, Mul, Min, Max, S, SympifyError, Tuple, sympify from sympy.core.add import _addsort from sympy.core.mul import _mulsort +from sympy.functions.elementary.piecewise import ExprCondPair from devito.finite_differences.differentiable import ( EvalDerivative, IndexDerivative @@ -356,6 +357,8 @@ def pow_to_mul(expr): else: # Default. We should not end up here as all cases are handled return expr + elif expr.func is ExprCondPair: + return expr.func(*[pow_to_mul(i) for i in expr.args]) else: args = [pow_to_mul(i) for i in expr.args] diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 367090d0d6..8daf598249 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -7,7 +7,7 @@ from sympy import And, Expr, Number, Symbol from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos, - Min, Max, Real, Imag, Conj, SubDomain, configuration) + Piecewise, Min, Max, Real, Imag, Conj, SubDomain, configuration) from devito.finite_differences.differentiable import SafeInv, Weights, Mul from devito.ir import Expression, FindNodes, ccode from devito.ir.support.guards import GuardExpr, simplify_and @@ -1104,6 +1104,15 @@ def test_print_div(): assert cstr == 'sizeof(int)/sizeof(long)' +def test_piecewise(): + grid = Grid(shape=(11,)) + u = Function(name='u', grid=grid, space_order=2) + v = Function(name='v', grid=grid, space_order=2) + eq_u = Eq(u, Piecewise((1, v < 10), (2, True))) + op = Operator(eq_u) + # check that the code generated a condition + assert "v[x + 2] < 10" in str(op.ccode) + def test_customdtype_complex(): """ Test that `CustomDtype` doesn't brak is_imag