Skip to content

[MRG] Projection Robust Wasserstein#267

Merged
rflamary merged 9 commits into
PythonOT:masterfrom
mhhuang95:dr-prw
Sep 6, 2021
Merged

[MRG] Projection Robust Wasserstein#267
rflamary merged 9 commits into
PythonOT:masterfrom
mhhuang95:dr-prw

Conversation

@mhhuang95

Copy link
Copy Markdown
Contributor

Types of changes

  • Docs change / refactoring / dependency upgrade
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)

Motivation and context / Related issue

Code for
A Riemannian Block Coordinate Descent Method for Computing the PRW Distance, ICML 2021
Source: https://github.com/mhhuang95/PRW_RBCD

How has this been tested (if it applies)

tested on Fragmented Hypercube problem

Checklist

  • The documentation is up-to-date with the changes I made.
  • I have read the CONTRIBUTING document.
  • All tests passed, and additional code has been covered with new tests.

@codecov

codecov Bot commented Aug 5, 2021

Copy link
Copy Markdown

Codecov Report

Merging #267 (21ce5b8) into master (c105dcb) will increase coverage by 0.07%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master     #267      +/-   ##
==========================================
+ Coverage   93.26%   93.34%   +0.07%     
==========================================
  Files          17       17              
  Lines        3506     3547      +41     
==========================================
+ Hits         3270     3311      +41     
  Misses        236      236              

@agramfort

Copy link
Copy Markdown
Collaborator

You have failing tests @mhhuang95

Comment thread ot/dr.py Outdated
return Popt, proj


def prw(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the function name is not explicit enough. Can we call it: projection_robust_wasserstein? Maybe @rflamary has a better suggestion

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree projection_robust_wasserstein is longer but explicit, we will also need to change the wda and pca functions and add a deprecation on the old names.

Comment thread ot/dr.py Outdated

def prw(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
r"""
Projection Robust Wasserstein Distance _[12],[13]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I doubt this is formatted properly to have links to references below

Comment thread ot/dr.py Outdated
Samples from measure \mu
Y : ndarray, shape (n, d)
Samples from measure \nu
a : ndarray, shape (n, 1)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary to have ndim == 2 with shape[1] == 1? I would favor a flat vector for a and b

Comment thread ot/dr.py Outdated
k : int
Subspace dimension
stopThr : float, optional
Accuracy

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accuracy is to me misleading. Maybe tolerance ? But on what criteria?

Comment thread ot/dr.py
U = np.random.randn(d, k)
U, _ = np.linalg.qr(U)
else:
U = U0

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line need to be covered by a test

@rflamary rflamary changed the title ot.dr: PRW code; text.text_dr: PRW test code. [WIP] Projection Robust Wasserstein Aug 9, 2021

@agramfort agramfort left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment thread ot/dr.py Outdated
The function solves the following optimization problem:

.. math::
max_{U \in St(d, k)} min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j}*||U^T(x_i - y_j)||^2 - reg * H(\pi)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
max_{U \in St(d, k)} min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j}*||U^T(x_i - y_j)||^2 - reg * H(\pi)
\max_{U \in St(d, k)} \min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j} \|U^T(x_i - y_j)\|^2 - reg * H(\pi)

Comment thread ot/dr.py Outdated

- :math:`U` is a linear projection operator in the Stiefel(d, k) manifold
- :math:`H(\pi)` is entropy regularizer
- :math:`x_i`, `y_j` are samples of measures \mu and \nu respectively

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- :math:`x_i`, `y_j` are samples of measures \mu and \nu respectively
- :math:`x_i`, :math:`y_j` are samples of measures \mu and \nu respectively

@rflamary rflamary changed the title [WIP] Projection Robust Wasserstein [MRG] Projection Robust Wasserstein Sep 6, 2021
@rflamary

rflamary commented Sep 6, 2021

Copy link
Copy Markdown
Collaborator

Thank you again @mhhuang95 for this contribution and welcome to the POT contributors.

@rflamary rflamary merged commit 96bf1a4 into PythonOT:master Sep 6, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants