Skip to content

Commit

Permalink
implicit graph construction using graphelement
Browse files Browse the repository at this point in the history
fixes #148
  • Loading branch information
hexaeder committed Sep 27, 2024
1 parent 6dd9cf0 commit 326dcea
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 4 deletions.
5 changes: 4 additions & 1 deletion docs/src/API.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,12 @@ coupling
### Component Metadata API
```@docs
metadata
get_metadata(::NetworkDynamics.ComponentFunction, ::Symbol)
has_metadata(::NetworkDynamics.ComponentFunction, ::Symbol)
get_metadata(::NetworkDynamics.ComponentFunction, ::Symbol)
set_metadata!(::NetworkDynamics.ComponentFunction, ::Symbol, ::Any)
has_graphelement
get_graphelement
set_graphelement!
```
### Per-Symbol Metadata API
```@docs
Expand Down
8 changes: 7 additions & 1 deletion docs/src/metadata.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@ Component metadata is a `Dict{Symbol,Any}` attached to each component to store v

To access the data, you can use the methods `has_metadata`, `get_metadata` and `set_metadata!` (see [Component Metadata API](@ref)).

Special uses: after [component wise initialization](@ref), the field `:init_residual` stores the residual vector of the nonlinear problem.
Special metadata:

- `:init_residual`: after [component wise initialization](@ref), this field stores the residual vector of the nonlinear problem.
- `:graphelement`: optional field to specialize the graphelement for each
component (`vidx`) for vertices, `(;src,dst)` named tuple of either vertex
names or vertex indices for edges. Has special accessors `has_/get_/set_graphelement`.


## Symbol Metadata
Each component stores symbol metadata. The symbol metadata is a `Dict{Symbol, Dict{Symbol, Any}}` which stores a metadate dict per symbol. Symbols are everything that appears in [`sym`](@ref), [`psym`](@ref), [`obssym`](@ref) and [`inputsym`](@ref).
Expand Down
4 changes: 3 additions & 1 deletion src/NetworkDynamics.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module NetworkDynamics
using Graphs: Graphs, AbstractGraph, SimpleEdge, edges, vertices, ne, nv
using Graphs: Graphs, AbstractGraph, SimpleEdge, edges, vertices, ne, nv,
SimpleGraph, SimpleDiGraph, add_edge!, has_edge
using TimerOutputs: @timeit_debug, reset_timer!

using ArgCheck: @argcheck
Expand Down Expand Up @@ -43,6 +44,7 @@ export has_default, get_default, set_default!
export has_guess, get_guess, set_guess!
export has_init, get_init, set_init!
export has_bounds, get_bounds, set_bounds!
export has_graphelement, get_graphelement, set_graphelement!
include("component_functions.jl")

export Network
Expand Down
41 changes: 40 additions & 1 deletion src/component_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,17 @@ function _fill_defaults(T, kwargs)
_maybewrap!(dict, :obssym, Symbol)

symmetadata = get!(dict, :symmetadata, Dict{Symbol,Dict{Symbol,Any}}())
metadata = get!(dict, :metadata, Dict{Symbol,Any}())

metadata = try
convert(Dict{Symbol,Any}, get!(dict, :metadata, Dict{Symbol,Any}()))
catch e
throw(ArgumentError("Provided metadata keyword musst be a Dict{Symbol,Any}. Got $(repr(dict[:metadata]))."))
end

if haskey(dict, :graphelement)
ge = pop!(dict, :graphelement)
metadata[:graphelement] = ge
end

# sym & dim
haskey(dict, :dim) || haskey(dict, :sym) || throw(ArgumentError("Either `dim` or `sym` must be provided to construct $T."))
Expand Down Expand Up @@ -743,3 +753,32 @@ get_metadata(c::ComponentFunction, key::Symbol) = metadata(c)[key]
Sets the metadata `key` for the component to `value`.
"""
set_metadata!(c::ComponentFunction, key::Symbol, val) = setindex!(metadata(c), val, key)

#### graphelement field for edges and vertices
"""
has_graphelement(c)
Checks if the edge or vetex function function has the `graphelement` metadata.
"""
has_graphelement(c::EdgeFunction) = has_metadata(c, :graphelement)
has_graphelement(c::VertexFunction) = has_metadata(c, :graphelement)
"""
get_graphelement(c::EdgeFunction)::@NamedTuple{src::T, dst::T}
get_graphelement(c::VertexFunction)::Int
Retrieves the `graphelement` metadata for the component function. For edges this
returns a named tupe `(;src, dst)` where both are either integers (vertex index)
or symbols (vertex name).
"""
get_graphelement(c::EdgeFunction) = get_metadata(c, :graphelement)::@NamedTuple{src::T, dst::T} where {T<:Union{Int,Symbol}}
get_graphelement(c::VertexFunction) = get_metadata(c, :graphelement)::Int
"""
set_graphelement!(c::EdgeFunction, src, dst)
set_graphelement!(c::VertexFunction, vidx)
Sets the `graphelement` metadata for the edge function. For edges this takes two
arguments `src` and `dst` which are either integer (vertex index) or symbol
(vertex name). For vertices it takes a single integer `vidx`.
"""
set_graphelement!(c::EdgeFunction, nt::@NamedTuple{src::T, dst::T}) where {T<:Union{Int,Symbol}} = set_metadata!(c, :graphelement, nt)
set_graphelement!(c::VertexFunction, vidx::Int) = set_metadata!(c, :graphelement, vidx)
90 changes: 90 additions & 0 deletions src/construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,28 @@ function Network(g::AbstractGraph,
@argcheck length(_vertexf) == nv(g)
@argcheck length(_edgef) == ne(g)

# check if graphelement is set correctly, warn otherwise
for (i, v) in pairs(_vertexf)
if has_graphelement(v)
if get_graphelement(v) != i
@warn "Vertex function $v has wrong `:graphelement` $(get_graphelement(v)) != $i. \
Using this constructor the provided `:graphelement` is ignored!"
end
end
end
if any(has_graphelement, _edgef)
vnamedict = _unique_name_dict(vertexfs)

for (iteredge, comf) in zip(edges(g), _edgef)
if has_graphelement(compf)
if iteredge != _resolve_ge_to_edge(ge, vnamedict)
@warn "Edge function $comf has wrong `:graphelement` $(get_graphelement(comf)) != $iteredge. \
Using this constructor the provided `:graphelement` is ignored!"
end
end
end
end

verbose &&
println("Create dynamic network with $(nv(g)) vertices and $(ne(g)) edges:")
@argcheck execution isa ExecutionStyle "Execution type $execution not supported (choose from $(subtypes(ExecutionStyle)))"
Expand Down Expand Up @@ -96,6 +118,74 @@ function Network(g::AbstractGraph,
return nw
end

function Network(vertexfs, edgefs; kwargs...)
@argcheck all(has_graphelement, vertexfs) "All vertex functions must have assigned `graphelement` to implicitly construct graph!"
@argcheck all(has_graphelement, edgefs) "All edge functions must have assigned `graphelement` to implicitly construct graph!"

vidxs = get_graphelement.(vertexfs)
allunique(vidxs) || throw(ArgumentError("All vertex functions must have unique `graphelement`!"))
sort(vidxs) == 1:length(vidxs) || throw(ArgumentError("Vertex functions must have `graphelement` in range 1:length(vertexfs)!"))

vdict = Dict(vidxs .=> vertexfs)

vnamedict = _unique_name_dict(vertexfs)

simpleedges = map(edgefs) do e
ge = get_graphelement(e)
_resolve_ge_to_edge(ge, vnamedict)
end
allunique(simpleedges) || throw(ArgumentError("Not all assigned edges are unique!"))
edict = Dict(simpleedges .=> edgefs)

# if all src < dst then we can use SimpleGraph, else digraph
g = if all(e -> e.src < e.dst, simpleedges)
SimpleGraph(length(vertexfs))
else
SimpleDiGraph(length(vertexfs))
end
for edge in simpleedges
if g isa SimpleDiGraph && has_edge(g, edge.dst, edge.src)
@warn "Edges $(edge.src) -> $(edge.dst) and $(edge.dst) -> $(edge.src) are both present in the graph!"
end
r = add_edge!(g, edge)
r || error("Could not add edge $(edge) to graph $(g)!")
end

vfs_ordered = [vdict[k] for k in vertices(g)]
efs_ordered = [edict[k] for k in edges(g)]

Network(g, vfs_ordered, efs_ordered; kwargs...)
end

function _unique_name_dict(cfs::AbstractVector{<:ComponentFunction})
# find all names to resolve
names = getproperty.(cfs, :name)
dict = Dict(names .=> vidxs)
# delete all names which occure multiple times
for i in eachindex(names)
if names[i] @views names[i+1:end]
delete!(dict, names[i])
end
end
dict
end
# resolve the graphelement ge (named tuple) to simple edge with potential lookup in vertex name dict dict
function _resolve_ge_to_edge(ge, vnamedict)
src = if ge.src isa Symbol
haskey(vnamedict, ge.src) || throw(ArgumentError("Edge function has unknown or non-unique source vertex name $(ge.src)"))
vnamedict[ge.src]
else
ge.src
end
dst = if ge.dst isa Symbol
haskey(vnamedict, ge.dst) || throw(ArgumentError("Edge function has unknown or non-unique source vertex name $(ge.dst)"))
vnamedict[ge.dst]
else
ge.dst
end
SimpleEdge(src, dst)
end

function VertexBatch(im::IndexManager, idxs::Vector{Int}; verbose)
components = @view im.vertexf[idxs]

Expand Down
33 changes: 33 additions & 0 deletions test/construction_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,39 @@ using Graphs
@test_throws ArgumentError nd(_du, _u, rand(pdim(nd)+1), 0.0)
end

@testset "graphless constructor" begin
@test_throws ArgumentError ODEVertex(x->x^1, 2, 0; metadata="foba")
v1 = ODEVertex(x->x^1, 2, 0; metadata=Dict(:graphelement=>1), name=:v1)
@test has_graphelement(v1) && get_graphelement(v1) == 1
v2 = ODEVertex(x->x^2, 2, 0; name=:v2)
set_graphelement!(v2, 3)
v3 = ODEVertex(x->x^3, 2, 0; name=:v3)
set_graphelement!(v3, 2)

e1 = StaticEdge(nothing, 0, Symmetric(); graphelement=(;src=1,dst=2))
@test get_graphelement(e1) == (;src=1,dst=2)
e2 = StaticEdge(nothing, 0, Symmetric())
set_graphelement!(e2, (;src=:v3,dst=:v2))
e3 = StaticEdge(nothing, 0, Symmetric())

@test_throws ArgumentError Network([v1,v2,v3], [e1,e2,e3])
set_graphelement!(e3, (;src=3,dst=1))

nw = Network([v1,v2,v3], [e1,e2,e3])
@test nw.im.vertexf == [v1, v3, v2]
g = SimpleDiGraph(3)
add_edge!(g, 1, 2)
add_edge!(g, 2, 3)
add_edge!(g, 3, 1)
@test nw.im.g == g

set_graphelement!(e3, (;src=1,dst=2))
@test_throws ArgumentError Network([v1,v2,v3], [e1,e2,e3])

set_graphelement!(e3, (;src=2,dst=1))
Network([v1,v2,v3], [e1,e2,e3]) # throws waring about 1->2 and 2->1 beeing present
end

@testset "Vertex batch" begin
using NetworkDynamics: BatchStride, VertexBatch, parameter_range
vb = VertexBatch{ODEVertex, typeof(sum), Vector{Int}}([1, 2, 3, 4], # vertices
Expand Down

0 comments on commit 326dcea

Please sign in to comment.