|
1 | 1 | # -*- coding: utf-8 -*-
|
2 | 2 | r"""
|
3 |
| -=============================================================== |
4 |
| -Learning sample marginal distribution with CO-Optimal Transport |
5 |
| -=============================================================== |
| 3 | +====================================================================================================================================== |
| 4 | +Detecting outliers by learning sample marginal distribution with CO-Optimal Transport and by using unbalanced Co-Optimal Transport |
| 5 | +====================================================================================================================================== |
6 | 6 |
|
7 |
| -In this example, we illustrate how to estimate the sample marginal distribution which minimizes |
8 |
| -the CO-Optimal Transport distance [47]_ between two matrices. More precisely, given a source data |
| 7 | +In this example, we consider two point clouds living in different Euclidean spaces, where the outliers |
| 8 | +are artificially injected into the target data. We illustrate two methods which allow us to filter out |
| 9 | +these outliers. |
| 10 | +
|
| 11 | +The first method requires learning the sample marginal distribution which minimizes |
| 12 | +the CO-Optimal Transport distance [49]_ between two input spaces. |
| 13 | +More precisely, given a source data |
9 | 14 | :math:`(X, \mu_x^{(s)}, \mu_x^{(f)})` and a target matrix :math:`Y` associated with a fixed
|
10 | 15 | histogram on features :math:`\mu_y^{(f)}`, we want to solve the following problem
|
11 | 16 |
|
|
17 | 22 | allows us to compute the CO-Optimal Transport distance with :func:`ot.coot.co_optimal_transport2`
|
18 | 23 | with differentiable losses.
|
19 | 24 |
|
| 25 | +The second method simply requires direct application of unbalanced Co-Optimal Transport [71]_. |
| 26 | +More precisely, it is enough to use the sample and feature coupling from solving |
| 27 | +
|
| 28 | +.. math:: |
| 29 | + \text{UCOOT}\left( (X, \mu_x^{(s)}, \mu_x^{(f)}), (Y, \mu_y^{(s)}, \mu_y^{(f)}) \right) |
| 30 | +
|
| 31 | +where all the marginal distributions are uniform. |
| 32 | +
|
20 | 33 | .. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
|
21 | 34 | `CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_.
|
22 | 35 | Advances in Neural Information Processing Systems, 33.
|
| 36 | +.. [71] H. Tran, H. Janati, N. Courty, R. Flamary, I. Redko, P. Demetci & R. Singh (2023). `Unbalanced Co-Optimal Transport <https://dl.acm.org/doi/10.1609/aaai.v37i8.26193>`_. |
| 37 | + AAAI Conference on Artificial Intelligence. |
23 | 38 | """
|
24 | 39 |
|
25 | 40 | # Author: Remi Flamary <remi.flamary@unice.fr>
|
|
35 | 50 |
|
36 | 51 | from ot.coot import co_optimal_transport as coot
|
37 | 52 | from ot.coot import co_optimal_transport2 as coot2
|
| 53 | +from ot.gromov._unbalanced import unbalanced_co_optimal_transport |
38 | 54 |
|
39 | 55 |
|
40 | 56 | # %%
|
|
148 | 164 | con = ConnectionPatch(
|
149 | 165 | xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue")
|
150 | 166 | fig.add_artist(con)
|
| 167 | + |
| 168 | +# %% |
| 169 | +# Now, let us see whether we can use unbalanced Co-Optimal Transport to recover the clean OT plans, |
| 170 | +# without needing to learn the marginal distribution as in Co-Optimal Transport. |
| 171 | +# ----------------------------------------------------------------------------------------- |
| 172 | + |
| 173 | +pi_sample, pi_feature = unbalanced_co_optimal_transport( |
| 174 | + X=X, Y=Y_noisy, reg_marginals=(10, 10), epsilon=0, divergence="kl", |
| 175 | + unbalanced_solver="mm", max_iter=1000, tol=1e-6, |
| 176 | + max_iter_ot=1000, tol_ot=1e-6, log=False, verbose=False |
| 177 | +) |
| 178 | + |
| 179 | +# %% |
| 180 | +# Visualizing the row and column alignments learned by unbalanced Co-Optimal Transport. |
| 181 | +# ----------------------------------------------------------------------------------------- |
| 182 | +# |
| 183 | +# Similar to Co-Optimal Transport, we are also able to fully recover the clean OT plans. |
| 184 | + |
| 185 | +fig = pl.figure(4, (9, 7)) |
| 186 | +pl.clf() |
| 187 | + |
| 188 | +ax1 = pl.subplot(2, 2, 3) |
| 189 | +pl.imshow(X, vmin=-2, vmax=2) |
| 190 | +pl.xlabel('$X$') |
| 191 | + |
| 192 | +ax2 = pl.subplot(2, 2, 2) |
| 193 | +ax2.yaxis.tick_right() |
| 194 | +pl.imshow(np.transpose(Y_noisy), vmin=-2, vmax=2) |
| 195 | +pl.title("Transpose(Noisy $Y$)") |
| 196 | +ax2.xaxis.tick_top() |
| 197 | + |
| 198 | +for i in range(n1): |
| 199 | + j = np.argmax(pi_sample[i, :]) |
| 200 | + xyA = (d1 - .5, i) |
| 201 | + xyB = (j, d2 - .5) |
| 202 | + con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData, |
| 203 | + coordsB=ax2.transData, color="black") |
| 204 | + fig.add_artist(con) |
| 205 | + |
| 206 | +for i in range(d1): |
| 207 | + j = np.argmax(pi_feature[i, :]) |
| 208 | + xyA = (i, -.5) |
| 209 | + xyB = (-.5, j) |
| 210 | + con = ConnectionPatch( |
| 211 | + xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue") |
| 212 | + fig.add_artist(con) |
0 commit comments