Persistence diagram for image data / embedding data #58

Open
dgm2 opened this issue Jun 2, 2022 · 5 comments

@dgm2

dgm2 commented Jun 2, 2022

Hello,
Thanks for maintaining this repo.
Two questions on processing image datasets (e.g. torchvision MNIST).

  1. Is there an example of computing the persistence diagram of an image?
  2. Is it possible to get a persistence diagram after the image has been passed through a layer (e.g. a convolution or a linear layer)? That is, instead of a tensor of shape (xxx, 2), we only have a single one-dimensional embedding.

Example

input (X1) : 1x28x28
emb1 = conv(X1) : 1x512
diag = RipsComplex(emb1)

I used RipsComplex here, but any object that computes persistence would be fine.

Many thanks!

@VincentRouvreau
Contributor

@dgm2 Why not use the cubical complex?

from sklearn.datasets import fetch_openml
import matplotlib.pyplot as plt
import gudhi as gd

X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)

# X[17] is an '8'
cc = gd.CubicalComplex(top_dimensional_cells=X[17], dimensions=[28, 28])
diag = cc.persistence()
gd.plot_persistence_diagram(diag, legend=True)
plt.show()

[Figure_1: persistence diagram of X[17], an '8']

# X[1] is a '0'
cc = gd.CubicalComplex(top_dimensional_cells=X[1], dimensions=[28, 28])
diag = cc.persistence()
gd.plot_persistence_diagram(diag, legend=True)
plt.show()

[Figure_2: persistence diagram of X[1], a '0']

Another example is available in this tutorial on cubical complexes.

@dgm2
Author

dgm2 commented Jun 4, 2022

sounds good, thanks!

What would be the best way to replicate something like the following Dionysus code with GUDHI?

  1. Get a persistence diagram, e.g. d1:
from dionysus import *
filtered = fill_freudenthal(image)
persistence = homology_persistence(filtered)
diag = init_diagrams(persistence, filtered)
  2. Compute the distance between two diagrams,
    e.g. di.wasserstein_distance(d1, d2).
    I have tried with CubicalComplex, but I get all zeros for the Wasserstein distance.

References: fill_freudenthal, homology_persistence

Many thanks!

import itertools
import numpy as np
from torchvision import datasets

from gudhi.cubical_complex import CubicalComplex
from gudhi.wasserstein import wasserstein_distance


def pers_diag(pts):
    pers = CubicalComplex(top_dimensional_cells=pts, dimensions=[28, 28]).persistence()
    res = np.array([list(b) for (_, b) in pers])
    return res


dataset2 = datasets.MNIST('../data', train=False, download=True)

diagrams = []
labels = []
n = 10
for dat, lab in zip(dataset2.data[:n], dataset2.targets[:n]):
    pts = dat.numpy().reshape(-1)
    diagrams.append(pers_diag(pts))
    labels.append(lab.item())


def print_wd(i, j):
    print(labels[i], labels[j], wasserstein_distance(diagrams[i], diagrams[j]))


for i, j in itertools.combinations(range(n), 2):
    print_wd(i, j)

Output:

labels 7 2 | wass 0.0 | bott 24.5
labels 7 1 | wass 0.0 | bott 52.0
labels 7 0 | wass 0.0 | bott 125.5
labels 7 4 | wass 0.0 | bott 18.0
labels 7 1 | wass 0.0 | bott 24.5
labels 7 4 | wass 0.0 | bott 32.0
labels 7 9 | wass 0.0 | bott 92.5
labels 7 5 | wass 0.0 | bott 126.5
labels 7 9 | wass 0.0 | bott 114.5
labels 2 1 | wass 0.0 | bott 52.0
labels 2 0 | wass 0.0 | bott 125.5
labels 2 4 | wass 0.0 | bott 26.5
labels 2 1 | wass 0.0 | bott 3.5
labels 2 4 | wass 0.0 | bott 40.0
labels 2 9 | wass 0.0 | bott 92.5
labels 2 5 | wass 0.0 | bott 126.5
labels 2 9 | wass 0.0 | bott 114.5
labels 1 0 | wass 0.0 | bott 125.5
labels 1 4 | wass 0.0 | bott 45.0
labels 1 1 | wass 0.0 | bott 52.0
labels 1 4 | wass 0.0 | bott 25.0
labels 1 9 | wass 0.0 | bott 92.5
labels 1 5 | wass 0.0 | bott 126.5
labels 1 9 | wass 0.0 | bott 114.5
labels 0 4 | wass 0.0 | bott 125.5
labels 0 1 | wass 0.0 | bott 125.5
labels 0 4 | wass 0.0 | bott 125.5
labels 0 9 | wass 0.0 | bott 66.0
labels 0 5 | wass 0.0 | bott 100.5
labels 0 9 | wass 0.0 | bott 38.0
labels 4 1 | wass 0.0 | bott 26.5
labels 4 4 | wass 0.0 | bott 25.0
labels 4 9 | wass 0.0 | bott 92.5
labels 4 5 | wass 0.0 | bott 126.5
labels 4 9 | wass 0.0 | bott 114.5
labels 1 4 | wass nan | bott 40.0
labels 1 9 | wass 0.0 | bott 92.5
labels 1 5 | wass 0.0 | bott 126.5
labels 1 9 | wass 0.0 | bott 114.5
labels 4 9 | wass 0.0 | bott 92.5
labels 4 5 | wass 0.0 | bott 126.5
labels 4 9 | wass 0.0 | bott 114.5
labels 9 5 | wass 0.0 | bott 100.5
labels 9 9 | wass 0.0 | bott 44.0
labels 5 9 | wass 0.0 | bott 100.5

@VincentRouvreau
Contributor

🤔 Your second point seems strange to me...
I changed 2-3 things in your code, but yours was (almost) working:

import itertools
import numpy as np
from torchvision import datasets

from gudhi.cubical_complex import CubicalComplex
from gudhi.wasserstein import wasserstein_distance
from gudhi import bottleneck_distance


def pers_diag(pts):
    pers = CubicalComplex(top_dimensional_cells=pts, dimensions=[28, 28]).persistence()
    res = np.array([list(b) for (_, b) in pers])
    return res


dataset2 = datasets.MNIST('data', train=False, download=True)

diagrams = []
labels = []
n = 10
for dat, lab in zip(dataset2.data[:n], dataset2.targets[:n]):
    pts = dat.numpy().reshape(-1)
    diagrams.append(pers_diag(pts))
    labels.append(lab.item())


def print_wd(i, j):
    print("labels ", labels[i], labels[j], " | was ", wasserstein_distance(diagrams[i], diagrams[j]), " | bot ", bottleneck_distance(diagrams[i], diagrams[j]))


for i, j in itertools.combinations(range(n), 2):
    print_wd(i, j)

Output:

labels  7 2  | was  107.0  | bot  24.5
labels  7 1  | was  119.0  | bot  52.0
labels  7 0  | was  241.0  | bot  125.5
labels  7 4  | was  90.0  | bot  18.0
labels  7 1  | was  109.0  | bot  24.5
labels  7 4  | was  104.5  | bot  32.0
labels  7 9  | was  132.0  | bot  92.5
labels  7 5  | was  288.5  | bot  126.5
labels  7 9  | was  248.5  | bot  114.5
labels  2 1  | was  123.0  | bot  52.0
labels  2 0  | was  141.0  | bot  125.5
labels  2 4  | was  146.0  | bot  26.5
labels  2 1  | was  8.0  | bot  3.5
labels  2 4  | was  147.5  | bot  40.0
labels  2 9  | was  201.0  | bot  92.5
labels  2 5  | was  282.0  | bot  126.5
labels  2 9  | was  197.0  | bot  114.5
labels  1 0  | was  222.0  | bot  125.5
labels  1 4  | was  115.5  | bot  45.0
labels  1 1  | was  117.5  | bot  52.0
labels  1 4  | was  114.0  | bot  25.0
labels  1 9  | was  196.5  | bot  92.5
labels  1 5  | was  253.0  | bot  126.5
labels  1 9  | was  274.5  | bot  114.5
labels  0 4  | was  274.0  | bot  125.5
labels  0 1  | was  136.5  | bot  125.5
labels  0 4  | was  281.5  | bot  125.5
labels  0 9  | was  187.0  | bot  66.0
labels  0 5  | was  161.0  | bot  100.5
labels  0 9  | was  110.0  | bot  38.0
labels  4 1  | was  140.0  | bot  26.5
labels  4 4  | was  118.0  | bot  25.0
labels  4 9  | was  192.0  | bot  92.5
labels  4 5  | was  323.0  | bot  126.5
labels  4 9  | was  290.0  | bot  114.5
labels  1 4  | was  148.0  | bot  40.0
labels  1 9  | was  205.0  | bot  92.5
labels  1 5  | was  276.5  | bot  126.5
labels  1 9  | was  200.5  | bot  114.5
labels  4 9  | was  175.0  | bot  92.5
labels  4 5  | was  315.5  | bot  126.5
labels  4 9  | was  280.0  | bot  114.5
labels  9 5  | was  250.0  | bot  100.5
labels  9 9  | was  179.5  | bot  44.0
labels  5 9  | was  222.5  | bot  100.5

@VincentRouvreau
Contributor

@dgm2 What is your gudhi version? python -c "import gudhi; print(gudhi.__version__)"

@VincentRouvreau
Contributor

Here is an example showing how to do the same computation with Dionysus and GUDHI:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
import dionysus as d
import gudhi as gd

X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)

# a zero
a = X[1].reshape((28,28))
#a = np.random.random((10,10))
plt.matshow(a)
plt.colorbar()
plt.show()

f_lower_star = d.fill_freudenthal(a)
p = d.homology_persistence(f_lower_star)
dgms = d.init_diagrams(p, f_lower_star)

for i,dgm in enumerate(dgms):
    print(i)
    for pt in dgm:
        print(pt)

# 0
# (0,inf)
# (0,165)
# (84,96)
# 1
# (0,255)
# (173,252)
# (223,253)
# (225,252)
# (225,253)
# (238,253)
# (240,253)
# (246,253)
# (252,253)
# (252,253)
# (252,253)
# (252,253)
# (253,255)

cc = gd.CubicalComplex(top_dimensional_cells=a)
cc.compute_persistence()
cc.persistence_intervals_in_dimension(0)
# array([[ 84.,  96.],
#        [  0., 165.],
#        [  0.,  inf]])
cc.persistence_intervals_in_dimension(1)
#array([[173., 252.],
#       [225., 252.],
#       [252., 253.],
#       [246., 253.],
#       [240., 253.],
#       [238., 253.],
#       [252., 253.],
#       [252., 253.],
#       [225., 253.],
#       [252., 253.],
#       [237., 253.],
#       [223., 253.],
#       [253., 255.],
#       [  0., 255.]])
