Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

path compression variants for union-find IntDisjointSet #913

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/DataStructures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ module DataStructures
include("queue.jl")
include("accumulator.jl")
include("disjoint_set.jl")
export PCRecursive, PCIterative, PCHalving, PCSplitting
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after making it an enum, export the enum

include("heaps.jl")

include("default_dict.jl")
Expand Down
57 changes: 56 additions & 1 deletion src/disjoint_set.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,64 @@
return p
end

# iterative path compression: makes every node on the path point directly to the root
@inline function find_root_iterative!(parents::Vector{T}, x::Integer) where {T<:Integer}
current = x
# find the root of the tree
@inbounds while parents[current] != current
current = parents[current]
end
root = current
# compress the path: make every node point directly to the root
current = x
@inbounds while parents[current] != root
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

address the test coverage warning.

p = parents[current] # temporarily store the parent
parents[current] = root # point directly to the root
current = p # move to the next node in the original path
end

Check warning on line 77 in src/disjoint_set.jl

View check run for this annotation

Codecov / codecov/patch

src/disjoint_set.jl#L74-L77

Added lines #L74 - L77 were not covered by tests
return root
end

# path-halving and path-splitting are a one-pass forms of path compression with inverse-ackerman complexity
# e.g., see p.19 of https://www.cs.princeton.edu/courses/archive/spr11/cos423/Lectures/PathCompressionAnalysisII.pdf

# path-halving: every node on the path points to its grandparent
@inline function find_root_halving!(parents::Vector{T}, x::Integer) where {T<:Integer}
current = x # use a separate variable 'current' to track traversal
@inbounds while parents[current] != current
@inbounds parents[current] = parents[parents[current]] # point to grandparent
@inbounds current = parents[current] # move to grandparent
end
return current
end

# path-splitting: every node on the path points to its grandparent
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the exact difference between path compression using path halving and path splitting? it's not very clear. can you illustrate with an example?

@inline function find_root_splitting!(parents::Vector{T}, x::Integer) where {T<:Integer}
@inbounds while parents[x] != x
p = parents[x] # store the current parent
parents[x] = parents[p] # point to grandparent
x = p # move to parent
end
return x
end


struct PCRecursive end # path compression types
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make this an enum.

struct PCIterative end # path compression types
struct PCHalving end # path compression types
struct PCSplitting end # path compression types

"""
find_root!(s::IntDisjointSet{T}, x::T)

Find the root element of the subset that contains an member `x`.
Path compression happens here.
"""
find_root!(s::IntDisjointSet{T}, x::T) where {T<:Integer} = find_root_impl!(s.parents, x)
@inline find_root!(s::IntDisjointSet{T}, x::T) where {T<:Integer} = find_root_impl!(s.parents, x) # default
@inline find_root!(s::IntDisjointSet{T}, x::T, ::PCRecursive) where {T<:Integer} = find_root_impl!(s.parents, x)

Check warning on line 117 in src/disjoint_set.jl

View check run for this annotation

Codecov / codecov/patch

src/disjoint_set.jl#L117

Added line #L117 was not covered by tests
@inline find_root!(s::IntDisjointSet{T}, x::T, ::PCIterative) where {T<:Integer} = find_root_iterative!(s.parents, x)
@inline find_root!(s::IntDisjointSet{T}, x::T, ::PCHalving) where {T<:Integer} = find_root_halving!(s.parents, x)
@inline find_root!(s::IntDisjointSet{T}, x::T, ::PCSplitting) where {T<:Integer} = find_root_splitting!(s.parents, x)

"""
in_same_set(s::IntDisjointSet{T}, x::T, y::T)
Expand Down Expand Up @@ -191,6 +242,10 @@
Find the root element of the subset in `s` which has the element `x` as a member.
"""
find_root!(s::DisjointSet{T}, x::T) where {T} = s.revmap[find_root!(s.internal, s.intmap[x])]
find_root!(s::DisjointSet{T}, x::T, ::PCIterative) where {T} = s.revmap[find_root!(s.internal, s.intmap[x], PCIterative())]
find_root!(s::DisjointSet{T}, x::T, ::PCRecursive) where {T} = s.revmap[find_root!(s.internal, s.intmap[x], PCRecursive())]

