Skip to content

Commit

Permalink
Fix GUDHI#461 and review all error cases (no more prints, warnings an…
Browse files Browse the repository at this point in the history
…d exceptions instead)
  • Loading branch information
VincentRouvreau committed Jun 17, 2021
1 parent 486b281 commit cb10892
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 108 deletions.
4 changes: 4 additions & 0 deletions src/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,10 @@ if(PYTHONINTERP_FOUND)
add_gudhi_py_test(test_dtm_rips_complex)
endif()

# persistence graphical tools
if(MATPLOTLIB_FOUND)
add_gudhi_py_test(test_persistence_graphical_tools)
endif()

# Set missing or not modules
set(GUDHI_MODULES ${GUDHI_MODULES} "python" CACHE INTERNAL "GUDHI_MODULES")
Expand Down
245 changes: 137 additions & 108 deletions src/python/gudhi/persistence_graphical_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import numpy as np
from functools import lru_cache
import warnings
import errno
import os

from gudhi.reader_utils import read_persistence_intervals_in_dimension
from gudhi.reader_utils import read_persistence_intervals_grouped_by_dimension
Expand Down Expand Up @@ -45,6 +47,9 @@ def __min_birth_max_death(persistence, band=0.0):
min_birth = float(interval[1][0])
if band > 0.0:
max_death += band
# can happen if only points at inf death
if min_birth == max_death:
max_death = max_death + 1.
return (min_birth, max_death)


