Commit 375c44d6 by Matthieu Boileau

### Fix bug in earth_movers and add 1DEM

parent 69dac777
 ... ... @@ -7,14 +7,12 @@ import matplotlib.pyplot as plt import ot class EarthMovers: class EarthMovers2D: def __init__(self, n): def __init__(self, n: int): self.n = n # number of samples # source position self.xs = np.random.random_sample((self.n, self.n)) # target position self.xt = np.random.random_sample((self.n, self.n)) self.p = None self._set_positions() # Cost matrix self.M = None # OT matrix ... ... @@ -23,24 +21,24 @@ class EarthMovers: self.a = [] self.b = [] def _set_positions(self): # source position self.xs = np.random.random_sample((self.n, 2)) # target position self.xt = np.random.random_sample((self.n, 2)) def _compute_loss_matrix(self): """Return loss matrix""" M = ot.dist(self.xs, self.xt) M /= M.max() self.M = M self.M = (ot.dist(self.xs, self.xt, metric='sqeuclidean'))**(self.p / 2) def get_ot_matrix(self): """Return optimal transport matrix""" if self.M is None: self._compute_loss_matrix() self._compute_loss_matrix() return ot.emd(self.a, self.b, self.M) def get_wasserstein_distance(self) -> float: """Return Wasserstein_distance""" if self.M is None: self._compute_loss_matrix() if self.T is None: self.compute_ot() return np.sum(self.T * self.M) def compute_ot(self): ... ... @@ -49,8 +47,6 @@ class EarthMovers: def get_distances(self): """Return a 1D-array of the distances""" if self.T is None: self.compute_ot() return np.extract(self.T / self.T.max() > 1e-8, self.M) def create_figure(self, suptitle: str): ... ... @@ -62,45 +58,71 @@ class EarthMovers: fig.suptitle(suptitle) return ax def plot_ot(self): def plot_ot(self, p=1., plot_points=True): """A 2D plot of the OT problem""" self.p = p xs = self.xs xt = self.xt ax = self.create_figure(suptitle='Source and target distributions') if self.T is None: self.compute_ot() self.compute_ot() max_distance = self.get_distances().max() # inspired by plot2D_samples_mat() mx = self.T.max() for i in range(self.n): for j in range(self.n): if self.T[i, j] / mx > 1e-8: ax.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], alpha=self.T[i, j] / mx, c=[.5, .5, 1]) ax.plot(xs[:, 0], xs[:, 1], 'ob', label='Source samples') ax.plot(xt[:, 0], xt[:, 1], 'or', label='Target samples') ax.legend(loc=0) color_scale = 1 - self.M[i, j] / max_distance c = [color_scale, color_scale, color_scale] ax.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], c=c) if plot_points: ax.plot(xs[:, 0], xs[:, 1], 'ob', label='Source samples') ax.plot(xt[:, 0], xt[:, 1], 'or', label='Target samples') ax.legend(loc=0) wd = self.get_wasserstein_distance() ax.set_title(f"Wasserstein distance: {wd:f}", fontsize=10) ax.set_title(f"\$p = {self.p}\$ - Wasserstein distance: {wd:f}", fontsize=10) def plot_distance_histogram(self, bins=10): def plot_distance_histogram(self, p=1., bins=10): """Plot an histogram of distance""" self.p = p self.compute_ot() distances = self.get_distances() fig, ax = plt.subplots() plt.hist(distances, bins=bins) ax.set_xlabel("Distance") ax.set_ylabel("Number of matchings") ax.set_title("Histogram of distance") ax.set_title(f"Histogram of distance (\$p = {self.p}\$)") class EarthMovers1D(EarthMovers2D): def _set_positions(self): # source and target positions self.xs = np.empty((self.n, 2)) self.xt = np.empty((self.n, 2)) # source self.xs[:, 0] = np.random.random_sample((self.n, )) self.xs[:, 1] = 0. # target self.xt[:, 0] = np.random.random_sample((self.n, )) self.xt[:, 1] = 1. def _compute_loss_matrix(self): """Return loss matrix""" self.M = (ot.dist(self.xs, self.xt, metric='sqeuclidean') - 1)**(self.p / 2) if __name__ == '__main__': em = EarthMovers(50) em.plot_ot() em = EarthMovers2D(500) em.plot_ot(p=1., plot_points=False) em1D = EarthMovers1D(50) em1D.plot_ot(p=1.00001) em1000 = EarthMovers(2000) print(f"Wasserstein distance: {em1000.get_wasserstein_distance():f}") em1000 = EarthMovers2D(1000) em1000.plot_distance_histogram(bins=20) plt.show()
This diff is collapsed.
This diff is collapsed.
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!