add empirical sinkhorn and sinkhorn divergence functions #80
Hello @kilianFatras, and thank you for the PR. Could you please build the HTML documentation (in the /docs folder, run make html) and check that the documentation for all the new functions compiles correctly?
Hello, thank you for your answer. I updated the documentation in the ot.bregman file, and I also updated the documentation in ot.stochastic.
.. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS) 21, 2018
'''
sinkhorn_div = (2 * empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) -
This function fails if log=True. In fact, if log=True, you should return a log containing all 3 logs plus the loss for each of the 3 OT computations.
We also need a test for log=True.
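A minimal NumPy-only sketch of what the reviewer asks for, assuming each inner call returns a (loss, log) pair when log=True, as the quoted line implies. Under that assumption, subtracting those tuples raises a TypeError. The helpers sinkhorn_loss and sqeuclidean are hypothetical stand-ins for ot.sinkhorn2 and ot.dist, and the log key names are illustrative, not POT's actual API. The sketch uses the 1/2-weighted form adopted later in this thread.

```python
import numpy as np


def sqeuclidean(X, Y):
    # Pairwise squared Euclidean distances between the rows of X and Y
    return ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)


def sinkhorn_loss(a, b, M, reg, num_iter_max=1000, stop_thr=1e-9, log=False):
    # Hypothetical stand-in for ot.sinkhorn2: Sinkhorn-Knopp scaling,
    # returning the linear transport cost <P, M> of the entropic plan P
    K = np.exp(-M / reg)  # Gibbs kernel
    u, v = np.ones_like(a), np.ones_like(b)
    err, n_iter = np.inf, 0
    for n_iter in range(1, num_iter_max + 1):
        u = a / (K @ v)
        v = b / (K.T @ u)
        # Columns match b exactly after the v-update; measure the row error
        err = np.abs(u * (K @ v) - a).sum()
        if err < stop_thr:
            break
    loss = float(np.sum(u[:, None] * K * v[None, :] * M))
    if log:
        return loss, {"u": u, "v": v, "err": err, "n_iter": n_iter}
    return loss


def sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, log=False, **kwargs):
    # Uniform weights on the samples unless weights are given
    a = np.full(len(X_s), 1.0 / len(X_s)) if a is None else a
    b = np.full(len(X_t), 1.0 / len(X_t)) if b is None else b
    # Always request the inner logs so each call returns (loss, log)
    # and the arithmetic below never operates on tuples
    loss_ab, log_ab = sinkhorn_loss(a, b, sqeuclidean(X_s, X_t), reg, log=True, **kwargs)
    loss_a, log_a = sinkhorn_loss(a, a, sqeuclidean(X_s, X_s), reg, log=True, **kwargs)
    loss_b, log_b = sinkhorn_loss(b, b, sqeuclidean(X_t, X_t), reg, log=True, **kwargs)
    div = loss_ab - 0.5 * (loss_a + loss_b)
    if not log:
        return div
    # Hypothetical key names: the three inner logs plus the three losses
    return div, {
        "sinkhorn_loss_ab": loss_ab,
        "sinkhorn_loss_a": loss_a,
        "sinkhorn_loss_b": loss_b,
        "log_sinkhorn_ab": log_ab,
        "log_sinkhorn_a": log_a,
        "log_sinkhorn_b": log_b,
    }


rng = np.random.default_rng(0)
X_s = rng.uniform(size=(20, 2))
X_t = rng.uniform(size=(30, 2)) + 0.5
div, div_log = sinkhorn_divergence(X_s, X_t, reg=1.0, log=True)
print(div, div_log["log_sinkhorn_ab"]["n_iter"])
```

Always running the inner solves with log=True keeps a single code path. The log=False case simply drops the dictionary, so both modes return the same divergence value.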
M_t = ot.dist(X_t, X_t)
emp_sinkhorn_div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, 1)
sinkhorn_div = (2 * ot.sinkhorn2(a, b, M, 1) - ot.sinkhorn2(a, a, M_s, 1) -
Definitely add a test for log=True here.
I added two tests for log=True. I also changed the Sinkhorn divergence to sinkhorn_div = ot.sinkhorn2(a, b, M, 1) - 1/2 * ot.sinkhorn2(a, a, M_s, 1) - 1/2 * ot.sinkhorn2(b, b, M_t, 1)
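Written out, and assuming ot.sinkhorn2 returns the linear transport cost of the entropic plan, the new definition reads as follows. Here $W_\varepsilon(a, b; C)$ denotes the value returned by ot.sinkhorn2(a, b, C, reg), $P_\varepsilon$ is the corresponding entropic plan, and the test uses $\varepsilon = 1$:

```latex
W_\varepsilon(a, b; C) = \langle P_\varepsilon(a, b; C),\, C \rangle, \qquad
S_\varepsilon(a, b) = W_\varepsilon(a, b; M) - \tfrac{1}{2} W_\varepsilon(a, a; M_s) - \tfrac{1}{2} W_\varepsilon(b, b; M_t)
                    = \tfrac{1}{2}\bigl[\, 2 W_\varepsilon(a, b; M) - W_\varepsilon(a, a; M_s) - W_\varepsilon(b, b; M_t) \,\bigr]
```

The change therefore halves the earlier 2 * ... - ... - ... combination. Both versions vanish when the source and target samples coincide, because then a = b and M = M_s = M_t.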
It seems that what
Hello,
I am opening this PR to add empirical functions: the new functions only need the source data, the target data, and the regularization parameter for entropic OT. In this PR, you will find:
These functions are in the bregman.py file, their tests are in test_bregman.py, and their examples have been added to plot_OT_2D_samples.py.
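To illustrate what "empirical" means here, the following NumPy-only sketch builds uniform weights and a squared Euclidean cost directly from the samples (squared Euclidean is the default metric of ot.dist), then runs plain Sinkhorn-Knopp iterations. The names empirical_sinkhorn_sketch and empirical_sinkhorn2_sketch are hypothetical stand-ins for a plan-returning function and for the empirical_sinkhorn2 call quoted in the review. They are not the PR's implementation.

```python
import numpy as np


def empirical_sinkhorn_sketch(X_s, X_t, reg, a=None, b=None, num_iter_max=1000, stop_thr=1e-9):
    """Hypothetical sketch: entropic OT plan computed directly from samples."""
    n, m = X_s.shape[0], X_t.shape[0]
    # Uniform empirical weights unless explicit weights are given
    a = np.full(n, 1.0 / n) if a is None else np.asarray(a, dtype=float)
    b = np.full(m, 1.0 / m) if b is None else np.asarray(b, dtype=float)
    # Squared Euclidean cost between every source and target sample
    M = ((X_s[:, None, :] - X_t[None, :, :]) ** 2).sum(axis=-1)
    K = np.exp(-M / reg)  # Gibbs kernel
    u, v = np.ones(n), np.ones(m)
    for _ in range(num_iter_max):
        u = a / (K @ v)
        v = b / (K.T @ u)
        # Stop once the row marginal of the current plan matches a
        if np.abs(u * (K @ v) - a).sum() < stop_thr:
            break
    # Transport plan P = diag(u) K diag(v), returned with the cost matrix
    return u[:, None] * K * v[None, :], M


def empirical_sinkhorn2_sketch(X_s, X_t, reg, **kwargs):
    # Linear transport cost <P, M> of the entropic plan
    P, M = empirical_sinkhorn_sketch(X_s, X_t, reg, **kwargs)
    return float(np.sum(P * M))


rng = np.random.default_rng(42)
X_s = rng.uniform(size=(50, 2))
X_t = rng.uniform(size=(40, 2)) + 0.5
P, M = empirical_sinkhorn_sketch(X_s, X_t, reg=1.0)
print(P.shape)  # (50, 40)
print(empirical_sinkhorn2_sketch(X_s, X_t, reg=1.0))
```

The caller never builds a cost matrix or weight vectors, which matches the PR's goal of needing only the source data, the target data, and the regularization parameter.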