Commit 375c44d6 authored by Matthieu Boileau
Browse files

Fix bug in earth_movers and add 1DEM

parent 69dac777
......@@ -7,14 +7,12 @@ import matplotlib.pyplot as plt
import ot
class EarthMovers:
class EarthMovers2D:
def __init__(self, n):
def __init__(self, n: int):
self.n = n # number of samples
# source position
self.xs = np.random.random_sample((self.n, self.n))
# target position
self.xt = np.random.random_sample((self.n, self.n))
self.p = None
self._set_positions()
# Cost matrix
self.M = None
# OT matrix
......@@ -23,24 +21,24 @@ class EarthMovers:
self.a = []
self.b = []
def _set_positions(self):
# source position
self.xs = np.random.random_sample((self.n, 2))
# target position
self.xt = np.random.random_sample((self.n, 2))
def _compute_loss_matrix(self):
    """Compute the transport cost matrix and store it in ``self.M``.

    ``self.M[i, j]`` is the squared Euclidean distance between
    ``self.xs[i]`` and ``self.xt[j]`` raised to the power ``p / 2``, i.e.
    the Euclidean distance to the power ``p``.

    ``p`` is ``self.p``; while ``self.p`` is still ``None`` the value 1 is
    used, matching the default ``p`` of the plotting methods, so the
    matrix can be built before any plot has set the exponent.
    """
    p = 1. if self.p is None else self.p
    sq_dist = ot.dist(self.xs, self.xt, metric='sqeuclidean')
    self.M = sq_dist ** (p / 2)
def get_ot_matrix(self):
    """Return the optimal transport matrix computed by ``ot.emd``.

    The loss matrix is rebuilt on every call rather than only when it is
    missing, so a matrix left over from a previous value of ``self.p`` is
    never reused.

    Returns:
        The result of ``ot.emd(self.a, self.b, self.M)``.
    """
    self._compute_loss_matrix()
    return ot.emd(self.a, self.b, self.M)
def get_wasserstein_distance(self) -> float:
    """Return the total transport cost, the sum of ``self.T * self.M``.

    The loss matrix and the transport matrix are computed first when
    they are still ``None``.
    """
    if self.M is None:
        self._compute_loss_matrix()
    if self.T is None:
        self.compute_ot()
    cost_per_matching = self.T * self.M
    return np.sum(cost_per_matching)
def compute_ot(self):
......@@ -49,8 +47,6 @@ class EarthMovers:
def get_distances(self):
    """Return a 1D array with the cost of every active matching.

    An entry of ``self.M`` is kept when the matching entry of ``self.T``,
    divided by the largest entry of ``self.T``, exceeds ``1e-8``. The
    transport matrix is computed first when it is still ``None``.
    """
    if self.T is None:
        self.compute_ot()
    is_active = self.T / self.T.max() > 1e-8
    return np.extract(is_active, self.M)
def create_figure(self, suptitle: str):
......@@ -62,45 +58,71 @@ class EarthMovers:
fig.suptitle(suptitle)
return ax
def plot_ot(self, p=1., plot_points=True):
    """Plot the optimal transport plan between source and target samples.

    Every matching whose transported mass is non-negligible is drawn as a
    segment. Its grey level encodes its cost relative to the most
    expensive active matching: cheap matchings are light, the most
    expensive one is black.

    Args:
        p: exponent stored in ``self.p`` before the plan is recomputed.
        plot_points: when true, also draw the source samples (blue), the
            target samples (red) and a legend.
    """
    self.p = p
    xs = self.xs
    xt = self.xt
    ax = self.create_figure(suptitle='Source and target distributions')

    # Recompute unconditionally so that a plan cached for another
    # exponent is not reused.
    self.compute_ot()
    max_distance = self.get_distances().max()

    # inspired by plot2D_samples_mat()
    mx = self.T.max()
    # Visit only the active matchings instead of testing all n * n pairs
    # one by one in Python; np.argwhere keeps the row-major order.
    for i, j in np.argwhere(self.T / mx > 1e-8):
        color_scale = 1 - self.M[i, j] / max_distance
        ax.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]],
                c=[color_scale] * 3)

    if plot_points:
        ax.plot(xs[:, 0], xs[:, 1], 'ob', label='Source samples')
        ax.plot(xt[:, 0], xt[:, 1], 'or', label='Target samples')
        # Only the sample markers carry labels, so the legend is drawn
        # together with them.
        ax.legend(loc=0)

    wd = self.get_wasserstein_distance()
    ax.set_title(f"$p = {self.p}$ - Wasserstein distance: {wd:f}",
                 fontsize=10)
def plot_distance_histogram(self, p=1., bins=10):
    """Plot a histogram of the costs of the active matchings.

    Args:
        p: exponent stored in ``self.p`` before the plan is recomputed.
        bins: forwarded unchanged to the histogram call.
    """
    self.p = p
    self.compute_ot()
    distances = self.get_distances()

    _, ax = plt.subplots()
    # Draw on the axes created above rather than on pyplot's implicit
    # current axes.
    ax.hist(distances, bins=bins)
    ax.set_xlabel("Distance")
    ax.set_ylabel("Number of matchings")
    ax.set_title(f"Histogram of distance ($p = {self.p}$)")
class EarthMovers1D(EarthMovers2D):
    """Earth mover's problem for samples lying on two horizontal lines.

    Sources sit at ordinate 0 and targets at ordinate 1, which lets the
    inherited 2D plotting code draw them; only the abscissas enter the
    transport cost.
    """

    def _set_positions(self):
        """Draw random abscissas for the sources (y = 0) and targets (y = 1)."""
        n = self.n
        # source (drawn first)
        self.xs = np.column_stack((np.random.random_sample((n, )),
                                   np.zeros(n)))
        # target
        self.xt = np.column_stack((np.random.random_sample((n, )),
                                   np.ones(n)))

    def _compute_loss_matrix(self):
        """Store ``|xs[i, 0] - xt[j, 0]| ** p`` in ``self.M[i, j]``.

        ``p`` is ``self.p``; while ``self.p`` is still ``None`` the value
        1 is used, matching the default ``p`` of the plotting methods.
        """
        p = 1. if self.p is None else self.p
        # Work on the abscissas directly. Subtracting the constant
        # vertical offset from the squared 2D distance afterwards can
        # leave tiny negative values through rounding, and those turn
        # into NaN once raised to a fractional power.
        gap = np.abs(self.xs[:, [0]] - self.xt[:, 0])
        self.M = gap ** p
if __name__ == '__main__':
    # 2D problem: only the transport segments are drawn, not the samples.
    em = EarthMovers2D(500)
    em.plot_ot(p=1., plot_points=False)

    # 1D problem.
    # NOTE(review): p sits just above 1, presumably to avoid the p = 1
    # case on a line - confirm the intent.
    em1D = EarthMovers1D(50)
    em1D.plot_ot(p=1.00001)

    # Histogram of matching costs on a larger 2D problem.
    em1000 = EarthMovers2D(1000)
    em1000.plot_distance_histogram(bins=20)

    plt.show()
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment