Skip to content

Commit 791137b

Browse files
authored
[WIP] Implementation of FUGW and UCOOT (#677)
* implementation of FUGW and UCOOT * fix pep8 error * fix test_utils error * remove print * add documentation and fix bug * first code review * fix documentation
1 parent 2aa8338 commit 791137b

15 files changed

+3180
-118
lines changed

README.md

+14-3
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ POT provides the following generic OT solvers (links to examples):
5252
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
5353
* [Smooth Strongly Convex Nearest Brenier Potentials](https://pythonot.github.io/auto_examples/others/plot_SSNB.html#sphx-glr-auto-examples-others-plot-ssnb-py) [58], with an extension to bounding potentials using [59].
5454
* Gaussian Mixture Model OT [69]
55+
* [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and
56+
[unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71].
57+
* Fused unbalanced Gromov-Wasserstein [70].
5558

5659
POT provides the following Machine Learning related solvers:
5760

@@ -62,7 +65,7 @@ POT provides the following Machine Learning related solvers:
6265
* [Linear OT mapping](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_linear_mapping.html) [14] and [Joint OT mapping estimation](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_mapping.html) [8].
6366
* [Wasserstein Discriminant Analysis](https://pythonot.github.io/auto_examples/others/plot_WDA.html) [11] (requires autograd + pymanopt).
6467
* [JCPOT algorithm for multi-source domain adaptation with target shift](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_jcpot.html) [27].
65-
* [Graph Neural Network OT layers TFGW](https://pythonot.github.io/auto_examples/gromov/plot_gnn_TFGW.html) [52] and TW (OT-GNN) [53]
68+
* [Graph Neural Network OT layers TFGW](https://pythonot.github.io/auto_examples/gromov/plot_gnn_TFGW.html) [52] and TW (OT-GNN) [53]
6669

6770
Some other examples are available in the [documentation](https://pythonot.github.io/auto_examples/index.html).
6871

@@ -198,7 +201,7 @@ This toolbox has been created by
198201
* [Rémi Flamary](https://remi.flamary.com/)
199202
* [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/)
200203

201-
It is currently maintained by
204+
It is currently maintained by
202205

203206
* [Rémi Flamary](https://remi.flamary.com/)
204207
* [Cédric Vincent-Cuaz](https://cedricvincentcuaz.github.io/)
@@ -370,4 +373,12 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil
370373

371374
[68] Chowdhury, S., Miller, D., & Needham, T. (2021). [Quantized gromov-wasserstein](https://link.springer.com/chapter/10.1007/978-3-030-86523-8_49). ECML PKDD 2021. Springer International Publishing.
372375

373-
[69] Delon, J., & Desolneux, A. (2020). [A Wasserstein-type distance in the space of Gaussian mixture models](https://epubs.siam.org/doi/abs/10.1137/19M1301047). SIAM Journal on Imaging Sciences, 13(2), 936-970.
376+
[69] Delon, J., & Desolneux, A. (2020). [A Wasserstein-type distance in the space of Gaussian mixture models](https://epubs.siam.org/doi/abs/10.1137/19M1301047). SIAM Journal on Imaging Sciences, 13(2), 936-970.
377+
378+
[70] A. Thual, H. Tran, T. Zemskova, N. Courty, R. Flamary, S. Dehaene
379+
& B. Thirion (2022). [Aligning individual brains with Fused Unbalanced Gromov-Wasserstein.](https://proceedings.neurips.cc/paper_files/paper/2022/file/8906cac4ca58dcaf17e97a0486ad57ca-Paper-Conference.pdf). Neural Information Processing Systems (NeurIPS).
380+
381+
[71] H. Tran, H. Janati, N. Courty, R. Flamary, I. Redko, P. Demetci & R. Singh (2023). [Unbalanced Co-Optimal Transport](https://dl.acm.org/doi/10.1609/aaai.v37i8.26193). AAAI Conference on
382+
Artificial Intelligence.
383+
384+
[72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS).

RELEASES.md

+13-12
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@
33
## 0.9.5dev
44

55
#### New features
6-
- Add feature `mass=True` for `nx.kl_div` (PR #654)
7-
- Gaussian Mixture Model OT `ot.gmm` (PR #649)
8-
- Add feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter updates `update_barycenter_structure` and `update_barycenter_feature` (PR #659)
9-
- Add initialization heuristics for sr(F)GW problems via `semirelaxed_init_plan`, integrated in all sr(F)GW solvers (PR #659)
6+
- Added feature `mass=True` for `nx.kl_div` (PR #654)
7+
- Implemented Gaussian Mixture Model OT `ot.gmm` (PR #649)
8+
- Added feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter updates `update_barycenter_structure` and `update_barycenter_feature` (PR #659)
9+
- Added initialization heuristics for sr(F)GW problems via `semirelaxed_init_plan`, integrated in all sr(F)GW solvers (PR #659)
1010
- Improved `ot.plot.plot1D_mat` (PR #649)
1111
- Added `nx.det` (PR #649)
1212
- `nx.sqrtm` is now broadcastable (takes ..., d, d) inputs (PR #649)
13-
- restructure `ot.unbalanced` module (PR #658)
14-
- add `ot.unbalanced.lbfgsb_unbalanced2` and add flexible reference measure `c` in all unbalanced solvers (PR #658)
13+
- Restructured `ot.unbalanced` module (PR #658)
14+
- Added `ot.unbalanced.lbfgsb_unbalanced2` and add flexible reference measure `c` in all unbalanced solvers (PR #658)
15+
- Implemented Fused unbalanced Gromov-Wasserstein and unbalanced Co-Optimal Transport (PR #677)
1516

1617
#### Closed issues
1718
- Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648)
@@ -72,7 +73,7 @@ xs, xt = np.random.randn(100, 2), np.random.randn(50, 2)
7273

7374
# Solve OT problem with empirical samples
7475
sol = ot.solve_sample(xs, xt) # Exact OT betwen smaples with uniform weights
75-
sol = ot.solve_sample(xs, xt, wa, wb) # Exact OT with weights given by user
76+
sol = ot.solve_sample(xs, xt, wa, wb) # Exact OT with weights given by user
7677

7778
sol = ot.solve_sample(xs, xt, reg= 1, metric='euclidean') # sinkhorn with euclidean metric
7879

@@ -84,15 +85,15 @@ sol = ot.solve_sample(x,x2, method='lowrank', rank=10) # compute lowrank sinkhor
8485

8586
value_bw = ot.solve_sample(xs, xt, method='gaussian').value # Bures-Wasserstein distance
8687

87-
# Solve GW problem
88+
# Solve GW problem
8889
Cs, Ct = ot.dist(xs, xs), ot.dist(xt, xt) # compute cost matrices
8990
sol = ot.solve_gromov(Cs,Ct) # Exact GW between samples with uniform weights
9091

9192
# Solve FGW problem
9293
M = ot.dist(xs, xt) # compute cost matrix
9394

9495
# Exact FGW between samples with uniform weights
95-
sol = ot.solve_gromov(Cs, Ct, M, loss='KL', alpha=0.7) # FGW with KL data fitting
96+
sol = ot.solve_gromov(Cs, Ct, M, loss='KL', alpha=0.7) # FGW with KL data fitting
9697

9798

9899
# recover solutions objects
@@ -102,14 +103,14 @@ value = sol.value # OT value
102103

103104
# for GW and FGW
104105
value_linear = sol.value_linear # linear part of the loss
105-
value_quad = sol.value_quad # quadratic part of the loss
106+
value_quad = sol.value_quad # quadratic part of the loss
106107

107108
```
108109

109110
Users are encouraged to use the new API (it is much simpler) but it might still be subjects to small changes before the release of POT 1.0 .
110111

111112

112-
We also fixed a number of issues, the most pressing being a problem of GPU memory allocation when pytorch is installed that will not happen now thanks to Lazy initialization of the backends. We now also have the possibility to deactivate some backends using environment which prevents POT from importing them and can lead to large import speedup.
113+
We also fixed a number of issues, the most pressing being a problem of GPU memory allocation when pytorch is installed that will not happen now thanks to Lazy initialization of the backends. We now also have the possibility to deactivate some backends using environment which prevents POT from importing them and can lead to large import speedup.
113114

114115

115116
#### New features
@@ -143,7 +144,7 @@ We also fixed a number of issues, the most pressing being a problem of GPU memor
143144
- Correct independence of `fgw_barycenters` to `init_C` and `init_X` (Issue #547, PR #566)
144145
- Avoid precision change when computing norm using PyTorch backend (Discussion #570, PR #572)
145146
- Create `ot/bregman/`repository (Issue #567, PR #569)
146-
- Fix matrix feature shape in `entropic_fused_gromov_barycenters`(Issue #574, PR #573)
147+
- Fix matrix feature shape in `entropic_fused_gromov_barycenters`(Issue #574, PR #573)
147148
- Fix (fused) gromov-wasserstein barycenter solvers to support `kl_loss`(PR #576)
148149

149150

examples/others/plot_learning_weights_with_COOT.py renamed to examples/others/plot_outlier_detection_with_COOT_and_unbalanced_COOT.py

+67-5
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
# -*- coding: utf-8 -*-
22
r"""
3-
===============================================================
4-
Learning sample marginal distribution with CO-Optimal Transport
5-
===============================================================
3+
======================================================================================================================================
4+
Detecting outliers by learning sample marginal distribution with CO-Optimal Transport and by using unbalanced Co-Optimal Transport
5+
======================================================================================================================================
66
7-
In this example, we illustrate how to estimate the sample marginal distribution which minimizes
8-
the CO-Optimal Transport distance [47]_ between two matrices. More precisely, given a source data
7+
In this example, we consider two point clouds living in different Euclidean spaces, where the outliers
8+
are artifically injected into the target data. We illustrate two methods which allow to filter out
9+
these outliers.
10+
11+
The first method requires learning the sample marginal distribution which minimizes
12+
the CO-Optimal Transport distance [49] between two input spaces.
13+
More precisely, given a source data
914
:math:`(X, \mu_x^{(s)}, \mu_x^{(f)})` and a target matrix :math:`Y` associated with a fixed
1015
histogram on features :math:`\mu_y^{(f)}`, we want to solve the following problem
1116
@@ -17,9 +22,19 @@
1722
allows us to compute the CO-Optimal Transport distance with :func:`ot.coot.co_optimal_transport2`
1823
with differentiable losses.
1924
25+
The second method simply requires direct application of unbalanced Co-Optimal Transport [71].
26+
More precisely, it is enough to use the sample and feature coupling from solving
27+
28+
.. math::
29+
\text{UCOOT}\left( (X, \mu_x^{(s)}, \mu_x^{(f)}), (Y, \mu_y^{(s)}, \mu_y^{(f)}) \right)
30+
31+
where all the marginal distributions are uniform.
32+
2033
.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
2134
`CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_.
2235
Advances in Neural Information Processing Systems, 33.
36+
.. [71] H. Tran, H. Janati, N. Courty, R. Flamary, I. Redko, P. Demetci & R. Singh (2023). [Unbalanced Co-Optimal Transport](https://dl.acm.org/doi/10.1609/aaai.v37i8.26193).
37+
AAAI Conference on Artificial Intelligence.
2338
"""
2439

2540
# Author: Remi Flamary <remi.flamary@unice.fr>
@@ -35,6 +50,7 @@
3550

3651
from ot.coot import co_optimal_transport as coot
3752
from ot.coot import co_optimal_transport2 as coot2
53+
from ot.gromov._unbalanced import unbalanced_co_optimal_transport
3854

3955

4056
# %%
@@ -148,3 +164,49 @@
148164
con = ConnectionPatch(
149165
xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue")
150166
fig.add_artist(con)
167+
168+
# %%
169+
# Now, let see if we can use unbalanced Co-Optimal Transport to recover the clean OT plans,
170+
# without the need of learning the marginal distribution as in Co-Optimal Transport.
171+
# -----------------------------------------------------------------------------------------
172+
173+
pi_sample, pi_feature = unbalanced_co_optimal_transport(
174+
X=X, Y=Y_noisy, reg_marginals=(10, 10), epsilon=0, divergence="kl",
175+
unbalanced_solver="mm", max_iter=1000, tol=1e-6,
176+
max_iter_ot=1000, tol_ot=1e-6, log=False, verbose=False
177+
)
178+
179+
# %%
180+
# Visualizing the row and column alignments learned by unbalanced Co-Optimal Transport.
181+
# -----------------------------------------------------------------------------------------
182+
#
183+
# Similar to Co-Optimal Transport, we are also be able to fully recover the clean OT plans.
184+
185+
fig = pl.figure(4, (9, 7))
186+
pl.clf()
187+
188+
ax1 = pl.subplot(2, 2, 3)
189+
pl.imshow(X, vmin=-2, vmax=2)
190+
pl.xlabel('$X$')
191+
192+
ax2 = pl.subplot(2, 2, 2)
193+
ax2.yaxis.tick_right()
194+
pl.imshow(np.transpose(Y_noisy), vmin=-2, vmax=2)
195+
pl.title("Transpose(Noisy $Y$)")
196+
ax2.xaxis.tick_top()
197+
198+
for i in range(n1):
199+
j = np.argmax(pi_sample[i, :])
200+
xyA = (d1 - .5, i)
201+
xyB = (j, d2 - .5)
202+
con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
203+
coordsB=ax2.transData, color="black")
204+
fig.add_artist(con)
205+
206+
for i in range(d1):
207+
j = np.argmax(pi_feature[i, :])
208+
xyA = (i, -.5)
209+
xyB = (-.5, j)
210+
con = ConnectionPatch(
211+
xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue")
212+
fig.add_artist(con)

ot/gromov/__init__.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,16 @@
66

77
# Author: Remi Flamary <remi.flamary@unice.fr>
88
# Cedric Vincent-Cuaz <cedvincentcuaz@gmail.com>
9+
# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
910
#
1011
# License: MIT License
1112

1213
# All submodules and packages
1314
from ._utils import (init_matrix, tensor_product, gwloss, gwggrad,
1415
init_matrix_semirelaxed, semirelaxed_init_plan,
15-
update_barycenter_structure, update_barycenter_feature)
16+
update_barycenter_structure, update_barycenter_feature,
17+
div_between_product, div_to_product, fused_unbalanced_across_spaces_cost,
18+
uot_cost_matrix, uot_parameters_and_measures)
1619

1720
from ._gw import (gromov_wasserstein, gromov_wasserstein2,
1821
fused_gromov_wasserstein, fused_gromov_wasserstein2,
@@ -63,9 +66,17 @@
6366
quantized_fused_gromov_wasserstein_samples
6467
)
6568

69+
from ._unbalanced import (fused_unbalanced_gromov_wasserstein,
70+
fused_unbalanced_gromov_wasserstein2,
71+
unbalanced_co_optimal_transport,
72+
unbalanced_co_optimal_transport2,
73+
fused_unbalanced_across_spaces_divergence)
74+
6675
__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad',
6776
'init_matrix_semirelaxed', 'semirelaxed_init_plan',
6877
'update_barycenter_structure', 'update_barycenter_feature',
78+
'div_between_product', 'div_to_product', 'fused_unbalanced_across_spaces_cost',
79+
'uot_cost_matrix', 'uot_parameters_and_measures',
6980
'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein',
7081
'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters',
7182
'fgw_barycenters', 'entropic_gromov_wasserstein', 'entropic_gromov_wasserstein2',
@@ -87,4 +98,7 @@
8798
'get_graph_representants', 'format_partitioned_graph',
8899
'quantized_fused_gromov_wasserstein', 'get_partition_and_representants_samples',
89100
'format_partitioned_samples', 'quantized_fused_gromov_wasserstein_samples',
101+
'fused_unbalanced_gromov_wasserstein', 'fused_unbalanced_gromov_wasserstein2',
102+
'unbalanced_co_optimal_transport', 'unbalanced_co_optimal_transport2',
103+
'fused_unbalanced_across_spaces_divergence'
90104
]

0 commit comments

Comments
 (0)