Fix complex-valued t_span
for complex-valued ODEs in odeint
#179
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Following up on my PR #178 and on this issue #177.
In the current state of
odeint
, the time valuest
are complex-valued ifx
is initially complex-valued when callingf(t, x)
at every time step. Instead, they should be float like. Otherwise, this could be problematic iff(t, x)
requires real-valued times (e.g. iff(t, x) = torch.erf(t) * x
).This PR fixes this by making changes to the
solver.tableau
definitions. Thec
constants in the tableaus are initialized with a float dtype of the same precision as the (complex or float) dtype ofx
. This is done by callingt_dtype = getattr(torch, torch.finfo(x.dtype).dtype)
. Ifx
is float-valued, thent_dtype = x.dtype
and things work as usual. Changes insync_device_dtype
were also required to differentiate between thet.dtype
and thex.dtype
.Sorry for having missed this in the initial PR !