-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path tda_iris.py
339 lines (36 loc) · 1.4 KB
/
tda_iris.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import numpy as np;
import matplotlib.pyplot as plt;
import ripser;
import persim;
import sklearn;
from sklearn import datasets;
from sklearn.metrics import pairwise_distances;
#
# Topological data analysis of the Iris dataset: compute Vietoris-Rips
# persistent homology with ripser and plot the H1 persistence diagram.
#

# Load the iris dataset from sklearn. `d` is the feature matrix (one row per
# sample); `target` holds the class labels (not used further in this script).
iris = datasets.load_iris()
d = iris.data
target = iris.target

# Pairwise Euclidean distances between all samples.
distance_matrix = pairwise_distances(d, metric='euclidean')

# Compute persistent homology (dimensions 0..2) from the precomputed distance
# matrix. No `metric` is passed: ripser only uses it to derive distances from
# raw point coordinates, so it has no effect when distance_matrix=True.
rips = ripser.ripser(X=distance_matrix, maxdim=2, distance_matrix=True)

# ripser returns one (birth, death) array per homology dimension; index 1 is
# H1. Only features with a finite death time can be placed on the plot.
persistence_diagrams = rips['dgms'][1]
finite_pairs = [pt for pt in persistence_diagrams if np.isfinite(pt[1])]
births = [pt[0] for pt in finite_pairs]
deaths = [pt[1] for pt in finite_pairs]

plt.scatter(births, deaths)
# Reference diagonal (birth == death): a point's distance from this line is
# its persistence, which is what a persistence diagram is read by.
# `default` keeps this working when there are no finite H1 features.
upper = max(deaths, default=1.0)
plt.plot([0, upper], [0, upper], 'k--', linewidth=1)
plt.xlabel('Birth Time')
plt.ylabel('Death Time')
plt.title('Iris Data H1 Persistence Diagram')
# Call plt.show() here to display the figure interactively.