-
-
Notifications
You must be signed in to change notification settings - Fork 167
Added JumpStepWrapper #484
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
78b122a to
0eac356
Compare
|
I now also added the functionality to revisit rejected steps. In addition, I also imporved the runtime of Also I think there was a bug in the PID controller, where it would sometimes reject a step, but have diffrax/diffrax/_step_size_controller/adaptive.py Lines 569 to 574 in 501bed5
I think possibly something smaller than just self.safety would make even more sense, I feel like if a step is rejected the next step should be at least 0.5x smaller. But I'm not an expert.
I added a test for revisiting steps and it all seems to work. I also sprinkled in a bunch of I think I commented the code quite well, so hopefully you can easily notice if I made a mistake somewhere. P.S.: Sorry for bombarding you with PRs. As far as I'm concerned this one is very low priority, I can use the code even if it isn't merged into diffrax proper. |
d022ac1 to
4702380
Compare
|
Hi @patrick-kidger, diffrax/benchmarks/jump_step_timing.py Lines 126 to 128 in 345e23a
|
patrick-kidger
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, quick first pass at a review!
|
Thanks for the review! I made all the edits I could and I left some comments where I need guidance (no hurry though, this is not high priority for me). Also, should I get rid of |
0050fa2 to
c3c4dcf
Compare
If it's easy to do that in a separate commit afterwards then I would say yes. A separate commit just so it's easy to revert if it turns out we were wrong about something here :D |
patrick-kidger
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I'm really sorry for taking so long to get around to this one! Some other work projects got in the way for a bit. (But on the plus side I have a few more open source projects in the pipe, keep an eye out for those ;) ) This is a really useful PR that I very much want to see in.
I've just done another revivew, LMK what you think!
|
Thanks for the review, Patrick! I'll probably make the fixes sometime in the coming week. I am also making progress on the ML examples for the Single-seed paper, but it is slower now, due to my internship. |
|
I am very confused about what the correct value of Suppose there is a jump at t=2. I will present 2 possible scenarios, in both of which I think something goes wrong (although maybe diffeqsolve might correct for the issue in scenario B). I wrote them as if JSW and the controller are separate, but the same holds for just the old PID controller. ====== scenario A =======
====== scenario B =======
Another way of seeing this all is through this:
Hence setting |
|
So I think the Line 390 in daec89c
I made the decision to handle some of the step-rejection logic in the main So I think this fine? Do double-check my logic though! :p Other than that, one thing I am noticing is that this |
|
Great, that's exactly the line I was looking for (I must admit I looked in Thinking about it now, the Edit: I already implemented what I mentioned above and wrote the proof in a comment. If you're curious and have extra time (yes I know that's a very far tail event :)) you can find it on my |
e203b53 to
7325e74
Compare
|
Hi Patrick! I just pushed a new version of this PR, rebased on top of the most current main. I think I addressed everything you asked me to fix. As it stands this contains 3 commits, contatining:
I left some conversations unresloved. I did try to fix the things mentioned in those, but I am not sure whether what I did was the best way to tackle that so I wanted to hear your opinion. Also the test are failing because pyright doesn't know how to import PS: The linear search I added slows it down compared to the way I wrote it before, but it is still faster than the old implementation with binary search. In particular the times (as obtained by
Additionally, changing the length of |
patrick-kidger
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay! I think I really like this.
First of all, I think I'm basically happy with pretty much everything outside of jump_step_wrapper.py. The changes here are pleasingly simple ^^
For jump_step_wrapper.py, I think my main question is around whether the rejected-step-buffer should actually be part of this wrapper at all -- since that handles SDEs with any kind of step rejection, which I think is completely orthogonal to clipping steps? (Not sure how I didn't notice this before!) I've also commented on a few other more minor points.
By the way, what did you think of the idea of moving next_made_jump into _integrate.py? It doesn't have to be now -- happy for that to be a separate PR -- just checking your thoughts on whether it is a generalisable thing.
Finally: merry Christmas, and a happy new year! :D
diffrax/_step_size_controller/pid.py
Outdated
| at_dtmin = at_dtmin | (prev_dt <= self.dtmin) | ||
| keep_step = keep_step | at_dtmin |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, does at_dtmin need to be state? (I'm not sure it ever did.) I think we might just be able to have keep_step = keep_step | (prev_dt <= self.dtmin)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, good point. Looking at the code, I don't really see a reason to keep it, but I'll get rid of it in a separate commit so we can roll it back easily.
| The `step_ts` and `jump_ts` are used to force the solver to step to certain times. | ||
| They mostly act in the same way, except that when we hit an element of `jump_ts`, | ||
| the controller must return `made_jump = True`, so that the diffeqsolve function | ||
| knows that the vector field has a discontinuity at that point, in which case it | ||
| re-evaluates it right after the jump point. In addition, the | ||
| exact time of the jump will be skipped using eqxi.prevbefore and eqxi.nextafter. | ||
| So now to the explanation of the two (we will use `step_ts` as an example, but the | ||
| same applies to `jump_ts`): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I often rewrite parts of docs after merging anyway, so feel free to ignore this for now -- but just a heads-up that this part is discussing a lot of implementation details: made_jump = True and eqxi.{prevbefore,nextafter} are not details familiar to most users.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I intended to keep that as a comment, not part of the docs, but you suggested I put it in the docstring and I wasn't sure what exactly you wanted. I don't have strong opinions here, so feel free to rewrite it however you wish.
| i = jax.lax.while_loop(cond_up, lambda _i: _i + 1, i) | ||
| i = jax.lax.while_loop(cond_down, lambda _i: _i - 1, i) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we have both of these loops? I think we only need a linear search in one direction: to find the next element of ts to clip to?
(And if we do need a bidirectional search, then given a hint n it's probably more efficient to search e.g. n / n+1 / n-1 / n+2 / n-2 / ... etc back and forth?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At most one of the two loops will trigger, so I am pretty sure doing it this way is faster and cleaner than your second suggestion (unless I'm getting it completely wrong??). And this is probably the safest option, but yeah I think we can easily just have the upwards loop, everything should still work if my logic is correct. Well in fact everything worked perfectly without any loops at all, the reason we added this is to be extra sure there aren't any edge cases. So I'd say if we want to be safe, let's be completely safe and have both loops. But up to you.
|
|
||
| # This is just a logging utility for testing purposes | ||
| if self.callback_on_reject is not None: | ||
| jax.debug.callback(self.callback_on_reject, keep_step, t1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I might suggest making this a pure_callback or io_callback, so that it will definitely be called in the right order across steps. JAX doesn't actually offer guarantees about the order in which multiple debug callbacks are called.
See for example how eqx.error_if works, which does the same thing by requiring a token.
(There is actually jax.debug.callback(..., ordered=True), but this works by having JAX sneakily rewriting the jaxpr to thread a dummy argument through as a token so as to order things... and I think that edge cases, so I try to avoid it.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good! I don't think the order matters in the test I use this for, but I suppose might as well do it properly.
| # Let's prove that the line below is correct. Say the inner controller is | ||
| # itself a JumpStepWrapper (JSW) with some inner_jump_ts. Then, given that | ||
| # it propsed (next_t0, original_next_t1), there cannot be any jumps in | ||
| # inner_jump_ts between next_t0 and original_next_t1. So if the next_t1 | ||
| # proposed by the outer JSW is different from the original_next_t1 then | ||
| # next_t1 \in (next_t0, original_next_t1) and hence there cannot be a jump | ||
| # in inner_jump_ts at next_t1. So the jump_at_next_t1 only depends on | ||
| # jump_at_next_t1. | ||
| # On the other hand if original_next_t1 == next_t1, then we just take an | ||
| # OR of the two. | ||
| jump_at_next_t1 = jnp.where( | ||
| next_t1 == original_next_t1, | ||
| jump_at_original_next_t1, | ||
| jump_at_next_t1 | jump_at_original_next_t1, | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, I don't think I completely believe this. Can we have the following:
- the PID controller proposes
t1. - the inner JSW wants to clip to a jump
b < t1. - the outer JSW wants to clip to a step (not a jump!)
a < b
?
In this case then we will have next_t1 != original_next_t1, an jump_at_original_next_t1 == True... but overall we want made_jump == False?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+can we have a test for two tested JSW, including the above scenario? It doesn't need to be a full diffeqsolve, just directly calling adapt_step_size and checking that we get the right output.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I bungled the code completely, because it doesn't do what's written in the comment at all. But I believe the comment (if implemented correctly) is correct.
I think the code should be:
jump_at_next_t1 = jnp.where(
next_t1 == original_next_t1,
jump_at_next_t1 | jump_at_original_next_t1,
jump_at_next_t1,
)
And this indeed works in the case you brought up as well.
Given this, I don't think moving next_made_jump into integrate is strictly necessary at this point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, yes I'll add a test.
05bbcb2 to
8d4212c
Compare
|
Hi Patrick! Sorry for the long silence, the last few weeks have been very busy. I made the corrections you suggested. Regarding splitting off revisiting steps from So unless you feel very strongly about separating these two features, I would prefer not to add extra hurdles to this PR and conclude it in the near future. I think Owen has been itching to get this done as fast as possible as well. |
Leaving it here on that basis sounds totally reasonable to me. Let's just tweak the name from |
patrick-kidger
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I started reviewing this... and now it's way too late in the evening for me to parse the math on this! :D
The one big thing I want to check for is that the numerics in JumpStepWrapper are doing the correct thing. I'll do that tomorrow and then let's aim to merge this. :)
| sol_no_jump_ts = run() | ||
| sol_with_jump_ts = run(jump_ts=[7.5]) | ||
| assert sol_no_jump_ts.stats["num_steps"] > sol_with_jump_ts.stats["num_steps"] | ||
| print(sol_no_jump_ts.stats["num_steps"], sol_with_jump_ts.stats["num_steps"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
!
effab83 to
8c5e87e
Compare
|
Okay @andyElking, I've finished going over this PR and I've pushed changes back to this branch in one commit. Changes are:
Lmk what you think! |
|
I read your changes and it all LGTM. Indeed the trick of just checking equality over all Oh, I never really thought of using the I completely agree with the renaming as well. Thanks for adding the citation! Also I find the Thanks for finishing this up! |
8c5e87e to
d2ed71a
Compare
|
Awesome. In that case: now merged. Thank you for your efforts driving this one! I hope it will be useful for some of the interesting adaptive SDE work you have lined up. :) |
Hi Patrick,
I factored the
jump_tsandstep_tsout of thePIDControllerintoJumpStepWrapper(I'm not very set on this name, lmk if you have ideas). I also made it behave as we discussed in #483. In particular, the following three rules are maintained:t1-t0 <= prev_dt(this is checked viaeqx.error_if), with inequality only if the step was clipped or if we hit the end of the integration interval (we do not explicitly check for that).next_dtmust be>=prev_dt.next_dtmust be< t1-t0.We achieve this in a very simple way here:
diffrax/diffrax/_step_size_controller/jump_step_wrapper.py
Lines 119 to 123 in 78b122a
The next step is to add a parameter
JumpStepWrapper.revisit_rejected_stepswhich does what you expect. That will appear in a future commit in this same PR.