From: Brian Warner Date: Wed, 7 May 2008 23:53:30 +0000 (-0700) Subject: add a basic concurrency limiter utility X-Git-Tag: allmydata-tahoe-1.1.0~152 X-Git-Url: https://git.rkrishnan.org/?a=commitdiff_plain;h=c7e441309d247c144594467aa971dfe6f5a63858;p=tahoe-lafs%2Ftahoe-lafs.git add a basic concurrency limiter utility --- diff --git a/src/allmydata/test/test_util.py b/src/allmydata/test/test_util.py index 98d6da71..c990ca9c 100644 --- a/src/allmydata/test/test_util.py +++ b/src/allmydata/test/test_util.py @@ -3,12 +3,12 @@ def foo(): pass # keep the line number constant import os from twisted.trial import unittest -from twisted.internet import defer +from twisted.internet import defer, reactor from twisted.python import failure from allmydata.util import base32, idlib, humanreadable, mathutil, hashutil from allmydata.util import assertutil, fileutil, testutil, deferredutil - +from allmydata.util import limiter class Base32(unittest.TestCase): def test_b2a_matches_Pythons(self): @@ -444,3 +444,68 @@ class HashUtilTests(unittest.TestCase): h2 = hashutil.plaintext_segment_hasher() h2.update("foo") self.failUnlessEqual(h1, h2.digest()) + +class Limiter(unittest.TestCase): + def job(self, i, foo): + self.calls.append( (i, foo) ) + self.simultaneous += 1 + self.peak_simultaneous = max(self.simultaneous, self.peak_simultaneous) + d = defer.Deferred() + def _done(): + self.simultaneous -= 1 + d.callback("done") + reactor.callLater(1.0, _done) + return d + + def bad_job(self, i, foo): + raise RuntimeError("bad_job %d" % i) + + def test_limiter(self): + self.calls = [] + self.simultaneous = 0 + self.peak_simultaneous = 0 + l = limiter.ConcurrencyLimiter() + dl = [] + for i in range(20): + dl.append(l.add(self.job, i, foo=str(i))) + d = defer.DeferredList(dl, fireOnOneErrback=True) + def _done(res): + self.failUnlessEqual(self.simultaneous, 0) + self.failUnless(self.peak_simultaneous <= 10) + self.failUnlessEqual(len(self.calls), 20) + for i in range(20): + self.failUnless( (i, str(i)) in self.calls) + d.addCallback(_done) + return d + + def test_errors(self): + self.calls = [] + self.simultaneous = 0 + self.peak_simultaneous = 0 + l = limiter.ConcurrencyLimiter() + dl = [] + for i in range(20): + dl.append(l.add(self.job, i, foo=str(i))) + d2 = l.add(self.bad_job, 21, "21") + d = defer.DeferredList(dl, fireOnOneErrback=True) + def _most_done(res): + self.failUnless(self.peak_simultaneous <= 10) + self.failUnlessEqual(len(self.calls), 20) + for i in range(20): + self.failUnless( (i, str(i)) in self.calls) + def _good(res): + self.fail("should have failed, not got %s" % (res,)) + def _err(f): + f.trap(RuntimeError) + self.failUnless("bad_job 21" in str(f)) + d2.addCallbacks(_good, _err) + return d2 + d.addCallback(_most_done) + def _all_done(res): + self.failUnlessEqual(self.simultaneous, 0) + self.failUnless(self.peak_simultaneous <= 10) + self.failUnlessEqual(len(self.calls), 20) + for i in range(20): + self.failUnless( (i, str(i)) in self.calls) + d.addCallback(_all_done) + return d diff --git a/src/allmydata/util/limiter.py b/src/allmydata/util/limiter.py new file mode 100644 index 00000000..62850bc4 --- /dev/null +++ b/src/allmydata/util/limiter.py @@ -0,0 +1,36 @@ + +from twisted.internet import defer +from foolscap.eventual import eventually + +class ConcurrencyLimiter: + """I implement a basic concurrency limiter. Add work to it in the form of + (callable, args, kwargs) tuples. No more than LIMIT callables will be + outstanding at any one time. + """ + + def __init__(self, limit=10): + self.limit = 10 + self.pending = [] + self.active = 0 + + def add(self, cb, *args, **kwargs): + d = defer.Deferred() + task = (cb, args, kwargs, d) + self.pending.append(task) + self.maybe_start_task() + return d + + def maybe_start_task(self): + if self.active >= self.limit: + return + if not self.pending: + return + (cb, args, kwargs, done_d) = self.pending.pop(0) + self.active += 1 + d = defer.maybeDeferred(cb, *args, **kwargs) + d.addBoth(self._done, done_d) + + def _done(self, res, done_d): + self.active -= 1 + eventually(done_d.callback, res) + self.maybe_start_task()