From 6ebdd9052d1dddd0a9ea3cef6a71c18b557ae474 Mon Sep 17 00:00:00 2001
From: Matthieu Boileau <matthieu.boileau@math.unistra.fr>
Date: Sun, 19 Jan 2020 13:37:45 +0100
Subject: [PATCH] Control nstepmax in srw

---
 mc2020.ipynb |  6 +++---
 srw.py       | 28 ++++++++++++++--------------
 2 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/mc2020.ipynb b/mc2020.ipynb
index 7aee5ab..5c34546 100644
--- a/mc2020.ipynb
+++ b/mc2020.ipynb
@@ -124,7 +124,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "srw.FinalDistance(1000).plot()"
+    "srw.FinalDistance(nwalk=1000).plot()"
    ]
   },
   {
@@ -140,7 +140,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "srw.MaxDistance(1000).plot()"
+    "srw.MaxDistance(nwalk=1000).plot()"
    ]
   },
   {
@@ -156,7 +156,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "srw.BackToStart(10000).plot()"
+    "srw.BackToStart(nwalk=10000, nstepmax=5000).plot()"
    ]
   },
   {
diff --git a/srw.py b/srw.py
index 042677f..9aecb80 100644
--- a/srw.py
+++ b/srw.py
@@ -112,8 +112,10 @@ class NWalk:
     suptitle = ''
     title = ''
 
-    def __init__(self, nwalk):
+    def __init__(self, nwalk, nstepmax=1000):
         self.nwalk = nwalk
+        self.nsteps = np.linspace(1, nstepmax, num=10, endpoint=True,
+                                  dtype=int)
 
     @staticmethod
     def compute_step(nstep: int):
@@ -124,12 +126,12 @@ class NWalk:
         vfunc = np.vectorize(self.compute_step)
         return np.sum(vfunc(np.full(self.nwalk, nstep))) / self.nwalk
 
-    def compute_steps(self, nsteps: np.ndarray) -> np.ndarray:
+    def compute_steps(self) -> np.ndarray:
         """
         return a mean over nwalk samples of the distance for various nsteps
         """
-        result = np.empty_like(nsteps, dtype=float)
-        for i, nstep in np.ndenumerate(nsteps):
+        result = np.empty_like(self.nsteps, dtype=float)
+        for i, nstep in np.ndenumerate(self.nsteps):
             result[i] = self.compute_average(nstep)
         return result
 
@@ -159,14 +161,13 @@ class Distance(NWalk):
         Plot mean distance from starting point as function of number of steps
         """
 
-        nsteps = np.arange(1, 1000, 100)
-        distances = self.compute_steps(nsteps)
+        distances = self.compute_steps()
 
         fig, ax = self._init_figure()
 
-        ax.plot(nsteps, np.sqrt(2 * nsteps / pi),
+        ax.plot(self.nsteps, np.sqrt(2 * self.nsteps / pi),
                 label=r'$\sqrt{\frac{2n}{\pi}}$')
-        ax.plot(nsteps, distances, 'o',
+        ax.plot(self.nsteps, distances, 'o',
                 label=f'Average over {self.nwalk} samples')
         ax.legend()
 
@@ -253,12 +254,11 @@ class BackToStart(NWalk):
         Plot mean distance from starting point as function of number of steps
         """
 
-        nsteps = np.arange(1, 1000, 100)
-        ntimes = self.compute_steps(nsteps)
+        ntimes = self.compute_steps()
 
         fig, ax = self._init_figure(ylabel='Number of times')
 
-        ax.plot(nsteps, ntimes, 'o',
+        ax.plot(self.nsteps, ntimes, 'o',
                 label=f'Average over {self.nwalk} samples')
         ax.legend()
 
@@ -330,7 +330,7 @@ if __name__ == '__main__':
     anim = walk.generate_animation()
     walk.plot()
 
-    FinalDistance(1000).plot()
-    MaxDistance(1000).plot()
-    BackToStart(10000).plot()
+    FinalDistance(nwalk=1000).plot()
+    MaxDistance(nwalk=1000).plot()
+    BackToStart(nwalk=10000, nstepmax=5000).plot()
     plt.show()
-- 
GitLab