import os
from twisted.trial import unittest
-from twisted.internet import defer
+from twisted.internet import defer, reactor
from twisted.python import failure
from allmydata.util import base32, idlib, humanreadable, mathutil, hashutil
from allmydata.util import assertutil, fileutil, testutil, deferredutil
-
+from allmydata.util import limiter
class Base32(unittest.TestCase):
def test_b2a_matches_Pythons(self):
h2 = hashutil.plaintext_segment_hasher()
h2.update("foo")
self.failUnlessEqual(h1, h2.digest())
+
+class Limiter(unittest.TestCase):
+ def job(self, i, foo):
+ self.calls.append( (i, foo) )
+ self.simultaneous += 1
+ self.peak_simultaneous = max(self.simultaneous, self.peak_simultaneous)
+ d = defer.Deferred()
+ def _done():
+ self.simultaneous -= 1
+ d.callback("done")
+ reactor.callLater(1.0, _done)
+ return d
+
+ def bad_job(self, i, foo):
+ raise RuntimeError("bad_job %d" % i)
+
+ def test_limiter(self):
+ self.calls = []
+ self.simultaneous = 0
+ self.peak_simultaneous = 0
+ l = limiter.ConcurrencyLimiter()
+ dl = []
+ for i in range(20):
+ dl.append(l.add(self.job, i, foo=str(i)))
+ d = defer.DeferredList(dl, fireOnOneErrback=True)
+ def _done(res):
+ self.failUnlessEqual(self.simultaneous, 0)
+ self.failUnless(self.peak_simultaneous <= 10)
+ self.failUnlessEqual(len(self.calls), 20)
+ for i in range(20):
+ self.failUnless( (i, str(i)) in self.calls)
+ d.addCallback(_done)
+ return d
+
+ def test_errors(self):
+ self.calls = []
+ self.simultaneous = 0
+ self.peak_simultaneous = 0
+ l = limiter.ConcurrencyLimiter()
+ dl = []
+ for i in range(20):
+ dl.append(l.add(self.job, i, foo=str(i)))
+ d2 = l.add(self.bad_job, 21, "21")
+ d = defer.DeferredList(dl, fireOnOneErrback=True)
+ def _most_done(res):
+ self.failUnless(self.peak_simultaneous <= 10)
+ self.failUnlessEqual(len(self.calls), 20)
+ for i in range(20):
+ self.failUnless( (i, str(i)) in self.calls)
+ def _good(res):
+ self.fail("should have failed, not got %s" % (res,))
+ def _err(f):
+ f.trap(RuntimeError)
+ self.failUnless("bad_job 21" in str(f))
+ d2.addCallbacks(_good, _err)
+ return d2
+ d.addCallback(_most_done)
+ def _all_done(res):
+ self.failUnlessEqual(self.simultaneous, 0)
+ self.failUnless(self.peak_simultaneous <= 10)
+ self.failUnlessEqual(len(self.calls), 20)
+ for i in range(20):
+ self.failUnless( (i, str(i)) in self.calls)
+ d.addCallback(_all_done)
+ return d
--- /dev/null
+
+from twisted.internet import defer
+from foolscap.eventual import eventually
+
+class ConcurrencyLimiter:
+ """I implement a basic concurrency limiter. Add work to it in the form of
+ (callable, args, kwargs) tuples. No more than LIMIT callables will be
+ outstanding at any one time.
+ """
+
+ def __init__(self, limit=10):
+ self.limit = 10
+ self.pending = []
+ self.active = 0
+
+ def add(self, cb, *args, **kwargs):
+ d = defer.Deferred()
+ task = (cb, args, kwargs, d)
+ self.pending.append(task)
+ self.maybe_start_task()
+ return d
+
+ def maybe_start_task(self):
+ if self.active >= self.limit:
+ return
+ if not self.pending:
+ return
+ (cb, args, kwargs, done_d) = self.pending.pop(0)
+ self.active += 1
+ d = defer.maybeDeferred(cb, *args, **kwargs)
+ d.addBoth(self._done, done_d)
+
+ def _done(self, res, done_d):
+ self.active -= 1
+ eventually(done_d.callback, res)
+ self.maybe_start_task()