Pass full integrator instead of parameters #116
Comments
Yes, this makes sense. I am a little worried about compile times, but maybe it all just quickly compiles away.
Yes, hopefully the compiler is smart enough. However, there's another issue: in the same way we have to pass around Alternatively, one could define functions such as
function DiffEqBase.analytic(f::ODEFunction{iip,unpack}, u, integrator, t) where {iip,unpack}
    has_analytic(f) || error("analytical solution is not defined")
    unpack ? f.analytic(u, get_p(integrator), t) : f.analytic(u, integrator, t)
end
for all such overloads, but I don't know if this makes any difference. I still like the idea of attacking this problem on the lowest level, but of course an alternative would be to explicitly define
function perform_step!(integrator, cache::BS3ConstantCache)
    p = unpack_params(integrator, integrator.f)
    .....
end
unpack_params(integrator::ODEIntegrator, ::ODEFunction{iip,false}) where iip = integrator
unpack_params(integrator::ODEIntegrator, ::ODEFunction{iip,true}) where iip = get_p(integrator)
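The dispatch idea in this comment can be tried in isolation. The following is only a minimal sketch that does not use DiffEqBase or OrdinaryDiffEq: ToyFunction, ToyIntegrator, and call_rhs are hypothetical stand-ins, and only the names get_p and unpack_params are taken from the thread. It assumes the unpack flag is carried as a Bool type parameter, so the choice between parameters and integrator is made by dispatch rather than by a runtime branch.

```julia
# Hypothetical stand-ins for ODEFunction and ODEIntegrator (not the real types).
struct ToyFunction{unpack,F}
    f::F
end
# Convenience constructor: only the `unpack` flag has to be given explicitly.
ToyFunction{unpack}(f::F) where {unpack,F} = ToyFunction{unpack,F}(f)

struct ToyIntegrator{U,P}
    u::U
    p::P
end

# Default accessor for the parameters.
get_p(integrator::ToyIntegrator) = integrator.p

# Select what is handed to the user function by dispatching on `unpack`.
unpack_params(integrator::ToyIntegrator, ::ToyFunction{true}) = get_p(integrator)
unpack_params(integrator::ToyIntegrator, ::ToyFunction{false}) = integrator

# Stand-in for a solver step: the call site is identical for both conventions.
function call_rhs(fun::ToyFunction, integrator::ToyIntegrator, t)
    arg = unpack_params(integrator, fun)
    return fun.f(integrator.u, arg, t)
end

integrator = ToyIntegrator(2.0, 3.0)

# This user function expects the parameters `p`.
rhs_params = ToyFunction{true}((u, p, t) -> p * u + t)
# This user function expects the full integrator.
rhs_integrator = ToyFunction{false}((u, integ, t) -> integ.p * u + t)

result_params = call_rhs(rhs_params, integrator, 1.0)          # 3.0 * 2.0 + 1.0
result_integrator = call_rhs(rhs_integrator, integrator, 1.0)  # same value
```

Because the flag is part of the function's type in this sketch, each specialization of call_rhs should only contain one of the two unpack_params methods. Whether the same holds for the real ODEFunction and integrator types is not checked here.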
We can also hack it with getproperty overloading
I'm working on a prototype for However, I'm not sure how to deal with the fact that Can we get around this problem somehow by not caching
Since passing around the integrator in OrdinaryDiffEq is not completely straightforward (at least it seems to me), I started playing around with something that's more centered around the use case in DelayDiffEq. One idea was to use getproperty overloading such that all calls of
using DelayDiffEq, DiffEqBase, Test

# Wraps a DDE function together with its history function.
struct ODEFunctionWrapper{iip,F,H} <: DiffEqBase.AbstractODEFunction{iip}
    f::F
    h::H
end

function wrap(prob::DDEProblem)
    ODEFunctionWrapper{isinplace(prob.f),typeof(prob.f),typeof(prob.h)}(prob.f, prob.h)
end

# Calling the wrapper inserts the stored history function.
(f::ODEFunctionWrapper{false})(u, p, t) = f.f(u, f.h, p, t)
(f::ODEFunctionWrapper{true})(du, u, p, t) = f.f(du, u, f.h, p, t)

struct TestStruct{F,A}
    f::F
    a::A
end

function buildTestStruct(prob::DDEProblem, u, p, t)
    f = wrap(prob)
    a = f(u, p, t)
    TestStruct(f, a)
end

function buildTestStruct(prob::DDEProblem, du, u, p, t)
    f = wrap(prob)
    f(du, u, p, t)
    TestStruct(f, first(du))
end

# Accessing `test.f` returns a closure that replaces the history function
# by one that depends on the field `a`.
function Base.getproperty(test::TestStruct, x::Symbol)
    if x === :f
        f = getfield(test, :f)
        if isinplace(f)
            (du, u, p, t) -> f.f(du, u, (p, t) -> [t * test.a], p, t)
        else
            (u, p, t) -> f.f(u, (p, t) -> t * test.a, p, t)
        end
    else
        getfield(test, x)
    end
end

function calc(test::TestStruct, u, p, t)
    f = test.f
    f(u, p, t)
end

function calc!(test::TestStruct, du, u, p, t)
    f = test.f
    f(du, u, p, t)
    nothing
end

function f_ip(du, u, h, p, t)
    du[1] = h(p, t)[1] - u[1]
    nothing
end

f_scalar(u, h, p, t) = h(p, t) - u

function test()
    prob_ip = DDEProblem(f_ip, [1.0], (p, t) -> [0.0], (0.0, 10.0))
    prob_scalar = DDEProblem(f_scalar, 1.0, (p, t) -> 0.0, (0.0, 10.0))

    wrap_ip = wrap(prob_ip)
    wrap_scalar = wrap(prob_scalar)

    # The wrappers use the history function of the problem.
    a = [0.0]
    wrap_ip(a, [5.0], nothing, 0.0)
    @test a[1] == -5.0
    wrap_ip(a, [5.0], nothing, 5.0)
    @test a[1] == -5.0
    wrap_ip(a, [5.0], nothing, 10.0)
    @test a[1] == -5.0

    @test wrap_scalar(5.0, nothing, 0.0) == -5.0
    @test wrap_scalar(5.0, nothing, 5.0) == -5.0
    @test wrap_scalar(5.0, nothing, 10.0) == -5.0

    struct_ip = buildTestStruct(prob_ip, [0.0], [5.0], nothing, 4.0)
    @test struct_ip.a == -5.0

    struct_scalar = buildTestStruct(prob_scalar, 5.0, nothing, 4.0)
    @test struct_scalar.a == -5.0

    # Calls through `test.f` use the replaced history function t * a.
    b = [0.0]
    calc!(struct_ip, b, [5.0], nothing, 1.0)
    @test b[1] == -10.0
    calc!(struct_ip, b, [5.0], nothing, 4.0)
    @test b[1] == -25.0

    @test calc(struct_scalar, 5.0, nothing, 2.0) == -15.0
    @test calc(struct_scalar, 5.0, nothing, 6.0) == -35.0
end
However, I'm not sure how this will affect performance, if it is possible at all.
As discussed in SciML/DiffEqProblemLibrary.jl#39, especially for the history function it seems reasonable to pass the full integrator as an argument instead of only the parameters, i.e., having h(integrator, t) instead of h(p, t) and also f(u, h, integrator, t) instead of f(u, h, p, t). This would enable the user to write generic history functions with correct output types (see the discussion in the PR) and hopefully allow us to simplify the implementation in DelayDiffEq.
According to @ChrisRackauckas
I think we should approach this issue slightly differently. A user has to decide whether to pass around the integrator or only the parameters already when implementing f (or h), i.e., it is a property of the differential equation function rather than of the numerical algorithm. Hence I guess it would make sense to handle this issue by modifying DiffEqFunctions instead of the different algorithms. We could replace
with
and then define, e.g.,
In that way, we just have to implement get_p for every integrator (which would be integrator.p by default) and could always pass integrator in every package.
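To make the motivation concrete, here is a small sketch of the two calling conventions. It is not DelayDiffEq code: ToyIntegrator, h_params, h_generic, and the right-hand side f below are hypothetical, and only the signatures h(integrator, t), f(u, h, integrator, t) and the accessor get_p come from the proposal above. It assumes the integrator exposes the current state as a field u.

```julia
# Hypothetical minimal integrator; a real integrator carries much more state.
struct ToyIntegrator{U,P}
    u::U
    p::P
end

# Default accessor for the parameters.
get_p(integrator) = integrator.p

# With h(p, t) the output type has to be hard-coded by the user.
h_params(p, t) = 0.0

# With h(integrator, t) the output can follow the type of the current state.
h_generic(integrator, t) = zero(integrator.u)

# f(u, h, integrator, t) reads the parameters through get_p.
function f(u, h, integrator, t)
    p = get_p(integrator)
    return p * h(integrator, t - 1) - u
end

integ64 = ToyIntegrator(1.0, 2.0)
integ32 = ToyIntegrator(1.0f0, 2.0f0)
integvec = ToyIntegrator([1.0, 2.0], 2.0)

out64 = f(integ64.u, h_generic, integ64, 0.0)      # Float64 state
out32 = f(integ32.u, h_generic, integ32, 0.0f0)    # Float32 state stays Float32
outvec = f(integvec.u, h_generic, integvec, 0.0)   # vector state

# The hard-coded history function promotes the Float32 problem to Float64.
promoted = 2.0f0 * h_params(get_p(integ32), 0.0f0) - integ32.u
```

In this sketch the same h_generic serves scalar, single-precision, and vector states, which a fixed h(p, t) = 0.0 cannot do without promoting or failing.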