[MRG] EMD and Wasserstein 1D #89

Merged
rflamary merged 14 commits into PythonOT:master from rtavenar:master on Jun 27, 2019

@rtavenar

Contributor

Hi there,

I started coding a dedicated EMD solver for the one-dimensional case (i.e. when sorting both arrays is enough).
The docs are missing for the moment (I will add them asap), but a basic implementation that covers the non-uniform weight case is already there, along with tests that check that the results are consistent with EMD.
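As background on why sorting is enough: in 1D, for a convex cost such as the squared difference, a monotone coupling (smallest mass with smallest mass, and so on) is optimal, so one pass over the two sorted arrays yields both the plan and the cost. The following is a minimal pure-Python sketch of that idea, not the code from this PR; the name `emd_1d_sketch` and the triplet output format are assumptions made for illustration, and inputs are assumed non-empty.

```python
def emd_1d_sketch(u, v, u_weights=None, v_weights=None):
    """Return (cost, plan) for 1D transport with squared-difference cost.

    Hypothetical helper for illustration only. `plan` is a list of
    (i, j, mass) triplets that index the ORIGINAL (unsorted) inputs.
    """
    n, m = len(u), len(v)
    # Missing or empty weights are treated as uniform weights.
    a = list(u_weights) if u_weights is not None and len(u_weights) else [1.0 / n] * n
    b = list(v_weights) if v_weights is not None and len(v_weights) else [1.0 / m] * m

    # Sorting is the only super-linear step: O(n log n + m log m).
    order_u = sorted(range(n), key=lambda k: u[k])
    order_v = sorted(range(m), key=lambda k: v[k])

    plan, cost = [], 0.0
    i = j = 0
    rem_a, rem_b = a[order_u[0]], b[order_v[0]]
    while i < n and j < m:
        iu, jv = order_u[i], order_v[j]
        mass = min(rem_a, rem_b)  # move as much mass as both sides allow
        plan.append((iu, jv, mass))
        cost += mass * (u[iu] - v[jv]) ** 2
        rem_a -= mass
        rem_b -= mass
        # Whichever side was exhausted (possibly both) moves to its next point.
        if rem_a == 0.0:
            i += 1
            if i < n:
                rem_a = a[order_u[i]]
        if rem_b == 0.0:
            j += 1
            if j < m:
                rem_b = b[order_v[j]]
    return cost, plan


cost, plan = emd_1d_sketch([3.0, 1.0], [2.0])
print(cost, plan)
```

The loop body runs at most n + m - 1 times, which is why the 1D case can be so much cheaper than a general LP solve on the full n x m cost matrix.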

On my machine, I ran the following timing test:

>>> n = 20000
>>> m = 3000
>>> u = np.random.randn(n, 1)
>>> v = np.random.randn(m, 1)
>>> ot.tic(); ot.emd_1d([], [], u, v, metric='sqeuclidean'); ot.toc()
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])
Elapsed time : 2.3728668689727783 s
2.3728668689727783
>>> ot.tic(); M = ot.dist(u, v, metric='sqeuclidean'); ot.emd([], [], M); ot.toc()
RESULT MIGHT BE INACURATE
Max number of iteration reached, currently 100000. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher
/Users/tavenard_r/Documents/costel/src/POT/ot/lp/__init__.py:104: UserWarning: numItermax reached before optimality. Try to increase numItermax.
  result_code_string = check_result(result_code)
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])
Elapsed time : 8.67806887626648 s

Romain

Comment thread ot/lp/emd_wrap.pyx Outdated
        np.ndarray[double, ndim=2, mode="c"] v,
        str metric='sqeuclidean'):
    r"""
    Roro's stuff

Collaborator

nice documentation indeed ;)

Contributor Author

:-P

@rflamary

Collaborator

Thank you Romain, this is nice.

Is it just me, or is Cython not particularly fast here (2 s for n=20000)? It is probably due to the use of the dist function; you should probably implement the cost in Cython for the squared and absolute-value cases and use dist only for weird stuff ;)

Rémi.

@rtavenar

rtavenar commented Jun 20, 2019

Contributor Author

If I change to the following:

        if metric == 'sqeuclidean':
            m_ij = (u[i, 0] - v[j, 0]) ** 2
        else:
            m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
                        metric=metric)[0, 0]

I get the same timings (on the order of 2 s)...

The slow part seems to be when we deal with G. If I remove all the stuff related to G I get:

Elapsed time : 0.0061719417572021484 s

I will check if using a sparse representation for G helps.

EDIT: OK, when I remove the overhead for G, I can see a 100x improvement in timings with this if...else, so will keep it for L1 and L2 norms and resort to dist for other distances.
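To illustrate why G dominates the runtime: a dense plan has n * m cells (60 million for n=20000, m=3000), whereas a monotone 1D coupling has at most n + m - 1 non-zero entries. Below is a minimal sketch, assuming the points are already sorted and only the weights matter; `monotone_plan_triplets` and `to_dense` are hypothetical names used for illustration, not functions from POT.

```python
def monotone_plan_triplets(a, b):
    """Greedily couple two non-empty weight vectors of already-sorted points.

    Hypothetical helper for illustration only. Returns parallel lists
    (rows, cols, masses) with at most len(a) + len(b) - 1 entries.
    """
    rows, cols, masses = [], [], []
    i = j = 0
    rem_a, rem_b = a[0], b[0]
    while i < len(a) and j < len(b):
        mass = min(rem_a, rem_b)
        rows.append(i)
        cols.append(j)
        masses.append(mass)
        rem_a -= mass
        rem_b -= mass
        if rem_a == 0.0:  # source i is exhausted
            i += 1
            if i < len(a):
                rem_a = a[i]
        if rem_b == 0.0:  # target j is exhausted
            j += 1
            if j < len(b):
                rem_b = b[j]
    return rows, cols, masses


def to_dense(rows, cols, masses, n, m):
    """Expand the triplets into an n x m nested list (O(n * m) memory)."""
    G = [[0.0] * m for _ in range(n)]
    for r, c, w in zip(rows, cols, masses):
        G[r][c] += w
    return G


rows, cols, masses = monotone_plan_triplets([0.25] * 4, [0.5, 0.5])
print(len(masses), "non-zeros instead of", 4 * 2, "dense cells")
```

With SciPy available, the three lists can be handed to `scipy.sparse.coo_matrix((masses, (rows, cols)), shape=(n, m))`, so the dense array is only built when a caller actually asks for it.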

Comment thread test/test_ot.py Outdated
Comment thread ot/lp/emd_wrap.pyx Outdated
                          dtype=np.float64)
    while i < n and j < m:
        m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
                    metric=metric)[0, 0]

Collaborator

Since you have a pure Python function call in the loop, I doubt that Cython brings you any speed gain.

my 2c

Contributor Author

I've tried something for the basic metrics (euclidean and sqeuclidean); I am not sure how to handle the others differently.

@rtavenar

rtavenar commented Jun 21, 2019

Contributor Author

Also, new timings (for a larger problem than above) are:

>>> import ot
>>> import numpy as np
>>> from scipy.stats import wasserstein_distance
>>> 
>>> n = 20000
>>> m = 30000
>>> u = np.random.randn(n)
>>> v = np.random.randn(m)
>>> 
>>> ot.tic(); _ = wasserstein_distance(u, v); _ = ot.toc()
Elapsed time : 0.012831926345825195 s
>>> ot.tic(); _ = ot.emd_1d([], [], u, v, metric='euclidean', dense=False); _ = ot.toc()
Elapsed time : 0.04144096374511719 s
>>> ot.tic(); M = ot.dist(u.reshape((-1, 1)), v.reshape((-1, 1)),
...                       metric='euclidean'); _ = ot.emd([], [], M); _ = ot.toc()

RESULT MIGHT BE INACURATE
Max number of iteration reached, currently 100000. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher
/Users/tavenard_r/Documents/costel/src/POT/ot/lp/__init__.py:106: UserWarning: numItermax reached before optimality. Try to increase numItermax.
  result_code_string = check_result(result_code)
Elapsed time : 312.5033311843872 s

We are a bit slower than scipy's implementation (about 0.04 s vs. 0.01 s here). I am not sure whether this is due to Cython or to the fact that scipy does not deal with G :/
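For context on the "does not deal with G" point: the 1D Wasserstein-1 value can be computed without any transport plan, as the integral of the absolute difference between the two empirical CDFs. The sketch below shows that idea in pure Python for uniform weights; it is an illustration under that assumption, not scipy's actual code, and `wasserstein_1d_value` is a hypothetical name.

```python
from bisect import bisect_right


def wasserstein_1d_value(u, v):
    """W1 between two uniform empirical distributions; no plan is built.

    Hypothetical helper for illustration only.
    """
    su, sv = sorted(u), sorted(v)
    all_values = sorted(su + sv)
    total = 0.0
    for left, right in zip(all_values[:-1], all_values[1:]):
        # Both empirical CDFs are constant on the interval [left, right).
        cdf_u = bisect_right(su, left) / len(su)
        cdf_v = bisect_right(sv, left) / len(sv)
        total += abs(cdf_u - cdf_v) * (right - left)
    return total


print(wasserstein_1d_value([0.0, 1.0], [1.0, 2.0]))
```

Since only a scalar is accumulated, there is no allocation or bookkeeping for plan entries, which may account for part of the gap in the timings above.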

@agramfort

Collaborator

@rtavenar did you run "cython -a" on the .pyx file to check that it is all white (no yellow lines, which mark slow Python calls)?

@rtavenar

Contributor Author

@agramfort

  1. Did not know this one, thanks for the tip
  2. I've changed the np.abs, now the only yellow lines I get are the return, the np.zeros lines and the cdist, but I do not know how to remove these ones.

@agramfort

Collaborator

You cannot remove the yellow lines for np.zeros or return.

For cdist, either you call BLAS functions from scipy directly, or you need to code the metrics in Cython.

@rtavenar

Contributor Author

I've had a look there, could not find obvious matches for distances, but maybe it's not the right place :/

Regarding coding the metrics in Cython, this is what I have done for Euclidean distance and Squared Euclidean distance up to now. The question is: should I code all of them even if they are unlikely to be used, or only a subset?

@rflamary

Collaborator

Hello, I think those two are OK. Just be clear in the documentation that the other metrics are slower because they go through cdist (such a slow function btw ;) )

Rémi

@rtavenar

Contributor Author

OK, and I'll also have to be clear that only strings are accepted as metrics for emd_1d
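The dispatch discussed in this thread (inline arithmetic for the two common metrics, a generic fallback for the rest, strings only) could look roughly like the sketch below. It is plain Python rather than Cython, and `pointwise_cost` and `_SLOW_METRICS` are hypothetical names; the dictionary merely stands in for a generic call such as scipy's cdist.

```python
# Hypothetical stand-in for a generic (slower) distance routine such as cdist.
_SLOW_METRICS = {
    "cityblock": lambda x, y: abs(x - y),
    "chebyshev": lambda x, y: abs(x - y),
}


def pointwise_cost(x, y, metric="sqeuclidean"):
    """Cost between two scalars; `metric` must be a string.

    Hypothetical helper for illustration only.
    """
    if not isinstance(metric, str):
        raise TypeError("metric must be a string, got %r" % type(metric).__name__)
    d = x - y
    # Fast paths: plain arithmetic, no generic function call.
    if metric == "sqeuclidean":
        return d * d
    if metric == "euclidean":
        return abs(d)
    # Slow path: look up a generic implementation by name.
    try:
        return _SLOW_METRICS[metric](x, y)
    except KeyError:
        raise ValueError("unknown metric: %r" % metric) from None


print(pointwise_cost(1.0, 3.0), pointwise_cost(1.0, 3.0, metric="euclidean"))
```

Rejecting callables up front keeps the fast paths a simple string comparison and makes the documented restriction an explicit error rather than a silent fallback.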

@rtavenar

Contributor Author

OK, so now I added proper docstrings. Let me know if something is missing or should be changed.

@rflamary

Collaborator

This is great, thank you @rtavenar for the code and optimization.

I will merge it now.

@rflamary rflamary changed the title [WIP] EMD 1d [MRG] EMD and Wassersyein 1D Jun 27, 2019
@rflamary rflamary changed the title [MRG] EMD and Wassersyein 1D [MRG] EMD and Wasserstein 1D Jun 27, 2019
@rflamary rflamary merged commit a9b8af1 into PythonOT:master Jun 27, 2019

4 participants