[MRG] Projection Robust Wasserstein#267
Conversation
Codecov Report
@@ Coverage Diff @@
## master #267 +/- ##
==========================================
+ Coverage 93.26% 93.34% +0.07%
==========================================
Files 17 17
Lines 3506 3547 +41
==========================================
+ Hits 3270 3311 +41
Misses 236 236 |
|
You have failing tests @mhhuang95 |
| return Popt, proj | ||
|
|
||
|
|
||
| def prw(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0): |
There was a problem hiding this comment.
I think the function name is not explicit enough. Can we call it: projection_robust_wasserstein? Maybe @rflamary has a better suggestion
There was a problem hiding this comment.
I agree projection_robust_wasserstein is longer but explicit, we will also need to change the wda and pca functions and add a deprecation on the old names.
|
|
||
| def prw(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0): | ||
| r""" | ||
| Projection Robust Wasserstein Distance _[12],[13] |
There was a problem hiding this comment.
I doubt this is formatted properly to have links to references below
| Samples from measure \mu | ||
| Y : ndarray, shape (n, d) | ||
| Samples from measure \nu | ||
| a : ndarray, shape (n, 1) |
There was a problem hiding this comment.
Is it necessary to have ndim == 2 with shape[1] == 1? I would favor a flat vector for a and b
| k : int | ||
| Subspace dimension | ||
| stopThr : float, optional | ||
| Accuracy |
There was a problem hiding this comment.
Accuracy is to me misleading. Maybe tolerance ? But on what criteria?
| U = np.random.randn(d, k) | ||
| U, _ = np.linalg.qr(U) | ||
| else: | ||
| U = U0 |
There was a problem hiding this comment.
This line need to be covered by a test
agramfort
left a comment
There was a problem hiding this comment.
problems spotted by looking at https://724-71472695-gh.circle-artifacts.com/0/dev/gen_modules/ot.dr.html#ot.dr.projection_robust_wasserstein
| The function solves the following optimization problem: | ||
|
|
||
| .. math:: | ||
| max_{U \in St(d, k)} min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j}*||U^T(x_i - y_j)||^2 - reg * H(\pi) |
There was a problem hiding this comment.
| max_{U \in St(d, k)} min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j}*||U^T(x_i - y_j)||^2 - reg * H(\pi) | |
| \max_{U \in St(d, k)} \min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j} \|U^T(x_i - y_j)\|^2 - reg * H(\pi) |
|
|
||
| - :math:`U` is a linear projection operator in the Stiefel(d, k) manifold | ||
| - :math:`H(\pi)` is entropy regularizer | ||
| - :math:`x_i`, `y_j` are samples of measures \mu and \nu respectively |
There was a problem hiding this comment.
| - :math:`x_i`, `y_j` are samples of measures \mu and \nu respectively | |
| - :math:`x_i`, :math:`y_j` are samples of measures \mu and \nu respectively |
|
Thank you again @mhhuang95 for this contribution and welcome to the POT contributors. |
Types of changes
Motivation and context / Related issue
Code for
A Riemannian Block Coordinate Descent Method for Computing the PRW Distance, ICML 2021
Source: https://github.com/mhhuang95/PRW_RBCD
How has this been tested (if it applies)
tested on Fragmented Hypercube problem
Checklist