Skip to content

Commit

Permalink
Add in place integration
Browse files Browse the repository at this point in the history
  • Loading branch information
Michele Zaffalon committed Mar 21, 2024
1 parent 0e8470d commit 28f4d7a
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 6 deletions.
41 changes: 36 additions & 5 deletions src/HCubature.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ module HCubature
using StaticArrays, LinearAlgebra
import Combinatorics, DataStructures, QuadGK

export hcubature, hquadrature, hcubature_buffer
export hcubature, hcubature!, hquadrature, hcubature_buffer

include("genz-malik.jl")
include("gauss-kronrod.jl")
Expand All @@ -35,6 +35,7 @@ end
Base.isless(i::Box, j::Box) = isless(i.E, j.E)

cubrule(v::Val{n}, ::Type{T}) where {n,T} = GenzMalik(v, T)
cubrule(v::Val{n}, ::Type{T}, r::Q) where {n,T,Q} = GenzMalik_InPlace(v, T, r)
cubrule(::Val{1}, ::Type{T}) where {T} = GaussKronrod(T)

# trivial rule for 0-dimensional integrals
Expand Down Expand Up @@ -96,13 +97,13 @@ function hcubature_buffer_(f, a::Tuple{Vararg{Real,n}}, b::Tuple{Vararg{Real,n}}
hcubature_buffer_(f, SVector{n}(float.(a)), SVector{n}(float.(b)), norm)
end

function hcubature_(f::F, a::SVector{n,T}, b::SVector{n,T}, norm, rtol_, atol, maxevals, initdiv, buf) where {F, n, T<:Real}
function hcubature_(f::F, a::SVector{n,T}, b::SVector{n,T}, norm, rtol_, atol, maxevals, initdiv, buf, rule) where {F, n, T<:Real}
rtol = rtol_ == 0 == atol ? sqrt(eps(T)) : rtol_
(rtol < 0 || atol < 0) && throw(ArgumentError("invalid negative tolerance"))
maxevals < 0 && throw(ArgumentError("invalid negative maxevals"))
initdiv < 1 && throw(ArgumentError("initdiv must be positive"))

rule = cubrule(Val{n}(), T)
#rule = cubrule(Val{n}(), T)
numevals = evals_per_box = countevals(rule)

Δ = (b-a) / initdiv
Expand Down Expand Up @@ -176,12 +177,34 @@ function hcubature_(f, a::AbstractVector{T}, b::AbstractVector{S},
norm, rtol, atol, maxevals, initdiv, buf) where {T<:Real, S<:Real}
length(a) == length(b) || throw(DimensionMismatch("endpoints $a and $b must have the same length"))
F = float(promote_type(T, S))
return hcubature_(f, SVector{length(a),F}(a), SVector{length(a),F}(b), norm, rtol, atol, maxevals, initdiv, buf)
rule = cubrule(Val{length(a)}(), F)
return hcubature_(f, SVector{length(a),F}(a), SVector{length(a),F}(b), norm, rtol, atol, maxevals, initdiv, buf, rule)
end
function hcubature_(f, a::Tuple{Vararg{Real,n}}, b::Tuple{Vararg{Real,n}}, norm, rtol, atol, maxevals, initdiv, buf) where {n}
hcubature_(f, SVector{n}(float.(a)), SVector{n}(float.(b)), norm, rtol, atol, maxevals, initdiv, buf)
sa = SVector{n}(float.(a))
sb = SVector{n}(float.(b))
rule = cubrule(Val{length(sa)}(), eltype(sa))
hcubature_(f, sa, sb, norm, rtol, atol, maxevals, initdiv, buf, rule)
end


function hcubature_(result::Q, f!, a::AbstractVector{T}, b::AbstractVector{S},

Check warning on line 191 in src/HCubature.jl

View check run for this annotation

Codecov / codecov/patch

src/HCubature.jl#L191

Added line #L191 was not covered by tests
norm, rtol, atol, maxevals, initdiv, buf) where {Q, T<:Real, S<:Real}
length(a) == length(b) || throw(DimensionMismatch("endpoints $a and $b must have the same length"))
F = float(promote_type(T, S))
rule = cubrule(Val(length(a)), F, result)
return hcubature_(f!, SVector{length(a),F}(a), SVector{length(a),F}(b), norm, rtol, atol, maxevals, initdiv, buf, rule)

Check warning on line 196 in src/HCubature.jl

View check run for this annotation

Codecov / codecov/patch

src/HCubature.jl#L193-L196

Added lines #L193 - L196 were not covered by tests
end
function hcubature_(result::Q, f!, a::Tuple{Vararg{Real,n}}, b::Tuple{Vararg{Real,n}}, norm, rtol, atol, maxevals, initdiv, buf) where {n, Q}
sa = SVector{n}(float.(a))
sb = SVector{n}(float.(b))
rule = cubrule(Val(length(sa)), eltype(sa), result)
hcubature_(f!, sa, sb, norm, rtol, atol, maxevals, initdiv, buf, rule)
end




"""
hcubature(f, a, b; norm=norm, rtol=sqrt(eps), atol=0, maxevals=typemax(Int),
initdiv=1, buffer=nothing)
Expand Down Expand Up @@ -236,6 +259,14 @@ hcubature(f, a, b; norm=norm, rtol::Real=0, atol::Real=0,
hcubature_(f, a, b, norm, rtol, atol, maxevals, initdiv, buffer)


function hcubature!(result, f!, a, b; norm=norm, rtol::Real=0, atol::Real=0,
maxevals::Integer=typemax(Int), initdiv::Integer=1, buffer=nothing)
(I, E) = hcubature_(result, f!, a, b, norm, rtol, atol, maxevals, initdiv, buffer)
result .= I
return I, E
end


"""
hquadrature(f, a, b; norm=norm, rtol=sqrt(eps), atol=0, maxevals=typemax(Int), initdiv=1)
Expand Down
110 changes: 109 additions & 1 deletion src/genz-malik.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ function GenzMalik(v::Val{n}, ::Type{T}=Float64) where {n, T<:Real}
return g
end

countevals(g::GenzMalik{n}) where {n} = 1 + 4n + 2*n*(n-1) + (1<<n)

#countevals(g::GenzMalik{n}) where {n} = 1 + 4n + 2*n*(n-1) + (1<<n)

"""
genzmalik(f, a, b, norm=norm)
Expand Down Expand Up @@ -165,3 +166,110 @@ function (g::GenzMalik{n,T})(f::F, a::SVector{n}, b::SVector{n}, norm=norm) wher

return I, E, kdivide
end


struct GenzMalik_InPlace{n,T<:Real,Q}
p::NTuple{4,Vector{SVector{n,T}}} # points for the last 4 G-M weights
w::NTuple{5,T} # weights for the 5 terms in the G-M rule
w′::NTuple{4,T} # weights for the embedded lower-degree rule

# internal variables
f₁::Q
f₂::Q
f₃::Q
twelvef₁::Q
f₂ᵢ::Q
f₃ᵢ::Q
f₄::Q
f₅::Q
t::Q
end

countevals(g::Union{GenzMalik_InPlace{n},GenzMalik{n}}) where {n} = 1 + 4n + 2*n*(n-1) + (1<<n)

function GenzMalik_InPlace(v::Val{n}, ::Type{T}, r::Q) where {n, T<:Real, Q}
#haskey(gmcache, (n,T)) && return gmcache[n,T]::GenzMalik{n,T}

n < 2 && throw(ArgumentError("invalid dimension $n: GenzMalik rule requires dimension ≠ 2"))

λ₄ = sqrt(9/T(10))
λ₂ = sqrt(9/T(70))
λ₃ = λ₄
λ₅ = sqrt(9/T(19))

twoⁿ = 1 << n
w₁ = twoⁿ * ((12824 - 9120n + 400n^2) / T(19683))
w₂ = twoⁿ * (980 / T(6561))
w₃ = twoⁿ * ((1820 - 400n) / T(19683))
w₄ = twoⁿ * (200 / T(19683))
w₅ = 6859/T(19683)
w₄′ = twoⁿ * (25/T(729))
w₃′ = twoⁿ * ((265 - 100n)/T(1458))
w₂′ = twoⁿ * (245/T(486))
w₁′ = twoⁿ * ((729 - 950n + 50n^2)/T(729))

p₂ = combos(1, λ₂, v)
p₃ = combos(1, λ₃, v)
p₄ = signcombos(2, λ₄, v)
p₅ = signcombos(n, λ₅, v)

g = GenzMalik_InPlace{n,T,Q}((p₂,p₃,p₄,p₅), (w₁,w₂,w₃,w₄,w₅), (w₁′,w₂′,w₃′,w₄′),
similar(r), similar(r), similar(r), similar(r),
similar(r), similar(r), similar(r), similar(r),
similar(r))
#gmcache[n,T] = g
return g
end


function (g::GenzMalik_InPlace{n,T,Q})(f!::F, a::SVector{n}, b::SVector{n}, norm=norm) where {F, n,T,Q}
c = T(0.5).*(a.+b)
Δ = T(0.5).*(b.-a)
V = prod(Δ)

f!(g.f₁, c)
fill!(g.f₂, 0)
fill!(g.f₃, 0)
g.twelvef₁ .= 12g.f₁
maxdivdiff = zero(norm(g.f₁))
divdiff = similar(SVector{n,typeof(maxdivdiff)})
for i = 1:n
p₂ = Δ .* g.p[1][i]
f!(g.f₂ᵢ, c + p₂); f!(g.t, c - p₂); g.f₂ᵢ .+= g.t
p₃ = Δ .* g.p[2][i]
f!(g.f₃ᵢ, c + p₃); f!(g.t, c - p₃); g.f₃ᵢ .+= g.t
g.f₂ .+= g.f₂ᵢ
g.f₃ .+= g.f₃ᵢ
# fourth divided difference: f₃ᵢ-2f₁ - 7*(f₂ᵢ-2f₁),
# where 7 = (λ₃/λ₂)^2 [see van Dooren and de Ridder]
divdiff[i] = norm(g.f₃ᵢ + g.twelvef₁ - 7*g.f₂ᵢ)
end

fill!(g.f₄, 0)
for p in g.p[3]
f!(g.t, c .+ Δ .* p); g.f₄ .+= g.t
end

fill!(g.f₅, 0)
for p in g.p[4]
f!(g.t, c .+ Δ .* p); g.f₅ .+= g.t
end

I = V * (g.w[1]*g.f₁ + g.w[2]*g.f₂ + g.w[3]*g.f₃ + g.w[4]*g.f₄ + g.w[5]*g.f₅)
I′ = V * (g.w′[1]*g.f₁ + g.w′[2]*g.f₂ + g.w′[3]*g.f₃ + g.w′[4]*g.f₄)
E = norm(I - I′)

# choose axis
kdivide = 1
δf = E / (10^n * V)
for i = 1:n
if= divdiff[i] - maxdivdiff) > δf
kdivide = i
maxdivdiff = divdiff[i]
elseif abs(δ) <= δf && abs(Δ[i]) > abs(Δ[kdivide])
kdivide = i
end
end

return I, E, kdivide
end
9 changes: 9 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,12 @@ end
@test hquadrature(x -> exp(-x^2), T(0), T(1); rtol = 1e-20)[1] 0.7468241328124270254
@test hcubature(x -> exp(-x[1]^2), T.((0,0)), T.((1,1)); rtol = 1e-20)[1] 0.7468241328124270254
end

@testset "in place" begin
f(x) = [cos(x[1]) sin(x[2]); cos(x[1])^2 cos(x[2])^2]
f!(r, x) = r .= f(x)
result = zeros(2, 2)
I₁ = hcubature(f, (-1,0), (1,1))[1]
I₂ = hcubature!(result, f!, (-1,0), (1,1))[1]
@test I₁ I₂
end

0 comments on commit 28f4d7a

Please sign in to comment.