Expand All @@ -54,7 +59,7 @@ def _array_handler(a):
persistence-compatible list (padding with 0), so that the
plot can be performed seamlessly.
'''
if isinstance(a[0][1], np.float64) or isinstance(a[0][1], float):
if isinstance(a[0][1], np.floating) or isinstance(a[0][1], float):
return [[0, x] for x in a]
else:
return a
Expand Down Expand Up @@ -88,7 +93,7 @@ def _matplotlib_can_use_tex():
from matplotlib import checkdep_usetex
return checkdep_usetex(True)
except ImportError:
print("This function is not available, you may be missing matplotlib.")
warnings.warn("This function is not available, you may be missing matplotlib.")


def plot_persistence_barcode(
Expand Down Expand Up @@ -157,53 +162,58 @@ def plot_persistence_barcode(
for persistence_interval in diag[key]:
persistence.append((key, persistence_interval))
else:
print("file " + persistence_file + " not found.")
return None
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)

persistence = _array_handler(persistence)

persistence = _limit_to_max_intervals(persistence, max_intervals,
key = lambda life_time: life_time[1][1] - life_time[1][0])

if colormap == None:
colormap = plt.cm.Set1.colors
if axes == None:
_, axes = plt.subplots(1, 1)

persistence = sorted(persistence, key=lambda birth: birth[1][0])
try:
persistence = _array_handler(persistence)
persistence = _limit_to_max_intervals(persistence, max_intervals,
key = lambda life_time: life_time[1][1] - life_time[1][0])
(min_birth, max_death) = __min_birth_max_death(persistence)
persistence = sorted(persistence, key=lambda birth: birth[1][0])
except IndexError:
min_birth, max_death = 0., 1.
pass

(min_birth, max_death) = __min_birth_max_death(persistence)
ind = 0
delta = (max_death - min_birth) * inf_delta
# Replace infinity values with max_death + delta for bar code to be more
# readable
infinity = max_death + delta
axis_start = min_birth - delta

if axes == None:
_, axes = plt.subplots(1, 1)
if colormap == None:
colormap = plt.cm.Set1.colors
ind = 0

# Draw horizontal bars in loop
for interval in reversed(persistence):
if float(interval[1][1]) != float("inf"):
# Finite death case
axes.barh(
ind,
(interval[1][1] - interval[1][0]),
height=0.8,
left=interval[1][0],
alpha=alpha,
color=colormap[interval[0]],
linewidth=0,
)
else:
# Infinite death case for diagram to be nicer
axes.barh(
ind,
(infinity - interval[1][0]),
height=0.8,
left=interval[1][0],
alpha=alpha,
color=colormap[interval[0]],
linewidth=0,
)
ind = ind + 1
try:
if float(interval[1][1]) != float("inf"):
# Finite death case
axes.barh(
ind,
(interval[1][1] - interval[1][0]),
height=0.8,
left=interval[1][0],
alpha=alpha,
color=colormap[interval[0]],
linewidth=0,
)
else:
# Infinite death case for diagram to be nicer
axes.barh(
ind,
(infinity - interval[1][0]),
height=0.8,
left=interval[1][0],
alpha=alpha,
color=colormap[interval[0]],
linewidth=0,
)
ind = ind + 1
except IndexError:
pass

if legend:
dimensions = list(set(item[0] for item in persistence))
Expand All @@ -218,11 +228,12 @@ def plot_persistence_barcode(
axes.set_title("Persistence barcode", fontsize=fontsize)

# Ends plot on infinity value and starts a little bit before min_birth
axes.axis([axis_start, infinity, 0, ind])
if ind != 0:
axes.axis([axis_start, infinity, 0, ind])
return axes

except ImportError:
print("This function is not available, you may be missing matplotlib.")
warnings.warn("This function is not available, you may be missing matplotlib.")


def plot_persistence_diagram(
Expand Down Expand Up @@ -296,54 +307,61 @@ def plot_persistence_diagram(
for persistence_interval in diag[key]:
persistence.append((key, persistence_interval))
else:
print("file " + persistence_file + " not found.")
return None

persistence = _array_handler(persistence)

persistence = _limit_to_max_intervals(persistence, max_intervals,
key = lambda life_time: life_time[1][1] - life_time[1][0])
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)

if colormap == None:
colormap = plt.cm.Set1.colors
if axes == None:
_, axes = plt.subplots(1, 1)
try:
persistence = _array_handler(persistence)
persistence = _limit_to_max_intervals(persistence, max_intervals,
key = lambda life_time: life_time[1][1] - life_time[1][0])
min_birth, max_death = __min_birth_max_death(persistence, band)
except IndexError:
min_birth, max_death = 0., 1.
pass

(min_birth, max_death) = __min_birth_max_death(persistence, band)
delta = (max_death - min_birth) * inf_delta
# Replace infinity values with max_death + delta for diagram to be more
# readable
infinity = max_death + delta
axis_end = max_death + delta / 2
axis_start = min_birth - delta

if axes == None:
_, axes = plt.subplots(1, 1)
if colormap == None:
colormap = plt.cm.Set1.colors
# bootstrap band
if band > 0.0:
x = np.linspace(axis_start, infinity, 1000)
axes.fill_between(x, x, x + band, alpha=alpha, facecolor="red")
# lower diag patch
if greyblock:
axes.add_patch(mpatches.Polygon([[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]], fill=True, color='lightgrey'))
axes.add_patch(mpatches.Polygon([[axis_start, axis_start], [axis_end, axis_start], [axis_end, axis_end]],
fill=True, color='lightgrey'))
# line display of equation : birth = death
axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k")

# Draw points in loop
pts_at_infty = False # Records presence of pts at infty
for interval in reversed(persistence):
if float(interval[1][1]) != float("inf"):
# Finite death case
axes.scatter(
interval[1][0],
interval[1][1],
alpha=alpha,
color=colormap[interval[0]],
)
else:
pts_at_infty = True
# Infinite death case for diagram to be nicer
axes.scatter(
interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]]
)
try:
if float(interval[1][1]) != float("inf"):
# Finite death case
axes.scatter(
interval[1][0],
interval[1][1],
alpha=alpha,
color=colormap[interval[0]],
)
else:
pts_at_infty = True
# Infinite death case for diagram to be nicer
axes.scatter(
interval[1][0], infinity, alpha=alpha, color=colormap[interval[0]]
)
except IndexError:
pass
if pts_at_infty:
# infinity line and text
axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k")
axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha)
# Infinity label
yt = axes.get_yticks()
Expand Down Expand Up @@ -371,7 +389,7 @@ def plot_persistence_diagram(
return axes

except ImportError:
print("This function is not available, you may be missing matplotlib.")
warnings.warn("This function is not available, you may be missing matplotlib.")


def plot_persistence_density(
Expand Down Expand Up @@ -461,51 +479,64 @@ def plot_persistence_density(
persistence_file=persistence_file, only_this_dim=dimension
)
else:
print("file " + persistence_file + " not found.")
return None

if len(persistence) > 0:
persistence = _array_handler(persistence)
persistence_dim = np.array(
[
(dim_interval[1][0], dim_interval[1][1])
for dim_interval in persistence
if (dim_interval[0] == dimension) or (dimension is None)
]
)

persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])]

persistence_dim = np.array(_limit_to_max_intervals(persistence_dim, max_intervals,
key = lambda life_time: life_time[1] - life_time[0]))

# Set as numpy array birth and death (remove undefined values - inf and NaN)
birth = persistence_dim[:, 0]
death = persistence_dim[:, 1]
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)

# default cmap value cannot be done at argument definition level as matplotlib is not yet defined.
if cmap is None:
cmap = plt.cm.hot_r
if axes == None:
_, axes = plt.subplots(1, 1)

try:
if len(persistence) > 0:
# if not read from file but given by an argument
persistence = _array_handler(persistence)
persistence_dim = np.array(
[
(dim_interval[1][0], dim_interval[1][1])
for dim_interval in persistence
if (dim_interval[0] == dimension) or (dimension is None)
]
)
persistence_dim = persistence_dim[np.isfinite(persistence_dim[:, 1])]
persistence_dim = np.array(_limit_to_max_intervals(persistence_dim, max_intervals,
key = lambda life_time: life_time[1] - life_time[0]))

# Set as numpy array birth and death (remove undefined values - inf and NaN)
birth = persistence_dim[:, 0]
death = persistence_dim[:, 1]
birth_min = birth.min()
birth_max = birth.max()
death_min = death.min()
death_max = death.max()

# Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents
k = kde.gaussian_kde([birth, death], bw_method=bw_method)
xi, yi = np.mgrid[
birth_min : birth_max : nbins * 1j,
death_min : death_max : nbins * 1j,
]
zi = k(np.vstack([xi.flatten(), yi.flatten()]))
# Make the plot
img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading='auto')

# IndexError on empty diagrams, ValueError on only inf death values
except (IndexError, ValueError):
birth_min = 0.
birth_max = 1.
death_min = 0.
death_max = 1.
pass

# line display of equation : birth = death
x = np.linspace(death.min(), birth.max(), 1000)
x = np.linspace(death_min, birth_max, 1000)
axes.plot(x, x, color="k", linewidth=1.0)

# Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents
k = kde.gaussian_kde([birth, death], bw_method=bw_method)
xi, yi = np.mgrid[
birth.min() : birth.max() : nbins * 1j,
death.min() : death.max() : nbins * 1j,
]
zi = k(np.vstack([xi.flatten(), yi.flatten()]))

# Make the plot
img = axes.pcolormesh(xi, yi, zi.reshape(xi.shape), cmap=cmap, shading='auto')

if greyblock:
axes.add_patch(mpatches.Polygon([[birth.min(), birth.min()], [death.max(), birth.min()], [death.max(), death.max()]], fill=True, color='lightgrey'))
axes.add_patch(mpatches.Polygon([[birth_min, birth_min],
[death_max, birth_min],
[death_max, death_max]],
fill=True, color='lightgrey'))

if legend:
plt.colorbar(img, ax=axes)
Expand All @@ -517,6 +548,4 @@ def plot_persistence_density(
return axes

except ImportError:
print(
"This function is not available, you may be missing matplotlib and/or scipy."
)
warnings.warn("This function is not available, you may be missing matplotlib and/or scipy.")

0 comments on commit cb10892

Please sign in to comment.