]> git.rkrishnan.org Git - tahoe-lafs/tahoe-lafs.git/commitdiff
add a basic concurrency limiter utility
authorBrian Warner <warner@allmydata.com>
Wed, 7 May 2008 23:53:30 +0000 (16:53 -0700)
committerBrian Warner <warner@allmydata.com>
Wed, 7 May 2008 23:53:30 +0000 (16:53 -0700)
src/allmydata/test/test_util.py
src/allmydata/util/limiter.py [new file with mode: 0644]

index 98d6da71b9f72ad8f2e22d5e67ee1b2ac4c36460..c990ca9c9fa5e6c08a95b994dcd3467781640c4a 100644 (file)
@@ -3,12 +3,12 @@ def foo(): pass # keep the line number constant
 
 import os
 from twisted.trial import unittest
-from twisted.internet import defer
+from twisted.internet import defer, reactor
 from twisted.python import failure
 
 from allmydata.util import base32, idlib, humanreadable, mathutil, hashutil
 from allmydata.util import assertutil, fileutil, testutil, deferredutil
-
+from allmydata.util import limiter
 
 class Base32(unittest.TestCase):
     def test_b2a_matches_Pythons(self):
@@ -444,3 +444,68 @@ class HashUtilTests(unittest.TestCase):
         h2 = hashutil.plaintext_segment_hasher()
         h2.update("foo")
         self.failUnlessEqual(h1, h2.digest())
+
+class Limiter(unittest.TestCase):
+    def job(self, i, foo):
+        self.calls.append( (i, foo) )
+        self.simultaneous += 1
+        self.peak_simultaneous = max(self.simultaneous, self.peak_simultaneous)
+        d = defer.Deferred()
+        def _done():
+            self.simultaneous -= 1
+            d.callback("done")
+        reactor.callLater(1.0, _done)
+        return d
+
+    def bad_job(self, i, foo):
+        raise RuntimeError("bad_job %d" % i)
+
+    def test_limiter(self):
+        self.calls = []
+        self.simultaneous = 0
+        self.peak_simultaneous = 0
+        l = limiter.ConcurrencyLimiter()
+        dl = []
+        for i in range(20):
+            dl.append(l.add(self.job, i, foo=str(i)))
+        d = defer.DeferredList(dl, fireOnOneErrback=True)
+        def _done(res):
+            self.failUnlessEqual(self.simultaneous, 0)
+            self.failUnless(self.peak_simultaneous <= 10)
+            self.failUnlessEqual(len(self.calls), 20)
+            for i in range(20):
+                self.failUnless( (i, str(i)) in self.calls)
+        d.addCallback(_done)
+        return d
+
+    def test_errors(self):
+        self.calls = []
+        self.simultaneous = 0
+        self.peak_simultaneous = 0
+        l = limiter.ConcurrencyLimiter()
+        dl = []
+        for i in range(20):
+            dl.append(l.add(self.job, i, foo=str(i)))
+        d2 = l.add(self.bad_job, 21, "21")
+        d = defer.DeferredList(dl, fireOnOneErrback=True)
+        def _most_done(res):
+            self.failUnless(self.peak_simultaneous <= 10)
+            self.failUnlessEqual(len(self.calls), 20)
+            for i in range(20):
+                self.failUnless( (i, str(i)) in self.calls)
+            def _good(res):
+                self.fail("should have failed, not got %s" % (res,))
+            def _err(f):
+                f.trap(RuntimeError)
+                self.failUnless("bad_job 21" in str(f))
+            d2.addCallbacks(_good, _err)
+            return d2
+        d.addCallback(_most_done)
+        def _all_done(res):
+            self.failUnlessEqual(self.simultaneous, 0)
+            self.failUnless(self.peak_simultaneous <= 10)
+            self.failUnlessEqual(len(self.calls), 20)
+            for i in range(20):
+                self.failUnless( (i, str(i)) in self.calls)
+        d.addCallback(_all_done)
+        return d
diff --git a/src/allmydata/util/limiter.py b/src/allmydata/util/limiter.py
new file mode 100644 (file)
index 0000000..62850bc
--- /dev/null
@@ -0,0 +1,36 @@
+
+from twisted.internet import defer
+from foolscap.eventual import eventually
+
+class ConcurrencyLimiter:
+    """I implement a basic concurrency limiter. Add work to it in the form of
+    (callable, args, kwargs) tuples. No more than LIMIT callables will be
+    outstanding at any one time.
+    """
+
+    def __init__(self, limit=10):
+        self.limit = 10
+        self.pending = []
+        self.active = 0
+
+    def add(self, cb, *args, **kwargs):
+        d = defer.Deferred()
+        task = (cb, args, kwargs, d)
+        self.pending.append(task)
+        self.maybe_start_task()
+        return d
+
+    def maybe_start_task(self):
+        if self.active >= self.limit:
+            return
+        if not self.pending:
+            return
+        (cb, args, kwargs, done_d) = self.pending.pop(0)
+        self.active += 1
+        d = defer.maybeDeferred(cb, *args, **kwargs)
+        d.addBoth(self._done, done_d)
+
+    def _done(self, res, done_d):
+        self.active -= 1
+        eventually(done_d.callback, res)
+        self.maybe_start_task()