Check warning on line 246 in src/disjoint_set.jl

View check run for this annotation

Codecov / codecov/patch

src/disjoint_set.jl#L246

Added line #L246 was not covered by tests
find_root!(s::DisjointSet{T}, x::T, ::PCHalving) where {T} = s.revmap[find_root!(s.internal, s.intmap[x], PCHalving())]
find_root!(s::DisjointSet{T}, x::T, ::PCSplitting) where {T} = s.revmap[find_root!(s.internal, s.intmap[x], PCSplitting())]

"""
in_same_set(s::DisjointSet{T}, x::T, y::T)
Expand Down
42 changes: 41 additions & 1 deletion test/bench_disjoint_set.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Benchmark on disjoint set forests

using DataStructures
using DataStructures, BenchmarkTools

# do 10^6 random unions over 10^6 element set

Expand Down Expand Up @@ -29,3 +29,43 @@ x = rand(1:n, T)
y = rand(1:n, T)

@time batch_union!(s, x, y)

#=
benchmark `find` operation
=#

function create_disjoint_set_struct(n::Int)
parents = [1; collect(1:n-1)] # each element's parent is its predecessor
ranks = zeros(Int, n) # ranks are all zero
IntDisjointSet(parents, ranks, n)
end

# benchmarking function
function benchmark_find_root(n::Int)
println("Benchmarking recursive path compression implementation (find_root_impl!):")
if n >= 10^5
println("Recursive may path compression may encounter stack-overflow; skipping")
else
s = create_disjoint_set_struct(n)
@btime find_root!($s, $n, PCRecursive())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

increase the number of evals to let's say 100. post the median and max time. do it for all of the methods

end

println("Benchmarking iterative path compression implementation (find_root_iterative!):")
s = create_disjoint_set_struct(n) # reset parents
@btime find_root!($s, $n, PCIterative())

println("Benchmarking path-halving implementation (find_root_halving!):")
s = create_disjoint_set_struct(n) # reset parents
@btime find_root!($s, $n, PCHalving())

println("Benchmarking path-splitting implementation (find_root_path_splitting!):")
s = create_disjoint_set_struct(n) # reset parents
@btime find_root!($s, $n, PCSplitting())
end

# run benchmark tests
benchmark_find_root(1_000)
benchmark_find_root(10_000)
benchmark_find_root(100_000)
benchmark_find_root(1_000_000)
benchmark_find_root(10_000_000)
111 changes: 111 additions & 0 deletions test/test_disjoint_set.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,16 @@
@test num_groups(s) == T(9)
@test in_same_set(s, T(2), T(3))
@test find_root!(s, T(3)) == T(2)
@test find_root!(s, T(3), PCIterative()) == T(2)
@test find_root!(s, T(3), PCHalving()) == T(2)
@test find_root!(s, T(3), PCSplitting()) == T(2)
union!(s, T(3), T(2))
@test num_groups(s) == T(9)
@test in_same_set(s, T(2), T(3))
@test find_root!(s, T(3)) == T(2)
@test find_root!(s, T(3), PCIterative()) == T(2)
@test find_root!(s, T(3), PCHalving()) == T(2)
@test find_root!(s, T(3), PCSplitting()) == T(2)
end

@testset "more tests" begin
Expand All @@ -48,10 +54,19 @@
@test union!(s, T(8), T(5)) == T(8)
@test num_groups(s) == T(7)
@test find_root!(s, T(6)) == T(8)
@test find_root!(s, T(6), PCIterative()) == T(8)
@test find_root!(s, T(6), PCHalving()) == T(8)
@test find_root!(s, T(6), PCSplitting()) == T(8)
union!(s, T(2), T(6))
@test find_root!(s, T(2)) == T(8)
root1 = find_root!(s, T(6))
root1 = find_root!(s, T(6), PCIterative())
root1 = find_root!(s, T(6), PCHalving())
root1 = find_root!(s, T(6), PCSplitting())
root2 = find_root!(s, T(2))
root2 = find_root!(s, T(2), PCIterative())
root2 = find_root!(s, T(2), PCHalving())
root2 = find_root!(s, T(2), PCSplitting())
@test root_union!(s, T(root1), T(root2)) == T(8)
@test union!(s, T(5), T(6)) == T(8)
end
Expand Down Expand Up @@ -98,6 +113,12 @@

r = [find_root!(s, i) for i in 1 : 10]
@test isequal(r, collect(1:10))
r = [find_root!(s, i, PCIterative()) for i in 1 : 10]
@test isequal(r, collect(1:10))
r = [find_root!(s, i, PCHalving()) for i in 1 : 10]
@test isequal(r, collect(1:10))
r = [find_root!(s, i, PCSplitting()) for i in 1 : 10]
@test isequal(r, collect(1:10))
end

@testset "union!" begin
Expand All @@ -117,6 +138,57 @@
@test num_groups(s) == 2
end

@testset "union! PCIterative" begin
for i = 1 : 5
x = 2 * i - 1
y = 2 * i
union!(s, x, y)
@test find_root!(s, x, PCIterative()) == find_root!(s, y, PCIterative())
end


@test union!(s, 1, 4) == find_root!(s, 1, PCIterative())
@test union!(s, 3, 5) == find_root!(s, 1, PCIterative())
@test union!(s, 7, 9) == find_root!(s, 7, PCIterative())

@test length(s) == 10
@test num_groups(s) == 2
end

@testset "union! PCHalving" begin
for i = 1 : 5
x = 2 * i - 1
y = 2 * i
union!(s, x, y)
@test find_root!(s, x, PCHalving()) == find_root!(s, y, PCHalving())
end


@test union!(s, 1, 4) == find_root!(s, 1, PCHalving())
@test union!(s, 3, 5) == find_root!(s, 1, PCHalving())
@test union!(s, 7, 9) == find_root!(s, 7, PCHalving())

@test length(s) == 10
@test num_groups(s) == 2
end

@testset "union! PCSplitting" begin
for i = 1 : 5
x = 2 * i - 1
y = 2 * i
union!(s, x, y)
@test find_root!(s, x, PCSplitting()) == find_root!(s, y, PCSplitting())
end


@test union!(s, 1, 4) == find_root!(s, 1, PCSplitting())
@test union!(s, 3, 5) == find_root!(s, 1, PCSplitting())
@test union!(s, 7, 9) == find_root!(s, 7, PCSplitting())

@test length(s) == 10
@test num_groups(s) == 2
end

@testset "r0" begin
r0 = [ find_root!(s,i) for i in 1:10 ]
# Since this is a DisjointSet (not IntDisjointSet), the root for 17 will be 17, not 11
Expand All @@ -130,6 +202,45 @@
@test isequal(r, r0)
end

@testset "r0 Iterative" begin
r0 = [ find_root!(s,i) for i in 1:10 ]
# Since this is a DisjointSet (not IntDisjointSet), the root for 17 will be 17, not 11
push!(s, 17)

@test length(s) == 11
@test num_groups(s) == 3

r0 = [ r0 ; 17]
r = [find_root!(s, i, PCIterative()) for i in [1 : 10; 17] ]
@test isequal(r, r0)
end

@testset "r0 Splitting" begin
r0 = [ find_root!(s,i) for i in 1:10 ]
# Since this is a DisjointSet (not IntDisjointSet), the root for 17 will be 17, not 11
push!(s, 17)

@test length(s) == 11
@test num_groups(s) == 3

r0 = [ r0 ; 17]
r = [find_root!(s, i, PCSplitting()) for i in [1 : 10; 17] ]
@test isequal(r, r0)
end

@testset "r0 Halving" begin
r0 = [ find_root!(s,i) for i in 1:10 ]
# Since this is a DisjointSet (not IntDisjointSet), the root for 17 will be 17, not 11
push!(s, 17)

@test length(s) == 11
@test num_groups(s) == 3

r0 = [ r0 ; 17]
r = [find_root!(s, i, PCHalving()) for i in [1 : 10; 17] ]
@test isequal(r, r0)
end

@testset "root_union!" begin
root1 = find_root!(s, 7)
root2 = find_root!(s, 3)
Expand Down
Loading