Skip to content

Commit

Permalink
Merge pull request #26 from rafaqz/loose_typed_nodes
Browse files Browse the repository at this point in the history
More types for nodes
  • Loading branch information
ChrisRackauckas authored May 8, 2018
2 parents 8f2a994 + b14833b commit c7bb8fe
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 18 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,15 @@ tissue2 = construct(Tissue, deepcopy([population2, population]))
embryo = construct(Embryo, deepcopy([tissue1, tissue2])) # Make an embryo from Tissues
```

Note that tuples can be used as well. This allows for type-stable indexing with
heterogeneous nodes. For example:

```julia
tissue1 = construct(Tissue, deepcopy((population, cell3)))
```

(of course at the cost of mutability).

The head node then acts as the king. It is designed to have functionality which
mimics a vector in order for usage in DifferentialEquations or Optim. So for example

Expand Down
7 changes: 5 additions & 2 deletions src/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
nodeselect(ns, i, I...) = ns[i][I...]
nodechild(ns, i, j) = ns[i].nodes[j]

function bisect_search(a, i)
first(searchsorted(a,i))
end
Expand All @@ -7,7 +10,7 @@ Base.IndexStyle(::Type{<:AbstractMultiScaleArray}) = IndexLinear()
@inline function getindex(m::AbstractMultiScaleArray, i::Int)
idx = bisect_search(m.end_idxs, i)
idx > 1 && (i -= m.end_idxs[idx-1]) # also works with values
(isempty(m.values) || idx < length(m.end_idxs)) ? m.nodes[idx][i] : m.values[i]
(isempty(m.values) || idx < length(m.end_idxs)) ? nodeselect(m.nodes, idx, i) : m.values[i]
end

@inline function setindex!(m::AbstractMultiScaleArray, nodes, i::Int)
Expand All @@ -25,7 +28,7 @@ end

@inline function getindex(m::AbstractMultiScaleArray, i, I...)
if isempty(m.values) || i < length(m.end_idxs)
length(I) == 1 ? m.nodes[i].nodes[I[1]] : m.nodes[i][I...]
length(I) == 1 ? nodechild(m.nodes, i, I[1]) : nodeselect(m.nodes, i, I...)
else
m.values[I...]
end
Expand Down
30 changes: 14 additions & 16 deletions src/shape_construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,11 @@ end
end

recursive_similar(x,T) = [similar(y, T) for y in x]
recursive_similar(x::Tuple,T) = tuple((similar(y, T) for y in x)...)

construct(::Type{T}, args...) where {T<:AbstractMultiScaleArrayLeaf} = T(args...)

function __construct(nodes::Vector{<:AbstractMultiScaleArray})
end_idxs = Vector{Int}(length(nodes))
off = 0
@inbounds for i in 1:length(nodes)
end_idxs[i] = (off += length(nodes[i]))
end
end_idxs
end

function (construct(::Type{T}, nodes::Vector{<:AbstractMultiScaleArray},args...)
where {T<:AbstractMultiScaleArray})
T(nodes, eltype(T)[], __construct(nodes),args...)
end

function (construct(::Type{T}, nodes::Vector{<:AbstractMultiScaleArray}, values, args...)
where {T<:AbstractMultiScaleArray})
function __construct(T, nodes, values, args...)
vallen = length(values)
end_idxs = Vector{Int}(length(nodes) + ifelse(vallen == 0, 0, 1))
off = 0
Expand All @@ -53,6 +39,18 @@ function (construct(::Type{T}, nodes::Vector{<:AbstractMultiScaleArray}, values,
T(nodes, values, end_idxs, args...)
end

(construct(::Type{T}, nodes::AbstractVector{<:AbstractMultiScaleArray}, args...)
where {T<:AbstractMultiScaleArray}) = __construct(T, nodes, eltype(T)[], args...)

(construct(::Type{T}, nodes::AbstractVector{<:AbstractMultiScaleArray}, values, args...)
where {T<:AbstractMultiScaleArray}) = __construct(T, nodes, values, args...)

(construct(::Type{T}, nodes::Tuple{Vararg{<:AbstractMultiScaleArray}}, args...)
where {T<:AbstractMultiScaleArray}) = __construct(T, nodes, eltype(T)[], args...)

(construct(::Type{T}, nodes::Tuple{Vararg{<:AbstractMultiScaleArray}}, values, args...)
where {T<:AbstractMultiScaleArray}) = __construct(T, nodes, values, args...)

vcat(m1::AbstractMultiScaleArray, m2::AbstractMultiScaleArray) =
error("AbstractMultiScaleArrays cannot be concatenated")

Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using MultiScaleArrays, OrdinaryDiffEq, DiffEqBase, StochasticDiffEq
using Base.Test

@time @testset "Tuple Nodes" begin include("tuple_nodes.jl") end
@time @testset "Bisect Search Tests" begin include("bisect_search_tests.jl") end
@time @testset "Indexing and Creation Tests" begin include("indexing_and_creation_tests.jl") end
@time @testset "Values Indexing" begin include("values_indexing.jl") end
Expand Down
88 changes: 88 additions & 0 deletions test/tuple_nodes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
using MultiScaleArrays, OrdinaryDiffEq, DiffEqBase, StochasticDiffEq
using Base.Test

struct PlantSettings{T} x::T end
struct OrganParams{T} y::T end

struct Organ{B<:Number,P} <: AbstractMultiScaleArrayLeaf{B}
values::Vector{B}
name::Symbol
params::P
end

struct Plant{B,S,N<:Tuple{Vararg{<:Organ{<:Number}}}} <: AbstractMultiScaleArray{B}
nodes::N
values::Vector{B}
end_idxs::Vector{Int}
settings::S
end

struct Community{B,N<:Tuple{Vararg{<:Plant{<:Number}}}} <: AbstractMultiScaleArray{B}
nodes::N
values::Vector{B}
end_idxs::Vector{Int}
end

mutable struct Scenario{B,N<:Tuple{Vararg{<:Community{<:Number}}}} <: AbstractMultiScaleArrayHead{B}
nodes::N
values::Vector{B}
end_idxs::Vector{Int}
end

organ1 = Organ([1.1,2.1,3.1], :Shoot, OrganParams(:grows_up))
organ2 = Organ([4.1,5.1,6.1], :Root, OrganParams("grows down"))
organ3 = Organ([1.2,2.2,3.2], :Shoot, OrganParams(true))
organ4 = Organ([4.2,5.2,6.2], :Root, OrganParams(1//3))
plant1 = construct(Plant, (deepcopy(organ1), deepcopy(organ2)), Float64[], PlantSettings(1))
plant2 = construct(Plant, (deepcopy(organ3), deepcopy(organ4)), Float64[], PlantSettings(1.0))
community = construct(Community, (deepcopy(plant1), deepcopy(plant2), ))
scenario = construct(Scenario, (deepcopy(community),))

@inferred getindex(organ1, 1)
@inferred getindex(plant1, 3)
@inferred getindex(community, 4)
@inferred getindex(scenario, 8)

@test scenario[1] == 1.1
@test scenario[2] == 2.1
@test scenario[3] == 3.1
@test scenario[4] == 4.1
@test scenario[5] == 5.1
@test scenario[6] == 6.1
@test scenario[7] == 1.2
@test scenario[8] == 2.2
@test scenario[9] == 3.2
@test scenario[10] == 4.2
@test scenario[11] == 5.2
@test scenario[12] == 6.2

@test getindices(scenario, 1) == 1:12
@test getindices(scenario, 1, 1) == 1:6
@test getindices(scenario, 1, 2) == 7:12
@test getindices(scenario, 1, 1, 1) == 1:3
@test getindices(scenario, 1, 1, 2) == 4:6
@test getindices(scenario, 1, 2, 1) == 7:9
@test getindices(scenario, 1, 2, 2) == 10:12

organ_ode = function (dorgan,organ,p,t)
m = mean(organ)
for i in eachindex(organ)
dorgan[i] = -m*organ[i]
end
end
f = function (dscenario,scenario,p,t)
for (organ, y, z) in LevelIterIdx(scenario, 2)
organ_ode(@view(dscenario[y:z]),organ,p,t)
end
end
affect! = function (integrator)
add_node!(integrator, integrator.u[1, 1, 1], 1, 1)
end

println("ODE with tuple nodes")

prob = ODEProblem(f, scenario, (0.0, 1.0))

sol = solve(prob, Tsit5())

@test length(sol[end]) == 12

0 comments on commit c7bb8fe

Please sign in to comment.