-
-
Notifications
You must be signed in to change notification settings - Fork 214
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: change to new @muladd #89
Conversation
Tests are canceled/will fail since the new |
@@ -131,6 +131,6 @@ end | |||
|
|||
copy!(u,nlres) | |||
integrator.f[2](t+dt,nlres,rtmp1) | |||
integrator.fsallast .= A*u .+ rtmp1 | |||
integrator.fsallast = A*u .+ rtmp1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I somehow lost a dot, it should be added again
@@ -35,14 +35,14 @@ function ode_determine_initdt{tType,uType}(u0,t::tType,tdir,dtmax,abstol,reltol, | |||
|
|||
#@. u₁ = @muladd u0 + tdir*dt₀*f₀ | |||
@tight_loop_macros for i in uidx | |||
@inbounds u₁[i] = u0[i] + tdir*dt₀*f₀[i] | |||
@inbounds u₁[i] = u0[i] + (tdir*dt₀)*f₀[i] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
probably best to just hoist that out of the loop. Does @muladd
work on this now? I think it had a problem with it before.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes that's probably the best. @muladd
works on this now, it handles also products with more than two factors (should write a test for it...) - but since it is not a dot call all factors should be scalars. It does not fail for vectors of vectors since tdir*dt₀
is a scalar, but it also does not lead to a call of @llvm.fmuladd
; see SciML/DiffEqBase.jl#57 (comment) and:
julia> @code_lowered muladd(1., 1., 1.)
CodeInfo(:(begin
nothing
return (Base.muladd_float)(x, y, z)
end))
julia> @code_lowered muladd(1., [1.], [1.])
CodeInfo(:(begin
nothing
return x * y + z
end))
julia> @code_lowered muladd([1.], [1.], [1.])
CodeInfo(:(begin
nothing
return x * y + z
end))
In the current implementation @muladd a*b*c+d
is transformed to muladd(a, b*c, d)
- so the first factor always ends up as the first argument to muladd
and the product of all other factors builds the second argument. I don't know if this is better/more natural than muladd(a*b, c, d)
.
In this case @muladd
produces
julia> macroexpand(:(@muladd @tight_loop_macros for i in uidx
@inbounds u₁[i] = u0[i] + (tdir*dt₀)*f₀[i]
end))
:(for i = uidx # REPL[6], line 2:
begin
$(Expr(:inbounds, true))
u₁[i] = (muladd)(tdir * dt₀, f₀[i], u0[i])
$(Expr(:inbounds, :pop))
end
end)
and without brackets:
julia> macroexpand(:(@muladd @tight_loop_macros for i in uidx
@inbounds u₁[i] = u0[i] + tdir*dt₀*f₀[i]
end))
:(for i = uidx # REPL[7], line 2:
begin
$(Expr(:inbounds, true))
u₁[i] = (muladd)(tdir, dt₀ * f₀[i], u0[i])
$(Expr(:inbounds, :pop))
end
end)
But of course the best is to move the multiplication of the first two factors completely out of the loop.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, good to see that works though!
k3 = f(t+c2*dt, @. uprev+dt*(a31*k1+a32*k2)) | ||
k4 = f(t+c3*dt, @. uprev+dt*(a41*k1+a42*k2+a43*k3)) | ||
k5 = f(t+c4*dt, @. uprev+dt*(a51*k1+a52*k2+a53*k3+a54*k4)) | ||
k6 = f(t+dt, @. uprev+dt*(a61*k1+a62*k2+a63*k3+a64*k4+a65*k5)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are the @.
s in these spots okay for performance? It should be fine in the long run since I'm hoping to get that fixed so we should stick with it for now (and usage with actual arrays should just be inplace anyways), but I'd like to know if you've checked.
Merging, but still open question: are the @.s in these spots okay for performance? It should be fine in the long run since I'm hoping to get that fixed so we should stick with it for now (and usage with actual arrays should just be inplace anyways), but I'd like to know if you've checked. |
At https://gist.github.com/devmotion/a3096d1c051449178f200bcbd202f66e#file-benchmarks_bs3-jl you can find a first benchmark of the new implementation of the step calculations with |
I did the same for |
Nice observation! That may be the only one that's small enough though, since I know |
I haven't benchmarked However, since the old, undotted implementation prevents the application of @inline @muladd function perform_step!(integrator,cache::BS5ConstantCache,f=integrator.f)
@unpack t,dt,uprev,u,k = integrator
@unpack c1,c2,c3,c4,c5,a21,a31,a32,a41,a42,a43,a51,a52,a53,a54,a61,a62,a63,a64,a65,a71,a72,a73,a74,a75,a76,a81,a83,a84,a85,a86,a87,bhat1,bhat3,bhat4,bhat5,bhat6,btilde1,btilde2,btilde3,btilde4,btilde5,btilde6,btilde7,btilde8 = cache
k1 = integrator.fsalfirst
a = dt*a21
if typeof(u) <: AbstractArray
uidx = eachindex(uprev)
tmp = similar(uprev)
@tight_loop_macros for i in uidx
@inbounds tmp[i] = uprev[i]+a*k1[i]
end
k2 = f(t+c1*dt,tmp)
@tight_loop_macros for i in uidx
@inbounds tmp[i] = uprev[i]+dt*(a31*k1[i]+a32*k2[i])
end
k3 = f(t+c2*dt,tmp)
@tight_loop_macros for i in uidx
@inbounds tmp[i] = uprev[i]+dt*(a41*k1[i]+a42*k2[i]+a43*k3[i])
end
k4 = f(t+c3*dt,tmp)
@tight_loop_macros for i in uidx
@inbounds tmp[i] = uprev[i]+dt*(a51*k1[i]+a52*k2[i]+a53*k3[i]+a54*k4[i])
end
k5 = f(t+c4*dt,tmp)
@tight_loop_macros for i in uidx
@inbounds tmp[i] = uprev[i]+dt*(a61*k1[i]+a62*k2[i]+a63*k3[i]+a64*k4[i]+a65*k5[i])
end
k6 = f(t+c5*dt,tmp)
@tight_loop_macros for i in uidx
@inbounds tmp[i] = uprev[i]+dt*(a71*k1[i]+a72*k2[i]+a73*k3[i]+a74*k4[i]+a75*k5[i]+a76*k6[i])
end
k7 = f(t+dt,tmp)
utmp = similar(u)
@tight_loop_macros for i in uidx
@inbounds utmp[i] = uprev[i]+dt*(a81*k1[i]+a83*k3[i]+a84*k4[i]+a85*k5[i]+a86*k6[i]+a87*k7[i])
end
u = convert(typeof(u),utmp) # fixes problem with StaticArrays where typeof(u) != typeof(utmp)
integrator.fsallast = f(t+dt,u); k8 = integrator.fsallast
if integrator.opts.adaptive
tmptilde = similar(uprev)
@tight_loop_macros for (i,atol,rtol) in zip(uidx,Iterators.cycle(integrator.opts.abstol),Iterators.cycle(integrator.opts.reltol))
@inbounds uhat = dt*(bhat1*k1[i] + bhat3*k3[i] + bhat4*k4[i] + bhat5*k5[i] + bhat6*k6[i])
@inbounds utilde = uprev[i] + dt*(btilde1*k1[i] + btilde2*k2[i] + btilde3*k3[i] + btilde4*k4[i] + btilde5*k5[i] + btilde6*k6[i] + btilde7*k7[i] + btilde8*k8[i])
@inbounds tmp[i] = uhat[i]/(atol+max(abs(uprev[i]),abs(u[i]))*rtol)
@inbounds tmptilde[i] = (utilde[i]-u[i])/(atol+max(abs(uprev[i]),abs(u[i]))*rtol)
end
EEst1 = integrator.opts.internalnorm(tmp)
EEst2 = integrator.opts.internalnorm(tmptilde)
integrator.EEst = max(EEst1,EEst2)
end
else
k2 = f(t+c1*dt, uprev+a*k1)
k3 = f(t+c2*dt, uprev+dt*(a31*k1+a32*k2))
k4 = f(t+c3*dt, uprev+dt*(a41*k1+a42*k2+a43*k3))
k5 = f(t+c4*dt, uprev+dt*(a51*k1+a52*k2+a53*k3+a54*k4))
k6 = f(t+c5*dt, uprev+dt*(a61*k1+a62*k2+a63*k3+a64*k4+a65*k5))
k7 = f(t+dt, uprev+dt*(a71*k1+a72*k2+a73*k3+a74*k4+a75*k5+a76*k6))
u = uprev+dt*(a81*k1+a83*k3+a84*k4+a85*k5+a86*k6+a87*k7)
integrator.fsallast = f(t+dt,u); k8 = integrator.fsallast
if integrator.opts.adaptive
uhat = dt*(bhat1*k1 + bhat3*k3 + bhat4*k4 + bhat5*k5 + bhat6*k6)
utilde = uprev + dt*(btilde1*k1 + btilde2*k2 + btilde3*k3 + btilde4*k4 + btilde5*k5 + btilde6*k6 + btilde7*k7 + btilde8*k8)
EEst1 = integrator.opts.internalnorm(uhat/(integrator.opts.abstol+max(abs(uprev),abs(u))*integrator.opts.reltol))
EEst2 = integrator.opts.internalnorm((utilde-u)/(integrator.opts.abstol+max(abs(uprev),abs(u))*integrator.opts.reltol))
integrator.EEst = max(EEst1,EEst2)
end
end
integrator.k[1]=k1; integrator.k[2]=k2; integrator.k[3]=k3;integrator.k[4]=k4;integrator.k[5]=k5;integrator.k[6]=k6;integrator.k[7]=k7;integrator.k[8]=k8
@pack integrator = t,dt,u,k
end The advantage of this implementation is that it does not use any broadcasts but still correctly applies So I would strongly suggest commenting out dot calls as soon as we hit the limit (as the examples for DelayDiffEq and BS3 show, it should be fine for smaller expressions) and instead replacing them with the so-called mixed implementation. |
That would break StaticArrays though, which is arguably the most clear use for the not-in-place versions. I don't think we should put too much effort into workarounds here (except for maybe the most popular methods), and instead just focus on getting broadcast fusion fixed in Base and other algorithmic details. It's puzzling that broadcasts are so much faster though when the loops are small. Have you done any tests to try and find out why? Is SIMD not being applied to the loops or something? |
The tests with StaticArrays pass successfully, also with this mixed implementation.
No, I haven't done any tests, and I'm also surprised. I just copied the loops from OrdinaryDiffEq, how could I check if SIMD is not applied or something else is missing? |
These are some first changes to the integrators in order to correctly apply the new
@muladd
macro SciML/DiffEqBase.jl#57 (which might still end in a separate package?). There are still some integrators missing and one should also update the calculations in thedense
subdirectory.A general problem is that
muladd
can only act as dot call on vectors, resulting in many broadcasts especially for not in-place methods without for loops. However, according to JuliaLang/julia#22255 this leads to performance issues. Nevertheless for the beginning I added dots to many not in-place methods, similar to the commented code, since it guarantees the correct application ofmuladd
. Probably a workaround for these problems (as long as there is no fix for the performance issues) is to replace expressions with 12+ broadcasts with the not dotted expression if all variables are scalars and otherwise create an output vector which can be filled in a for loop (as it is done for in-place methods); depending on the type ofu
one of these options would be executed. This should guarantee a correct application ofmuladd
but needs code duplication.