Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,6 +1084,22 @@ def _promote(yi):
y0 = jtu.tree_map(_promote, y0)
del timelikes

# Check if the solver is an instance of AbstractSolver and provide an informative
# error if it is not. Addresses https://github.com/patrick-kidger/diffrax/issues/705
if not isinstance(solver, AbstractSolver):
if issubclass(solver, AbstractSolver):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will error out if solver is not a class. We need isinstance(solver, type) and issubclass(solver, AbstractSolver).

msg = (
"It looks like you forgot to instantiate your solver, e.g. by passing "
"`dfx.Euler` instead of `dfx.Euler()`."
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dfx abbreviation isn't as standardized as eqx is; I think we should write out diffrax instead.

)
raise ValueError(msg)
else:
msg = (
"Argument `solver` must be an instance of (some subclass of) "
"`dfx.AbstractSolver`, but its type is not recognised."
)
raise ValueError(msg)

# Backward compatibility
if isinstance(solver, (EulerHeun, ItoMilstein, StratonovichMilstein)):
try:
Expand Down