Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ot/lp/EMD.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ enum ProblemType {
MAX_ITER_REACHED
};

int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter);
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init);
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads);

int EMD_wrap_sparse(
Expand Down
16 changes: 15 additions & 1 deletion ot/lp/EMD_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@


int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
double* alpha, double* beta, double *cost, uint64_t maxIter) {
double* alpha, double* beta, double *cost, uint64_t maxIter,
double* alpha_init, double* beta_init) {
// beware M and C are stored in row major C style!!!

using namespace lemon;
Expand Down Expand Up @@ -93,6 +94,19 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
}
}

// Set warmstart potentials if provided
if (alpha_init != nullptr && beta_init != nullptr) {
// Compress warmstart potentials to only non-zero entries
std::vector<double> alpha_compressed(n);
std::vector<double> beta_compressed(m);
for (uint64_t i = 0; i < n; i++) {
alpha_compressed[i] = alpha_init[indI[i]];
}
for (uint64_t j = 0; j < m; j++) {
beta_compressed[j] = beta_init[indJ[j]];
}
net.setWarmstartPotentials(&alpha_compressed[0], &beta_compressed[0], (int)n, (int)m);
}

// Solve the problem with the network simplex algorithm

Expand Down
23 changes: 22 additions & 1 deletion ot/lp/_network_simplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def emd(
center_dual=True,
numThreads=1,
check_marginals=True,
warmstart_dual=None,
):
r"""Solves the Earth Movers distance problem and returns the OT matrix

Expand Down Expand Up @@ -237,6 +238,11 @@ def emd(
check_marginals: bool, optional (default=True)
If True, checks that the marginals mass are equal. If False, skips the
check.
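warmstart_dual: tuple of two arrays (alpha, beta), optional (default=None)
Warmstart dual potentials intended to accelerate convergence. Should be a tuple
(alpha, beta) where alpha has shape (ns,) and beta has shape (nt,); both are
converted to float64 arrays. These potentials are used to guide initial pivots
in the network simplex. Typically obtained from a previous EMD solve or a
Sinkhorn approximation. They are only passed to the dense solver, and only the
single-threaded path (``numThreads=1``) uses them; they are ignored for sparse
inputs and by the multi-threaded solver.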

.. note:: The solver automatically detects sparse format using the backend's
:py:meth:`issparse` method. For sparse inputs:
Expand Down Expand Up @@ -373,8 +379,18 @@ def emd(
a, b, edge_sources, edge_targets, edge_costs, numItermax
)
else:
# Prepare warmstart if provided
alpha_init = None
beta_init = None
if warmstart_dual is not None:
alpha_init, beta_init = warmstart_dual
alpha_init = np.asarray(alpha_init, dtype=np.float64)
beta_init = np.asarray(beta_init, dtype=np.float64)

# Dense solver
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
G, cost, u, v, result_code = emd_c(
a, b, M, numItermax, numThreads, alpha_init, beta_init
)

# ============================================================================
# POST-PROCESS DUAL VARIABLES AND CREATE TRANSPORT PLAN
Expand Down Expand Up @@ -513,6 +529,11 @@ def emd2(
check_marginals: bool, optional (default=True)
If True, checks that the marginals mass are equal. If False, skips the
check.
warmstart_dual: tuple of two arrays (alpha, beta), optional (default=None)
Warmstart dual potentials intended to accelerate convergence. Should be a tuple
(alpha, beta) where alpha has shape (ns,) and beta has shape (nt,).
These potentials are used to guide initial pivots in the network simplex.
Typically obtained from a previous EMD solve or a Sinkhorn approximation.

.. note:: The solver automatically detects sparse format using the backend's
:py:meth:`issparse` method. For sparse inputs:
Expand Down
23 changes: 20 additions & 3 deletions ot/lp/emd_wrap.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import warnings


cdef extern from "EMD.h":
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil
int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, uint64_t *edge_sources, uint64_t *edge_targets, double *edge_costs, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter) nogil
int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b, int dim, int metric, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil
Expand All @@ -42,7 +42,7 @@ def check_result(result_code):

@cython.boundscheck(False)
@cython.wraparound(False)
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, uint64_t max_iter, int numThreads):
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, uint64_t max_iter, int numThreads, alpha_init=None, beta_init=None):
"""
Solves the Earth Movers distance problem and returns the optimal transport matrix

Expand Down Expand Up @@ -81,6 +81,10 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
max_iter : uint64_t
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.
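alpha_init : (ns,) numpy.ndarray, float64, optional
Initial dual potentials for sources (warmstart). Used only when beta_init
is also given and numThreads == 1; otherwise ignored.
beta_init : (nt,) numpy.ndarray, float64, optional
Initial dual potentials for targets (warmstart). Used only when alpha_init
is also given and numThreads == 1; otherwise ignored.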

Returns
-------
Expand All @@ -101,6 +105,12 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([0, 0])

cdef np.ndarray[double, ndim=1, mode="c"] Gv=np.zeros(0)

# Warmstart potentials
cdef np.ndarray[double, ndim=1, mode="c"] alpha_init_c
cdef np.ndarray[double, ndim=1, mode="c"] beta_init_c
cdef double* alpha_init_ptr = NULL
cdef double* beta_init_ptr = NULL

if not len(a):
a=np.ones((n1,))/n1
Expand All @@ -110,11 +120,18 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod

# init OT matrix
G=np.zeros([n1, n2])

# Setup warmstart pointers if provided
if alpha_init is not None and beta_init is not None:
alpha_init_c = np.ascontiguousarray(alpha_init, dtype=np.float64)
beta_init_c = np.ascontiguousarray(beta_init, dtype=np.float64)
alpha_init_ptr = <double*> alpha_init_c.data
beta_init_ptr = <double*> beta_init_c.data

# calling the function
with nogil:
if numThreads == 1:
result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, alpha_init_ptr, beta_init_ptr)
else:
result_code = EMD_wrap_omp(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, numThreads)
return G, cost, alpha, beta, result_code
Expand Down
Loading
Loading