PollMixin: add timeout= argument, rewrite to avoid tail-recursion problems
authorBrian Warner <warner@allmydata.com>
Tue, 5 Feb 2008 03:35:07 +0000 (20:35 -0700)
committerBrian Warner <warner@allmydata.com>
Tue, 5 Feb 2008 03:35:07 +0000 (20:35 -0700)
src/allmydata/test/test_util.py
src/allmydata/util/testutil.py

index 284bf5e5f82384743d8fc3258678d45424786a7f..b59e47a26ac9988acb718a0cbfc3ddcde02aefc4 100644 (file)
@@ -374,19 +374,25 @@ class PollMixinTests(unittest.TestCase):
     def setUp(self):
         self.pm = testutil.PollMixin()
 
-    def _check(self, d):
-        def fail_unless_arg_is_true(arg):
-            self.failUnless(arg is True, repr(arg))
-        d.addCallback(fail_unless_arg_is_true)
-        return d
-
     def test_PollMixin_True(self):
         d = self.pm.poll(check_f=lambda : True,
                          pollinterval=0.1)
-        return self._check(d)
+        return d
 
     def test_PollMixin_False_then_True(self):
         i = iter([False, True])
         d = self.pm.poll(check_f=i.next,
                          pollinterval=0.1)
-        return self._check(d)
+        return d
+
+    def test_timeout(self):
+        d = self.pm.poll(check_f=lambda: False,
+                         pollinterval=0.01,
+                         timeout=1)
+        def _suc(res):
+            self.fail("poll should have failed, not returned %s" % (res,))
+        def _err(f):
+            f.trap(testutil.TimeoutError)
+            return None # success
+        d.addCallbacks(_suc, _err)
+        return d
index 867cf00c17ed0916b2e78d5f1d2b0741c0e0a5be..13e72911a6cafc564c6cdd715f3af8eb087ed27c 100644 (file)
@@ -1,6 +1,6 @@
 import os, signal, time
 
-from twisted.internet import reactor, defer
+from twisted.internet import reactor, defer, task
 from twisted.python import failure
 
 
@@ -30,22 +30,32 @@ class SignalMixin:
         if self.sigchldHandler:
             signal.signal(signal.SIGCHLD, self.sigchldHandler)
 
+class TimeoutError(Exception):
+    pass
+
 class PollMixin:
 
-    def poll(self, check_f, pollinterval=0.01):
+    def poll(self, check_f, pollinterval=0.01, timeout=None):
         # Return a Deferred, then call check_f periodically until it returns
         # True, at which point the Deferred will fire.. If check_f raises an
-        # exception, the Deferred will errback.
-        d = defer.maybeDeferred(self._poll, None, check_f, pollinterval)
+        # exception, the Deferred will errback. If the check_f does not
+        # indicate success within timeout= seconds, the Deferred will
+        # errback. If timeout=None, no timeout will be enforced.
+        cutoff = None
+        if timeout is not None:
+            cutoff = time.time() + timeout
+        stash = [] # ick. We have to pass the LoopingCall into itself
+        lc = task.LoopingCall(self._poll, check_f, stash, cutoff)
+        stash.append(lc)
+        d = lc.start(pollinterval)
         return d
 
-    def _poll(self, res, check_f, pollinterval):
+    def _poll(self, check_f, stash, cutoff):
+        if cutoff is not None and time.time() > cutoff:
+            raise TimeoutError()
+        lc = stash[0]
         if check_f():
-            return True
-        d = defer.Deferred()
-        d.addCallback(self._poll, check_f, pollinterval)
-        reactor.callLater(pollinterval, d.callback, None)
-        return d
+            lc.stop()
 
 class ShouldFailMixin: