Fix assert error for type of keep_step#310
Fix assert error for type of keep_step#310mstoelzle wants to merge 1 commit intopatrick-kidger:mainfrom
keep_step#310Conversation
|
What version of JAX and what version of NumPy are you using? |
I am using python 3.11, numpy 1.23.5, jax 0.4.14, equinox 0.10.11, diffrax 0.4.1 |
|
Hmm. I'm not able to easily reproduce this with those versions. It should always be the case that |
|
Hi @mstoelzle and @patrick-kidger, I had the same error. However, when I initialized the model beforehand as I did during training, loading the checkpoint and subsequently using I hope this helps to resolve the issues you are having, @mstoelzle. Best regards, |
When I am running a normal integration such as
I will get an error similar to
When I print
jnp.result_type(keep_step), I getboolinstead ofjnp.dtype(bool).I would like to stress that this issue only appears for certain
ode_fn. I haven't quite figured out yet which change/property of theode_fncauses this error to occur.Still, this backwards-compatible change should work for any case.